ai_agent/bridge/
jwt_utils.rs1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
6use std::collections::HashMap;
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
/// Lead time before a token's `exp` at which the scheduler targets a refresh
/// (5 minutes); it becomes `TokenRefreshScheduler::refresh_buffer_ms`.
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1_000;

/// 30 minutes, in milliseconds.
/// NOTE(review): not referenced by any code visible here; the name suggests an
/// interval for follow-up refreshes — confirm before relying on it.
const FALLBACK_REFRESH_INTERVAL_MS: u64 = 30 * 60 * 1_000;

/// Presumably the number of consecutive refresh failures tolerated per session.
/// NOTE(review): not referenced by any code visible here.
const MAX_REFRESH_FAILURES: u32 = 3;

/// One minute, in milliseconds; presumably the delay before retrying a failed
/// refresh. NOTE(review): not referenced by any code visible here.
const REFRESH_RETRY_DELAY_MS: u64 = 60 * 1_000;
24
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// A system clock set before 1970 yields `0` rather than an error.
fn current_timestamp_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
36
/// Renders a millisecond count as a short human-readable string.
///
/// Under one minute the result is whole seconds (`"45s"`). Otherwise it is
/// whole minutes, followed by any leftover whole seconds (`"2m"`, `"2m 5s"`).
/// Partial seconds are truncated, and there is no hour unit:
/// `3_600_000` renders as `"60m"`.
pub fn format_duration(ms: u64) -> String {
    let whole_secs = ms / 1000;
    let (mins, secs) = (whole_secs / 60, whole_secs % 60);
    match (mins, secs) {
        (0, s) => format!("{}s", s),
        (m, 0) => format!("{}m", m),
        (m, s) => format!("{}m {}s", m, s),
    }
}
54
55pub fn decode_jwt_payload(token: &str) -> Option<serde_json::Value> {
60 let jwt = if token.starts_with("sk-ant-si-") {
61 &token["sk-ant-si-".len()..]
62 } else {
63 token
64 };
65
66 let parts: Vec<&str> = jwt.split('.').collect();
67 if parts.len() != 3 || parts[1].is_empty() {
68 return None;
69 }
70
71 let payload_str = match URL_SAFE_NO_PAD.decode(parts[1]) {
73 Ok(bytes) => String::from_utf8(bytes).ok()?,
74 Err(_) => return None,
75 };
76
77 serde_json::from_str(&payload_str).ok()
79}
80
81pub fn decode_jwt_expiry(token: &str) -> Option<i64> {
84 let payload = decode_jwt_payload(token)?;
85 if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
86 Some(exp)
87 } else {
88 None
89 }
90}
91
92pub struct TokenRefreshScheduler {
98 timers: HashMap<String, TimerState>,
99 failure_counts: HashMap<String, u32>,
100 generations: HashMap<String, u32>,
101 get_access_token: Box<dyn Fn() -> Option<String> + Send + Sync>,
102 on_refresh: Box<dyn Fn(&str, &str) + Send + Sync>,
103 label: String,
104 refresh_buffer_ms: u64,
105}
106
107#[derive(Debug)]
108struct TimerState {
109 timer: Option<tokio::time::Sleep>,
110 expiry: Option<Instant>,
111}
112
113pub struct TokenRefreshSchedulerHandle {
115 scheduler: std::sync::Arc<std::sync::Mutex<TokenRefreshScheduler>>,
116}
117
118impl TokenRefreshSchedulerHandle {
119 pub fn schedule(&self, session_id: &str, token: &str) {
121 if let Ok(mut scheduler) = self.scheduler.lock() {
122 scheduler.schedule(session_id, token);
123 }
124 }
125
126 pub fn schedule_from_expires_in(&self, session_id: &str, expires_in_seconds: u64) {
128 if let Ok(mut scheduler) = self.scheduler.lock() {
129 scheduler.schedule_from_expires_in(session_id, expires_in_seconds);
130 }
131 }
132
133 pub fn cancel(&self, session_id: &str) {
135 if let Ok(mut scheduler) = self.scheduler.lock() {
136 scheduler.cancel(session_id);
137 }
138 }
139
140 pub fn cancel_all(&self) {
142 if let Ok(mut scheduler) = self.scheduler.lock() {
143 scheduler.cancel_all();
144 }
145 }
146}
147
148pub fn create_token_refresh_scheduler<G, R, L>(
154 get_access_token: G,
155 on_refresh: R,
156 label: L,
157) -> TokenRefreshSchedulerHandle
158where
159 G: Fn() -> Option<String> + Send + Sync + 'static,
160 R: Fn(&str, &str) + Send + Sync + 'static,
161 L: Into<String>,
162{
163 let scheduler = TokenRefreshScheduler {
164 timers: HashMap::new(),
165 failure_counts: HashMap::new(),
166 generations: HashMap::new(),
167 get_access_token: Box::new(get_access_token),
168 on_refresh: Box::new(on_refresh),
169 label: label.into(),
170 refresh_buffer_ms: TOKEN_REFRESH_BUFFER_MS,
171 };
172
173 TokenRefreshSchedulerHandle {
174 scheduler: std::sync::Arc::new(std::sync::Mutex::new(scheduler)),
175 }
176}
177
178impl TokenRefreshScheduler {
179 fn next_generation(&mut self, session_id: &str) -> u32 {
180 let r#gen = self.generations.get(session_id).copied().unwrap_or(0) + 1;
181 self.generations.insert(session_id.to_string(), r#gen);
182 r#gen
183 }
184
185 fn schedule(&mut self, session_id: &str, token: &str) {
186 let expiry = decode_jwt_expiry(token);
187 if expiry.is_none() {
188 eprintln!(
190 "[{}:token] Could not decode JWT expiry for sessionId={}, token prefix={}..., keeping existing timer",
191 self.label,
192 session_id,
193 &token[..15.min(token.len())]
194 );
195 return;
196 }
197
198 let expiry = expiry.unwrap();
199
200 self.timers.remove(session_id);
202
203 let r#gen = self.next_generation(session_id);
205
206 let expiry_date = chrono::DateTime::from_timestamp(expiry, 0)
207 .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
208 .unwrap_or_else(|| "unknown".to_string());
209
210 let delay_ms = (expiry * 1000) as i64
211 - current_timestamp_millis() as i64
212 - self.refresh_buffer_ms as i64;
213
214 if delay_ms <= 0 {
215 eprintln!(
216 "[{}:token] Token for sessionId={} expires={} (past or within buffer), refreshing immediately",
217 self.label, session_id, expiry_date
218 );
219 return;
221 }
222
223 eprintln!(
224 "[{}:token] Scheduled token refresh for sessionId={} in {} (expires={}, buffer={}s)",
225 self.label,
226 session_id,
227 format_duration(delay_ms as u64),
228 expiry_date,
229 self.refresh_buffer_ms / 1000
230 );
231
232 }
234
235 fn schedule_from_expires_in(&mut self, session_id: &str, expires_in_seconds: u64) {
236 self.timers.remove(session_id);
238
239 let r#gen = self.next_generation(session_id);
240
241 let delay_ms = (expires_in_seconds * 1000)
243 .saturating_sub(self.refresh_buffer_ms)
244 .max(30_000);
245
246 eprintln!(
247 "[{}:token] Scheduled token refresh for sessionId={} in {} (expires_in={}s, buffer={}s)",
248 self.label,
249 session_id,
250 format_duration(delay_ms),
251 expires_in_seconds,
252 self.refresh_buffer_ms / 1000
253 );
254
255 }
257
258 fn cancel(&mut self, session_id: &str) {
259 self.next_generation(session_id);
261 self.timers.remove(session_id);
262 self.failure_counts.remove(session_id);
263 }
264
265 fn cancel_all(&mut self) {
266 let session_ids: Vec<String> = self.generations.keys().cloned().collect();
268 for session_id in session_ids {
269 self.next_generation(&session_id);
270 }
271 self.timers.clear();
272 self.failure_counts.clear();
273 }
274}
275
276pub fn is_token_expired(token: &str, buffer_ms: u64) -> bool {
282 if let Some(expiry) = decode_jwt_expiry(token) {
283 let expiry_ms = expiry * 1000;
284 let now = current_timestamp_millis();
285 expiry_ms + buffer_ms as i64 <= now as i64
286 } else {
287 false
289 }
290}
291
292pub fn get_token_remaining_time(token: &str) -> Option<Duration> {
294 let expiry = decode_jwt_expiry(token)?;
295 let expiry_ms = expiry * 1000;
296 let now = current_timestamp_millis() as i64;
297 let remaining = expiry_ms - now;
298 if remaining > 0 {
299 Some(Duration::from_millis(remaining as u64))
300 } else {
301 Some(Duration::ZERO)
302 }
303}