Skip to main content

ai_agent/bridge/
jwt_utils.rs

1//! JWT utilities for bridge token handling.
2//!
3//! Translated from openclaudecode/src/bridge/jwtUtils.ts
4
5use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
6use std::collections::HashMap;
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9// =============================================================================
10// CONSTANTS
11// =============================================================================
12
/// Refresh buffer: request a new token this long before it expires (5 minutes).
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;

// The constants below are not referenced anywhere in this file yet. They
// presumably belong to the refresh retry/fallback logic that the stubbed
// scheduler does not implement; `dead_code` is allowed so the values can stay
// next to the buffer they pair with without failing `clippy -D warnings`.

/// Fallback refresh interval when the new token's expiry is unknown (30 minutes).
#[allow(dead_code)]
const FALLBACK_REFRESH_INTERVAL_MS: u64 = 30 * 60 * 1000;

/// Max consecutive failures before giving up on the refresh chain.
#[allow(dead_code)]
const MAX_REFRESH_FAILURES: u32 = 3;

/// Retry delay when the `get_access_token` callback returns `None` (1 minute).
#[allow(dead_code)]
const REFRESH_RETRY_DELAY_MS: u64 = 60_000;
24
25// =============================================================================
26// TIME HELPERS
27// =============================================================================
28
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// A clock reading earlier than the epoch yields `0` instead of panicking.
fn current_timestamp_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        // System clock set before 1970: report the epoch itself.
        Err(_) => 0,
    }
}
36
37// =============================================================================
38// JWT DECODING
39// =============================================================================
40
/// Render a millisecond duration as a compact human-readable string.
///
/// Sub-second remainders are truncated. Durations under a minute print as
/// seconds only (`"45s"`). Longer durations print minutes plus any leftover
/// seconds (`"5m 30s"`), or just minutes when the seconds are zero (`"2m"`).
/// Hours are not broken out (`"90m"`).
pub fn format_duration(ms: u64) -> String {
    let whole_secs = ms / 1000;
    let (minutes, seconds) = (whole_secs / 60, whole_secs % 60);

    match (minutes, seconds) {
        (0, secs) => format!("{}s", secs),
        (mins, 0) => format!("{}m", mins),
        (mins, secs) => format!("{}m {}s", mins, secs),
    }
}
54
55/// Decode a JWT's payload segment without verifying the signature.
56/// Strips the `sk-ant-si-` session-ingress prefix if present.
57/// Returns the parsed JSON payload as a Value, or None if the
58/// token is malformed or the payload is not valid JSON.
59pub fn decode_jwt_payload(token: &str) -> Option<serde_json::Value> {
60    let jwt = if token.starts_with("sk-ant-si-") {
61        &token["sk-ant-si-".len()..]
62    } else {
63        token
64    };
65
66    let parts: Vec<&str> = jwt.split('.').collect();
67    if parts.len() != 3 || parts[1].is_empty() {
68        return None;
69    }
70
71    // Decode base64url
72    let payload_str = match URL_SAFE_NO_PAD.decode(parts[1]) {
73        Ok(bytes) => String::from_utf8(bytes).ok()?,
74        Err(_) => return None,
75    };
76
77    // Parse JSON
78    serde_json::from_str(&payload_str).ok()
79}
80
81/// Decode the `exp` (expiry) claim from a JWT without verifying the signature.
82/// Returns the `exp` value in Unix seconds, or None if unparseable.
83pub fn decode_jwt_expiry(token: &str) -> Option<i64> {
84    let payload = decode_jwt_payload(token)?;
85    if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
86        Some(exp)
87    } else {
88        None
89    }
90}
91
92// =============================================================================
93// TOKEN REFRESH SCHEDULER
94// =============================================================================
95
/// Token refresh scheduler state.
///
/// Holds per-session refresh bookkeeping keyed by session ID. Its methods are
/// private, so callers drive it through [`TokenRefreshSchedulerHandle`], which
/// wraps it in an `Arc<Mutex<_>>`.
pub struct TokenRefreshScheduler {
    /// Pending refresh timer per session ID. Entries are removed on
    /// reschedule and cancel.
    /// NOTE(review): nothing in this file inserts entries yet.
    timers: HashMap<String, TimerState>,
    /// Refresh failure count per session ID; cleared on cancel.
    /// Presumably compared against `MAX_REFRESH_FAILURES` once retry logic is
    /// ported — TODO confirm.
    failure_counts: HashMap<String, u32>,
    /// Generation counter per session ID. It is bumped on every
    /// schedule/cancel so that an in-flight refresh for an older generation
    /// can be recognised as stale.
    generations: HashMap<String, u32>,
    /// Supplies the bridge's OAuth access token; may return `None`.
    get_access_token: Box<dyn Fn() -> Option<String> + Send + Sync>,
    /// Refresh callback, documented as receiving `(session_id, access_token)`.
    on_refresh: Box<dyn Fn(&str, &str) + Send + Sync>,
    /// Tag used in log lines as `[label:token]`.
    label: String,
    /// How far before expiry a refresh is scheduled, in milliseconds.
    refresh_buffer_ms: u64,
}
106
/// Per-session timer bookkeeping stored in `TokenRefreshScheduler::timers`.
///
/// NOTE(review): nothing in this file constructs a `TimerState` yet; the
/// fields are placeholders for the unported timer logic.
#[derive(Debug)]
struct TimerState {
    /// Pending sleep future for the scheduled refresh, if one is armed.
    timer: Option<tokio::time::Sleep>,
    /// Presumably the instant the token expires, when known — TODO confirm.
    expiry: Option<Instant>,
}
112
113/// Token refresh scheduler handle for controlling the scheduler.
114pub struct TokenRefreshSchedulerHandle {
115    scheduler: std::sync::Arc<std::sync::Mutex<TokenRefreshScheduler>>,
116}
117
118impl TokenRefreshSchedulerHandle {
119    /// Schedule refresh for a token with a given session ID.
120    pub fn schedule(&self, session_id: &str, token: &str) {
121        if let Ok(mut scheduler) = self.scheduler.lock() {
122            scheduler.schedule(session_id, token);
123        }
124    }
125
126    /// Schedule refresh using an explicit TTL (seconds until expiry).
127    pub fn schedule_from_expires_in(&self, session_id: &str, expires_in_seconds: u64) {
128        if let Ok(mut scheduler) = self.scheduler.lock() {
129            scheduler.schedule_from_expires_in(session_id, expires_in_seconds);
130        }
131    }
132
133    /// Cancel refresh for a session.
134    pub fn cancel(&self, session_id: &str) {
135        if let Ok(mut scheduler) = self.scheduler.lock() {
136            scheduler.cancel(session_id);
137        }
138    }
139
140    /// Cancel all scheduled refreshes.
141    pub fn cancel_all(&self) {
142        if let Ok(mut scheduler) = self.scheduler.lock() {
143            scheduler.cancel_all();
144        }
145    }
146}
147
148/// Create a token refresh scheduler that proactively refreshes session tokens
149/// before they expire. Used by both the standalone bridge and the REPL bridge.
150///
151/// When a token is about to expire, the scheduler calls `on_refresh` with the
152/// session ID and the bridge's OAuth access token.
153pub fn create_token_refresh_scheduler<G, R, L>(
154    get_access_token: G,
155    on_refresh: R,
156    label: L,
157) -> TokenRefreshSchedulerHandle
158where
159    G: Fn() -> Option<String> + Send + Sync + 'static,
160    R: Fn(&str, &str) + Send + Sync + 'static,
161    L: Into<String>,
162{
163    let scheduler = TokenRefreshScheduler {
164        timers: HashMap::new(),
165        failure_counts: HashMap::new(),
166        generations: HashMap::new(),
167        get_access_token: Box::new(get_access_token),
168        on_refresh: Box::new(on_refresh),
169        label: label.into(),
170        refresh_buffer_ms: TOKEN_REFRESH_BUFFER_MS,
171    };
172
173    TokenRefreshSchedulerHandle {
174        scheduler: std::sync::Arc::new(std::sync::Mutex::new(scheduler)),
175    }
176}
177
178impl TokenRefreshScheduler {
179    fn next_generation(&mut self, session_id: &str) -> u32 {
180        let r#gen = self.generations.get(session_id).copied().unwrap_or(0) + 1;
181        self.generations.insert(session_id.to_string(), r#gen);
182        r#gen
183    }
184
185    fn schedule(&mut self, session_id: &str, token: &str) {
186        let expiry = decode_jwt_expiry(token);
187        if expiry.is_none() {
188            // Token is not a decodable JWT
189            eprintln!(
190                "[{}:token] Could not decode JWT expiry for sessionId={}, token prefix={}..., keeping existing timer",
191                self.label,
192                session_id,
193                &token[..15.min(token.len())]
194            );
195            return;
196        }
197
198        let expiry = expiry.unwrap();
199
200        // Clear any existing timer
201        self.timers.remove(session_id);
202
203        // Bump generation
204        let r#gen = self.next_generation(session_id);
205
206        let expiry_date = chrono::DateTime::from_timestamp(expiry, 0)
207            .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
208            .unwrap_or_else(|| "unknown".to_string());
209
210        let delay_ms = (expiry * 1000) as i64
211            - current_timestamp_millis() as i64
212            - self.refresh_buffer_ms as i64;
213
214        if delay_ms <= 0 {
215            eprintln!(
216                "[{}:token] Token for sessionId={} expires={} (past or within buffer), refreshing immediately",
217                self.label, session_id, expiry_date
218            );
219            // Would trigger refresh here in async context
220            return;
221        }
222
223        eprintln!(
224            "[{}:token] Scheduled token refresh for sessionId={} in {} (expires={}, buffer={}s)",
225            self.label,
226            session_id,
227            format_duration(delay_ms as u64),
228            expiry_date,
229            self.refresh_buffer_ms / 1000
230        );
231
232        // Timer would be scheduled here
233    }
234
235    fn schedule_from_expires_in(&mut self, session_id: &str, expires_in_seconds: u64) {
236        // Clear any existing timer
237        self.timers.remove(session_id);
238
239        let r#gen = self.next_generation(session_id);
240
241        // Clamp to 30s floor
242        let delay_ms = (expires_in_seconds * 1000)
243            .saturating_sub(self.refresh_buffer_ms)
244            .max(30_000);
245
246        eprintln!(
247            "[{}:token] Scheduled token refresh for sessionId={} in {} (expires_in={}s, buffer={}s)",
248            self.label,
249            session_id,
250            format_duration(delay_ms),
251            expires_in_seconds,
252            self.refresh_buffer_ms / 1000
253        );
254
255        // Timer would be scheduled here
256    }
257
258    fn cancel(&mut self, session_id: &str) {
259        // Bump generation to invalidate any in-flight refresh
260        self.next_generation(session_id);
261        self.timers.remove(session_id);
262        self.failure_counts.remove(session_id);
263    }
264
265    fn cancel_all(&mut self) {
266        // Bump all generations
267        let session_ids: Vec<String> = self.generations.keys().cloned().collect();
268        for session_id in session_ids {
269            self.next_generation(&session_id);
270        }
271        self.timers.clear();
272        self.failure_counts.clear();
273    }
274}
275
276// =============================================================================
277// STANDALONE FUNCTIONS
278// =============================================================================
279
280/// Check if a token is expired or about to expire.
281pub fn is_token_expired(token: &str, buffer_ms: u64) -> bool {
282    if let Some(expiry) = decode_jwt_expiry(token) {
283        let expiry_ms = expiry * 1000;
284        let now = current_timestamp_millis();
285        expiry_ms + buffer_ms as i64 <= now as i64
286    } else {
287        // Can't decode - assume not expired
288        false
289    }
290}
291
292/// Get remaining time until token expires.
293pub fn get_token_remaining_time(token: &str) -> Option<Duration> {
294    let expiry = decode_jwt_expiry(token)?;
295    let expiry_ms = expiry * 1000;
296    let now = current_timestamp_millis() as i64;
297    let remaining = expiry_ms - now;
298    if remaining > 0 {
299        Some(Duration::from_millis(remaining as u64))
300    } else {
301        Some(Duration::ZERO)
302    }
303}