use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use std::collections::HashMap;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
const FALLBACK_REFRESH_INTERVAL_MS: u64 = 30 * 60 * 1000;
const MAX_REFRESH_FAILURES: u32 = 3;
const REFRESH_RETRY_DELAY_MS: u64 = 60_000;
fn current_timestamp_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
pub fn format_duration(ms: u64) -> String {
if ms < 60_000 {
return format!("{}s", ms / 1000);
}
let m = ms / 60_000;
let s = (ms % 60_000) / 1000;
if s > 0 {
format!("{}m {}s", m, s)
} else {
format!("{}m", m)
}
}
pub fn decode_jwt_payload(token: &str) -> Option<serde_json::Value> {
let jwt = if token.starts_with("sk-ant-si-") {
&token["sk-ant-si-".len()..]
} else {
token
};
let parts: Vec<&str> = jwt.split('.').collect();
if parts.len() != 3 || parts[1].is_empty() {
return None;
}
let payload_str = match URL_SAFE_NO_PAD.decode(parts[1]) {
Ok(bytes) => String::from_utf8(bytes).ok()?,
Err(_) => return None,
};
serde_json::from_str(&payload_str).ok()
}
pub fn decode_jwt_expiry(token: &str) -> Option<i64> {
let payload = decode_jwt_payload(token)?;
if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
Some(exp)
} else {
None
}
}
pub struct TokenRefreshScheduler {
timers: HashMap<String, TimerState>,
failure_counts: HashMap<String, u32>,
generations: HashMap<String, u32>,
get_access_token: Box<dyn Fn() -> Option<String> + Send + Sync>,
on_refresh: Box<dyn Fn(&str, &str) + Send + Sync>,
label: String,
refresh_buffer_ms: u64,
}
#[derive(Debug)]
struct TimerState {
timer: Option<tokio::time::Sleep>,
expiry: Option<Instant>,
}
pub struct TokenRefreshSchedulerHandle {
scheduler: std::sync::Arc<std::sync::Mutex<TokenRefreshScheduler>>,
}
impl TokenRefreshSchedulerHandle {
pub fn schedule(&self, session_id: &str, token: &str) {
if let Ok(mut scheduler) = self.scheduler.lock() {
scheduler.schedule(session_id, token);
}
}
pub fn schedule_from_expires_in(&self, session_id: &str, expires_in_seconds: u64) {
if let Ok(mut scheduler) = self.scheduler.lock() {
scheduler.schedule_from_expires_in(session_id, expires_in_seconds);
}
}
pub fn cancel(&self, session_id: &str) {
if let Ok(mut scheduler) = self.scheduler.lock() {
scheduler.cancel(session_id);
}
}
pub fn cancel_all(&self) {
if let Ok(mut scheduler) = self.scheduler.lock() {
scheduler.cancel_all();
}
}
}
pub fn create_token_refresh_scheduler<G, R, L>(
get_access_token: G,
on_refresh: R,
label: L,
) -> TokenRefreshSchedulerHandle
where
G: Fn() -> Option<String> + Send + Sync + 'static,
R: Fn(&str, &str) + Send + Sync + 'static,
L: Into<String>,
{
let scheduler = TokenRefreshScheduler {
timers: HashMap::new(),
failure_counts: HashMap::new(),
generations: HashMap::new(),
get_access_token: Box::new(get_access_token),
on_refresh: Box::new(on_refresh),
label: label.into(),
refresh_buffer_ms: TOKEN_REFRESH_BUFFER_MS,
};
TokenRefreshSchedulerHandle {
scheduler: std::sync::Arc::new(std::sync::Mutex::new(scheduler)),
}
}
impl TokenRefreshScheduler {
fn next_generation(&mut self, session_id: &str) -> u32 {
let r#gen = self.generations.get(session_id).copied().unwrap_or(0) + 1;
self.generations.insert(session_id.to_string(), r#gen);
r#gen
}
fn schedule(&mut self, session_id: &str, token: &str) {
let expiry = decode_jwt_expiry(token);
if expiry.is_none() {
eprintln!(
"[{}:token] Could not decode JWT expiry for sessionId={}, token prefix={}..., keeping existing timer",
self.label,
session_id,
&token[..15.min(token.len())]
);
return;
}
let expiry = expiry.unwrap();
self.timers.remove(session_id);
let r#gen = self.next_generation(session_id);
let expiry_date = chrono::DateTime::from_timestamp(expiry, 0)
.map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
.unwrap_or_else(|| "unknown".to_string());
let delay_ms = (expiry * 1000) as i64
- current_timestamp_millis() as i64
- self.refresh_buffer_ms as i64;
if delay_ms <= 0 {
eprintln!(
"[{}:token] Token for sessionId={} expires={} (past or within buffer), refreshing immediately",
self.label, session_id, expiry_date
);
return;
}
eprintln!(
"[{}:token] Scheduled token refresh for sessionId={} in {} (expires={}, buffer={}s)",
self.label,
session_id,
format_duration(delay_ms as u64),
expiry_date,
self.refresh_buffer_ms / 1000
);
}
fn schedule_from_expires_in(&mut self, session_id: &str, expires_in_seconds: u64) {
self.timers.remove(session_id);
let r#gen = self.next_generation(session_id);
let delay_ms = (expires_in_seconds * 1000)
.saturating_sub(self.refresh_buffer_ms)
.max(30_000);
eprintln!(
"[{}:token] Scheduled token refresh for sessionId={} in {} (expires_in={}s, buffer={}s)",
self.label,
session_id,
format_duration(delay_ms),
expires_in_seconds,
self.refresh_buffer_ms / 1000
);
}
fn cancel(&mut self, session_id: &str) {
self.next_generation(session_id);
self.timers.remove(session_id);
self.failure_counts.remove(session_id);
}
fn cancel_all(&mut self) {
let session_ids: Vec<String> = self.generations.keys().cloned().collect();
for session_id in session_ids {
self.next_generation(&session_id);
}
self.timers.clear();
self.failure_counts.clear();
}
}
pub fn is_token_expired(token: &str, buffer_ms: u64) -> bool {
if let Some(expiry) = decode_jwt_expiry(token) {
let expiry_ms = expiry * 1000;
let now = current_timestamp_millis();
expiry_ms + buffer_ms as i64 <= now as i64
} else {
false
}
}
pub fn get_token_remaining_time(token: &str) -> Option<Duration> {
let expiry = decode_jwt_expiry(token)?;
let expiry_ms = expiry * 1000;
let now = current_timestamp_millis() as i64;
let remaining = expiry_ms - now;
if remaining > 0 {
Some(Duration::from_millis(remaining as u64))
} else {
Some(Duration::ZERO)
}
}