Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use parking_lot::Mutex as ParkingMutex;
7
8use axum::extract::{Request, State};
9use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum::Router;
13use bytes::Bytes;
14use serde_json::json;
15use tokio::sync::RwLock;
16use tracing::{error, info, warn};
17
18use crate::config::{state_path, Config, CredentialsStore};
19use crate::credential::Credential;
20use crate::forwarder::Forwarder;
21use crate::provider::Provider;
22use crate::quota;
23use crate::router;
24use crate::state::StateStore;
25use crate::telemetry::TelemetryClient;
26
/// Upper bound, in bytes, on a request body buffered by the proxy: 100 MiB.
///
/// Chosen to be sufficient for any LLM request, including large context windows.
const MAX_REQUEST_BODY: usize = 100 << 20;
29
30#[derive(Clone)]
31struct AppState {
32    config: Arc<Config>,
33    forwarder: Arc<Forwarder>,
34    state: StateStore,
35    /// Live credentials — can be refreshed at runtime without restarting.
36    credentials: Arc<RwLock<HashMap<String, Credential>>>,
37    /// Per-account mutex that serialises concurrent token-refresh attempts.
38    ///
39    /// When multiple in-flight requests hit a 401 for the same account at the
40    /// same time, only one should call the upstream OAuth endpoint; the others
41    /// should wait and then re-use the fresh token instead of each making their
42    /// own refresh call (which would rotate the refresh_token out from under the
43    /// others and cause cascading auth failures).
44    refresh_locks: Arc<ParkingMutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
45    /// Epoch-ms when this proxy instance started.
46    started_ms: u64,
47    /// If set, /v1/chat/completions requests are translated and forwarded here
48    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
49    anthropic_base_url: Option<String>,
50    /// Optional relay-server telemetry client.
51    telemetry: Option<TelemetryClient>,
52    /// Per-IP token-bucket rate limiter (#16). None when rate_limit_rpm == 0.
53    rate_limiter: Option<Arc<ParkingMutex<HashMap<IpAddr, TokenBucket>>>>,
54}
55
/// Simple token bucket for per-IP rate limiting.
struct TokenBucket {
    /// Tokens currently available; fractional values accumulate between requests.
    tokens: f64,
    /// Instant of the most recent refill.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket holding `capacity` tokens whose refill clock starts now.
    ///
    /// The balance is clamped to the burst ceiling on the first call to
    /// [`check_and_consume`](Self::check_and_consume), so a `capacity` above
    /// that ceiling does not grant extra requests.
    fn new(capacity: f64) -> Self {
        Self { tokens: capacity, last_refill: Instant::now() }
    }

    /// Refills tokens in proportion to the time since the last call, then
    /// tries to consume one.
    ///
    /// `rpm` is the sustained rate in requests per minute. Tokens are added at
    /// `rpm / 60` per second and capped at a burst ceiling of
    /// `max(rpm / 6, 10)` tokens.
    ///
    /// Returns `true` if a token was available (request allowed), `false`
    /// otherwise. The refill timestamp is advanced in both cases.
    fn check_and_consume(&mut self, rpm: f64) -> bool {
        // Read the clock exactly once: using a second `Instant::now()` for the
        // new timestamp would drop the time between the two reads from every
        // refill.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        let burst = (rpm / 6.0).max(10.0);
        self.tokens = (self.tokens + elapsed * rpm / 60.0).min(burst);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
83
84pub fn create_app(config: Config) -> anyhow::Result<Router> {
85    let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
86    Ok(app)
87}
88
/// Shared live credentials map, keyed by account name.
///
/// Held behind an `Arc<RwLock<…>>` so it can be written to without
/// restarting the proxy.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
91
92/// Create a pure proxy app (no management routes).
93/// Registers /v1/messages, /v1/chat/completions, /v1/models, and a fallback.
94/// Build a shared `AppState` and the `LiveCredentials` handle it references.
95fn build_app_state(
96    config: Config,
97    state: StateStore,
98    anthropic_base_url: Option<String>,
99) -> anyhow::Result<(AppState, LiveCredentials)> {
100    let forwarder = Forwarder::new(config.server.request_timeout_secs)?;
101
102    for a in &config.accounts {
103        if a.provider.auth_kind() == crate::provider::AuthKind::None {
104            // Local providers never need credentials — clear any stale auth_failed from disk.
105            state.clear_auth_failed(&a.name);
106        } else if a.credential.is_none() {
107            state.set_auth_failed(&a.name);
108        }
109    }
110
111    let credentials: LiveCredentials = Arc::new(RwLock::new(
112        config.accounts.iter()
113            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
114            .collect::<HashMap<_, _>>(),
115    ));
116
117    let telemetry = config.server.telemetry_url.as_deref().map(|url| {
118        TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
119    });
120
121    let rate_limiter = if config.server.rate_limit_rpm > 0 {
122        Some(Arc::new(ParkingMutex::new(HashMap::<IpAddr, TokenBucket>::new())))
123    } else {
124        None
125    };
126
127    let app_state = AppState {
128        config: Arc::new(config),
129        forwarder: Arc::new(forwarder),
130        state,
131        credentials: Arc::clone(&credentials),
132        refresh_locks: Arc::new(ParkingMutex::new(HashMap::new())),
133        started_ms: now_ms(),
134        anthropic_base_url,
135        telemetry,
136        rate_limiter,
137    };
138
139    Ok((app_state, credentials))
140}
141
142pub fn create_proxy_app(
143    config: Config,
144    state: StateStore,
145    anthropic_base_url: Option<String>,
146) -> anyhow::Result<(Router, LiveCredentials)> {
147    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
148
149    let app = Router::new()
150        .route("/v1/messages", post(proxy_handler))
151        .route("/v1/messages/count_tokens", post(proxy_handler))
152        .route("/v1/chat/completions", post(openai_compat_handler))
153        .route("/v1/models", get(openai_models_handler))
154        .fallback(proxy_handler)
155        .with_state(app_state);
156
157    Ok((app, credentials))
158}
159
160/// Create a control plane app (management routes only — sees ALL accounts).
161/// Registers /health, /status, /use.
162pub fn create_control_app(
163    config: Config,
164    state: StateStore,
165) -> anyhow::Result<Router> {
166    let (app_state, _) = build_app_state(config, state, None)?;
167
168    let app = Router::new()
169        .route("/health", get(health))
170        .route("/status", get(status_handler))
171        .route("/use", post(use_handler))
172        .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
173        .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
174        .with_state(app_state);
175
176    Ok(app)
177}
178
179/// Combined app used by tests and the single-port fallback mode.
180/// Includes both proxy routes and management routes (/health, /status, /use)
181/// sharing a single AppState so state changes are visible across all routes.
182pub fn create_app_with_state(
183    config: Config,
184    state: StateStore,
185    anthropic_base_url: Option<String>,
186) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
187    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
188    let telemetry = app_state.telemetry.clone();
189
190    let app = Router::new()
191        // Management routes
192        .route("/health", get(health))
193        .route("/status", get(status_handler))
194        .route("/use", post(use_handler))
195        .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
196        .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
197        // Proxy routes
198        .route("/v1/messages", post(proxy_handler))
199        .route("/v1/messages/count_tokens", post(proxy_handler))
200        .route("/v1/chat/completions", post(openai_compat_handler))
201        .route("/v1/models", get(openai_models_handler))
202        .fallback(proxy_handler)
203        .with_state(app_state);
204
205    Ok((app, credentials, telemetry))
206}
207
208/// Build a status JSON snapshot from config + state — used by the heartbeat loop.
209pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
210    let account_states = state.account_states();
211    let rate_limits    = state.rate_limit_snapshot();
212
213    let accounts: Vec<_> = config.accounts.iter().map(|a| {
214        let st            = account_states.get(&a.name);
215        let rl            = rate_limits.get(&a.name);
216        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
217        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
218        let reset_5h       = rl.and_then(|r| r.reset_5h);
219        let reset_7d       = rl.and_then(|r| r.reset_7d);
220        let disabled       = st.map(|s| s.disabled).unwrap_or(false);
221        let auth_failed    = st.map(|s| s.auth_failed).unwrap_or(false);
222        let health_check_failed = st.map(|s| s.health_check_failed).unwrap_or(false);
223        let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
224        let available      = state.is_available(&a.name);
225        let email          = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
226
227        json!({
228            "name": a.name,
229            "email": email,
230            "provider": a.provider.to_string(),
231            "available": available,
232            "disabled": disabled,
233            "auth_failed": auth_failed,
234            "health_check_failed": health_check_failed,
235            "cooldown_until_ms": cooldown_until_ms,
236            "utilization_5h": utilization_5h,
237            "reset_5h": reset_5h,
238            "utilization_7d": utilization_7d,
239            "reset_7d": reset_7d,
240        })
241    }).collect();
242
243    json!({
244        "started_ms": started_ms,
245        "accounts": accounts,
246        "pinned_account": state.get_pinned(),
247        "last_used_account": state.get_last_used(),
248    })
249}
250
251async fn health() -> impl IntoResponse {
252    axum::Json(json!({"status": "ok"}))
253}
254
255async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
256    let account_states = s.state.account_states();
257    let quotas = s.state.quota_snapshot();
258    let rate_limits = s.state.rate_limit_snapshot();
259
260    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
261        let st = account_states.get(&a.name);
262        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
263            "reauth_required"
264        } else if st.map(|s| s.disabled).unwrap_or(false) {
265            "disabled"
266        } else if st.map(|s| s.health_check_failed).unwrap_or(false) {
267            "unhealthy"
268        } else if s.state.is_available(&a.name) {
269            "available"
270        } else {
271            "cooling"
272        };
273
274        let quota = quotas.get(&a.name);
275        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
276        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
277        let tokens_used = quota.map(|q| json!({
278            "input": q.input_tokens,
279            "output": q.output_tokens,
280            "total": q.total_tokens(),
281        }));
282
283        let rl = rate_limits.get(&a.name);
284        let rate_limit = rl.map(|r| json!({
285            "utilization_5h": r.utilization_5h,
286            "reset_5h": r.reset_5h,
287            "status_5h": r.status_5h,
288            "utilization_7d": r.utilization_7d,
289            "reset_7d": r.reset_7d,
290            "status_7d": r.status_7d,
291            "representative_claim": r.representative_claim,
292            "updated_ms": r.updated_ms,
293        }));
294
295        let acc_state = account_states.get(&a.name);
296        let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
297        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
298        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
299        let health_check_failed = acc_state.map(|s| s.health_check_failed).unwrap_or(false);
300        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
301        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
302        let reset_5h = rl.and_then(|r| r.reset_5h);
303        let status_5h = rl.and_then(|r| r.status_5h.clone());
304        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
305        let reset_7d = rl.and_then(|r| r.reset_7d);
306        let status_7d = rl.and_then(|r| r.status_7d.clone());
307        let available = s.state.is_available(&a.name);
308
309        json!({
310            "name": a.name,
311            "email": email,
312            "plan_type": a.plan_type,
313            "provider": a.provider.to_string(),
314            "status": avail_status,
315            "available": available,
316            "disabled": disabled,
317            "auth_failed": auth_failed,
318            "health_check_failed": health_check_failed,
319            "cooldown_until_ms": cooldown_until_ms,
320            "utilization_5h": utilization_5h,
321            "reset_5h": reset_5h,
322            "status_5h": status_5h,
323            "utilization_7d": utilization_7d,
324            "reset_7d": reset_7d,
325            "status_7d": status_7d,
326            "window_expires_ms": window_expires_ms,
327            "tokens_used": tokens_used,
328            "rate_limit": rate_limit,
329        })
330    }).collect();
331
332    let recent_requests = s.state.recent_requests_snapshot();
333    let savings = s.state.savings_snapshot();
334
335    axum::Json(json!({
336        "version": env!("CARGO_PKG_VERSION"),
337        "started_ms": s.started_ms,
338        "accounts": accounts,
339        "pinned_account": s.state.get_pinned(),
340        "last_used_account": s.state.get_last_used(),
341        "recent_requests": recent_requests,
342        "savings": savings,
343    }))
344}
345
346async fn use_handler(
347    State(s): State<AppState>,
348    axum::Json(body): axum::Json<serde_json::Value>,
349) -> Response {
350    let account = body["account"].as_str().map(|s| s.to_owned());
351    // Validate the account name exists (unless clearing to auto)
352    if let Some(ref name) = account {
353        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
354            return (StatusCode::BAD_REQUEST, axum::Json(json!({
355                "error": format!("unknown account '{name}'")
356            }))).into_response();
357        }
358        let pinned = if name == "auto" { None } else { Some(name.clone()) };
359        s.state.set_pinned(pinned);
360        axum::Json(json!({ "pinned": name })).into_response()
361    } else {
362        s.state.set_pinned(None);
363        axum::Json(json!({ "pinned": null })).into_response()
364    }
365}
366
367async fn model_get_handler(State(s): State<AppState>) -> impl IntoResponse {
368    let model = s.state.get_model_override();
369    axum::Json(json!({ "model": model }))
370}
371
372async fn model_set_handler(
373    State(s): State<AppState>,
374    axum::Json(body): axum::Json<serde_json::Value>,
375) -> Response {
376    let Some(model) = body["model"].as_str() else {
377        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing model field" }))).into_response();
378    };
379    s.state.set_model_override(model.to_owned());
380    info!(model, "model override set");
381    axum::Json(json!({ "model": model })).into_response()
382}
383
384async fn model_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
385    s.state.clear_model_override();
386    info!("model override cleared");
387    axum::Json(json!({ "model": null }))
388}
389
390async fn strategy_get_handler(State(s): State<AppState>) -> impl IntoResponse {
391    let (strategy_str, source) = match s.state.get_routing_strategy() {
392        Some(st) => (st.as_str(), "override"),
393        None => (s.config.server.routing_strategy.as_str(), "config"),
394    };
395    axum::Json(json!({ "strategy": strategy_str, "source": source }))
396}
397
398async fn strategy_set_handler(
399    State(s): State<AppState>,
400    axum::Json(body): axum::Json<serde_json::Value>,
401) -> Response {
402    let Some(name) = body["strategy"].as_str() else {
403        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing strategy field" }))).into_response();
404    };
405    let Some(strategy) = crate::config::RoutingStrategy::from_str(name) else {
406        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": format!("unknown strategy '{name}'") }))).into_response();
407    };
408    s.state.set_routing_strategy(strategy);
409    info!(strategy = name, "routing strategy override set");
410    axum::Json(json!({ "strategy": strategy.as_str(), "source": "override" })).into_response()
411}
412
413async fn strategy_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
414    s.state.clear_routing_strategy();
415    info!("routing strategy override cleared");
416    let strategy_str = s.config.server.routing_strategy.as_str();
417    axum::Json(json!({ "strategy": strategy_str, "source": "config" }))
418}
419
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
424
425/// Extract client IP for rate limiting.
426///
427/// `X-Real-IP` is only trusted when `trust_proxy_headers` is explicitly enabled
428/// in config — otherwise any client could spoof the header to rotate its bucket.
429/// When not trusted (the default), all requests share a single loopback bucket,
430/// giving a global RPM cap rather than a per-IP one.
431fn extract_client_ip(req: &Request, trust_proxy_headers: bool) -> IpAddr {
432    if trust_proxy_headers {
433        if let Some(ip) = req.headers()
434            .get("x-real-ip")
435            .and_then(|v| v.to_str().ok())
436            .and_then(|s| s.parse().ok())
437        {
438            return ip;
439        }
440    }
441    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
442}
443
444async fn proxy_handler(
445    State(s): State<AppState>,
446    req: Request,
447) -> Result<Response, ProxyError> {
448    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
449    if let Some(ref expected) = s.config.server.remote_key {
450        let provided = req.headers()
451            .get("x-api-key")
452            .and_then(|v| v.to_str().ok())
453            .unwrap_or("");
454        if provided != expected {
455            return Err(ProxyError::Unauthorized);
456        }
457    }
458
459    // #16: per-IP rate limiting (token bucket, configurable via rate_limit_rpm).
460    if let Some(ref rl) = s.rate_limiter {
461        let ip = extract_client_ip(&req, s.config.server.trust_proxy_headers);
462        let rpm = s.config.server.rate_limit_rpm as f64;
463        let allowed = rl.lock().entry(ip).or_insert_with(|| TokenBucket::new(rpm)).check_and_consume(rpm);
464        if !allowed {
465            return Err(ProxyError::RateLimited);
466        }
467    }
468
469    let method = req.method().as_str().to_owned();
470    let path = req.uri().path().to_owned();
471    let headers = req.headers().clone();
472
473    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
474        .await
475        .map_err(|_| ProxyError::BodyRead)?;
476
477    // Apply model override: if set, patch the `model` field in the JSON body before forwarding.
478    // Also strip unsupported params for models that don't support them (e.g. Haiku).
479    let body_bytes = if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
480        let mut changed = false;
481        if let Some(override_model) = s.state.get_model_override() {
482            if val.get("model").is_some() {
483                val["model"] = serde_json::Value::String(override_model);
484                changed = true;
485            }
486        }
487        let resolved_model = val["model"].as_str().unwrap_or("").to_owned();
488        if is_simple_model(&resolved_model) {
489            if let Some(obj) = val.as_object_mut() {
490                // Strip features unsupported by simpler models (Haiku).
491                for key in &["thinking", "effort", "reasoning_effort"] {
492                    if obj.remove(*key).is_some() { changed = true; }
493                }
494                // Strip effort from output_config if present
495                if let Some(serde_json::Value::Object(oc)) = obj.get_mut("output_config") {
496                    if oc.remove("effort").is_some() { changed = true; }
497                    // Remove output_config entirely if empty
498                    if oc.is_empty() { obj.remove("output_config"); }
499                }
500                // Strip context_management (thinking-related edit rules)
501                if obj.remove("context_management").is_some() { changed = true; }
502                // Remove extended-thinking beta flag
503                if let Some(serde_json::Value::Array(betas)) = obj.get_mut("betas") {
504                    let before = betas.len();
505                    betas.retain(|b| b.as_str() != Some("interleaved-thinking-2025-05-14"));
506                    if betas.len() != before { changed = true; }
507                }
508            }
509        }
510        if changed {
511            Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()))
512        } else {
513            body_bytes
514        }
515    } else {
516        body_bytes
517    };
518
519    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
520        .ok()
521        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
522        .unwrap_or_default();
523
524    // Strip thinking/effort-related beta flags from the anthropic-beta header for simple models.
525    let mut headers = headers;
526    if is_simple_model(&model) {
527        if let Some(beta_val) = headers.get("anthropic-beta").and_then(|v| v.to_str().ok().map(|s| s.to_owned())) {
528            let filtered: Vec<&str> = beta_val.split(',')
529                .map(|s| s.trim())
530                .filter(|b| !b.contains("thinking") && !b.contains("effort"))
531                .collect();
532            let new_beta = filtered.join(",");
533            if filtered.is_empty() {
534                headers.remove("anthropic-beta");
535            } else if let Ok(v) = axum::http::HeaderValue::from_str(&new_beta) {
536                headers.insert("anthropic-beta", v);
537            }
538        }
539    }
540
541    let req_start_ms = now_ms();
542    let request_id = uuid::Uuid::new_v4().to_string()[..8].to_owned();
543
544    let fp = router::fingerprint(&body_bytes);
545    let fp_ref = fp.as_deref();
546
547    let mut tried: HashSet<String> = HashSet::new();
548    // Track accounts we've already attempted a token refresh for this request.
549    let mut refreshed: HashSet<String> = HashSet::new();
550    // Cap wait to the configured request timeout — the client's TCP connection
551    // won't survive 5 hours anyway; return 503 so the client can retry.
552    let wait_deadline_ms = now_ms() + s.config.server.request_timeout_secs.saturating_mul(1_000);
553
554    loop {
555        let effective_strategy = s.state.get_routing_strategy()
556            .unwrap_or(s.config.server.routing_strategy);
557        let account = match router::pick_account(
558            &s.config.accounts, &s.state, fp_ref, &tried,
559            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
560            effective_strategy,
561        ) {
562            Some(a) => a,
563            None => {
564                // Check whether any accounts are just temporarily cooling down
565                // (429/529 backoff) rather than permanently disabled / auth_failed.
566                // If so, wait for the soonest one to recover and retry.
567                let account_states = s.state.account_states();
568                let now = now_ms();
569                let soonest_ms = s.config.accounts.iter()
570                    .filter_map(|a| {
571                        let st = account_states.get(&a.name)?;
572                        if st.disabled { return None; } // auth_failed or permanently off
573                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
574                    })
575                    .min();
576
577                match soonest_ms {
578                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
579                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
580                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
581                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
582                        tried.clear(); // accounts may have recovered; try them again
583                    }
584                    _ => return Err(ProxyError::AllAccountsUnavailable),
585                }
586                continue;
587            }
588        };
589
590        let account_name = account.name.clone();
591
592        // Use the live (possibly refreshed) token rather than the one baked into config.
593        // For OpenAI/chatgpt.com accounts, Credential::bearer_token() returns id_token
594        // (short-lived OIDC JWT) which chatgpt.com requires. For all other providers it
595        // returns access_token. API-key accounts return the key directly.
596        let token = {
597            let creds = s.credentials.read().await;
598            let cred = creds.get(&account_name)
599                .cloned()
600                .or_else(|| account.credential.clone());
601            match cred {
602                Some(c) => c.bearer_token().to_owned(),
603                None => String::new(),
604            }
605        };
606
607        // Detect request and account protocols.  When they differ, translate
608        // the request body + path before forwarding and translate the response
609        // back so the client always sees its native wire format.
610        let req_is_anthropic = path.starts_with("/v1/messages");
611        let acct_is_anthropic = account.provider.wire_protocol()
612            == crate::provider::WireProtocol::Anthropic;
613        // chatgpt.com (Provider::OpenAI) uses a proprietary backend-api path + sentinel token.
614        // All other OpenAI-compat providers (OpenAIApi, Groq, Mistral, …) use /v1/chat/completions.
615        let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
616
617        // log_model: what we actually send to the upstream (after resolve_model).
618        // Defaults to the incoming model; overridden in the OpenAI-compat branch.
619        let mut log_model = model.clone();
620
621        let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
622            // Same wire protocol — pass through unchanged.
623            (path.clone(), body_bytes.clone(), headers.clone())
624        } else if req_is_anthropic && acct_is_chatgpt {
625            // Anthropic client → chatgpt.com account: translate to backend-api format.
626            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
627            let translated = translate_anthropic_req_to_chatgpt(&val);
628            let mut h = headers.clone();
629            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
630                h.remove(*name);
631            }
632            (
633                "/backend-api/conversation".to_owned(),
634                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
635                h,
636            )
637        } else if req_is_anthropic {
638            // Anthropic client → standard OpenAI-compat account (OpenAIApi, Groq, Mistral, …).
639            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
640            // Resolve the target model: account pin → global mapping → provider default.
641            let target_model = resolve_model(&model, account, &s.config.model_mapping);
642            log_model = target_model.clone();
643            let translated = translate_anthropic_req_to_openai(val, &target_model);
644            let mut h = headers.clone();
645            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
646                h.remove(*name);
647            }
648            (
649                "/v1/chat/completions".to_owned(),
650                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
651                h,
652            )
653        } else {
654            // OpenAI client → Anthropic account: translate O→A.
655            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
656            let translated = translate_to_anthropic(val);
657            (
658                "/v1/messages".to_owned(),
659                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
660                headers.clone(),
661            )
662        };
663
664        // Resolve upstream URL: per-account override (set at load time for non-primary
665        // providers, or explicitly in tests) → config server URL.
666        let upstream = account.upstream_url.as_deref()
667            .unwrap_or(&s.config.server.upstream_url);
668
669        // Inject chatgpt.com sentinel token — only for the chatgpt.com proprietary path.
670        // Wrap in tokio::time::timeout (3s) to guarantee we don't block on Cloudflare challenges.
671        if req_is_anthropic && acct_is_chatgpt {
672            tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
673            let sentinel_client = reqwest::Client::builder()
674                .timeout(std::time::Duration::from_secs(3))
675                .build()
676                .unwrap_or_default();
677            let sentinel_opt = tokio::time::timeout(
678                std::time::Duration::from_secs(3),
679                fetch_sentinel_token(&sentinel_client, upstream, &token),
680            ).await.ok().flatten();
681            if let Some(sentinel) = sentinel_opt {
682                if let Ok(name) = axum::http::header::HeaderName::from_bytes(
683                    b"openai-sentinel-chat-requirements-token",
684                ) {
685                    if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
686                        fwd_headers.insert(name, val);
687                    }
688                }
689            }
690        }
691
692        // Apply a hard 15s cap only for chatgpt.com: Cloudflare may hold the TCP connection
693        // open indefinitely for certain TLS fingerprints.  Standard API providers don't need this.
694        let response = if acct_is_chatgpt {
695            tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
696            match tokio::time::timeout(
697                std::time::Duration::from_secs(15),
698                s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
699            ).await {
700                Ok(Ok(r)) => r,
701                Ok(Err(e)) => {
702                    error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
703                    s.state.set_cooldown(&account_name, 5 * 60_000);
704                    tried.insert(account_name);
705                    continue;
706                }
707                Err(_) => {
708                    warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
709                    s.state.set_cooldown(&account_name, 5 * 60_000);
710                    tried.insert(account_name);
711                    continue;
712                }
713            }
714        } else {
715            s.forwarder
716                .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
717                .await
718                .map_err(|e| {
719                    error!("Forward error: {:#}", e);
720                    ProxyError::Upstream
721                })?
722        };
723
724        match response.status().as_u16() {
725            200..=299 => {
726                s.state.set_last_used(&account_name);
727                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
728                    s.state.update_rate_limits(&account_name, info);
729                }
730                // Translate response back to the client's expected protocol.
731                let response = if req_is_anthropic == acct_is_anthropic {
732                    response
733                } else if req_is_anthropic && acct_is_chatgpt {
734                    // Got chatgpt.com response; client expects Anthropic.
735                    translate_response_chatgpt_to_anthropic(response, &model).await
736                } else if req_is_anthropic {
737                    // Got standard OpenAI-compat response; client expects Anthropic.
738                    translate_response_openai_to_anthropic(response, &model).await
739                } else {
740                    // Got Anthropic response; client expects OpenAI.
741                    translate_response_anthropic_to_openai(response).await
742                };
743                return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms, &request_id, &path, tried.len()).await);
744            }
745            429 => {
746                let info = account.provider.parse_rate_limits(response.headers());
747                // Sleep until the actual reset time if the headers tell us when that is.
748                // Fall back to Retry-After (per-minute rate limits), then 60s.
749                let retry_after_ms = response.headers()
750                    .get("retry-after")
751                    .and_then(|v| v.to_str().ok())
752                    .and_then(|s| s.parse::<u64>().ok())
753                    .map(|secs| secs.saturating_mul(1_000).max(500));
754                let cooldown_ms = info.as_ref()
755                    .and_then(|i| i.reset_5h.or(i.reset_7d))
756                    .map(|reset_secs| {
757                        let reset_ms = reset_secs.saturating_mul(1_000);
758                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
759                    })
760                    .or(retry_after_ms)
761                    .unwrap_or(60_000);
762                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
763                if let Some(info) = info {
764                    s.state.update_rate_limits(&account_name, info);
765                }
766                s.state.set_cooldown(&account_name, cooldown_ms);
767                if cooldown_ms >= 5 * 60_000 {
768                    let mins = cooldown_ms / 60_000;
769                    notify(
770                        "shunt: Rate Limited",
771                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
772                        "Ping",
773                    );
774                }
775                tried.insert(account_name);
776            }
777            529 => {
778                warn!(account = %account_name, "529 overloaded — cooling 30s");
779                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
780                    s.state.update_rate_limits(&account_name, info);
781                }
782                s.state.set_cooldown(&account_name, 30_000);
783                tried.insert(account_name);
784            }
785            401 => {
786                if !refreshed.contains(&account_name) {
787                    // Access token invalidated (e.g. user logged out) — try refresh.
788                    //
789                    // Acquire the per-account refresh lock so concurrent requests
790                    // for the same account serialise here. The first waiter to get
791                    // the lock does the actual OAuth refresh; subsequent waiters
792                    // re-check credentials and skip the refresh if the token was
793                    // already rotated while they were queued.
794                    let account_lock = {
795                        let mut locks = s.refresh_locks.lock();
796                        locks.entry(account_name.clone())
797                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
798                            .clone()
799                    };
800                    let _guard = account_lock.lock().await;
801
802                    // Re-read credentials after acquiring the lock — another task
803                    // may have already refreshed while we were waiting.
804                    let cred_before = {
805                        let creds = s.credentials.read().await;
806                        creds.get(&account_name).cloned()
807                            .or_else(|| account.credential.clone())
808                    };
809                    let Some(cred) = cred_before else {
810                        tried.insert(account_name);
811                        continue;
812                    };
813
814                    // Check if the token already changed while we were waiting.
815                    let token_before = cred.access_token().to_owned();
816                    let already_refreshed = {
817                        let creds = s.credentials.read().await;
818                        creds.get(&account_name)
819                            .map(|c| c.access_token() != token_before)
820                            .unwrap_or(false)
821                    };
822
823                    if already_refreshed {
824                        // Another concurrent request already refreshed — just retry.
825                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
826                        refreshed.insert(account_name);
827                    } else if let Some(oauth_cred) = cred.as_oauth() {
828                        // OAuth account — attempt token refresh.
829                        match tokio::time::timeout(
830                            std::time::Duration::from_secs(10),
831                            account.provider.refresh_token(oauth_cred),
832                        ).await {
833                            Ok(Ok(fresh)) => {
834                                warn!(account = %account_name, "401 — token refreshed, retrying");
835                                {
836                                    let mut creds = s.credentials.write().await;
837                                    creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
838                                }
839                                // Persist to disk so the refreshed token survives a restart.
840                                let name = account_name.clone();
841                                let fresh = fresh.clone();
842                                tokio::task::spawn_blocking(move || {
843                                    let mut store = CredentialsStore::load();
844                                    store.accounts.insert(name, Credential::Oauth(fresh.clone()));
845                                    store.save().ok();
846                                    if fresh.id_token.is_some() {
847                                        crate::oauth::write_codex_auth_file(&fresh);
848                                    }
849                                });
850                                // Mark as refreshed but don't add to tried — retry this account.
851                                refreshed.insert(account_name);
852                            }
853                            _ => {
854                                // Refresh failed/timed out — cool down, don't permanently disable.
855                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
856                                s.state.set_cooldown(&account_name, 5 * 60_000);
857                                tried.insert(account_name);
858                            }
859                        }
860                    } else {
861                        // API-key account — 401 means the key is invalid; no refresh possible.
862                        error!(account = %account_name, "401 — API key rejected, cooling 5min");
863                        s.state.set_cooldown(&account_name, 5 * 60_000);
864                        tried.insert(account_name);
865                    }
866                } else {
867                    // Already refreshed once and still 401 — cool down this account.
868                    error!(account = %account_name, "401 after refresh — cooling 5min");
869                    s.state.set_cooldown(&account_name, 5 * 60_000);
870                    tried.insert(account_name);
871                }
872            }
873            403 => {
874                // Forbidden — could be a Cloudflare challenge (non-Anthropic providers)
875                // or a genuine subscription/org block (Anthropic). Use a short cooldown
876                // for non-Anthropic accounts so a CF block doesn't lock them out for 30m.
877                if acct_is_anthropic {
878                    error!(account = %account_name, "403 forbidden — cooling 30min");
879                    s.state.set_cooldown(&account_name, 30 * 60_000);
880                    notify(
881                        "shunt: Account Forbidden",
882                        &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
883                        "Basso",
884                    );
885                } else {
886                    warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
887                    s.state.set_cooldown(&account_name, 5 * 60_000);
888                }
889                tried.insert(account_name);
890            }
891            _ => {
892                // 400, 404, 500, etc. — return as-is, no retry
893                return Ok(response);
894            }
895        }
896    }
897}
898
899// ---------------------------------------------------------------------------
900// Usage extraction
901// ---------------------------------------------------------------------------
902
903/// Intercept a successful response to record token usage, then pass it through.
904///
905/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
906/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
907async fn tap_usage(
908    resp: Response,
909    state: &StateStore,
910    telemetry: Option<&TelemetryClient>,
911    account: &str,
912    model: &str,
913    req_start_ms: u64,
914    request_id: &str,
915    path: &str,
916    retries: usize,
917) -> Response {
918    use axum::body::Body;
919    use crate::state::RequestLog;
920
921    let streaming = quota::is_streaming_response(&resp);
922
923    if streaming {
924        let state      = state.clone();
925        let telem      = telemetry.cloned();
926        let account    = account.to_owned();
927        let model      = model.to_owned();
928        let request_id = request_id.to_owned();
929        let path       = path.to_owned();
930        let on_complete = Arc::new(move |input: u64, output: u64| {
931            let duration_ms = now_ms().saturating_sub(req_start_ms);
932            info!(
933                request_id = %request_id,
934                account    = %account,
935                model      = %model,
936                status     = 200,
937                latency_ms = duration_ms,
938                path       = %path,
939                stream     = true,
940                input_tokens  = input,
941                output_tokens = output,
942                retries    = retries,
943                "request complete"
944            );
945            let log = RequestLog {
946                ts_ms: req_start_ms,
947                account: account.clone(),
948                model: model.clone(),
949                status: 200,
950                input_tokens: input,
951                output_tokens: output,
952                duration_ms,
953            };
954            state.record_usage(&account, input, output);
955            state.record_global(&model, input, output);
956            if let Some(ref t) = telem { t.push_event(&log); }
957            state.record_request(log);
958        });
959        let (parts, body) = resp.into_parts();
960        let wrapped = quota::wrap_streaming_body(body, on_complete);
961        return Response::from_parts(parts, wrapped);
962    }
963
964    // Non-streaming: buffer, extract, rebuild
965    let (parts, body) = resp.into_parts();
966    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
967        Ok(b) => b,
968        Err(_) => return Response::from_parts(parts, Body::empty()),
969    };
970    let (input, output) = quota::extract_usage_from_json(&bytes);
971    let duration_ms = now_ms().saturating_sub(req_start_ms);
972    info!(
973        request_id    = %request_id,
974        account       = %account,
975        model         = %model,
976        status        = 200,
977        latency_ms    = duration_ms,
978        path          = %path,
979        stream        = false,
980        input_tokens  = input,
981        output_tokens = output,
982        retries       = retries,
983        "request complete"
984    );
985    let log = RequestLog {
986        ts_ms: req_start_ms,
987        account: account.to_owned(),
988        model: model.to_owned(),
989        status: 200,
990        input_tokens: input,
991        output_tokens: output,
992        duration_ms,
993    };
994    state.record_usage(account, input, output);
995    state.record_global(model, input, output);
996    if let Some(t) = telemetry { t.push_event(&log); }
997    state.record_request(log);
998    Response::from_parts(parts, Body::from(bytes))
999}
1000
1001
1002// ---------------------------------------------------------------------------
1003// Rate limit prefetch
1004// ---------------------------------------------------------------------------
1005
1006/// For any account with no rate-limit data yet, make a cheap request directly
1007/// to the upstream API so we populate metrics without waiting for a real user
1008/// request. Runs as a background task after startup.
1009pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
1010    let client = reqwest::Client::builder()
1011        .timeout(std::time::Duration::from_secs(20))
1012        .build()
1013        .unwrap_or_default();
1014
1015    for account in &config.accounts {
1016        // Skip if we already have data for this account.
1017        let rl = state.rate_limit_snapshot();
1018        if let Some(r) = rl.get(&account.name) {
1019            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
1020                continue;
1021            }
1022        }
1023
1024        // Skip accounts with no credentials or no prefetch support.
1025        let cred = match account.credential.clone() {
1026            Some(c) => c,
1027            None => continue,
1028        };
1029
1030        let Some((path, body)) = account.provider.prefetch_request() else {
1031            // No POST prefetch for this provider — do a lightweight GET auth check instead.
1032            if let Some(probe_path) = account.provider.auth_probe_get_path() {
1033                auth_probe_get(&client, probe_path, account, &state).await;
1034            }
1035            continue;
1036        };
1037        let url = format!("{}{}", config.server.upstream_url, path);
1038
1039        let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
1040
1041        let r = match resp {
1042            Ok(r) => r,
1043            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
1044        };
1045
1046        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
1047            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
1048            let Some(oauth_cred) = cred.as_oauth() else {
1049                // API-key account — 401 during prefetch means the key is invalid.
1050                tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
1051                state.set_auth_failed(&account.name);
1052                continue;
1053            };
1054            let fresh = match account.provider.refresh_token(oauth_cred).await {
1055                Ok(f) => f,
1056                Err(e) => {
1057                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
1058                    state.set_auth_failed(&account.name);
1059                    continue;
1060                }
1061            };
1062            let mut store = crate::config::CredentialsStore::load();
1063            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1064            store.save().ok();
1065            if fresh.id_token.is_some() {
1066                crate::oauth::write_codex_auth_file(&fresh);
1067            }
1068            // Update live credentials so the proxy uses the fresh token immediately.
1069            live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1070
1071            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
1072                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1073                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1074                    state.set_auth_failed(&account.name);
1075                }
1076                Ok(r2) => {
1077                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
1078                        state.update_rate_limits(&account.name, info);
1079                    }
1080                }
1081                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
1082            }
1083        } else {
1084            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
1085            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1086                state.update_rate_limits(&account.name, info);
1087            }
1088        }
1089    }
1090}
1091
1092/// Build and send a prefetch request for the given provider + token.
1093async fn prefetch_send(
1094    client: &reqwest::Client,
1095    url: &str,
1096    provider: &crate::provider::Provider,
1097    token: &str,
1098    body: &serde_json::Value,
1099) -> anyhow::Result<reqwest::Response> {
1100    let mut headers = reqwest::header::HeaderMap::new();
1101    provider.inject_auth_headers(&mut headers, token)?;
1102    for (name, value) in provider.prefetch_extra_headers() {
1103        headers.insert(
1104            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
1105            reqwest::header::HeaderValue::from_static(value),
1106        );
1107    }
1108    Ok(client.post(url).headers(headers).json(body).send().await?)
1109}
1110
1111/// GET a cheap endpoint to verify credentials are still valid for providers that
1112/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
1113/// marks the account as `reauth_required` if the refresh also fails.
1114async fn auth_probe_get(
1115    client: &reqwest::Client,
1116    path: &str,
1117    account: &crate::config::AccountConfig,
1118    state: &StateStore,
1119) {
1120    let cred = match account.credential.clone() {
1121        Some(c) => c,
1122        None => return,
1123    };
1124    let upstream = account.upstream_url.as_deref()
1125        .unwrap_or_else(|| account.provider.default_upstream_url());
1126    let url = format!("{}{}", upstream, path);
1127
1128    let do_get = |token: &str| -> reqwest::RequestBuilder {
1129        let mut headers = reqwest::header::HeaderMap::new();
1130        let _ = account.provider.inject_auth_headers(&mut headers, token);
1131        client.get(&url).headers(headers)
1132    };
1133
1134    let resp = match do_get(cred.bearer_token()).send().await {
1135        Ok(r) => r,
1136        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
1137    };
1138
1139    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
1140        tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
1141        let Some(oauth_cred) = cred.as_oauth() else {
1142            // API-key account — key is invalid; no refresh possible.
1143            tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
1144            state.set_auth_failed(&account.name);
1145            return;
1146        };
1147        let fresh = match account.provider.refresh_token(oauth_cred).await {
1148            Ok(f) => f,
1149            Err(e) => {
1150                tracing::warn!(account = %account.name, "token refresh failed: {e}");
1151                state.set_auth_failed(&account.name);
1152                return;
1153            }
1154        };
1155        let mut store = crate::config::CredentialsStore::load();
1156        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1157        store.save().ok();
1158        if fresh.id_token.is_some() {
1159            crate::oauth::write_codex_auth_file(&fresh);
1160        }
1161
1162        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
1163        match do_get(fresh_token).send().await {
1164            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1165                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1166                state.set_auth_failed(&account.name);
1167            }
1168            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
1169            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
1170        }
1171    } else {
1172        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
1173        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
1174        // with codex CLI, which also tries to refresh at startup using the same token.
1175        // Proactive refreshing is handled solely by openai_token_refresh_loop.
1176    }
1177}
1178
1179// ---------------------------------------------------------------------------
1180// Proactive OpenAI token refresh loop
1181// ---------------------------------------------------------------------------
1182
1183/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
1184/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
1185fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
1186    let now_ms = std::time::SystemTime::now()
1187        .duration_since(std::time::UNIX_EPOCH)
1188        .unwrap_or_default()
1189        .as_millis() as u64;
1190    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
1191        .unwrap_or(cred.expires_at);
1192    exp_ms < now_ms + threshold_mins * 60 * 1_000
1193}
1194
1195/// Sync live_creds from auth.json if auth.json has a newer token.
1196///
1197/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
1198/// we pull that in so we don't use a stale refresh_token that codex already rotated.
1199async fn sync_live_creds_from_auth_json(
1200    account_name: &str,
1201    live_creds: &LiveCredentials,
1202) {
1203    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
1204    let current_exp = live_creds.read().await
1205        .get(account_name)
1206        .and_then(|c| c.as_oauth())
1207        .map(|c| c.expires_at)
1208        .unwrap_or(0);
1209    if from_file.expires_at > current_exp {
1210        tracing::info!(account = %account_name, "synced fresher token from auth.json");
1211        live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
1212    }
1213}
1214
1215/// Perform a single proactive refresh for one account and persist the result.
1216async fn do_proactive_refresh(
1217    account: &crate::config::AccountConfig,
1218    creds: &crate::oauth::OAuthCredential,
1219    live_creds: &LiveCredentials,
1220    state: &StateStore,
1221) {
1222    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
1223    match account.provider.refresh_token(creds).await {
1224        Ok(fresh) => {
1225            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
1226            {
1227                let mut map = live_creds.write().await;
1228                map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1229            }
1230            let mut store = crate::config::CredentialsStore::load();
1231            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1232            store.save().ok();
1233            if fresh.id_token.is_some() {
1234                crate::oauth::write_codex_auth_file(&fresh);
1235            }
1236            state.clear_auth_failed(&account.name);
1237        }
1238        Err(e) => {
1239            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
1240            state.set_auth_failed(&account.name);
1241        }
1242    }
1243}
1244
1245
1246/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
1247///
1248/// Strategy: never proactively rotate the refresh_token — that races with
1249/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
1250/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
1251/// On-demand refresh (401 handler) covers the case where Codex isn't running
1252/// and the token has actually expired.
1253pub async fn openai_token_refresh_loop(
1254    config: Arc<Config>,
1255    state: StateStore,
1256    live_creds: LiveCredentials,
1257) {
1258    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
1259    for account in config.accounts.iter()
1260        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1261    {
1262        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1263            continue;
1264        }
1265        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1266
1267        let creds = {
1268            let map = live_creds.read().await;
1269            map.get(&account.name).cloned().or_else(|| account.credential.clone())
1270        };
1271        if let Some(creds) = creds {
1272            if let Some(oauth) = creds.as_oauth() {
1273                if access_token_expires_soon(oauth, 30) {
1274                    // access_token is nearly expired — refresh now so shunt can serve requests immediately.
1275                    do_proactive_refresh(account, oauth, &live_creds, &state).await;
1276                } else {
1277                    tracing::info!(account = %account.name, "access_token fresh at startup");
1278                }
1279            }
1280        }
1281    }
1282
1283    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
1284    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
1285    loop {
1286        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1287        for account in config.accounts.iter()
1288            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1289        {
1290            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1291        }
1292    }
1293}
1294
1295// ---------------------------------------------------------------------------
1296// Error type
1297// ---------------------------------------------------------------------------
1298
/// Failure modes the proxy reports to the client as a JSON error response.
///
/// The HTTP status for each variant is chosen by the `IntoResponse` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyError {
    /// The incoming request body could not be read (400).
    BodyRead,
    /// The request to the upstream failed (502).
    Upstream,
    /// Every account is on cooldown or disabled (503).
    AllAccountsUnavailable,
    /// The client's API key is invalid or missing (401).
    Unauthorized,
    /// The client is sending too many requests (429, with `Retry-After: 60`).
    RateLimited,
}
1306
1307impl IntoResponse for ProxyError {
1308    fn into_response(self) -> Response {
1309        match self {
1310            ProxyError::RateLimited => {
1311                let mut resp = (
1312                    StatusCode::TOO_MANY_REQUESTS,
1313                    axum::Json(json!({
1314                        "type": "error",
1315                        "error": {"type": "rate_limit_error", "message": "too many requests — slow down"}
1316                    })),
1317                ).into_response();
1318                resp.headers_mut().insert(
1319                    axum::http::header::RETRY_AFTER,
1320                    axum::http::HeaderValue::from_static("60"),
1321                );
1322                resp
1323            }
1324            other => {
1325                let (status, msg) = match other {
1326                    ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1327                    ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1328                    ProxyError::AllAccountsUnavailable => {
1329                        (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1330                    }
1331                    ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1332                    ProxyError::RateLimited => unreachable!(),
1333                };
1334                (status, axum::Json(json!({
1335                    "type": "error",
1336                    "error": {"type": "api_error", "message": msg}
1337                }))).into_response()
1338            }
1339        }
1340    }
1341}
1342
1343// ---------------------------------------------------------------------------
1344// Recovery watcher — periodically retries token refresh for auth_failed accounts
1345// ---------------------------------------------------------------------------
1346
1347/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
1348/// auth_failed account. If refresh succeeds the account is brought back online
1349/// without a process restart. If all accounts remain unrecoverable, fires a
1350/// macOS notification (at most once per hour).
1351pub async fn recovery_watcher(
1352    config: Arc<Config>,
1353    state: StateStore,
1354    credentials: LiveCredentials,
1355) {
1356    use std::time::{Duration, Instant};
1357    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1358    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1359
1360    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1361    let mut last_notified: Option<Instant> = None;
1362
1363    loop {
1364        tokio::time::sleep(CHECK_INTERVAL).await;
1365
1366        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1367        let failed = state.auth_failed_accounts(&name_refs);
1368        if failed.is_empty() {
1369            last_notified = None;
1370            continue;
1371        }
1372
1373        tracing::warn!(
1374            accounts = ?failed,
1375            "recovery: {} account(s) auth_failed, attempting token refresh",
1376            failed.len()
1377        );
1378
1379        let mut any_recovered = false;
1380
1381        for name in &failed {
1382            let cred = {
1383                let map = credentials.read().await;
1384                map.get(*name).cloned()
1385            };
1386            let Some(cred) = cred else { continue };
1387            if !cred.has_refresh_token() { continue; }
1388            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1389
1390            let provider = config.accounts.iter()
1391                .find(|a| a.name == *name)
1392                .map(|a| a.provider.clone())
1393                .unwrap_or_default();
1394
1395            let result = tokio::time::timeout(
1396                Duration::from_secs(20),
1397                provider.refresh_token(&oauth_cred),
1398            ).await;
1399
1400            match result {
1401                Ok(Ok(fresh)) => {
1402                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
1403                    {
1404                        let mut map = credentials.write().await;
1405                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1406                    }
1407                    let name_owned = name.to_string();
1408                    let fresh_owned = fresh.clone();
1409                    tokio::task::spawn_blocking(move || {
1410                        let mut store = crate::config::CredentialsStore::load();
1411                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1412                        store.save().ok();
1413                        if fresh_owned.id_token.is_some() {
1414                            crate::oauth::write_codex_auth_file(&fresh_owned);
1415                        }
1416                    });
1417                    state.clear_auth_failed(name);
1418                    any_recovered = true;
1419                }
1420                Ok(Err(e)) => {
1421                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1422                    notify(
1423                        "shunt: Reauth Required",
1424                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1425                        "Basso",
1426                    );
1427                }
1428                Err(_) => {
1429                    tracing::error!(account = %name, "recovery: token refresh timed out");
1430                    notify(
1431                        "shunt: Reauth Required",
1432                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1433                        "Basso",
1434                    );
1435                }
1436            }
1437        }
1438
1439        if any_recovered {
1440            tracing::info!("recovery: at least one account is back online");
1441            continue;
1442        }
1443
1444        // All accounts still auth_failed after refresh attempts — notify.
1445        let still_failed = state.auth_failed_accounts(&name_refs);
1446        if still_failed.len() == account_names.len() {
1447            let should_notify = last_notified
1448                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1449                .unwrap_or(true);
1450            if should_notify {
1451                error!(
1452                    "ALL accounts are offline (auth failed). \
1453                     Run `shunt add-account` to re-authorize."
1454                );
1455                notify(
1456                    "shunt: All Accounts Offline",
1457                    "All accounts need re-authorization. Run `shunt add-account`.",
1458                    "Basso",
1459                );
1460                last_notified = Some(Instant::now());
1461            }
1462        }
1463    }
1464}
1465
1466/// Sends a single lightweight prefetch request for `account` immediately after its
1467/// cooldown expires, so the router has fresh rate-limit headers before the next
1468/// real request arrives.
1469async fn post_cooldown_prefetch(
1470    client: &reqwest::Client,
1471    account: &crate::config::AccountConfig,
1472    token: &str,
1473    state: &StateStore,
1474    upstream_url: &str,
1475) {
1476    let Some((path, body)) = account.provider.prefetch_request() else {
1477        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1478            auth_probe_get(client, probe_path, account, state).await;
1479        }
1480        return;
1481    };
1482    let url = format!("{upstream_url}{path}");
1483    match prefetch_send(client, &url, &account.provider, token, &body).await {
1484        Ok(r) => {
1485            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1486                state.update_rate_limits(&account.name, info);
1487                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1488            }
1489        }
1490        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1491    }
1492}
1493
1494/// Periodic health-check loop: probes every account at a configurable interval
1495/// to detect dead/invalid accounts before real traffic hits them.
1496///
1497/// Uses exponential backoff (base_interval * 2^min(failures, 3)) per account,
1498/// capped at ~40 min. Marks accounts as `health_check_failed` after 2 consecutive
1499/// failures (tolerates one transient blip). On 401, delegates to `set_auth_failed`.
1500pub async fn health_check_loop(
1501    config: Arc<Config>,
1502    state: StateStore,
1503    live_creds: LiveCredentials,
1504) {
1505    if !config.server.health_check_enabled {
1506        return;
1507    }
1508
1509    // Let prefetch_rate_limits finish first.
1510    tokio::time::sleep(std::time::Duration::from_secs(15)).await;
1511
1512    let base_interval_ms = config.server.health_check_interval_secs * 1000;
1513    let timeout = std::time::Duration::from_secs(config.server.health_check_timeout_secs);
1514    let client = reqwest::Client::builder()
1515        .timeout(timeout)
1516        .build()
1517        .unwrap_or_default();
1518
1519    const FAILURE_THRESHOLD: u32 = 2;
1520    const MAX_BACKOFF_EXP: u32 = 3; // 2^3 = 8x → 40 min at 5-min base
1521
1522    loop {
1523        for account in &config.accounts {
1524            // Skip accounts already handled by recovery_watcher.
1525            {
1526                let states = state.account_states();
1527                if let Some(acc_state) = states.get(&account.name) {
1528                    if acc_state.disabled || acc_state.auth_failed {
1529                        continue;
1530                    }
1531                }
1532            }
1533
1534            // Exponential backoff based on consecutive failure count.
1535            let (last_check_ms, failures) = state.health_check_info(&account.name);
1536            let backoff_factor = 1u64 << failures.min(MAX_BACKOFF_EXP);
1537            let effective_interval_ms = base_interval_ms.saturating_mul(backoff_factor);
1538            let now = crate::state::now_ms_pub();
1539            if last_check_ms > 0 && now.saturating_sub(last_check_ms) < effective_interval_ms {
1540                continue;
1541            }
1542
1543            state.update_last_health_check(&account.name);
1544
1545            // Resolve current credential from live_creds (may have been refreshed).
1546            let cred = {
1547                let creds = live_creds.read().await;
1548                creds.get(&account.name).cloned()
1549            }.or_else(|| account.credential.clone());
1550
1551            let cred = match cred {
1552                Some(c) => c,
1553                None => {
1554                    // Local providers have no cred — probe reachability via GET /v1/models.
1555                    if let Some(probe_path) = account.provider.auth_probe_get_path() {
1556                        let upstream = account.upstream_url.as_deref()
1557                            .unwrap_or_else(|| account.provider.default_upstream_url());
1558                        let url = format!("{upstream}{probe_path}");
1559                        match client.get(&url).send().await {
1560                            Ok(r) if r.status().is_success() => {
1561                                if state.is_health_check_failed(&account.name) {
1562                                    tracing::info!(account = %account.name, "health check recovered");
1563                                }
1564                                state.clear_health_check_failed(&account.name);
1565                            }
1566                            Ok(r) => {
1567                                let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1568                                tracing::warn!(account = %account.name, status = %r.status(),
1569                                    failures = count, "health check failed");
1570                            }
1571                            Err(e) => {
1572                                let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1573                                tracing::warn!(account = %account.name, failures = count,
1574                                    "health check unreachable: {e}");
1575                            }
1576                        }
1577                    }
1578                    continue;
1579                }
1580            };
1581
1582            let token = cred.bearer_token();
1583            let upstream = account.upstream_url.as_deref()
1584                .unwrap_or(&config.server.upstream_url);
1585
1586            // Try POST prefetch (Anthropic) or GET auth probe (other providers).
1587            if let Some((path, body)) = account.provider.prefetch_request() {
1588                let url = format!("{upstream}{path}");
1589                match prefetch_send(&client, &url, &account.provider, token, &body).await {
1590                    Ok(r) => {
1591                        let status = r.status();
1592                        if status == reqwest::StatusCode::UNAUTHORIZED {
1593                            // Attempt refresh for OAuth accounts.
1594                            if let Some(oauth_cred) = cred.as_oauth() {
1595                                match account.provider.refresh_token(oauth_cred).await {
1596                                    Ok(fresh) => {
1597                                        let mut store = crate::config::CredentialsStore::load();
1598                                        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1599                                        store.save().ok();
1600                                        live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1601                                        state.clear_auth_failed(&account.name);
1602                                        if state.is_health_check_failed(&account.name) {
1603                                            state.clear_health_check_failed(&account.name);
1604                                        }
1605                                        tracing::info!(account = %account.name, "health check: token refreshed");
1606                                    }
1607                                    Err(e) => {
1608                                        tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1609                                        state.set_auth_failed(&account.name);
1610                                    }
1611                                }
1612                            } else {
1613                                tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1614                                state.set_auth_failed(&account.name);
1615                            }
1616                        } else if status.is_server_error() {
1617                            let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1618                            tracing::warn!(account = %account.name, status = %status,
1619                                failures = count, "health check: server error");
1620                        } else {
1621                            // Success — update rate limits if available.
1622                            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1623                                state.update_rate_limits(&account.name, info);
1624                            }
1625                            if state.is_health_check_failed(&account.name) {
1626                                tracing::info!(account = %account.name, "health check recovered");
1627                            }
1628                            state.clear_health_check_failed(&account.name);
1629                        }
1630                    }
1631                    Err(e) => {
1632                        let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1633                        tracing::warn!(account = %account.name, failures = count,
1634                            "health check probe failed: {e}");
1635                    }
1636                }
1637            } else if let Some(probe_path) = account.provider.auth_probe_get_path() {
1638                let probe_upstream = account.upstream_url.as_deref()
1639                    .unwrap_or_else(|| account.provider.default_upstream_url());
1640                let url = format!("{probe_upstream}{probe_path}");
1641                let mut headers = reqwest::header::HeaderMap::new();
1642                let _ = account.provider.inject_auth_headers(&mut headers, token);
1643                match client.get(&url).headers(headers).send().await {
1644                    Ok(r) => {
1645                        let status = r.status();
1646                        if status == reqwest::StatusCode::UNAUTHORIZED {
1647                            if let Some(oauth_cred) = cred.as_oauth() {
1648                                match account.provider.refresh_token(oauth_cred).await {
1649                                    Ok(fresh) => {
1650                                        let mut store = crate::config::CredentialsStore::load();
1651                                        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1652                                        store.save().ok();
1653                                        live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1654                                        state.clear_auth_failed(&account.name);
1655                                        state.clear_health_check_failed(&account.name);
1656                                        tracing::info!(account = %account.name, "health check: token refreshed (GET probe)");
1657                                    }
1658                                    Err(e) => {
1659                                        tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1660                                        state.set_auth_failed(&account.name);
1661                                    }
1662                                }
1663                            } else {
1664                                tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1665                                state.set_auth_failed(&account.name);
1666                            }
1667                        } else if status.is_server_error() {
1668                            let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1669                            tracing::warn!(account = %account.name, status = %status,
1670                                failures = count, "health check: server error (GET probe)");
1671                        } else {
1672                            if state.is_health_check_failed(&account.name) {
1673                                tracing::info!(account = %account.name, "health check recovered");
1674                            }
1675                            state.clear_health_check_failed(&account.name);
1676                        }
1677                    }
1678                    Err(e) => {
1679                        let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1680                        tracing::warn!(account = %account.name, failures = count,
1681                            "health check probe failed: {e}");
1682                    }
1683                }
1684            }
1685        }
1686
1687        tokio::time::sleep(std::time::Duration::from_secs(config.server.health_check_interval_secs)).await;
1688    }
1689}
1690
1691/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1692/// so each account re-enters rotation with fresh rate-limit metrics.
1693///
1694/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1695/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1696/// the next cooldown deadline rather than polling at a fixed interval.
1697///
1698/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1699/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1700/// is triggered so the router always has fresh utilization metrics.
1701pub async fn cooldown_watcher(
1702    config: Arc<Config>,
1703    state: StateStore,
1704    credentials: LiveCredentials,
1705) {
1706    /// Re-fetch rate-limit headers if data is older than 1 hour.
1707    const STALE_RL_MS: u64 = 60 * 60_000;
1708
1709    let client = reqwest::Client::builder()
1710        .timeout(std::time::Duration::from_secs(20))
1711        .build()
1712        .unwrap_or_default();
1713
1714    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1715    // Prevents re-triggering on every poll after expiry.
1716    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1717    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1718    let mut notify_on_resume: HashSet<String> = HashSet::new();
1719    // Epoch-ms of the last successful stale-prefetch per account.
1720    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1721
1722    loop {
1723        let states = state.account_states();
1724        let rl_snapshot = state.rate_limit_snapshot();
1725        let now = now_ms();
1726        let mut next_wake_ms: Option<u64> = None;
1727
1728        for account in &config.accounts {
1729            let Some(st) = states.get(&account.name) else { continue };
1730            if st.disabled { continue; } // auth_failed or permanently disabled
1731            let cdl = st.cooldown_until_ms;
1732
1733            if cdl > 0 && cdl <= now {
1734                // Cooldown expired — skip if we already handled this exact deadline
1735                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1736                if !handled {
1737                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1738                    let token = {
1739                        let creds = credentials.read().await;
1740                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1741                    };
1742                    if let Some(token) = token {
1743                        post_cooldown_prefetch(
1744                            &client, account, &token, &state,
1745                            &config.server.upstream_url,
1746                        ).await;
1747                    }
1748                    if notify_on_resume.remove(&account.name) {
1749                        notify(
1750                            "shunt: Account Resumed",
1751                            &format!("Account '{}' is back online.", account.name),
1752                            "Glass",
1753                        );
1754                    }
1755                    last_resumed.insert(account.name.clone(), cdl);
1756                    last_stale_prefetch.insert(account.name.clone(), now);
1757                }
1758            } else if cdl > now {
1759                // Still cooling — schedule wake at expiry; flag for notification if long
1760                let remaining = cdl - now;
1761                if remaining >= 5 * 60_000 {
1762                    notify_on_resume.insert(account.name.clone());
1763                }
1764                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1765            } else {
1766                // Not in cooldown — check for stale rate-limit data
1767                let rl_age = rl_snapshot
1768                    .get(&account.name)
1769                    .map(|r| now.saturating_sub(r.updated_ms))
1770                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1771                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1772                let fetched_ago = now.saturating_sub(last_fetched);
1773
1774                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1775                    tracing::debug!(
1776                        account = %account.name,
1777                        age_min = rl_age / 60_000,
1778                        "rate-limit data stale — refreshing"
1779                    );
1780                    let token = {
1781                        let creds = credentials.read().await;
1782                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1783                    };
1784                    if let Some(token) = token {
1785                        post_cooldown_prefetch(
1786                            &client, account, &token, &state,
1787                            &config.server.upstream_url,
1788                        ).await;
1789                    }
1790                    last_stale_prefetch.insert(account.name.clone(), now);
1791                }
1792            }
1793        }
1794
1795        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1796        let sleep_ms = next_wake_ms
1797            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1798            .unwrap_or(30_000);
1799        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1800    }
1801}
1802
1803use crate::notify::notify;
1804use crate::translate::{
1805    translate_to_anthropic,
1806    translate_from_anthropic,
1807    uuid_v4,
1808    translate_anthropic_stream,
1809    translate_anthropic_req_to_chatgpt,
1810    translate_response_chatgpt_to_anthropic,
1811    translate_anthropic_req_to_openai,
1812    translate_response_openai_to_anthropic,
1813    translate_response_anthropic_to_openai,
1814};
1815
1816// ---------------------------------------------------------------------------
1817// OpenAI-compatible API (translates to Anthropic Claude)
1818// ---------------------------------------------------------------------------
1819//
1820// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1821// anthropic_base_url is configured, it translates the request to Anthropic
1822// Messages format and forwards it to the Anthropic proxy (which handles
1823// account selection, token management, and rate limiting).
1824// The response is translated back to OpenAI Chat Completions format.
1825
1826
1827
1828
1829/// GET /v1/models — return Claude models in OpenAI format.
1830async fn openai_models_handler() -> impl IntoResponse {
1831    axum::Json(json!({
1832        "object": "list",
1833        "data": [
1834            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1835            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1836            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1837        ]
1838    }))
1839}
1840
1841/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1842async fn openai_compat_handler(
1843    State(s): State<AppState>,
1844    req: Request,
1845) -> Result<Response, ProxyError> {
1846    let Some(ref anthropic_url) = s.anthropic_base_url else {
1847        // No Anthropic proxy configured — fall back to normal forwarding
1848        return proxy_handler(State(s), req).await;
1849    };
1850
1851    let body_bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
1852        .await
1853        .map_err(|_| ProxyError::BodyRead)?;
1854
1855    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1856        .unwrap_or(json!({}));
1857
1858    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1859    let anthropic_body = translate_to_anthropic(openai_body);
1860
1861    let client = reqwest::Client::builder()
1862        .timeout(std::time::Duration::from_secs(300))
1863        .build()
1864        .map_err(|_| ProxyError::Upstream)?;
1865
1866    let mut req_builder = client
1867        .post(format!("{anthropic_url}/v1/messages"))
1868        .header("content-type", "application/json")
1869        .header("anthropic-version", "2023-06-01")
1870        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1871        .header("x-shunt-compat", "openai");
1872    if let Some(ref key) = s.config.server.remote_key {
1873        req_builder = req_builder.header("x-api-key", key.as_str());
1874    }
1875    let resp = req_builder
1876        .json(&anthropic_body)
1877        .send()
1878        .await
1879        .map_err(|_| ProxyError::Upstream)?;
1880
1881    if !resp.status().is_success() {
1882        let status = resp.status();
1883        let body = resp.text().await.unwrap_or_default();
1884        let code = status.as_u16();
1885        return Ok(axum::response::Response::builder()
1886            .status(code)
1887            .header("content-type", "application/json")
1888            .body(axum::body::Body::from(body))
1889            .unwrap());
1890    }
1891
1892    if stream {
1893        // Translate Anthropic SSE stream → OpenAI SSE stream
1894        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1895        let stream = translate_anthropic_stream(resp, chat_id);
1896        Ok(axum::response::Response::builder()
1897            .status(200)
1898            .header("content-type", "text/event-stream")
1899            .header("cache-control", "no-cache")
1900            .body(axum::body::Body::from_stream(stream))
1901            .unwrap())
1902    } else {
1903        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1904        let openai_resp = translate_from_anthropic(anthropic_resp);
1905        Ok(axum::Json(openai_resp).into_response())
1906    }
1907}
1908
1909// ---------------------------------------------------------------------------
1910// ChatGPT backend API translation (chatgpt.com /backend-api/conversation)
1911// ---------------------------------------------------------------------------
1912
1913/// Fetch the sentinel token required by chatgpt.com's backend API.
1914/// Returns None if the request fails or proof-of-work is required.
1915async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1916    let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1917    let resp = client
1918        .get(&url)
1919        .header("Authorization", format!("Bearer {}", token))
1920        .send()
1921        .await
1922        .ok()?;
1923    if !resp.status().is_success() {
1924        return None;
1925    }
1926    let json: serde_json::Value = resp.json().await.ok()?;
1927    if json["proofofwork"]["required"].as_bool() == Some(true) {
1928        return None;
1929    }
1930    json["token"].as_str().map(ToOwned::to_owned)
1931}
1932
1933
/// Reports whether `model` is a "simple" model, i.e. one without support for
/// extended thinking / effort parameters, which must be stripped before
/// forwarding.
///
/// The check is a case-sensitive substring match against the known markers.
fn is_simple_model(model: &str) -> bool {
    // Substrings that identify a simple model family.
    const SIMPLE_MODEL_MARKERS: [&str; 1] = ["haiku"];

    SIMPLE_MODEL_MARKERS
        .iter()
        .any(|marker| model.contains(marker))
}
1939
1940/// Resolve the target model name for a non-Anthropic account.
1941///
1942/// Priority: per-account `model` pin → global `model_mapping` → provider `default_model()`.
1943/// If the provider is `Local` (default_model = ""), the incoming model name is passed through.
1944fn resolve_model(
1945    incoming: &str,
1946    account: &crate::config::AccountConfig,
1947    mapping: &std::collections::HashMap<String, String>,
1948) -> String {
1949    // 1. Per-account pin (highest priority).
1950    if let Some(m) = &account.model {
1951        return m.clone();
1952    }
1953    // 2. Global mapping for this specific incoming model name.
1954    if let Some(m) = mapping.get(incoming) {
1955        return m.clone();
1956    }
1957    // 3. Provider default.
1958    let default = account.provider.default_model();
1959    if !default.is_empty() {
1960        return default.to_owned();
1961    }
1962    // 4. Pass through (Local provider — model name is server-defined).
1963    incoming.to_owned()
1964}
1965