Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24    config: Arc<Config>,
25    forwarder: Arc<Forwarder>,
26    state: StateStore,
27    /// Live credentials — can be refreshed at runtime without restarting.
28    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29    /// Epoch-ms when this proxy instance started.
30    started_ms: u64,
31    /// If set, /v1/chat/completions requests are translated and forwarded here
32    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
33    anthropic_base_url: Option<String>,
34}
35
36pub fn create_app(config: Config) -> anyhow::Result<Router> {
37    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
38    Ok(app)
39}
40
41/// Shared live credentials map — can be written to without restarting the proxy.
42pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
43
44pub fn create_app_with_state(
45    config: Config,
46    state: StateStore,
47    anthropic_base_url: Option<String>,
48) -> anyhow::Result<(Router, LiveCredentials)> {
49    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
50
51    // Accounts with no credential are shown in status but skipped during routing.
52    // Mark them disabled immediately so the router ignores them.
53    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
54        state.set_auth_failed(&a.name);
55    }
56
57    let credentials: LiveCredentials = Arc::new(RwLock::new(
58        config.accounts.iter()
59            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
60            .collect::<HashMap<_, _>>(),
61    ));
62
63    let app_state = AppState {
64        config: Arc::new(config),
65        forwarder: Arc::new(forwarder),
66        state,
67        credentials: Arc::clone(&credentials),
68        started_ms: now_ms(),
69        anthropic_base_url,
70    };
71
72    // Register proxy routes appropriate for the provider.
73    // Anthropic: explicit paths only (maintains existing behaviour).
74    // OpenAI/others: wildcard catches all paths; also expose OpenAI-compat
75    //   endpoints that translate to Claude when anthropic_base_url is set.
76    let provider = app_state.config.accounts.first()
77        .map(|a| &a.provider)
78        .cloned()
79        .unwrap_or_default();
80
81    let proxy_routes = match provider {
82        Provider::Anthropic => Router::new()
83            .route("/v1/messages", post(proxy_handler))
84            .route("/v1/messages/count_tokens", post(proxy_handler)),
85        Provider::OpenAI => Router::new()
86            .route("/v1/chat/completions", post(openai_compat_handler))
87            .route("/v1/models", get(openai_models_handler))
88            .fallback(proxy_handler),
89    };
90
91    let app = Router::new()
92        .route("/health", get(health))
93        .route("/status", get(status_handler))
94        .route("/use", post(use_handler))
95        .merge(proxy_routes)
96        .with_state(app_state);
97
98    Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102    axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106    let account_states = s.state.account_states();
107    let quotas = s.state.quota_snapshot();
108    let rate_limits = s.state.rate_limit_snapshot();
109
110    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111        let st = account_states.get(&a.name);
112        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113            "reauth_required"
114        } else if st.map(|s| s.disabled).unwrap_or(false) {
115            "disabled"
116        } else if s.state.is_available(&a.name) {
117            "available"
118        } else {
119            "cooling"
120        };
121
122        let quota = quotas.get(&a.name);
123        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125        let tokens_used = quota.map(|q| json!({
126            "input": q.input_tokens,
127            "output": q.output_tokens,
128            "total": q.total_tokens(),
129        }));
130
131        let rl = rate_limits.get(&a.name);
132        let rate_limit = rl.map(|r| json!({
133            "utilization_5h": r.utilization_5h,
134            "reset_5h": r.reset_5h,
135            "status_5h": r.status_5h,
136            "utilization_7d": r.utilization_7d,
137            "reset_7d": r.reset_7d,
138            "status_7d": r.status_7d,
139            "representative_claim": r.representative_claim,
140            "updated_ms": r.updated_ms,
141        }));
142
143        let acc_state = account_states.get(&a.name);
144        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149        let reset_5h = rl.and_then(|r| r.reset_5h);
150        let total_tokens = quota.map(|q| q.total_tokens()).unwrap_or(0);
151        let available = s.state.is_available(&a.name);
152
153        json!({
154            "name": a.name,
155            "email": email,
156            "plan_type": a.plan_type,
157            "status": avail_status,
158            "available": available,
159            "disabled": disabled,
160            "auth_failed": auth_failed,
161            "cooldown_until_ms": cooldown_until_ms,
162            "utilization_5h": utilization_5h,
163            "reset_5h": reset_5h,
164            "total_tokens": total_tokens,
165            "window_expires_ms": window_expires_ms,
166            "tokens_used": tokens_used,
167            "rate_limit": rate_limit,
168        })
169    }).collect();
170
171    let recent_requests = s.state.recent_requests_snapshot();
172    let savings = s.state.savings_snapshot();
173
174    axum::Json(json!({
175        "version": env!("CARGO_PKG_VERSION"),
176        "started_ms": s.started_ms,
177        "accounts": accounts,
178        "pinned_account": s.state.get_pinned(),
179        "last_used_account": s.state.get_last_used(),
180        "recent_requests": recent_requests,
181        "savings": savings,
182    }))
183}
184
185async fn use_handler(
186    State(s): State<AppState>,
187    axum::Json(body): axum::Json<serde_json::Value>,
188) -> impl IntoResponse {
189    let account = body["account"].as_str().map(|s| s.to_owned());
190    // Validate the account name exists (unless clearing to auto)
191    if let Some(ref name) = account {
192        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
193            return axum::Json(json!({
194                "error": format!("unknown account '{name}'")
195            }));
196        }
197        let pinned = if name == "auto" { None } else { Some(name.clone()) };
198        s.state.set_pinned(pinned);
199        axum::Json(json!({ "pinned": name }))
200    } else {
201        s.state.set_pinned(None);
202        axum::Json(json!({ "pinned": null }))
203    }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn now_ms() -> u64 {
    let elapsed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    elapsed.as_millis() as u64
}
210
211async fn proxy_handler(
212    State(s): State<AppState>,
213    req: Request,
214) -> Result<Response, ProxyError> {
215    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
216    if let Some(ref expected) = s.config.server.remote_key {
217        let provided = req.headers()
218            .get("x-api-key")
219            .and_then(|v| v.to_str().ok())
220            .unwrap_or("");
221        if provided != expected {
222            return Err(ProxyError::Unauthorized);
223        }
224    }
225
226    let method = req.method().as_str().to_owned();
227    let path = req.uri().path().to_owned();
228    let headers = req.headers().clone();
229
230    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
231        .await
232        .map_err(|_| ProxyError::BodyRead)?;
233
234    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
235        .ok()
236        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
237        .unwrap_or_default();
238    let req_start_ms = now_ms();
239
240    let fp = router::fingerprint(&body_bytes);
241    let fp_ref = fp.as_deref();
242
243    let mut tried: HashSet<String> = HashSet::new();
244    // Track accounts we've already attempted a token refresh for this request.
245    let mut refreshed: HashSet<String> = HashSet::new();
246
247    loop {
248        let account = match router::pick_account(
249            &s.config.accounts, &s.state, fp_ref, &tried,
250            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
251        ) {
252            Some(a) => a,
253            None => return Err(ProxyError::AllAccountsUnavailable),
254        };
255
256        let account_name = account.name.clone();
257
258        // Use the live (possibly refreshed) token rather than the one baked into config.
259        // For OpenAI/chatgpt.com accounts, use the id_token (short-lived OIDC JWT) as
260        // the bearer — chatgpt.com's API authenticates via id_token, not access_token.
261        let token = {
262            let creds = s.credentials.read().await;
263            let cred = creds.get(&account_name)
264                .cloned()
265                .or_else(|| account.credential.clone());
266            match cred {
267                Some(c) if account.provider == crate::provider::Provider::OpenAI => {
268                    c.id_token.unwrap_or(c.access_token)
269                }
270                Some(c) => c.access_token,
271                None => String::new(),
272            }
273        };
274
275        let response = s.forwarder
276            .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
277            .await
278            .map_err(|e| {
279                error!("Forward error: {:#}", e);
280                ProxyError::Upstream
281            })?;
282
283        match response.status().as_u16() {
284            200..=299 => {
285                s.state.set_last_used(&account_name);
286                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
287                    s.state.update_rate_limits(&account_name, info);
288                }
289                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
290            }
291            429 => {
292                warn!(account = %account_name, "429 rate-limited — cooling 60s");
293                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
294                    s.state.update_rate_limits(&account_name, info);
295                }
296                s.state.set_cooldown(&account_name, 60_000);
297                tried.insert(account_name);
298            }
299            529 => {
300                warn!(account = %account_name, "529 overloaded — cooling 30s");
301                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
302                    s.state.update_rate_limits(&account_name, info);
303                }
304                s.state.set_cooldown(&account_name, 30_000);
305                tried.insert(account_name);
306            }
307            401 => {
308                if !refreshed.contains(&account_name) {
309                    // Access token invalidated (e.g. user logged out) — try refresh.
310                    let cred = {
311                        let creds = s.credentials.read().await;
312                        creds.get(&account_name).cloned()
313                            .or_else(|| account.credential.clone())
314                    };
315                    let Some(cred) = cred else {
316                        tried.insert(account_name);
317                        continue;
318                    };
319                    match tokio::time::timeout(
320                        std::time::Duration::from_secs(10),
321                        account.provider.refresh_token(&cred),
322                    ).await {
323                        Ok(Ok(fresh)) => {
324                            warn!(account = %account_name, "401 — token refreshed, retrying");
325                            {
326                                let mut creds = s.credentials.write().await;
327                                creds.insert(account_name.clone(), fresh.clone());
328                            }
329                            // Persist to disk so the refreshed token survives a restart.
330                            let name = account_name.clone();
331                            let fresh = fresh.clone();
332                            tokio::task::spawn_blocking(move || {
333                                let mut store = CredentialsStore::load();
334                                store.accounts.insert(name, fresh.clone());
335                                store.save().ok();
336                                if fresh.id_token.is_some() {
337                                    crate::oauth::write_codex_auth_file(&fresh);
338                                }
339                            });
340                            // Mark as refreshed but don't add to tried — retry this account.
341                            refreshed.insert(account_name);
342                        }
343                        _ => {
344                            // Refresh failed/timed out — cool down, don't permanently disable.
345                            error!(account = %account_name, "401 — token refresh failed, cooling 5min");
346                            s.state.set_cooldown(&account_name, 5 * 60_000);
347                            tried.insert(account_name);
348                        }
349                    }
350                } else {
351                    // Already refreshed once and still 401 — cool down this account.
352                    error!(account = %account_name, "401 after refresh — cooling 5min");
353                    s.state.set_cooldown(&account_name, 5 * 60_000);
354                    tried.insert(account_name);
355                }
356            }
357            403 => {
358                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
359                error!(account = %account_name, "403 forbidden — cooling 30min");
360                s.state.set_cooldown(&account_name, 30 * 60_000);
361                tried.insert(account_name);
362            }
363            _ => {
364                // 400, 404, 500, etc. — return as-is, no retry
365                return Ok(response);
366            }
367        }
368    }
369}
370
371// ---------------------------------------------------------------------------
372// Usage extraction
373// ---------------------------------------------------------------------------
374
375/// Intercept a successful response to record token usage, then pass it through.
376///
377/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
378/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
379async fn tap_usage(
380    resp: Response,
381    state: &StateStore,
382    account: &str,
383    model: &str,
384    req_start_ms: u64,
385) -> Response {
386    use axum::body::Body;
387    use crate::state::RequestLog;
388
389    if quota::is_streaming_response(&resp) {
390        let state = state.clone();
391        let account = account.to_owned();
392        let model = model.to_owned();
393        let on_complete = Arc::new(move |input: u64, output: u64| {
394            state.record_usage(&account, input, output);
395            state.record_global(&model, input, output);
396            state.record_request(RequestLog {
397                ts_ms: req_start_ms,
398                account: account.clone(),
399                model: model.clone(),
400                status: 200,
401                input_tokens: input,
402                output_tokens: output,
403                duration_ms: now_ms().saturating_sub(req_start_ms),
404            });
405        });
406        let (parts, body) = resp.into_parts();
407        let wrapped = quota::wrap_streaming_body(body, on_complete);
408        return Response::from_parts(parts, wrapped);
409    }
410
411    // Non-streaming: buffer, extract, rebuild
412    let (parts, body) = resp.into_parts();
413    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
414        Ok(b) => b,
415        Err(_) => return Response::from_parts(parts, Body::empty()),
416    };
417    let (input, output) = quota::extract_usage_from_json(&bytes);
418    state.record_usage(account, input, output);
419    state.record_global(model, input, output);
420    state.record_request(RequestLog {
421        ts_ms: req_start_ms,
422        account: account.to_owned(),
423        model: model.to_owned(),
424        status: 200,
425        input_tokens: input,
426        output_tokens: output,
427        duration_ms: now_ms().saturating_sub(req_start_ms),
428    });
429    Response::from_parts(parts, Body::from(bytes))
430}
431
432
433// ---------------------------------------------------------------------------
434// Rate limit prefetch
435// ---------------------------------------------------------------------------
436
437/// For any account with no rate-limit data yet, make a cheap request directly
438/// to the upstream API so we populate metrics without waiting for a real user
439/// request. Runs as a background task after startup.
440pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
441    let client = reqwest::Client::builder()
442        .timeout(std::time::Duration::from_secs(20))
443        .build()
444        .unwrap_or_default();
445
446    for account in &config.accounts {
447        // Skip if we already have data for this account.
448        let rl = state.rate_limit_snapshot();
449        if let Some(r) = rl.get(&account.name) {
450            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
451                continue;
452            }
453        }
454
455        // Skip accounts with no credentials or no prefetch support.
456        let creds = match account.credential.clone() {
457            Some(c) => c,
458            None => continue,
459        };
460
461        let Some((path, body)) = account.provider.prefetch_request() else {
462            // No POST prefetch for this provider — do a lightweight GET auth check instead.
463            if let Some(probe_path) = account.provider.auth_probe_get_path() {
464                auth_probe_get(&client, probe_path, account, &state).await;
465            }
466            continue;
467        };
468        let url = format!("{}{}", config.server.upstream_url, path);
469
470        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
471
472        let r = match resp {
473            Ok(r) => r,
474            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
475        };
476
477        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
478            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
479            let fresh = match account.provider.refresh_token(&creds).await {
480                Ok(f) => f,
481                Err(e) => {
482                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
483                    state.set_auth_failed(&account.name);
484                    continue;
485                }
486            };
487            let mut store = crate::config::CredentialsStore::load();
488            store.accounts.insert(account.name.clone(), fresh.clone());
489            store.save().ok();
490            if fresh.id_token.is_some() {
491                crate::oauth::write_codex_auth_file(&fresh);
492            }
493
494            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
495                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
496                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
497                    state.set_auth_failed(&account.name);
498                }
499                Ok(r2) => {
500                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
501                        state.update_rate_limits(&account.name, info);
502                    }
503                }
504                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
505            }
506        } else {
507            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
508            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
509                state.update_rate_limits(&account.name, info);
510            }
511        }
512    }
513}
514
515/// Build and send a prefetch request for the given provider + token.
516async fn prefetch_send(
517    client: &reqwest::Client,
518    url: &str,
519    provider: &crate::provider::Provider,
520    token: &str,
521    body: &serde_json::Value,
522) -> anyhow::Result<reqwest::Response> {
523    let mut headers = reqwest::header::HeaderMap::new();
524    provider.inject_auth_headers(&mut headers, token)?;
525    for (name, value) in provider.prefetch_extra_headers() {
526        headers.insert(
527            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
528            reqwest::header::HeaderValue::from_static(value),
529        );
530    }
531    Ok(client.post(url).headers(headers).json(body).send().await?)
532}
533
534/// GET a cheap endpoint to verify credentials are still valid for providers that
535/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
536/// marks the account as `reauth_required` if the refresh also fails.
537async fn auth_probe_get(
538    client: &reqwest::Client,
539    path: &str,
540    account: &crate::config::AccountConfig,
541    state: &StateStore,
542) {
543    let creds = match account.credential.clone() {
544        Some(c) => c,
545        None => return,
546    };
547    let upstream = match account.provider {
548        crate::provider::Provider::OpenAI => "https://chatgpt.com",
549        crate::provider::Provider::Anthropic => "https://api.anthropic.com",
550    };
551    let url = format!("{}{}", upstream, path);
552
553    let do_get = |token: &str| -> reqwest::RequestBuilder {
554        let mut headers = reqwest::header::HeaderMap::new();
555        let _ = account.provider.inject_auth_headers(&mut headers, token);
556        client.get(&url).headers(headers)
557    };
558
559    // Use id_token for chatgpt.com (same as the proxy handler).
560    let probe_token = creds.id_token.as_deref().unwrap_or(&creds.access_token);
561    let resp = match do_get(probe_token).send().await {
562        Ok(r) => r,
563        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
564    };
565
566    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
567        tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
568        let fresh = match account.provider.refresh_token(&creds).await {
569            Ok(f) => f,
570            Err(e) => {
571                tracing::warn!(account = %account.name, "token refresh failed: {e}");
572                state.set_auth_failed(&account.name);
573                return;
574            }
575        };
576        let mut store = crate::config::CredentialsStore::load();
577        store.accounts.insert(account.name.clone(), fresh.clone());
578        store.save().ok();
579        if fresh.id_token.is_some() {
580            crate::oauth::write_codex_auth_file(&fresh);
581        }
582
583        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
584        match do_get(fresh_token).send().await {
585            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
586                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
587                state.set_auth_failed(&account.name);
588            }
589            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
590            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
591        }
592    } else {
593        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
594        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
595        // with codex CLI, which also tries to refresh at startup using the same token.
596        // Proactive refreshing is handled solely by openai_token_refresh_loop.
597    }
598}
599
600// ---------------------------------------------------------------------------
601// Proactive OpenAI token refresh loop
602// ---------------------------------------------------------------------------
603
604/// Returns true if the id_token inside `cred` has fewer than `threshold_mins`
605/// minutes remaining, or if there is no id_token / it cannot be parsed.
606fn id_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
607    let Some(ref id_tok) = cred.id_token else { return true };
608    let Some(exp_ms) = crate::oauth::jwt_exp_ms(id_tok) else { return true };
609    let now_ms = std::time::SystemTime::now()
610        .duration_since(std::time::UNIX_EPOCH)
611        .unwrap_or_default()
612        .as_millis() as u64;
613    exp_ms < now_ms + threshold_mins * 60 * 1_000
614}
615
616/// Sync live_creds from auth.json if auth.json has a newer token.
617///
618/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
619/// we pull that in so we don't use a stale refresh_token that codex already rotated.
620async fn sync_live_creds_from_auth_json(
621    account_name: &str,
622    live_creds: &LiveCredentials,
623) {
624    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
625    let current_exp = live_creds.read().await
626        .get(account_name)
627        .map(|c| c.expires_at)
628        .unwrap_or(0);
629    if from_file.expires_at > current_exp {
630        tracing::info!(account = %account_name, "synced fresher token from auth.json");
631        live_creds.write().await.insert(account_name.to_owned(), from_file);
632    }
633}
634
635/// Perform a single proactive refresh for one account and persist the result.
636async fn do_proactive_refresh(
637    account: &crate::config::AccountConfig,
638    creds: &crate::oauth::OAuthCredential,
639    live_creds: &LiveCredentials,
640    state: &StateStore,
641) {
642    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
643    match account.provider.refresh_token(creds).await {
644        Ok(fresh) => {
645            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
646            {
647                let mut map = live_creds.write().await;
648                map.insert(account.name.clone(), fresh.clone());
649            }
650            let mut store = crate::config::CredentialsStore::load();
651            store.accounts.insert(account.name.clone(), fresh.clone());
652            store.save().ok();
653            if fresh.id_token.is_some() {
654                crate::oauth::write_codex_auth_file(&fresh);
655            }
656            state.clear_auth_failed(&account.name);
657        }
658        Err(e) => {
659            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
660            state.set_auth_failed(&account.name);
661        }
662    }
663}
664
665/// Compute how many seconds to sleep until 15 minutes before the soonest
666/// OpenAI id_token expiry. Capped at 45 minutes; at least 60 seconds.
667async fn secs_until_next_refresh(config: &Config, live_creds: &LiveCredentials) -> u64 {
668    let now_ms = std::time::SystemTime::now()
669        .duration_since(std::time::UNIX_EPOCH)
670        .unwrap_or_default()
671        .as_millis() as u64;
672    const WAKE_BEFORE_MS: u64 = 15 * 60 * 1_000; // wake 15 min before expiry
673    const MAX_SLEEP_SECS: u64 = 45 * 60;
674    const MIN_SLEEP_SECS: u64 = 60;
675
676    let mut min_sleep_secs = MAX_SLEEP_SECS;
677
678    for account in config.accounts.iter()
679        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
680    {
681        let creds = live_creds.read().await.get(&account.name).cloned();
682        let Some(creds) = creds else { continue };
683        let Some(ref id_tok) = creds.id_token else { continue };
684        let Some(exp_ms) = crate::oauth::jwt_exp_ms(id_tok) else { continue };
685
686        // How many ms until we want to wake (15 min before expiry)?
687        let wake_ms = exp_ms.saturating_sub(WAKE_BEFORE_MS);
688        let sleep_ms = wake_ms.saturating_sub(now_ms);
689        let sleep_secs = (sleep_ms / 1_000).clamp(MIN_SLEEP_SECS, MAX_SLEEP_SECS);
690        min_sleep_secs = min_sleep_secs.min(sleep_secs);
691    }
692
693    min_sleep_secs
694}
695
696/// Keeps OpenAI `id_token`s perpetually fresh.
697///
698/// Strategy:
699/// - At startup: only refresh if the id_token is already expired (< 2 min left).
700///   If the token is fresh, we skip — codex CLI does its own startup refresh and
701///   rotating the token from two places simultaneously causes "invalid_grant".
702/// - Dynamically scheduled: sleep until 15 min before the id_token expires,
703///   then re-sync from auth.json (codex may have refreshed it) and refresh.
704///
705/// This avoids racing with codex CLI's startup refresh while still ensuring
706/// the token is always fresh when codex needs it.
707pub async fn openai_token_refresh_loop(
708    config: Arc<Config>,
709    state: StateStore,
710    live_creds: LiveCredentials,
711) {
712    // Startup pass: only refresh if token is already expired.
713    for account in config.accounts.iter()
714        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
715    {
716        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
717            continue;
718        }
719        let creds = {
720            let map = live_creds.read().await;
721            map.get(&account.name).cloned().or_else(|| account.credential.clone())
722        };
723        if let Some(creds) = creds {
724            if id_token_expires_soon(&creds, 2) {
725                // Token already expired (or expiring within 2 min) at startup — refresh now.
726                do_proactive_refresh(account, &creds, &live_creds, &state).await;
727            } else {
728                tracing::info!(account = %account.name, "id_token fresh at startup — skipping immediate refresh");
729            }
730        }
731    }
732
733    loop {
734        // Sleep until ~15 min before the soonest id_token expiry (max 45 min).
735        // This ensures we always refresh before codex CLI needs to, regardless
736        // of when shunt started relative to the last token refresh.
737        let sleep_secs = secs_until_next_refresh(&config, &live_creds).await;
738        tracing::debug!("next OpenAI token refresh check in {}m {}s",
739            sleep_secs / 60, sleep_secs % 60);
740        tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
741
742        for account in config.accounts.iter()
743            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
744        {
745            if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
746                continue;
747            }
748
749            // Pull in any token rotation codex CLI has done since our last refresh.
750            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
751
752            let creds = {
753                let map = live_creds.read().await;
754                map.get(&account.name).cloned().or_else(|| account.credential.clone())
755            };
756            let Some(creds) = creds else { continue };
757
758            // Only refresh if id_token has < 15 minutes left.
759            if !id_token_expires_soon(&creds, 15) {
760                tracing::debug!(account = %account.name, "id_token still fresh, skipping refresh");
761                continue;
762            }
763
764            do_proactive_refresh(account, &creds, &live_creds, &state).await;
765        }
766    }
767}
768
769// ---------------------------------------------------------------------------
770// Error type
771// ---------------------------------------------------------------------------
772
773enum ProxyError {
774    BodyRead,
775    Upstream,
776    AllAccountsUnavailable,
777    Unauthorized,
778}
779
780impl IntoResponse for ProxyError {
781    fn into_response(self) -> Response {
782        let (status, msg) = match self {
783            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
784            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
785            ProxyError::AllAccountsUnavailable => {
786                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
787            }
788            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
789        };
790
791        (status, axum::Json(json!({
792            "type": "error",
793            "error": {"type": "api_error", "message": msg}
794        }))).into_response()
795    }
796}
797
798// ---------------------------------------------------------------------------
799// Recovery watcher — periodically retries token refresh for auth_failed accounts
800// ---------------------------------------------------------------------------
801
802/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
803/// auth_failed account. If refresh succeeds the account is brought back online
804/// without a process restart. If all accounts remain unrecoverable, fires a
805/// macOS notification (at most once per hour).
806pub async fn recovery_watcher(
807    config: Arc<Config>,
808    state: StateStore,
809    credentials: LiveCredentials,
810) {
811    use std::time::{Duration, Instant};
812    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
813    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
814
815    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
816    let mut last_notified: Option<Instant> = None;
817
818    loop {
819        tokio::time::sleep(CHECK_INTERVAL).await;
820
821        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
822        let failed = state.auth_failed_accounts(&name_refs);
823        if failed.is_empty() {
824            last_notified = None;
825            continue;
826        }
827
828        tracing::warn!(
829            accounts = ?failed,
830            "recovery: {} account(s) auth_failed, attempting token refresh",
831            failed.len()
832        );
833
834        let mut any_recovered = false;
835
836        for name in &failed {
837            let cred = {
838                let map = credentials.read().await;
839                map.get(*name).cloned()
840            };
841            let Some(cred) = cred else { continue };
842            if cred.refresh_token.is_empty() { continue; }
843
844            let provider = config.accounts.iter()
845                .find(|a| a.name == *name)
846                .map(|a| a.provider.clone())
847                .unwrap_or_default();
848
849            let result = tokio::time::timeout(
850                Duration::from_secs(20),
851                provider.refresh_token(&cred),
852            ).await;
853
854            match result {
855                Ok(Ok(fresh)) => {
856                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
857                    {
858                        let mut map = credentials.write().await;
859                        map.insert(name.to_string(), fresh.clone());
860                    }
861                    let name_owned = name.to_string();
862                    let fresh_owned = fresh.clone();
863                    tokio::task::spawn_blocking(move || {
864                        let mut store = crate::config::CredentialsStore::load();
865                        store.accounts.insert(name_owned, fresh_owned.clone());
866                        store.save().ok();
867                        if fresh_owned.id_token.is_some() {
868                            crate::oauth::write_codex_auth_file(&fresh_owned);
869                        }
870                    });
871                    state.clear_auth_failed(name);
872                    any_recovered = true;
873                }
874                Ok(Err(e)) => {
875                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
876                }
877                Err(_) => {
878                    tracing::error!(account = %name, "recovery: token refresh timed out");
879                }
880            }
881        }
882
883        if any_recovered {
884            tracing::info!("recovery: at least one account is back online");
885            continue;
886        }
887
888        // All accounts still auth_failed after refresh attempts — notify.
889        let still_failed = state.auth_failed_accounts(&name_refs);
890        if still_failed.len() == account_names.len() {
891            let should_notify = last_notified
892                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
893                .unwrap_or(true);
894            if should_notify {
895                error!(
896                    "ALL accounts are offline (auth failed). \
897                     Run `shunt add-account` to re-authorize."
898                );
899                notify_all_accounts_offline();
900                last_notified = Some(Instant::now());
901            }
902        }
903    }
904}
905
/// Best-effort desktop alert that every account has lost authentication.
///
/// On macOS this runs `osascript` with a `display notification` script (title
/// "shunt: All Accounts Offline", sound "Basso"); the exit status and any spawn
/// error are ignored. On every other platform this is a no-op.
fn notify_all_accounts_offline() {
    #[cfg(target_os = "macos")]
    {
        const SCRIPT: &str = concat!(
            r#"display notification "#,
            r#""All accounts have lost authentication. Run `shunt add-account` to re-authorize." "#,
            r#"with title "shunt: All Accounts Offline" sound name "Basso""#
        );
        let _ = std::process::Command::new("osascript")
            .arg("-e")
            .arg(SCRIPT)
            .status();
    }
}
918
919// ---------------------------------------------------------------------------
920// OpenAI-compatible API (translates to Anthropic Claude)
921// ---------------------------------------------------------------------------
922//
923// When the OpenAI proxy receives a request at /v1/chat/completions, if an
924// anthropic_base_url is configured, it translates the request to Anthropic
925// Messages format and forwards it to the Anthropic proxy (which handles
926// account selection, token management, and rate limiting).
927// The response is translated back to OpenAI Chat Completions format.
928
/// Map an OpenAI (or Claude) model name to the Claude model id that serves it.
///
/// - Names starting with `claude-` are normalised by family. A name containing
///   `opus` gives Opus, otherwise one containing `haiku` gives Haiku, otherwise
///   Sonnet. The return type is `&'static str`, so the caller's exact Claude
///   name is never echoed back.
/// - Large/reasoning OpenAI models (`gpt-4o`, `gpt-4.5`, `o1`, `o1-pro`, `o3`,
///   `o3-pro`, `gpt-5`, `gpt-5.5`) map to Opus.
/// - Small OpenAI models (`gpt-4o-mini`, `gpt-4o-mini-2024-07-18`, `o1-mini`,
///   `o3-mini`) map to Haiku.
/// - Anything else maps to Sonnet.
fn map_model(openai_model: &str) -> &'static str {
    const OPUS: &str = "claude-opus-4-6";
    const SONNET: &str = "claude-sonnet-4-6";
    const HAIKU: &str = "claude-haiku-4-5-20251001";

    if openai_model.starts_with("claude-") {
        return if openai_model.contains("opus") {
            OPUS
        } else if openai_model.contains("haiku") {
            HAIKU
        } else {
            SONNET
        };
    }

    let opus_tier = matches!(
        openai_model,
        "gpt-4o" | "gpt-4.5" | "o1" | "o1-pro" | "o3" | "o3-pro" | "gpt-5" | "gpt-5.5"
    );
    let haiku_tier = matches!(
        openai_model,
        "gpt-4o-mini" | "gpt-4o-mini-2024-07-18" | "o1-mini" | "o3-mini"
    );

    if opus_tier {
        OPUS
    } else if haiku_tier {
        HAIKU
    } else {
        SONNET
    }
}
948
949/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
950fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
951    let model = body["model"].as_str().unwrap_or("gpt-4o");
952    let claude_model = map_model(model).to_owned();
953
954    // Extract system message from messages array.
955    let mut system: Option<String> = None;
956    let mut messages = Vec::new();
957    if let Some(arr) = body["messages"].as_array() {
958        for msg in arr {
959            let role = msg["role"].as_str().unwrap_or("");
960            let content = msg["content"].as_str().unwrap_or("").to_owned();
961            if role == "system" {
962                system = Some(content);
963            } else {
964                messages.push(json!({ "role": role, "content": content }));
965            }
966        }
967    }
968
969    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
970    let stream = body["stream"].as_bool().unwrap_or(false);
971
972    let mut req = json!({
973        "model": claude_model,
974        "messages": messages,
975        "max_tokens": max_tokens,
976        "stream": stream,
977    });
978
979    if let Some(sys) = system {
980        req["system"] = json!(sys);
981    }
982    if let Some(temp) = body.get("temperature") {
983        req["temperature"] = temp.clone();
984    }
985    if let Some(sp) = body.get("stop") {
986        req["stop_sequences"] = sp.clone();
987    }
988
989    req
990}
991
992/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
993fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
994    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
995    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
996    let content = body["content"]
997        .as_array()
998        .and_then(|arr| arr.iter().find_map(|b| b["text"].as_str()))
999        .unwrap_or("")
1000        .to_owned();
1001    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1002    let finish_reason = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1003    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1004    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1005
1006    json!({
1007        "id": id,
1008        "object": "chat.completion",
1009        "model": model,
1010        "choices": [{
1011            "index": 0,
1012            "message": { "role": "assistant", "content": content },
1013            "finish_reason": finish_reason,
1014        }],
1015        "usage": {
1016            "prompt_tokens": input_tokens,
1017            "completion_tokens": output_tokens,
1018            "total_tokens": input_tokens + output_tokens,
1019        }
1020    })
1021}
1022
1023fn uuid_v4() -> String {
1024    use crate::oauth::rand_bytes;
1025    let b: [u8; 16] = rand_bytes();
1026    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1027        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1028        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1029        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1030        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1031        {
1032            let mut v = 0u64;
1033            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1034            v
1035        }
1036    )
1037}
1038
1039/// GET /v1/models — return Claude models in OpenAI format.
1040async fn openai_models_handler() -> impl IntoResponse {
1041    axum::Json(json!({
1042        "object": "list",
1043        "data": [
1044            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1045            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1046            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1047        ]
1048    }))
1049}
1050
1051/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1052async fn openai_compat_handler(
1053    State(s): State<AppState>,
1054    req: Request,
1055) -> Result<Response, ProxyError> {
1056    let Some(ref anthropic_url) = s.anthropic_base_url else {
1057        // No Anthropic proxy configured — fall back to normal forwarding
1058        return proxy_handler(State(s), req).await;
1059    };
1060
1061    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1062        .await
1063        .map_err(|_| ProxyError::BodyRead)?;
1064
1065    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1066        .unwrap_or(json!({}));
1067
1068    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1069    let anthropic_body = translate_to_anthropic(openai_body);
1070
1071    let client = reqwest::Client::builder()
1072        .timeout(std::time::Duration::from_secs(300))
1073        .build()
1074        .map_err(|_| ProxyError::Upstream)?;
1075
1076    let resp = client
1077        .post(format!("{anthropic_url}/v1/messages"))
1078        .header("content-type", "application/json")
1079        .header("anthropic-version", "2023-06-01")
1080        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1081        .header("x-shunt-compat", "openai")
1082        .json(&anthropic_body)
1083        .send()
1084        .await
1085        .map_err(|_| ProxyError::Upstream)?;
1086
1087    if !resp.status().is_success() {
1088        let status = resp.status();
1089        let body = resp.text().await.unwrap_or_default();
1090        let code = status.as_u16();
1091        return Ok(axum::response::Response::builder()
1092            .status(code)
1093            .header("content-type", "application/json")
1094            .body(axum::body::Body::from(body))
1095            .unwrap());
1096    }
1097
1098    if stream {
1099        // Translate Anthropic SSE stream → OpenAI SSE stream
1100        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1101        let stream = translate_anthropic_stream(resp, chat_id);
1102        Ok(axum::response::Response::builder()
1103            .status(200)
1104            .header("content-type", "text/event-stream")
1105            .header("cache-control", "no-cache")
1106            .body(axum::body::Body::from_stream(stream))
1107            .unwrap())
1108    } else {
1109        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1110        let openai_resp = translate_from_anthropic(anthropic_resp);
1111        Ok(axum::Json(openai_resp).into_response())
1112    }
1113}
1114
1115/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1116fn translate_anthropic_stream(
1117    resp: reqwest::Response,
1118    chat_id: String,
1119) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1120    use futures_util::StreamExt;
1121
1122    let id = chat_id;
1123    let byte_stream = resp.bytes_stream();
1124
1125    async_stream::stream! {
1126        let mut buf = String::new();
1127        futures_util::pin_mut!(byte_stream);
1128
1129        // Send initial role chunk
1130        let init = format!(
1131            "data: {}\n\n",
1132            serde_json::to_string(&json!({
1133                "id": id,
1134                "object": "chat.completion.chunk",
1135                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1136            })).unwrap()
1137        );
1138        yield Ok(bytes::Bytes::from(init));
1139
1140        while let Some(chunk) = byte_stream.next().await {
1141            let chunk = match chunk {
1142                Ok(c) => c,
1143                Err(_) => break,
1144            };
1145            buf.push_str(&String::from_utf8_lossy(&chunk));
1146
1147            // Process complete SSE lines
1148            while let Some(nl) = buf.find('\n') {
1149                let line = buf[..nl].trim_end_matches('\r').to_owned();
1150                buf = buf[nl + 1..].to_owned();
1151
1152                if !line.starts_with("data: ") { continue; }
1153                let data = &line["data: ".len()..];
1154                if data == "[DONE]" { continue; }
1155
1156                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1157                let event_type = event["type"].as_str().unwrap_or("");
1158
1159                let maybe_chunk = match event_type {
1160                    "content_block_delta" => {
1161                        let text = event["delta"]["text"].as_str().unwrap_or("");
1162                        if text.is_empty() { continue; }
1163                        Some(json!({
1164                            "id": id,
1165                            "object": "chat.completion.chunk",
1166                            "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1167                        }))
1168                    }
1169                    "message_delta" => {
1170                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1171                        let finish = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1172                        Some(json!({
1173                            "id": id,
1174                            "object": "chat.completion.chunk",
1175                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1176                        }))
1177                    }
1178                    _ => None,
1179                };
1180
1181                if let Some(c) = maybe_chunk {
1182                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1183                    yield Ok(bytes::Bytes::from(out));
1184                }
1185            }
1186        }
1187
1188        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1189    }
1190}