Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24    config: Arc<Config>,
25    forwarder: Arc<Forwarder>,
26    state: StateStore,
27    /// Live credentials — can be refreshed at runtime without restarting.
28    credentials: Arc<RwLock<HashMap<String, Credential>>>,
29    /// Per-account mutex that serialises concurrent token-refresh attempts.
30    ///
31    /// When multiple in-flight requests hit a 401 for the same account at the
32    /// same time, only one should call the upstream OAuth endpoint; the others
33    /// should wait and then re-use the fresh token instead of each making their
34    /// own refresh call (which would rotate the refresh_token out from under the
35    /// others and cause cascading auth failures).
36    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37    /// Epoch-ms when this proxy instance started.
38    started_ms: u64,
39    /// If set, /v1/chat/completions requests are translated and forwarded here
40    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
41    anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46    Ok(app)
47}
48
/// Shared live credentials map, keyed by account name.
///
/// Holders of this handle can write refreshed credentials into it while the
/// proxy is running; handlers read from it on every request, so no restart is
/// needed.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
51
52/// Create a pure proxy app (no management routes).
53/// Registers /v1/messages, /v1/chat/completions, /v1/models, and a fallback.
54/// Build a shared `AppState` and the `LiveCredentials` handle it references.
55fn build_app_state(
56    config: Config,
57    state: StateStore,
58    anthropic_base_url: Option<String>,
59) -> anyhow::Result<(AppState, LiveCredentials)> {
60    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
61
62    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
63        state.set_auth_failed(&a.name);
64    }
65
66    let credentials: LiveCredentials = Arc::new(RwLock::new(
67        config.accounts.iter()
68            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
69            .collect::<HashMap<_, _>>(),
70    ));
71
72    let app_state = AppState {
73        config: Arc::new(config),
74        forwarder: Arc::new(forwarder),
75        state,
76        credentials: Arc::clone(&credentials),
77        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
78        started_ms: now_ms(),
79        anthropic_base_url,
80    };
81
82    Ok((app_state, credentials))
83}
84
85pub fn create_proxy_app(
86    config: Config,
87    state: StateStore,
88    anthropic_base_url: Option<String>,
89) -> anyhow::Result<(Router, LiveCredentials)> {
90    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
91
92    let app = Router::new()
93        .route("/v1/messages", post(proxy_handler))
94        .route("/v1/messages/count_tokens", post(proxy_handler))
95        .route("/v1/chat/completions", post(openai_compat_handler))
96        .route("/v1/models", get(openai_models_handler))
97        .fallback(proxy_handler)
98        .with_state(app_state);
99
100    Ok((app, credentials))
101}
102
103/// Create a control plane app (management routes only — sees ALL accounts).
104/// Registers /health, /status, /use.
105pub fn create_control_app(
106    config: Config,
107    state: StateStore,
108) -> anyhow::Result<Router> {
109    let (app_state, _) = build_app_state(config, state, None)?;
110
111    let app = Router::new()
112        .route("/health", get(health))
113        .route("/status", get(status_handler))
114        .route("/use", post(use_handler))
115        .with_state(app_state);
116
117    Ok(app)
118}
119
120/// Combined app used by tests and the single-port fallback mode.
121/// Includes both proxy routes and management routes (/health, /status, /use)
122/// sharing a single AppState so state changes are visible across all routes.
123pub fn create_app_with_state(
124    config: Config,
125    state: StateStore,
126    anthropic_base_url: Option<String>,
127) -> anyhow::Result<(Router, LiveCredentials)> {
128    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
129
130    let app = Router::new()
131        // Management routes
132        .route("/health", get(health))
133        .route("/status", get(status_handler))
134        .route("/use", post(use_handler))
135        // Proxy routes
136        .route("/v1/messages", post(proxy_handler))
137        .route("/v1/messages/count_tokens", post(proxy_handler))
138        .route("/v1/chat/completions", post(openai_compat_handler))
139        .route("/v1/models", get(openai_models_handler))
140        .fallback(proxy_handler)
141        .with_state(app_state);
142
143    Ok((app, credentials))
144}
145
146async fn health() -> impl IntoResponse {
147    axum::Json(json!({"status": "ok"}))
148}
149
150async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
151    let account_states = s.state.account_states();
152    let quotas = s.state.quota_snapshot();
153    let rate_limits = s.state.rate_limit_snapshot();
154
155    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
156        let st = account_states.get(&a.name);
157        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
158            "reauth_required"
159        } else if st.map(|s| s.disabled).unwrap_or(false) {
160            "disabled"
161        } else if s.state.is_available(&a.name) {
162            "available"
163        } else {
164            "cooling"
165        };
166
167        let quota = quotas.get(&a.name);
168        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
169        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
170        let tokens_used = quota.map(|q| json!({
171            "input": q.input_tokens,
172            "output": q.output_tokens,
173            "total": q.total_tokens(),
174        }));
175
176        let rl = rate_limits.get(&a.name);
177        let rate_limit = rl.map(|r| json!({
178            "utilization_5h": r.utilization_5h,
179            "reset_5h": r.reset_5h,
180            "status_5h": r.status_5h,
181            "utilization_7d": r.utilization_7d,
182            "reset_7d": r.reset_7d,
183            "status_7d": r.status_7d,
184            "representative_claim": r.representative_claim,
185            "updated_ms": r.updated_ms,
186        }));
187
188        let acc_state = account_states.get(&a.name);
189        let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
190        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
191        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
192        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
193        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
194        let reset_5h = rl.and_then(|r| r.reset_5h);
195        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
196        let reset_7d = rl.and_then(|r| r.reset_7d);
197        let available = s.state.is_available(&a.name);
198
199        json!({
200            "name": a.name,
201            "email": email,
202            "plan_type": a.plan_type,
203            "provider": a.provider.to_string(),
204            "status": avail_status,
205            "available": available,
206            "disabled": disabled,
207            "auth_failed": auth_failed,
208            "cooldown_until_ms": cooldown_until_ms,
209            "utilization_5h": utilization_5h,
210            "reset_5h": reset_5h,
211            "utilization_7d": utilization_7d,
212            "reset_7d": reset_7d,
213            "window_expires_ms": window_expires_ms,
214            "tokens_used": tokens_used,
215            "rate_limit": rate_limit,
216        })
217    }).collect();
218
219    let recent_requests = s.state.recent_requests_snapshot();
220    let savings = s.state.savings_snapshot();
221
222    axum::Json(json!({
223        "version": env!("CARGO_PKG_VERSION"),
224        "started_ms": s.started_ms,
225        "accounts": accounts,
226        "pinned_account": s.state.get_pinned(),
227        "last_used_account": s.state.get_last_used(),
228        "recent_requests": recent_requests,
229        "savings": savings,
230    }))
231}
232
233async fn use_handler(
234    State(s): State<AppState>,
235    axum::Json(body): axum::Json<serde_json::Value>,
236) -> Response {
237    let account = body["account"].as_str().map(|s| s.to_owned());
238    // Validate the account name exists (unless clearing to auto)
239    if let Some(ref name) = account {
240        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
241            return (StatusCode::BAD_REQUEST, axum::Json(json!({
242                "error": format!("unknown account '{name}'")
243            }))).into_response();
244        }
245        let pinned = if name == "auto" { None } else { Some(name.clone()) };
246        s.state.set_pinned(pinned);
247        axum::Json(json!({ "pinned": name })).into_response()
248    } else {
249        s.state.set_pinned(None);
250        axum::Json(json!({ "pinned": null })).into_response()
251    }
252}
253
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
258
259async fn proxy_handler(
260    State(s): State<AppState>,
261    req: Request,
262) -> Result<Response, ProxyError> {
263    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
264    if let Some(ref expected) = s.config.server.remote_key {
265        let provided = req.headers()
266            .get("x-api-key")
267            .and_then(|v| v.to_str().ok())
268            .unwrap_or("");
269        if provided != expected {
270            return Err(ProxyError::Unauthorized);
271        }
272    }
273
274    let method = req.method().as_str().to_owned();
275    let path = req.uri().path().to_owned();
276    let headers = req.headers().clone();
277
278    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
279        .await
280        .map_err(|_| ProxyError::BodyRead)?;
281
282    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
283        .ok()
284        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
285        .unwrap_or_default();
286    let req_start_ms = now_ms();
287
288    let fp = router::fingerprint(&body_bytes);
289    let fp_ref = fp.as_deref();
290
291    let mut tried: HashSet<String> = HashSet::new();
292    // Track accounts we've already attempted a token refresh for this request.
293    let mut refreshed: HashSet<String> = HashSet::new();
294    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
295    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
296
297    loop {
298        let account = match router::pick_account(
299            &s.config.accounts, &s.state, fp_ref, &tried,
300            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
301        ) {
302            Some(a) => a,
303            None => {
304                // Check whether any accounts are just temporarily cooling down
305                // (429/529 backoff) rather than permanently disabled / auth_failed.
306                // If so, wait for the soonest one to recover and retry.
307                let account_states = s.state.account_states();
308                let now = now_ms();
309                let soonest_ms = s.config.accounts.iter()
310                    .filter_map(|a| {
311                        let st = account_states.get(&a.name)?;
312                        if st.disabled { return None; } // auth_failed or permanently off
313                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
314                    })
315                    .min();
316
317                match soonest_ms {
318                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
319                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
320                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
321                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
322                        tried.clear(); // accounts may have recovered; try them again
323                    }
324                    _ => return Err(ProxyError::AllAccountsUnavailable),
325                }
326                continue;
327            }
328        };
329
330        let account_name = account.name.clone();
331
332        // Use the live (possibly refreshed) token rather than the one baked into config.
333        // For OpenAI/chatgpt.com accounts, Credential::bearer_token() returns id_token
334        // (short-lived OIDC JWT) which chatgpt.com requires. For all other providers it
335        // returns access_token. API-key accounts return the key directly.
336        let token = {
337            let creds = s.credentials.read().await;
338            let cred = creds.get(&account_name)
339                .cloned()
340                .or_else(|| account.credential.clone());
341            match cred {
342                Some(c) => c.bearer_token().to_owned(),
343                None => String::new(),
344            }
345        };
346
347        // Detect request and account protocols.  When they differ, translate
348        // the request body + path before forwarding and translate the response
349        // back so the client always sees its native wire format.
350        let req_is_anthropic = path.starts_with("/v1/messages");
351        let acct_is_anthropic = account.provider.wire_protocol()
352            == crate::provider::WireProtocol::Anthropic;
353        // chatgpt.com (Provider::OpenAI) uses a proprietary backend-api path + sentinel token.
354        // All other OpenAI-compat providers (OpenAIApi, Groq, Mistral, …) use /v1/chat/completions.
355        let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
356
357        let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
358            // Same wire protocol — pass through unchanged.
359            (path.clone(), body_bytes.clone(), headers.clone())
360        } else if req_is_anthropic && acct_is_chatgpt {
361            // Anthropic client → chatgpt.com account: translate to backend-api format.
362            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
363            let translated = translate_anthropic_req_to_chatgpt(&val);
364            let mut h = headers.clone();
365            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
366                h.remove(*name);
367            }
368            (
369                "/backend-api/conversation".to_owned(),
370                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
371                h,
372            )
373        } else if req_is_anthropic {
374            // Anthropic client → standard OpenAI-compat account (OpenAIApi, Groq, Mistral, …).
375            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
376            let translated = translate_anthropic_req_to_openai(val);
377            let mut h = headers.clone();
378            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
379                h.remove(*name);
380            }
381            (
382                "/v1/chat/completions".to_owned(),
383                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
384                h,
385            )
386        } else {
387            // OpenAI client → Anthropic account: translate O→A.
388            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
389            let translated = translate_to_anthropic(val);
390            (
391                "/v1/messages".to_owned(),
392                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
393                headers.clone(),
394            )
395        };
396
397        // Resolve upstream URL: per-account override (set at load time for non-primary
398        // providers, or explicitly in tests) → config server URL.
399        let upstream = account.upstream_url.as_deref()
400            .unwrap_or(&s.config.server.upstream_url);
401
402        // Inject chatgpt.com sentinel token — only for the chatgpt.com proprietary path.
403        // Wrap in tokio::time::timeout (3s) to guarantee we don't block on Cloudflare challenges.
404        if req_is_anthropic && acct_is_chatgpt {
405            tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
406            let sentinel_client = reqwest::Client::builder()
407                .timeout(std::time::Duration::from_secs(3))
408                .build()
409                .unwrap_or_default();
410            let sentinel_opt = tokio::time::timeout(
411                std::time::Duration::from_secs(3),
412                fetch_sentinel_token(&sentinel_client, upstream, &token),
413            ).await.ok().flatten();
414            if let Some(sentinel) = sentinel_opt {
415                if let Ok(name) = axum::http::header::HeaderName::from_bytes(
416                    b"openai-sentinel-chat-requirements-token",
417                ) {
418                    if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
419                        fwd_headers.insert(name, val);
420                    }
421                }
422            }
423        }
424
425        // Apply a hard 15s cap only for chatgpt.com: Cloudflare may hold the TCP connection
426        // open indefinitely for certain TLS fingerprints.  Standard API providers don't need this.
427        let response = if acct_is_chatgpt {
428            tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
429            match tokio::time::timeout(
430                std::time::Duration::from_secs(15),
431                s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
432            ).await {
433                Ok(Ok(r)) => r,
434                Ok(Err(e)) => {
435                    error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
436                    s.state.set_cooldown(&account_name, 5 * 60_000);
437                    tried.insert(account_name);
438                    continue;
439                }
440                Err(_) => {
441                    warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
442                    s.state.set_cooldown(&account_name, 5 * 60_000);
443                    tried.insert(account_name);
444                    continue;
445                }
446            }
447        } else {
448            s.forwarder
449                .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
450                .await
451                .map_err(|e| {
452                    error!("Forward error: {:#}", e);
453                    ProxyError::Upstream
454                })?
455        };
456
457        match response.status().as_u16() {
458            200..=299 => {
459                s.state.set_last_used(&account_name);
460                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
461                    s.state.update_rate_limits(&account_name, info);
462                }
463                // Translate response back to the client's expected protocol.
464                let response = if req_is_anthropic == acct_is_anthropic {
465                    response
466                } else if req_is_anthropic && acct_is_chatgpt {
467                    // Got chatgpt.com response; client expects Anthropic.
468                    translate_response_chatgpt_to_anthropic(response, &model).await
469                } else if req_is_anthropic {
470                    // Got standard OpenAI-compat response; client expects Anthropic.
471                    translate_response_openai_to_anthropic(response, &model).await
472                } else {
473                    // Got Anthropic response; client expects OpenAI.
474                    translate_response_anthropic_to_openai(response).await
475                };
476                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
477            }
478            429 => {
479                let info = account.provider.parse_rate_limits(response.headers());
480                // Sleep until the actual reset time if the headers tell us when that is;
481                // otherwise fall back to 60s so we don't hammer the API.
482                let cooldown_ms = info.as_ref()
483                    .and_then(|i| i.reset_5h.or(i.reset_7d))
484                    .map(|reset_secs| {
485                        let reset_ms = reset_secs.saturating_mul(1_000);
486                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
487                    })
488                    .unwrap_or(60_000);
489                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
490                if let Some(info) = info {
491                    s.state.update_rate_limits(&account_name, info);
492                }
493                s.state.set_cooldown(&account_name, cooldown_ms);
494                if cooldown_ms >= 5 * 60_000 {
495                    let mins = cooldown_ms / 60_000;
496                    notify(
497                        "shunt: Rate Limited",
498                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
499                        "Ping",
500                    );
501                }
502                tried.insert(account_name);
503            }
504            529 => {
505                warn!(account = %account_name, "529 overloaded — cooling 30s");
506                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
507                    s.state.update_rate_limits(&account_name, info);
508                }
509                s.state.set_cooldown(&account_name, 30_000);
510                tried.insert(account_name);
511            }
512            401 => {
513                if !refreshed.contains(&account_name) {
514                    // Access token invalidated (e.g. user logged out) — try refresh.
515                    //
516                    // Acquire the per-account refresh lock so concurrent requests
517                    // for the same account serialise here. The first waiter to get
518                    // the lock does the actual OAuth refresh; subsequent waiters
519                    // re-check credentials and skip the refresh if the token was
520                    // already rotated while they were queued.
521                    let account_lock = {
522                        let mut locks = s.refresh_locks.lock().unwrap();
523                        locks.entry(account_name.clone())
524                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
525                            .clone()
526                    };
527                    let _guard = account_lock.lock().await;
528
529                    // Re-read credentials after acquiring the lock — another task
530                    // may have already refreshed while we were waiting.
531                    let cred_before = {
532                        let creds = s.credentials.read().await;
533                        creds.get(&account_name).cloned()
534                            .or_else(|| account.credential.clone())
535                    };
536                    let Some(cred) = cred_before else {
537                        tried.insert(account_name);
538                        continue;
539                    };
540
541                    // Check if the token already changed while we were waiting.
542                    let token_before = cred.access_token().to_owned();
543                    let already_refreshed = {
544                        let creds = s.credentials.read().await;
545                        creds.get(&account_name)
546                            .map(|c| c.access_token() != token_before)
547                            .unwrap_or(false)
548                    };
549
550                    if already_refreshed {
551                        // Another concurrent request already refreshed — just retry.
552                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
553                        refreshed.insert(account_name);
554                    } else if let Some(oauth_cred) = cred.as_oauth() {
555                        // OAuth account — attempt token refresh.
556                        match tokio::time::timeout(
557                            std::time::Duration::from_secs(10),
558                            account.provider.refresh_token(oauth_cred),
559                        ).await {
560                            Ok(Ok(fresh)) => {
561                                warn!(account = %account_name, "401 — token refreshed, retrying");
562                                {
563                                    let mut creds = s.credentials.write().await;
564                                    creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
565                                }
566                                // Persist to disk so the refreshed token survives a restart.
567                                let name = account_name.clone();
568                                let fresh = fresh.clone();
569                                tokio::task::spawn_blocking(move || {
570                                    let mut store = CredentialsStore::load();
571                                    store.accounts.insert(name, Credential::Oauth(fresh.clone()));
572                                    store.save().ok();
573                                    if fresh.id_token.is_some() {
574                                        crate::oauth::write_codex_auth_file(&fresh);
575                                    }
576                                });
577                                // Mark as refreshed but don't add to tried — retry this account.
578                                refreshed.insert(account_name);
579                            }
580                            _ => {
581                                // Refresh failed/timed out — cool down, don't permanently disable.
582                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
583                                s.state.set_cooldown(&account_name, 5 * 60_000);
584                                tried.insert(account_name);
585                            }
586                        }
587                    } else {
588                        // API-key account — 401 means the key is invalid; no refresh possible.
589                        error!(account = %account_name, "401 — API key rejected, cooling 5min");
590                        s.state.set_cooldown(&account_name, 5 * 60_000);
591                        tried.insert(account_name);
592                    }
593                } else {
594                    // Already refreshed once and still 401 — cool down this account.
595                    error!(account = %account_name, "401 after refresh — cooling 5min");
596                    s.state.set_cooldown(&account_name, 5 * 60_000);
597                    tried.insert(account_name);
598                }
599            }
600            403 => {
601                // Forbidden — could be a Cloudflare challenge (non-Anthropic providers)
602                // or a genuine subscription/org block (Anthropic). Use a short cooldown
603                // for non-Anthropic accounts so a CF block doesn't lock them out for 30m.
604                if acct_is_anthropic {
605                    error!(account = %account_name, "403 forbidden — cooling 30min");
606                    s.state.set_cooldown(&account_name, 30 * 60_000);
607                    notify(
608                        "shunt: Account Forbidden",
609                        &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
610                        "Basso",
611                    );
612                } else {
613                    warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
614                    s.state.set_cooldown(&account_name, 5 * 60_000);
615                }
616                tried.insert(account_name);
617            }
618            _ => {
619                // 400, 404, 500, etc. — return as-is, no retry
620                return Ok(response);
621            }
622        }
623    }
624}
625
626// ---------------------------------------------------------------------------
627// Usage extraction
628// ---------------------------------------------------------------------------
629
630/// Intercept a successful response to record token usage, then pass it through.
631///
632/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
633/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
634async fn tap_usage(
635    resp: Response,
636    state: &StateStore,
637    account: &str,
638    model: &str,
639    req_start_ms: u64,
640) -> Response {
641    use axum::body::Body;
642    use crate::state::RequestLog;
643
644    if quota::is_streaming_response(&resp) {
645        let state = state.clone();
646        let account = account.to_owned();
647        let model = model.to_owned();
648        let on_complete = Arc::new(move |input: u64, output: u64| {
649            state.record_usage(&account, input, output);
650            state.record_global(&model, input, output);
651            state.record_request(RequestLog {
652                ts_ms: req_start_ms,
653                account: account.clone(),
654                model: model.clone(),
655                status: 200,
656                input_tokens: input,
657                output_tokens: output,
658                duration_ms: now_ms().saturating_sub(req_start_ms),
659            });
660        });
661        let (parts, body) = resp.into_parts();
662        let wrapped = quota::wrap_streaming_body(body, on_complete);
663        return Response::from_parts(parts, wrapped);
664    }
665
666    // Non-streaming: buffer, extract, rebuild
667    let (parts, body) = resp.into_parts();
668    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
669        Ok(b) => b,
670        Err(_) => return Response::from_parts(parts, Body::empty()),
671    };
672    let (input, output) = quota::extract_usage_from_json(&bytes);
673    state.record_usage(account, input, output);
674    state.record_global(model, input, output);
675    state.record_request(RequestLog {
676        ts_ms: req_start_ms,
677        account: account.to_owned(),
678        model: model.to_owned(),
679        status: 200,
680        input_tokens: input,
681        output_tokens: output,
682        duration_ms: now_ms().saturating_sub(req_start_ms),
683    });
684    Response::from_parts(parts, Body::from(bytes))
685}
686
687
688// ---------------------------------------------------------------------------
689// Rate limit prefetch
690// ---------------------------------------------------------------------------
691
692/// For any account with no rate-limit data yet, make a cheap request directly
693/// to the upstream API so we populate metrics without waiting for a real user
694/// request. Runs as a background task after startup.
695pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
696    let client = reqwest::Client::builder()
697        .timeout(std::time::Duration::from_secs(20))
698        .build()
699        .unwrap_or_default();
700
701    for account in &config.accounts {
702        // Skip if we already have data for this account.
703        let rl = state.rate_limit_snapshot();
704        if let Some(r) = rl.get(&account.name) {
705            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
706                continue;
707            }
708        }
709
710        // Skip accounts with no credentials or no prefetch support.
711        let cred = match account.credential.clone() {
712            Some(c) => c,
713            None => continue,
714        };
715
716        let Some((path, body)) = account.provider.prefetch_request() else {
717            // No POST prefetch for this provider — do a lightweight GET auth check instead.
718            if let Some(probe_path) = account.provider.auth_probe_get_path() {
719                auth_probe_get(&client, probe_path, account, &state).await;
720            }
721            continue;
722        };
723        let url = format!("{}{}", config.server.upstream_url, path);
724
725        let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
726
727        let r = match resp {
728            Ok(r) => r,
729            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
730        };
731
732        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
733            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
734            let Some(oauth_cred) = cred.as_oauth() else {
735                // API-key account — 401 during prefetch means the key is invalid.
736                tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
737                state.set_auth_failed(&account.name);
738                continue;
739            };
740            let fresh = match account.provider.refresh_token(oauth_cred).await {
741                Ok(f) => f,
742                Err(e) => {
743                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
744                    state.set_auth_failed(&account.name);
745                    continue;
746                }
747            };
748            let mut store = crate::config::CredentialsStore::load();
749            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
750            store.save().ok();
751            if fresh.id_token.is_some() {
752                crate::oauth::write_codex_auth_file(&fresh);
753            }
754            // Update live credentials so the proxy uses the fresh token immediately.
755            live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
756
757            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
758                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
759                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
760                    state.set_auth_failed(&account.name);
761                }
762                Ok(r2) => {
763                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
764                        state.update_rate_limits(&account.name, info);
765                    }
766                }
767                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
768            }
769        } else {
770            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
771            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
772                state.update_rate_limits(&account.name, info);
773            }
774        }
775    }
776}
777
778/// Build and send a prefetch request for the given provider + token.
779async fn prefetch_send(
780    client: &reqwest::Client,
781    url: &str,
782    provider: &crate::provider::Provider,
783    token: &str,
784    body: &serde_json::Value,
785) -> anyhow::Result<reqwest::Response> {
786    let mut headers = reqwest::header::HeaderMap::new();
787    provider.inject_auth_headers(&mut headers, token)?;
788    for (name, value) in provider.prefetch_extra_headers() {
789        headers.insert(
790            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
791            reqwest::header::HeaderValue::from_static(value),
792        );
793    }
794    Ok(client.post(url).headers(headers).json(body).send().await?)
795}
796
797/// GET a cheap endpoint to verify credentials are still valid for providers that
798/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
799/// marks the account as `reauth_required` if the refresh also fails.
800async fn auth_probe_get(
801    client: &reqwest::Client,
802    path: &str,
803    account: &crate::config::AccountConfig,
804    state: &StateStore,
805) {
806    let cred = match account.credential.clone() {
807        Some(c) => c,
808        None => return,
809    };
810    let upstream = account.upstream_url.as_deref()
811        .unwrap_or_else(|| account.provider.default_upstream_url());
812    let url = format!("{}{}", upstream, path);
813
814    let do_get = |token: &str| -> reqwest::RequestBuilder {
815        let mut headers = reqwest::header::HeaderMap::new();
816        let _ = account.provider.inject_auth_headers(&mut headers, token);
817        client.get(&url).headers(headers)
818    };
819
820    let resp = match do_get(cred.bearer_token()).send().await {
821        Ok(r) => r,
822        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
823    };
824
825    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
826        tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
827        let Some(oauth_cred) = cred.as_oauth() else {
828            // API-key account — key is invalid; no refresh possible.
829            tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
830            state.set_auth_failed(&account.name);
831            return;
832        };
833        let fresh = match account.provider.refresh_token(oauth_cred).await {
834            Ok(f) => f,
835            Err(e) => {
836                tracing::warn!(account = %account.name, "token refresh failed: {e}");
837                state.set_auth_failed(&account.name);
838                return;
839            }
840        };
841        let mut store = crate::config::CredentialsStore::load();
842        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
843        store.save().ok();
844        if fresh.id_token.is_some() {
845            crate::oauth::write_codex_auth_file(&fresh);
846        }
847
848        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
849        match do_get(fresh_token).send().await {
850            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
851                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
852                state.set_auth_failed(&account.name);
853            }
854            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
855            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
856        }
857    } else {
858        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
859        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
860        // with codex CLI, which also tries to refresh at startup using the same token.
861        // Proactive refreshing is handled solely by openai_token_refresh_loop.
862    }
863}
864
865// ---------------------------------------------------------------------------
866// Proactive OpenAI token refresh loop
867// ---------------------------------------------------------------------------
868
869/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
870/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
871fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
872    let now_ms = std::time::SystemTime::now()
873        .duration_since(std::time::UNIX_EPOCH)
874        .unwrap_or_default()
875        .as_millis() as u64;
876    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
877        .unwrap_or(cred.expires_at);
878    exp_ms < now_ms + threshold_mins * 60 * 1_000
879}
880
881/// Sync live_creds from auth.json if auth.json has a newer token.
882///
883/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
884/// we pull that in so we don't use a stale refresh_token that codex already rotated.
885async fn sync_live_creds_from_auth_json(
886    account_name: &str,
887    live_creds: &LiveCredentials,
888) {
889    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
890    let current_exp = live_creds.read().await
891        .get(account_name)
892        .and_then(|c| c.as_oauth())
893        .map(|c| c.expires_at)
894        .unwrap_or(0);
895    if from_file.expires_at > current_exp {
896        tracing::info!(account = %account_name, "synced fresher token from auth.json");
897        live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
898    }
899}
900
901/// Perform a single proactive refresh for one account and persist the result.
902async fn do_proactive_refresh(
903    account: &crate::config::AccountConfig,
904    creds: &crate::oauth::OAuthCredential,
905    live_creds: &LiveCredentials,
906    state: &StateStore,
907) {
908    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
909    match account.provider.refresh_token(creds).await {
910        Ok(fresh) => {
911            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
912            {
913                let mut map = live_creds.write().await;
914                map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
915            }
916            let mut store = crate::config::CredentialsStore::load();
917            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
918            store.save().ok();
919            if fresh.id_token.is_some() {
920                crate::oauth::write_codex_auth_file(&fresh);
921            }
922            state.clear_auth_failed(&account.name);
923        }
924        Err(e) => {
925            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
926            state.set_auth_failed(&account.name);
927        }
928    }
929}
930
931
932/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
933///
934/// Strategy: never proactively rotate the refresh_token — that races with
935/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
936/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
937/// On-demand refresh (401 handler) covers the case where Codex isn't running
938/// and the token has actually expired.
939pub async fn openai_token_refresh_loop(
940    config: Arc<Config>,
941    state: StateStore,
942    live_creds: LiveCredentials,
943) {
944    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
945    for account in config.accounts.iter()
946        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
947    {
948        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
949            continue;
950        }
951        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
952
953        let creds = {
954            let map = live_creds.read().await;
955            map.get(&account.name).cloned().or_else(|| account.credential.clone())
956        };
957        if let Some(creds) = creds {
958            if let Some(oauth) = creds.as_oauth() {
959                if access_token_expires_soon(oauth, 30) {
960                    // access_token is nearly expired — refresh now so shunt can serve requests immediately.
961                    do_proactive_refresh(account, oauth, &live_creds, &state).await;
962                } else {
963                    tracing::info!(account = %account.name, "access_token fresh at startup");
964                }
965            }
966        }
967    }
968
969    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
970    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
971    loop {
972        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
973        for account in config.accounts.iter()
974            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
975        {
976            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
977        }
978    }
979}
980
981// ---------------------------------------------------------------------------
982// Error type
983// ---------------------------------------------------------------------------
984
/// Errors the proxy itself reports to the client. Each variant is rendered
/// with a fixed status code and message by the `IntoResponse` impl.
enum ProxyError {
    /// Rendered as 400 "failed to read request body".
    BodyRead,
    /// Rendered as 502 "upstream request failed".
    Upstream,
    /// Rendered as 503 "all accounts are on cooldown or disabled".
    AllAccountsUnavailable,
    /// Rendered as 401 "invalid or missing api key".
    Unauthorized,
}
991
992impl IntoResponse for ProxyError {
993    fn into_response(self) -> Response {
994        let (status, msg) = match self {
995            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
996            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
997            ProxyError::AllAccountsUnavailable => {
998                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
999            }
1000            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1001        };
1002
1003        (status, axum::Json(json!({
1004            "type": "error",
1005            "error": {"type": "api_error", "message": msg}
1006        }))).into_response()
1007    }
1008}
1009
1010// ---------------------------------------------------------------------------
1011// Recovery watcher — periodically retries token refresh for auth_failed accounts
1012// ---------------------------------------------------------------------------
1013
1014/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
1015/// auth_failed account. If refresh succeeds the account is brought back online
1016/// without a process restart. If all accounts remain unrecoverable, fires a
1017/// macOS notification (at most once per hour).
1018pub async fn recovery_watcher(
1019    config: Arc<Config>,
1020    state: StateStore,
1021    credentials: LiveCredentials,
1022) {
1023    use std::time::{Duration, Instant};
1024    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1025    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1026
1027    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1028    let mut last_notified: Option<Instant> = None;
1029
1030    loop {
1031        tokio::time::sleep(CHECK_INTERVAL).await;
1032
1033        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1034        let failed = state.auth_failed_accounts(&name_refs);
1035        if failed.is_empty() {
1036            last_notified = None;
1037            continue;
1038        }
1039
1040        tracing::warn!(
1041            accounts = ?failed,
1042            "recovery: {} account(s) auth_failed, attempting token refresh",
1043            failed.len()
1044        );
1045
1046        let mut any_recovered = false;
1047
1048        for name in &failed {
1049            let cred = {
1050                let map = credentials.read().await;
1051                map.get(*name).cloned()
1052            };
1053            let Some(cred) = cred else { continue };
1054            if !cred.has_refresh_token() { continue; }
1055            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1056
1057            let provider = config.accounts.iter()
1058                .find(|a| a.name == *name)
1059                .map(|a| a.provider.clone())
1060                .unwrap_or_default();
1061
1062            let result = tokio::time::timeout(
1063                Duration::from_secs(20),
1064                provider.refresh_token(&oauth_cred),
1065            ).await;
1066
1067            match result {
1068                Ok(Ok(fresh)) => {
1069                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
1070                    {
1071                        let mut map = credentials.write().await;
1072                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1073                    }
1074                    let name_owned = name.to_string();
1075                    let fresh_owned = fresh.clone();
1076                    tokio::task::spawn_blocking(move || {
1077                        let mut store = crate::config::CredentialsStore::load();
1078                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1079                        store.save().ok();
1080                        if fresh_owned.id_token.is_some() {
1081                            crate::oauth::write_codex_auth_file(&fresh_owned);
1082                        }
1083                    });
1084                    state.clear_auth_failed(name);
1085                    any_recovered = true;
1086                }
1087                Ok(Err(e)) => {
1088                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1089                    notify(
1090                        "shunt: Reauth Required",
1091                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1092                        "Basso",
1093                    );
1094                }
1095                Err(_) => {
1096                    tracing::error!(account = %name, "recovery: token refresh timed out");
1097                    notify(
1098                        "shunt: Reauth Required",
1099                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1100                        "Basso",
1101                    );
1102                }
1103            }
1104        }
1105
1106        if any_recovered {
1107            tracing::info!("recovery: at least one account is back online");
1108            continue;
1109        }
1110
1111        // All accounts still auth_failed after refresh attempts — notify.
1112        let still_failed = state.auth_failed_accounts(&name_refs);
1113        if still_failed.len() == account_names.len() {
1114            let should_notify = last_notified
1115                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1116                .unwrap_or(true);
1117            if should_notify {
1118                error!(
1119                    "ALL accounts are offline (auth failed). \
1120                     Run `shunt add-account` to re-authorize."
1121                );
1122                notify(
1123                    "shunt: All Accounts Offline",
1124                    "All accounts need re-authorization. Run `shunt add-account`.",
1125                    "Basso",
1126                );
1127                last_notified = Some(Instant::now());
1128            }
1129        }
1130    }
1131}
1132
1133/// Sends a single lightweight prefetch request for `account` immediately after its
1134/// cooldown expires, so the router has fresh rate-limit headers before the next
1135/// real request arrives.
1136async fn post_cooldown_prefetch(
1137    client: &reqwest::Client,
1138    account: &crate::config::AccountConfig,
1139    token: &str,
1140    state: &StateStore,
1141    upstream_url: &str,
1142) {
1143    let Some((path, body)) = account.provider.prefetch_request() else {
1144        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1145            auth_probe_get(client, probe_path, account, state).await;
1146        }
1147        return;
1148    };
1149    let url = format!("{upstream_url}{path}");
1150    match prefetch_send(client, &url, &account.provider, token, &body).await {
1151        Ok(r) => {
1152            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1153                state.update_rate_limits(&account.name, info);
1154                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1155            }
1156        }
1157        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1158    }
1159}
1160
1161/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1162/// so each account re-enters rotation with fresh rate-limit metrics.
1163///
1164/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1165/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1166/// the next cooldown deadline rather than polling at a fixed interval.
1167///
1168/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1169/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1170/// is triggered so the router always has fresh utilization metrics.
1171pub async fn cooldown_watcher(
1172    config: Arc<Config>,
1173    state: StateStore,
1174    credentials: LiveCredentials,
1175) {
1176    /// Re-fetch rate-limit headers if data is older than 1 hour.
1177    const STALE_RL_MS: u64 = 60 * 60_000;
1178
1179    let client = reqwest::Client::builder()
1180        .timeout(std::time::Duration::from_secs(20))
1181        .build()
1182        .unwrap_or_default();
1183
1184    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1185    // Prevents re-triggering on every poll after expiry.
1186    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1187    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1188    let mut notify_on_resume: HashSet<String> = HashSet::new();
1189    // Epoch-ms of the last successful stale-prefetch per account.
1190    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1191
1192    loop {
1193        let states = state.account_states();
1194        let rl_snapshot = state.rate_limit_snapshot();
1195        let now = now_ms();
1196        let mut next_wake_ms: Option<u64> = None;
1197
1198        for account in &config.accounts {
1199            let Some(st) = states.get(&account.name) else { continue };
1200            if st.disabled { continue; } // auth_failed or permanently disabled
1201            let cdl = st.cooldown_until_ms;
1202
1203            if cdl > 0 && cdl <= now {
1204                // Cooldown expired — skip if we already handled this exact deadline
1205                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1206                if !handled {
1207                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1208                    let token = {
1209                        let creds = credentials.read().await;
1210                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1211                    };
1212                    if let Some(token) = token {
1213                        post_cooldown_prefetch(
1214                            &client, account, &token, &state,
1215                            &config.server.upstream_url,
1216                        ).await;
1217                    }
1218                    if notify_on_resume.remove(&account.name) {
1219                        notify(
1220                            "shunt: Account Resumed",
1221                            &format!("Account '{}' is back online.", account.name),
1222                            "Glass",
1223                        );
1224                    }
1225                    last_resumed.insert(account.name.clone(), cdl);
1226                    last_stale_prefetch.insert(account.name.clone(), now);
1227                }
1228            } else if cdl > now {
1229                // Still cooling — schedule wake at expiry; flag for notification if long
1230                let remaining = cdl - now;
1231                if remaining >= 5 * 60_000 {
1232                    notify_on_resume.insert(account.name.clone());
1233                }
1234                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1235            } else {
1236                // Not in cooldown — check for stale rate-limit data
1237                let rl_age = rl_snapshot
1238                    .get(&account.name)
1239                    .map(|r| now.saturating_sub(r.updated_ms))
1240                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1241                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1242                let fetched_ago = now.saturating_sub(last_fetched);
1243
1244                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1245                    tracing::debug!(
1246                        account = %account.name,
1247                        age_min = rl_age / 60_000,
1248                        "rate-limit data stale — refreshing"
1249                    );
1250                    let token = {
1251                        let creds = credentials.read().await;
1252                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1253                    };
1254                    if let Some(token) = token {
1255                        post_cooldown_prefetch(
1256                            &client, account, &token, &state,
1257                            &config.server.upstream_url,
1258                        ).await;
1259                    }
1260                    last_stale_prefetch.insert(account.name.clone(), now);
1261                }
1262            }
1263        }
1264
1265        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1266        let sleep_ms = next_wake_ms
1267            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1268            .unwrap_or(30_000);
1269        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1270    }
1271}
1272
1273use crate::notify::notify;
1274
1275// ---------------------------------------------------------------------------
1276// OpenAI-compatible API (translates to Anthropic Claude)
1277// ---------------------------------------------------------------------------
1278//
1279// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1280// anthropic_base_url is configured, it translates the request to Anthropic
1281// Messages format and forwards it to the Anthropic proxy (which handles
1282// account selection, token management, and rate limiting).
1283// The response is translated back to OpenAI Chat Completions format.
1284
/// Pick the Claude model that stands in for an OpenAI model name.
///
/// Names starting with `claude-` are returned untouched. Known flagship
/// aliases map to Opus, known "mini" aliases map to Haiku, and every other
/// name falls back to Sonnet.
fn map_model(openai_model: &str) -> String {
    const OPUS_ALIASES: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: [&str; 4] = [
        "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini",
    ];

    if openai_model.starts_with("claude-") {
        return openai_model.to_owned();
    }

    let is_alias = |aliases: &[&str]| aliases.iter().any(|alias| *alias == openai_model);

    let target = if is_alias(&OPUS_ALIASES) {
        "claude-opus-4-6"
    } else if is_alias(&HAIKU_ALIASES) {
        "claude-haiku-4-5-20251001"
    } else {
        "claude-sonnet-4-6"
    };
    target.to_owned()
}
1301
1302/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
1303fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1304    let model = body["model"].as_str().unwrap_or("gpt-4o");
1305    let claude_model = map_model(model);
1306
1307    // Extract system message from messages array.
1308    let mut system: Option<String> = None;
1309    let mut messages = Vec::new();
1310    if let Some(arr) = body["messages"].as_array() {
1311        for msg in arr {
1312            let role = msg["role"].as_str().unwrap_or("");
1313            if role == "system" {
1314                // system can be a string or array of content parts
1315                let content = msg["content"].as_str()
1316                    .map(|s| s.to_owned())
1317                    .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1318                system = Some(content);
1319            } else if role == "tool" {
1320                // OpenAI tool result → Anthropic tool_result content block
1321                let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1322                let content = msg["content"].as_str().unwrap_or("").to_owned();
1323                messages.push(json!({
1324                    "role": "user",
1325                    "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1326                }));
1327            } else {
1328                // Check for tool_calls in assistant messages
1329                if let Some(tool_calls) = msg["tool_calls"].as_array() {
1330                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1331                    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1332                        content_blocks.push(json!({"type": "text", "text": text}));
1333                    }
1334                    for tc in tool_calls {
1335                        content_blocks.push(json!({
1336                            "type": "tool_use",
1337                            "id": tc["id"].as_str().unwrap_or(""),
1338                            "name": tc["function"]["name"].as_str().unwrap_or(""),
1339                            "input": serde_json::from_str::<serde_json::Value>(
1340                                tc["function"]["arguments"].as_str().unwrap_or("{}")
1341                            ).unwrap_or(json!({})),
1342                        }));
1343                    }
1344                    messages.push(json!({"role": "assistant", "content": content_blocks}));
1345                } else {
1346                    let content = msg["content"].as_str().unwrap_or("").to_owned();
1347                    messages.push(json!({ "role": role, "content": content }));
1348                }
1349            }
1350        }
1351    }
1352
1353    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1354    let stream = body["stream"].as_bool().unwrap_or(false);
1355
1356    let mut req = json!({
1357        "model": claude_model,
1358        "messages": messages,
1359        "max_tokens": max_tokens,
1360        "stream": stream,
1361    });
1362
1363    if let Some(sys) = system {
1364        req["system"] = json!(sys);
1365    }
1366    if let Some(temp) = body.get("temperature") {
1367        req["temperature"] = temp.clone();
1368    }
1369    if let Some(sp) = body.get("stop") {
1370        req["stop_sequences"] = sp.clone();
1371    }
1372
1373    // Translate OpenAI tools → Anthropic tools format
1374    if let Some(tools) = body["tools"].as_array() {
1375        let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1376            let func = &t["function"];
1377            Some(json!({
1378                "name": func["name"].as_str()?,
1379                "description": func["description"].as_str().unwrap_or(""),
1380                "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1381            }))
1382        }).collect();
1383        if !claude_tools.is_empty() {
1384            req["tools"] = json!(claude_tools);
1385        }
1386    }
1387
1388    req
1389}
1390
1391/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
1392fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1393    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1394    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1395
1396    // Extract text content and tool_use blocks.
1397    let mut text_content = String::new();
1398    let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1399    if let Some(blocks) = body["content"].as_array() {
1400        for (idx, block) in blocks.iter().enumerate() {
1401            match block["type"].as_str() {
1402                Some("text") => {
1403                    text_content.push_str(block["text"].as_str().unwrap_or(""));
1404                }
1405                Some("tool_use") => {
1406                    let args = match &block["input"] {
1407                        serde_json::Value::String(s) => s.clone(),
1408                        v => serde_json::to_string(v).unwrap_or_default(),
1409                    };
1410                    tool_calls.push(json!({
1411                        "id": block["id"].as_str().unwrap_or(""),
1412                        "type": "function",
1413                        "index": idx,
1414                        "function": {
1415                            "name": block["name"].as_str().unwrap_or(""),
1416                            "arguments": args,
1417                        }
1418                    }));
1419                }
1420                _ => {}
1421            }
1422        }
1423    }
1424
1425    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1426    let finish_reason = match stop_reason {
1427        "end_turn"   => "stop",
1428        "tool_use"   => "tool_calls",
1429        "max_tokens" => "length",
1430        other        => other,
1431    };
1432
1433    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1434    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1435
1436    let mut message = json!({"role": "assistant", "content": text_content});
1437    if !tool_calls.is_empty() {
1438        message["tool_calls"] = json!(tool_calls);
1439    }
1440
1441    json!({
1442        "id": id,
1443        "object": "chat.completion",
1444        "model": model,
1445        "choices": [{
1446            "index": 0,
1447            "message": message,
1448            "finish_reason": finish_reason,
1449        }],
1450        "usage": {
1451            "prompt_tokens": input_tokens,
1452            "completion_tokens": output_tokens,
1453            "total_tokens": input_tokens + output_tokens,
1454        }
1455    })
1456}
1457
1458fn uuid_v4() -> String {
1459    use crate::oauth::rand_bytes;
1460    let b: [u8; 16] = rand_bytes();
1461    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1462        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1463        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1464        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1465        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1466        {
1467            let mut v = 0u64;
1468            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1469            v
1470        }
1471    )
1472}
1473
1474/// GET /v1/models — return Claude models in OpenAI format.
1475async fn openai_models_handler() -> impl IntoResponse {
1476    axum::Json(json!({
1477        "object": "list",
1478        "data": [
1479            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1480            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1481            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1482        ]
1483    }))
1484}
1485
1486/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1487async fn openai_compat_handler(
1488    State(s): State<AppState>,
1489    req: Request,
1490) -> Result<Response, ProxyError> {
1491    let Some(ref anthropic_url) = s.anthropic_base_url else {
1492        // No Anthropic proxy configured — fall back to normal forwarding
1493        return proxy_handler(State(s), req).await;
1494    };
1495
1496    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1497        .await
1498        .map_err(|_| ProxyError::BodyRead)?;
1499
1500    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1501        .unwrap_or(json!({}));
1502
1503    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1504    let anthropic_body = translate_to_anthropic(openai_body);
1505
1506    let client = reqwest::Client::builder()
1507        .timeout(std::time::Duration::from_secs(300))
1508        .build()
1509        .map_err(|_| ProxyError::Upstream)?;
1510
1511    let resp = client
1512        .post(format!("{anthropic_url}/v1/messages"))
1513        .header("content-type", "application/json")
1514        .header("anthropic-version", "2023-06-01")
1515        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1516        .header("x-shunt-compat", "openai")
1517        .json(&anthropic_body)
1518        .send()
1519        .await
1520        .map_err(|_| ProxyError::Upstream)?;
1521
1522    if !resp.status().is_success() {
1523        let status = resp.status();
1524        let body = resp.text().await.unwrap_or_default();
1525        let code = status.as_u16();
1526        return Ok(axum::response::Response::builder()
1527            .status(code)
1528            .header("content-type", "application/json")
1529            .body(axum::body::Body::from(body))
1530            .unwrap());
1531    }
1532
1533    if stream {
1534        // Translate Anthropic SSE stream → OpenAI SSE stream
1535        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1536        let stream = translate_anthropic_stream(resp, chat_id);
1537        Ok(axum::response::Response::builder()
1538            .status(200)
1539            .header("content-type", "text/event-stream")
1540            .header("cache-control", "no-cache")
1541            .body(axum::body::Body::from_stream(stream))
1542            .unwrap())
1543    } else {
1544        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1545        let openai_resp = translate_from_anthropic(anthropic_resp);
1546        Ok(axum::Json(openai_resp).into_response())
1547    }
1548}
1549
1550/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1551/// Handles text content, tool_use blocks, and finish reasons.
1552fn translate_anthropic_stream(
1553    resp: reqwest::Response,
1554    chat_id: String,
1555) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1556    use futures_util::StreamExt;
1557
1558    let id = chat_id;
1559    let byte_stream = resp.bytes_stream();
1560
1561    async_stream::stream! {
1562        let mut buf = String::new();
1563        // Per-block state: block_index -> (tool_call_oai_index, tool_id, tool_name)
1564        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1565        let mut tool_call_count: usize = 0;
1566        futures_util::pin_mut!(byte_stream);
1567
1568        // Send initial role chunk
1569        let init = format!(
1570            "data: {}\n\n",
1571            serde_json::to_string(&json!({
1572                "id": id,
1573                "object": "chat.completion.chunk",
1574                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1575            })).unwrap()
1576        );
1577        yield Ok(bytes::Bytes::from(init));
1578
1579        while let Some(chunk) = byte_stream.next().await {
1580            let chunk = match chunk {
1581                Ok(c) => c,
1582                Err(_) => break,
1583            };
1584            buf.push_str(&String::from_utf8_lossy(&chunk));
1585
1586            // Process complete SSE lines
1587            while let Some(nl) = buf.find('\n') {
1588                let line = buf[..nl].trim_end_matches('\r').to_owned();
1589                buf = buf[nl + 1..].to_owned();
1590
1591                if !line.starts_with("data: ") { continue; }
1592                let data = &line["data: ".len()..];
1593                if data == "[DONE]" { continue; }
1594
1595                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1596                let event_type = event["type"].as_str().unwrap_or("");
1597
1598                let maybe_chunk = match event_type {
1599                    "content_block_start" => {
1600                        let block_idx = event["index"].as_u64().unwrap_or(0);
1601                        let cb = &event["content_block"];
1602                        if cb["type"].as_str() == Some("tool_use") {
1603                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1604                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1605                            let oai_idx = tool_call_count;
1606                            tool_call_count += 1;
1607                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1608                            Some(json!({
1609                                "id": id,
1610                                "object": "chat.completion.chunk",
1611                                "choices": [{"index": 0, "delta": {
1612                                    "tool_calls": [{
1613                                        "index": oai_idx,
1614                                        "id": tool_id,
1615                                        "type": "function",
1616                                        "function": {"name": tool_name, "arguments": ""}
1617                                    }]
1618                                }, "finish_reason": null}]
1619                            }))
1620                        } else {
1621                            None
1622                        }
1623                    }
1624                    "content_block_delta" => {
1625                        let block_idx = event["index"].as_u64().unwrap_or(0);
1626                        let delta = &event["delta"];
1627                        match delta["type"].as_str() {
1628                            Some("text_delta") => {
1629                                let text = delta["text"].as_str().unwrap_or("");
1630                                if text.is_empty() { continue; }
1631                                Some(json!({
1632                                    "id": id,
1633                                    "object": "chat.completion.chunk",
1634                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1635                                }))
1636                            }
1637                            Some("input_json_delta") => {
1638                                let args = delta["partial_json"].as_str().unwrap_or("");
1639                                if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1640                                    Some(json!({
1641                                        "id": id,
1642                                        "object": "chat.completion.chunk",
1643                                        "choices": [{"index": 0, "delta": {
1644                                            "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1645                                        }, "finish_reason": null}]
1646                                    }))
1647                                } else {
1648                                    None
1649                                }
1650                            }
1651                            _ => None,
1652                        }
1653                    }
1654                    "message_delta" => {
1655                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1656                        let finish = match stop_reason {
1657                            "end_turn"  => "stop",
1658                            "tool_use"  => "tool_calls",
1659                            "max_tokens" => "length",
1660                            other       => other,
1661                        };
1662                        Some(json!({
1663                            "id": id,
1664                            "object": "chat.completion.chunk",
1665                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1666                        }))
1667                    }
1668                    _ => None,
1669                };
1670
1671                if let Some(c) = maybe_chunk {
1672                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1673                    yield Ok(bytes::Bytes::from(out));
1674                }
1675            }
1676        }
1677
1678        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1679    }
1680}
1681
1682// ---------------------------------------------------------------------------
1683// Cross-protocol translation: Anthropic ↔ OpenAI
1684// ---------------------------------------------------------------------------
1685
/// Map Claude model names → OpenAI model names (mirror of `map_model`).
///
/// A name containing "haiku" (and not "opus") maps to `gpt-4o-mini`; opus,
/// sonnet and every other name map to `gpt-4o`.
fn map_model_to_openai(claude_model: &str) -> &str {
    let is_opus = claude_model.contains("opus");
    let is_haiku = claude_model.contains("haiku");
    if is_haiku && !is_opus {
        "gpt-4o-mini"
    } else {
        "gpt-4o"
    }
}
1694
1695// ---------------------------------------------------------------------------
1696// ChatGPT backend API translation (chatgpt.com /backend-api/conversation)
1697// ---------------------------------------------------------------------------
1698
1699/// Fetch the sentinel token required by chatgpt.com's backend API.
1700/// Returns None if the request fails or proof-of-work is required.
1701async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1702    let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1703    let resp = client
1704        .get(&url)
1705        .header("Authorization", format!("Bearer {}", token))
1706        .send()
1707        .await
1708        .ok()?;
1709    if !resp.status().is_success() {
1710        return None;
1711    }
1712    let json: serde_json::Value = resp.json().await.ok()?;
1713    if json["proofofwork"]["required"].as_bool() == Some(true) {
1714        return None;
1715    }
1716    json["token"].as_str().map(ToOwned::to_owned)
1717}
1718
/// Map Claude model names → chatgpt.com model names.
///
/// Only a name that contains "haiku" without also containing "opus" selects
/// `gpt-4o-mini`; all other names select `gpt-4o`.
fn map_model_to_chatgpt(model: &str) -> &str {
    match (model.contains("opus"), model.contains("haiku")) {
        (false, true) => "gpt-4o-mini",
        _ => "gpt-4o",
    }
}
1729
1730/// Extract flat text from an Anthropic content value (string or content block array).
1731/// Tool-use blocks are rendered as `[Tool: name(args)]`; tool_result blocks are flattened.
1732fn extract_text_from_anthropic_content(content: &serde_json::Value) -> String {
1733    if let Some(s) = content.as_str() {
1734        return s.to_owned();
1735    }
1736    if let Some(arr) = content.as_array() {
1737        let mut text = String::new();
1738        for block in arr {
1739            match block["type"].as_str() {
1740                Some("text") => text.push_str(block["text"].as_str().unwrap_or("")),
1741                Some("tool_use") => {
1742                    let name = block["name"].as_str().unwrap_or("tool");
1743                    let args = serde_json::to_string(&block["input"]).unwrap_or_default();
1744                    text.push_str(&format!("[Tool: {}({})]", name, args));
1745                }
1746                Some("tool_result") => {
1747                    let result = block["content"].as_str()
1748                        .map(|s| s.to_owned())
1749                        .unwrap_or_else(|| serde_json::to_string(&block["content"]).unwrap_or_default());
1750                    text.push_str(&result);
1751                }
1752                _ => {}
1753            }
1754        }
1755        return text;
1756    }
1757    String::new()
1758}
1759
1760/// Translate an Anthropic `/v1/messages` request body to chatgpt.com `/backend-api/conversation` format.
1761/// Tools are stripped — chatgpt.com's backend API does not support tool use.
1762fn translate_anthropic_req_to_chatgpt(body: &serde_json::Value) -> serde_json::Value {
1763    let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1764    let model = map_model_to_chatgpt(claude_model);
1765    let system_prompt = body["system"].as_str().unwrap_or("").to_owned();
1766
1767    let mut messages: Vec<serde_json::Value> = Vec::new();
1768    if let Some(arr) = body["messages"].as_array() {
1769        for msg in arr {
1770            let role = msg["role"].as_str().unwrap_or("user");
1771            let text = extract_text_from_anthropic_content(&msg["content"]);
1772            messages.push(json!({
1773                "id": uuid_v4(),
1774                "author": {"role": role},
1775                "content": {"content_type": "text", "parts": [text]},
1776                "metadata": {}
1777            }));
1778        }
1779    }
1780
1781    json!({
1782        "action": "next",
1783        "messages": messages,
1784        "model": model,
1785        "parent_message_id": uuid_v4(),
1786        "system_prompt": system_prompt,
1787        "history_and_training_disabled": true,
1788        "supports_modapi": false,
1789    })
1790}
1791
1792/// Translate a chatgpt.com non-streaming response to Anthropic format.
1793fn translate_chatgpt_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
1794    let id = format!("msg_{}", &uuid_v4()[..8]);
1795    let text = body["message"]["content"]["parts"][0]
1796        .as_str()
1797        .unwrap_or("")
1798        .to_owned();
1799    json!({
1800        "id": id,
1801        "type": "message",
1802        "role": "assistant",
1803        "model": model,
1804        "content": [{"type": "text", "text": text}],
1805        "stop_reason": "end_turn",
1806        "stop_sequence": null,
1807        "usage": {"input_tokens": 0, "output_tokens": 0}
1808    })
1809}
1810
1811/// Translate the response back from chatgpt.com format to Anthropic format.
1812/// Handles both streaming and non-streaming responses.
1813async fn translate_response_chatgpt_to_anthropic(resp: Response, model: &str) -> Response {
1814    use axum::body::Body;
1815    let msg_id = format!("msg_{}", &uuid_v4()[..8]);
1816    let model = model.to_owned();
1817
1818    if quota::is_streaming_response(&resp) {
1819        let (mut parts, body) = resp.into_parts();
1820        parts.headers.insert(
1821            axum::http::header::CONTENT_TYPE,
1822            axum::http::HeaderValue::from_static("text/event-stream"),
1823        );
1824        let stream = translate_chatgpt_stream_to_anthropic(body, model, msg_id);
1825        Response::from_parts(parts, Body::from_stream(stream))
1826    } else {
1827        let (mut parts, body) = resp.into_parts();
1828        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1829        let chatgpt_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1830        let anthropic_val = translate_chatgpt_resp_to_anthropic(chatgpt_val, &model);
1831        let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
1832        parts.headers.insert(
1833            axum::http::header::CONTENT_TYPE,
1834            axum::http::HeaderValue::from_static("application/json"),
1835        );
1836        Response::from_parts(parts, Body::from(out))
1837    }
1838}
1839
1840/// Stream-translate a chatgpt.com SSE response body into Anthropic SSE events.
1841///
1842/// chatgpt.com sends the **full accumulated text** in each chunk (not a delta),
1843/// so we track `prev_len` and compute deltas ourselves.
1844fn translate_chatgpt_stream_to_anthropic(
1845    body: axum::body::Body,
1846    model: String,
1847    msg_id: String,
1848) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1849    use futures_util::StreamExt;
1850
1851    async_stream::stream! {
1852        let start_evt = format!(
1853            "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
1854            serde_json::to_string(&json!({
1855                "type": "message_start",
1856                "message": {
1857                    "id": msg_id, "type": "message", "role": "assistant",
1858                    "content": [], "model": model, "stop_reason": null,
1859                    "usage": {"input_tokens": 0, "output_tokens": 0}
1860                }
1861            })).unwrap()
1862        );
1863        yield Ok(bytes::Bytes::from(start_evt));
1864
1865        let mut buf = String::new();
1866        let mut content_block_open = false;
1867        let mut prev_len: usize = 0;
1868        let byte_stream = body.into_data_stream();
1869        futures_util::pin_mut!(byte_stream);
1870
1871        'outer: while let Some(chunk) = byte_stream.next().await {
1872            let chunk = match chunk { Ok(c) => c, Err(_) => break };
1873            buf.push_str(&String::from_utf8_lossy(&chunk));
1874
1875            while let Some(nl) = buf.find('\n') {
1876                let line = buf[..nl].trim_end_matches('\r').to_owned();
1877                buf = buf[nl + 1..].to_owned();
1878                if !line.starts_with("data: ") { continue; }
1879                let data = &line["data: ".len()..];
1880                if data == "[DONE]" { break 'outer; }
1881                let Ok(val) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1882
1883                let text = match val["message"]["content"]["parts"][0].as_str() {
1884                    Some(t) => t.to_owned(),
1885                    None => continue,
1886                };
1887
1888                let delta = text[prev_len..].to_owned();
1889                if !delta.is_empty() {
1890                    if !content_block_open {
1891                        content_block_open = true;
1892                        yield Ok(bytes::Bytes::from(format!(
1893                            "event: content_block_start\ndata: {}\n\n",
1894                            serde_json::to_string(&json!({
1895                                "type": "content_block_start", "index": 0,
1896                                "content_block": {"type": "text", "text": ""}
1897                            })).unwrap()
1898                        )));
1899                    }
1900                    yield Ok(bytes::Bytes::from(format!(
1901                        "event: content_block_delta\ndata: {}\n\n",
1902                        serde_json::to_string(&json!({
1903                            "type": "content_block_delta", "index": 0,
1904                            "delta": {"type": "text_delta", "text": delta}
1905                        })).unwrap()
1906                    )));
1907                    prev_len = text.len();
1908                }
1909
1910                if val["message"]["end_turn"].as_bool() == Some(true) {
1911                    break 'outer;
1912                }
1913            }
1914        }
1915
1916        if content_block_open {
1917            yield Ok(bytes::Bytes::from(format!(
1918                "event: content_block_stop\ndata: {}\n\n",
1919                serde_json::to_string(&json!({"type": "content_block_stop", "index": 0})).unwrap()
1920            )));
1921        }
1922        yield Ok(bytes::Bytes::from(format!(
1923            "event: message_delta\ndata: {}\n\n",
1924            serde_json::to_string(&json!({
1925                "type": "message_delta",
1926                "delta": {"stop_reason": "end_turn", "stop_sequence": null},
1927                "usage": {"output_tokens": 0}
1928            })).unwrap()
1929        )));
1930        yield Ok(bytes::Bytes::from(
1931            "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
1932        ));
1933    }
1934}
1935
1936/// Translate an Anthropic `/v1/messages` request body to OpenAI `/v1/chat/completions` format.
1937/// Used when routing an Anthropic-protocol request to an OpenAI/Codex account.
1938fn translate_anthropic_req_to_openai(body: serde_json::Value) -> serde_json::Value {
1939    let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1940    let model = map_model_to_openai(claude_model);
1941    let stream = body["stream"].as_bool().unwrap_or(false);
1942    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1943
1944    let mut messages: Vec<serde_json::Value> = Vec::new();
1945
1946    // Prepend system prompt if present.
1947    if let Some(sys) = body["system"].as_str().filter(|s| !s.is_empty()) {
1948        messages.push(json!({"role": "system", "content": sys}));
1949    }
1950
1951    if let Some(arr) = body["messages"].as_array() {
1952        for msg in arr {
1953            let role = msg["role"].as_str().unwrap_or("user");
1954
1955            if let Some(blocks) = msg["content"].as_array() {
1956                // Check for tool_result blocks (user turn carrying tool results).
1957                let has_tool_result = blocks.iter().any(|b| b["type"] == "tool_result");
1958                if has_tool_result {
1959                    for b in blocks {
1960                        if b["type"] == "tool_result" {
1961                            let content = b["content"].as_str()
1962                                .map(|s| s.to_owned())
1963                                .unwrap_or_else(|| serde_json::to_string(&b["content"]).unwrap_or_default());
1964                            messages.push(json!({
1965                                "role": "tool",
1966                                "tool_call_id": b["tool_use_id"].as_str().unwrap_or(""),
1967                                "content": content,
1968                            }));
1969                        }
1970                    }
1971                    continue;
1972                }
1973
1974                // Regular content blocks — may include text and tool_use.
1975                let mut text = String::new();
1976                let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1977                for b in blocks {
1978                    match b["type"].as_str() {
1979                        Some("text") => text.push_str(b["text"].as_str().unwrap_or("")),
1980                        Some("tool_use") => {
1981                            let args = match &b["input"] {
1982                                serde_json::Value::String(s) => s.clone(),
1983                                v => serde_json::to_string(v).unwrap_or_default(),
1984                            };
1985                            tool_calls.push(json!({
1986                                "id": b["id"].as_str().unwrap_or(""),
1987                                "type": "function",
1988                                "function": {"name": b["name"].as_str().unwrap_or(""), "arguments": args},
1989                            }));
1990                        }
1991                        _ => {}
1992                    }
1993                }
1994                let mut m = json!({"role": role, "content": text});
1995                if !tool_calls.is_empty() {
1996                    m["tool_calls"] = json!(tool_calls);
1997                }
1998                messages.push(m);
1999            } else if let Some(s) = msg["content"].as_str() {
2000                messages.push(json!({"role": role, "content": s}));
2001            }
2002        }
2003    }
2004
2005    let mut req = json!({
2006        "model": model,
2007        "messages": messages,
2008        "max_tokens": max_tokens,
2009        "stream": stream,
2010    });
2011
2012    // Request usage data in stream final chunk.
2013    if stream {
2014        req["stream_options"] = json!({"include_usage": true});
2015    }
2016    if let Some(t) = body.get("temperature") { req["temperature"] = t.clone(); }
2017    if let Some(sp) = body.get("stop_sequences") { req["stop"] = sp.clone(); }
2018
2019    // Anthropic tools → OpenAI tools.
2020    if let Some(tools) = body["tools"].as_array() {
2021        let oai: Vec<serde_json::Value> = tools.iter().map(|t| json!({
2022            "type": "function",
2023            "function": {
2024                "name": t["name"].as_str().unwrap_or(""),
2025                "description": t["description"].as_str().unwrap_or(""),
2026                "parameters": t.get("input_schema").cloned()
2027                    .unwrap_or(json!({"type": "object", "properties": {}})),
2028            }
2029        })).collect();
2030        if !oai.is_empty() { req["tools"] = json!(oai); }
2031    }
2032
2033    if let Some(tc) = body.get("tool_choice") {
2034        req["tool_choice"] = match tc["type"].as_str() {
2035            Some("any")  => json!({"type": "required"}),
2036            Some("tool") => json!({"type": "function", "function": {"name": tc["name"]}}),
2037            _            => json!("auto"),
2038        };
2039    }
2040
2041    req
2042}
2043
2044/// Translate an OpenAI `/v1/chat/completions` non-streaming response to Anthropic format.
2045fn translate_openai_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
2046    let id = format!("msg_{}", &uuid_v4()[..8]);
2047    let choice = &body["choices"][0];
2048    let msg = &choice["message"];
2049
2050    let mut content: Vec<serde_json::Value> = Vec::new();
2051    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
2052        content.push(json!({"type": "text", "text": text}));
2053    }
2054    if let Some(tcs) = msg["tool_calls"].as_array() {
2055        for tc in tcs {
2056            content.push(json!({
2057                "type": "tool_use",
2058                "id": tc["id"].as_str().unwrap_or(""),
2059                "name": tc["function"]["name"].as_str().unwrap_or(""),
2060                "input": serde_json::from_str::<serde_json::Value>(
2061                    tc["function"]["arguments"].as_str().unwrap_or("{}")
2062                ).unwrap_or(json!({})),
2063            }));
2064        }
2065    }
2066
2067    let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
2068        "stop"       => "end_turn",
2069        "tool_calls" => "tool_use",
2070        "length"     => "max_tokens",
2071        other        => other,
2072    };
2073
2074    json!({
2075        "id": id,
2076        "type": "message",
2077        "role": "assistant",
2078        "model": model,
2079        "content": content,
2080        "stop_reason": stop_reason,
2081        "stop_sequence": null,
2082        "usage": {
2083            "input_tokens":  body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
2084            "output_tokens": body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
2085        }
2086    })
2087}
2088
2089/// Translate the response back from OpenAI format to Anthropic format.
2090/// Handles both streaming and non-streaming responses.
2091async fn translate_response_openai_to_anthropic(resp: Response, model: &str) -> Response {
2092    use axum::body::Body;
2093    let msg_id = format!("msg_{}", &uuid_v4()[..8]);
2094    let model = model.to_owned();
2095
2096    if quota::is_streaming_response(&resp) {
2097        let (mut parts, body) = resp.into_parts();
2098        parts.headers.insert(
2099            axum::http::header::CONTENT_TYPE,
2100            axum::http::HeaderValue::from_static("text/event-stream"),
2101        );
2102        let stream = translate_openai_stream_to_anthropic(body, model, msg_id);
2103        Response::from_parts(parts, Body::from_stream(stream))
2104    } else {
2105        let (mut parts, body) = resp.into_parts();
2106        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
2107        let openai_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
2108        let anthropic_val = translate_openai_resp_to_anthropic(openai_val, &model);
2109        let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
2110        parts.headers.insert(
2111            axum::http::header::CONTENT_TYPE,
2112            axum::http::HeaderValue::from_static("application/json"),
2113        );
2114        Response::from_parts(parts, Body::from(out))
2115    }
2116}
2117
2118/// Translate the response back from Anthropic format to OpenAI format.
2119async fn translate_response_anthropic_to_openai(resp: Response) -> Response {
2120    use axum::body::Body;
2121    let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
2122
2123    if quota::is_streaming_response(&resp) {
2124        let (parts, body) = resp.into_parts();
2125        let stream = translate_body_anthropic_to_openai(body, chat_id);
2126        Response::from_parts(parts, Body::from_stream(stream))
2127    } else {
2128        let (mut parts, body) = resp.into_parts();
2129        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
2130        let anthropic_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
2131        let openai_val = translate_from_anthropic(anthropic_val);
2132        let out = serde_json::to_vec(&openai_val).unwrap_or_default();
2133        parts.headers.insert(
2134            axum::http::header::CONTENT_TYPE,
2135            axum::http::HeaderValue::from_static("application/json"),
2136        );
2137        Response::from_parts(parts, Body::from(out))
2138    }
2139}
2140
2141/// Stream-translate an OpenAI SSE response body into Anthropic SSE events.
2142///
2143/// Emits: `message_start` → `content_block_start` → N×`content_block_delta`
2144///       → `content_block_stop` → `message_delta` → `message_stop`
2145fn translate_openai_stream_to_anthropic(
2146    body: axum::body::Body,
2147    model: String,
2148    msg_id: String,
2149) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
2150    use futures_util::StreamExt;
2151
2152    async_stream::stream! {
2153        // Send message_start immediately (input_tokens unknown yet, use 0).
2154        let start_evt = format!(
2155            "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
2156            serde_json::to_string(&json!({
2157                "type": "message_start",
2158                "message": {
2159                    "id": msg_id, "type": "message", "role": "assistant",
2160                    "content": [], "model": model, "stop_reason": null,
2161                    "usage": {"input_tokens": 0, "output_tokens": 0}
2162                }
2163            })).unwrap()
2164        );
2165        yield Ok(bytes::Bytes::from(start_evt));
2166
2167        let mut buf = String::new();
2168        let mut content_block_open = false;
2169        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
2170        let mut tool_call_count: usize = 0;
2171        let mut output_tokens: u64 = 0;
2172        let mut input_tokens: u64 = 0;
2173        let byte_stream = body.into_data_stream();
2174        futures_util::pin_mut!(byte_stream);
2175
2176        while let Some(chunk) = byte_stream.next().await {
2177            let chunk = match chunk { Ok(c) => c, Err(_) => break };
2178            buf.push_str(&String::from_utf8_lossy(&chunk));
2179
2180            while let Some(nl) = buf.find('\n') {
2181                let line = buf[..nl].trim_end_matches('\r').to_owned();
2182                buf = buf[nl + 1..].to_owned();
2183                if !line.starts_with("data: ") { continue; }
2184                let data = &line["data: ".len()..];
2185                if data == "[DONE]" { continue; }
2186                let Ok(ev) = serde_json::from_str::<serde_json::Value>(data) else { continue };
2187
2188                // Collect usage from final chunk (stream_options.include_usage).
2189                if let Some(u) = ev.get("usage") {
2190                    input_tokens  = u["prompt_tokens"].as_u64().unwrap_or(input_tokens);
2191                    output_tokens = u["completion_tokens"].as_u64().unwrap_or(output_tokens);
2192                }
2193
2194                let choice = &ev["choices"][0];
2195                let delta = &choice["delta"];
2196                let finish = choice["finish_reason"].as_str();
2197
2198                // Text delta.
2199                if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
2200                    if !content_block_open {
2201                        content_block_open = true;
2202                        let cb = format!(
2203                            "event: content_block_start\ndata: {}\n\n",
2204                            serde_json::to_string(&json!({
2205                                "type": "content_block_start", "index": 0,
2206                                "content_block": {"type": "text", "text": ""}
2207                            })).unwrap()
2208                        );
2209                        yield Ok(bytes::Bytes::from(cb));
2210                    }
2211                    let d = format!(
2212                        "event: content_block_delta\ndata: {}\n\n",
2213                        serde_json::to_string(&json!({
2214                            "type": "content_block_delta", "index": 0,
2215                            "delta": {"type": "text_delta", "text": text}
2216                        })).unwrap()
2217                    );
2218                    yield Ok(bytes::Bytes::from(d));
2219                }
2220
2221                // Tool call deltas.
2222                if let Some(tcs) = delta["tool_calls"].as_array() {
2223                    for tc in tcs {
2224                        let oai_idx = tc["index"].as_u64().unwrap_or(0);
2225                        // New tool call: emit content_block_start for tool_use.
2226                        if let Some(id) = tc["id"].as_str() {
2227                            let name = tc["function"]["name"].as_str().unwrap_or("").to_owned();
2228                            let my_idx = tool_call_count;
2229                            tool_call_count += 1;
2230                            tool_blocks.insert(oai_idx, (my_idx, id.to_owned(), name.clone()));
2231                            let cb = format!(
2232                                "event: content_block_start\ndata: {}\n\n",
2233                                serde_json::to_string(&json!({
2234                                    "type": "content_block_start",
2235                                    "index": my_idx + 1, // +1: text block at 0
2236                                    "content_block": {"type": "tool_use", "id": id, "name": name, "input": {}}
2237                                })).unwrap()
2238                            );
2239                            yield Ok(bytes::Bytes::from(cb));
2240                        }
2241                        // Streaming arguments.
2242                        if let Some(args_chunk) = tc["function"]["arguments"].as_str() {
2243                            if let Some(&(my_idx, _, _)) = tool_blocks.get(&oai_idx) {
2244                                let d = format!(
2245                                    "event: content_block_delta\ndata: {}\n\n",
2246                                    serde_json::to_string(&json!({
2247                                        "type": "content_block_delta",
2248                                        "index": my_idx + 1,
2249                                        "delta": {"type": "input_json_delta", "partial_json": args_chunk}
2250                                    })).unwrap()
2251                                );
2252                                yield Ok(bytes::Bytes::from(d));
2253                            }
2254                        }
2255                    }
2256                }
2257
2258                // Finish reason → close blocks + message_delta + message_stop.
2259                if let Some(fr) = finish {
2260                    let stop_reason = match fr {
2261                        "stop"       => "end_turn",
2262                        "tool_calls" => "tool_use",
2263                        "length"     => "max_tokens",
2264                        other        => other,
2265                    };
2266
2267                    // Close open content/tool blocks.
2268                    if content_block_open {
2269                        yield Ok(bytes::Bytes::from(format!(
2270                            "event: content_block_stop\ndata: {}\n\n",
2271                            serde_json::to_string(&json!({"type":"content_block_stop","index":0})).unwrap()
2272                        )));
2273                    }
2274                    for (_, (my_idx, _, _)) in &tool_blocks {
2275                        yield Ok(bytes::Bytes::from(format!(
2276                            "event: content_block_stop\ndata: {}\n\n",
2277                            serde_json::to_string(&json!({"type":"content_block_stop","index": my_idx + 1})).unwrap()
2278                        )));
2279                    }
2280
2281                    yield Ok(bytes::Bytes::from(format!(
2282                        "event: message_delta\ndata: {}\n\n",
2283                        serde_json::to_string(&json!({
2284                            "type": "message_delta",
2285                            "delta": {"stop_reason": stop_reason, "stop_sequence": null},
2286                            "usage": {"output_tokens": output_tokens}
2287                        })).unwrap()
2288                    )));
2289                    yield Ok(bytes::Bytes::from(
2290                        "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
2291                    ));
2292                }
2293            }
2294        }
2295    }
2296}
2297
2298/// Stream-translate an Anthropic SSE response body (from axum `Body`) into OpenAI SSE format.
2299/// Equivalent to `translate_anthropic_stream` but consumes an axum `Body` instead of a
2300/// `reqwest::Response`, so it can be used after the forwarder returns.
2301fn translate_body_anthropic_to_openai(
2302    body: axum::body::Body,
2303    chat_id: String,
2304) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
2305    use futures_util::StreamExt;
2306
2307    async_stream::stream! {
2308        let id = chat_id;
2309
2310        // Initial role chunk.
2311        let init = format!(
2312            "data: {}\n\n",
2313            serde_json::to_string(&json!({
2314                "id": id, "object": "chat.completion.chunk",
2315                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
2316            })).unwrap()
2317        );
2318        yield Ok(bytes::Bytes::from(init));
2319
2320        let mut buf = String::new();
2321        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
2322        let mut tool_call_count: usize = 0;
2323        let byte_stream = body.into_data_stream();
2324        futures_util::pin_mut!(byte_stream);
2325
2326        while let Some(chunk) = byte_stream.next().await {
2327            let chunk = match chunk { Ok(c) => c, Err(_) => break };
2328            buf.push_str(&String::from_utf8_lossy(&chunk));
2329
2330            while let Some(nl) = buf.find('\n') {
2331                let line = buf[..nl].trim_end_matches('\r').to_owned();
2332                buf = buf[nl + 1..].to_owned();
2333                if !line.starts_with("data: ") { continue; }
2334                let data = &line["data: ".len()..];
2335                if data == "[DONE]" { continue; }
2336                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
2337                let event_type = event["type"].as_str().unwrap_or("");
2338
2339                let maybe_chunk = match event_type {
2340                    "content_block_start" => {
2341                        let block_idx = event["index"].as_u64().unwrap_or(0);
2342                        let cb = &event["content_block"];
2343                        if cb["type"].as_str() == Some("tool_use") {
2344                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
2345                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
2346                            let oai_idx = tool_call_count;
2347                            tool_call_count += 1;
2348                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
2349                            Some(json!({
2350                                "id": id, "object": "chat.completion.chunk",
2351                                "choices": [{"index": 0, "delta": {
2352                                    "tool_calls": [{"index": oai_idx, "id": tool_id, "type": "function",
2353                                        "function": {"name": tool_name, "arguments": ""}}]
2354                                }, "finish_reason": null}]
2355                            }))
2356                        } else { None }
2357                    }
2358                    "content_block_delta" => {
2359                        let block_idx = event["index"].as_u64().unwrap_or(0);
2360                        let delta = &event["delta"];
2361                        match delta["type"].as_str() {
2362                            Some("text_delta") => {
2363                                let text = delta["text"].as_str().unwrap_or("");
2364                                if text.is_empty() { continue; }
2365                                Some(json!({
2366                                    "id": id, "object": "chat.completion.chunk",
2367                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
2368                                }))
2369                            }
2370                            Some("input_json_delta") => {
2371                                let args = delta["partial_json"].as_str().unwrap_or("");
2372                                tool_blocks.get(&block_idx).map(|(oai_idx, _, _)| json!({
2373                                    "id": id, "object": "chat.completion.chunk",
2374                                    "choices": [{"index": 0, "delta": {
2375                                        "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
2376                                    }, "finish_reason": null}]
2377                                }))
2378                            }
2379                            _ => None,
2380                        }
2381                    }
2382                    "message_delta" => {
2383                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
2384                        let finish = match stop_reason {
2385                            "end_turn"   => "stop",
2386                            "tool_use"   => "tool_calls",
2387                            "max_tokens" => "length",
2388                            other        => other,
2389                        };
2390                        Some(json!({
2391                            "id": id, "object": "chat.completion.chunk",
2392                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
2393                        }))
2394                    }
2395                    _ => None,
2396                };
2397
2398                if let Some(c) = maybe_chunk {
2399                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
2400                    yield Ok(bytes::Bytes::from(out));
2401                }
2402            }
2403        }
2404        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
2405    }
2406}