Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24    config: Arc<Config>,
25    forwarder: Arc<Forwarder>,
26    state: StateStore,
27    /// Live credentials — can be refreshed at runtime without restarting.
28    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29    /// Per-account mutex that serialises concurrent token-refresh attempts.
30    ///
31    /// When multiple in-flight requests hit a 401 for the same account at the
32    /// same time, only one should call the upstream OAuth endpoint; the others
33    /// should wait and then re-use the fresh token instead of each making their
34    /// own refresh call (which would rotate the refresh_token out from under the
35    /// others and cause cascading auth failures).
36    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37    /// Epoch-ms when this proxy instance started.
38    started_ms: u64,
39    /// If set, /v1/chat/completions requests are translated and forwarded here
40    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
41    anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46    Ok(app)
47}
48
/// Shared live credentials map — can be written to without restarting the proxy.
///
/// Keyed by account name. `create_app_with_state` hands out a clone of the
/// same `Arc` it stores in the handler state, so writes through either handle
/// are visible to both.
pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
51
52pub fn create_app_with_state(
53    config: Config,
54    state: StateStore,
55    anthropic_base_url: Option<String>,
56) -> anyhow::Result<(Router, LiveCredentials)> {
57    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
58
59    // Accounts with no credential are shown in status but skipped during routing.
60    // Mark them disabled immediately so the router ignores them.
61    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
62        state.set_auth_failed(&a.name);
63    }
64
65    let credentials: LiveCredentials = Arc::new(RwLock::new(
66        config.accounts.iter()
67            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
68            .collect::<HashMap<_, _>>(),
69    ));
70
71    let app_state = AppState {
72        config: Arc::new(config),
73        forwarder: Arc::new(forwarder),
74        state,
75        credentials: Arc::clone(&credentials),
76        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
77        started_ms: now_ms(),
78        anthropic_base_url,
79    };
80
81    // Always register both Anthropic and OpenAI routes so a single shunt
82    // instance can serve clients of either protocol and route to accounts of
83    // either provider, translating on the fly when needed.
84    let proxy_routes = Router::new()
85        .route("/v1/messages", post(proxy_handler))
86        .route("/v1/messages/count_tokens", post(proxy_handler))
87        .route("/v1/chat/completions", post(openai_compat_handler))
88        .route("/v1/models", get(openai_models_handler))
89        .fallback(proxy_handler);
90
91    let app = Router::new()
92        .route("/health", get(health))
93        .route("/status", get(status_handler))
94        .route("/use", post(use_handler))
95        .merge(proxy_routes)
96        .with_state(app_state);
97
98    Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102    axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106    let account_states = s.state.account_states();
107    let quotas = s.state.quota_snapshot();
108    let rate_limits = s.state.rate_limit_snapshot();
109
110    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111        let st = account_states.get(&a.name);
112        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113            "reauth_required"
114        } else if st.map(|s| s.disabled).unwrap_or(false) {
115            "disabled"
116        } else if s.state.is_available(&a.name) {
117            "available"
118        } else {
119            "cooling"
120        };
121
122        let quota = quotas.get(&a.name);
123        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125        let tokens_used = quota.map(|q| json!({
126            "input": q.input_tokens,
127            "output": q.output_tokens,
128            "total": q.total_tokens(),
129        }));
130
131        let rl = rate_limits.get(&a.name);
132        let rate_limit = rl.map(|r| json!({
133            "utilization_5h": r.utilization_5h,
134            "reset_5h": r.reset_5h,
135            "status_5h": r.status_5h,
136            "utilization_7d": r.utilization_7d,
137            "reset_7d": r.reset_7d,
138            "status_7d": r.status_7d,
139            "representative_claim": r.representative_claim,
140            "updated_ms": r.updated_ms,
141        }));
142
143        let acc_state = account_states.get(&a.name);
144        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149        let reset_5h = rl.and_then(|r| r.reset_5h);
150        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
151        let reset_7d = rl.and_then(|r| r.reset_7d);
152        let available = s.state.is_available(&a.name);
153
154        json!({
155            "name": a.name,
156            "email": email,
157            "plan_type": a.plan_type,
158            "status": avail_status,
159            "available": available,
160            "disabled": disabled,
161            "auth_failed": auth_failed,
162            "cooldown_until_ms": cooldown_until_ms,
163            "utilization_5h": utilization_5h,
164            "reset_5h": reset_5h,
165            "utilization_7d": utilization_7d,
166            "reset_7d": reset_7d,
167            "window_expires_ms": window_expires_ms,
168            "tokens_used": tokens_used,
169            "rate_limit": rate_limit,
170        })
171    }).collect();
172
173    let recent_requests = s.state.recent_requests_snapshot();
174    let savings = s.state.savings_snapshot();
175
176    axum::Json(json!({
177        "version": env!("CARGO_PKG_VERSION"),
178        "started_ms": s.started_ms,
179        "accounts": accounts,
180        "pinned_account": s.state.get_pinned(),
181        "last_used_account": s.state.get_last_used(),
182        "recent_requests": recent_requests,
183        "savings": savings,
184    }))
185}
186
187async fn use_handler(
188    State(s): State<AppState>,
189    axum::Json(body): axum::Json<serde_json::Value>,
190) -> impl IntoResponse {
191    let account = body["account"].as_str().map(|s| s.to_owned());
192    // Validate the account name exists (unless clearing to auto)
193    if let Some(ref name) = account {
194        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
195            return axum::Json(json!({
196                "error": format!("unknown account '{name}'")
197            }));
198        }
199        let pinned = if name == "auto" { None } else { Some(name.clone()) };
200        s.state.set_pinned(pinned);
201        axum::Json(json!({ "pinned": name }))
202    } else {
203        s.state.set_pinned(None);
204        axum::Json(json!({ "pinned": null }))
205    }
206}
207
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
212
213async fn proxy_handler(
214    State(s): State<AppState>,
215    req: Request,
216) -> Result<Response, ProxyError> {
217    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
218    if let Some(ref expected) = s.config.server.remote_key {
219        let provided = req.headers()
220            .get("x-api-key")
221            .and_then(|v| v.to_str().ok())
222            .unwrap_or("");
223        if provided != expected {
224            return Err(ProxyError::Unauthorized);
225        }
226    }
227
228    let method = req.method().as_str().to_owned();
229    let path = req.uri().path().to_owned();
230    let headers = req.headers().clone();
231
232    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
233        .await
234        .map_err(|_| ProxyError::BodyRead)?;
235
236    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
237        .ok()
238        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
239        .unwrap_or_default();
240    let req_start_ms = now_ms();
241
242    let fp = router::fingerprint(&body_bytes);
243    let fp_ref = fp.as_deref();
244
245    let mut tried: HashSet<String> = HashSet::new();
246    // Track accounts we've already attempted a token refresh for this request.
247    let mut refreshed: HashSet<String> = HashSet::new();
248    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
249    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
250
251    loop {
252        let account = match router::pick_account(
253            &s.config.accounts, &s.state, fp_ref, &tried,
254            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
255        ) {
256            Some(a) => a,
257            None => {
258                // Check whether any accounts are just temporarily cooling down
259                // (429/529 backoff) rather than permanently disabled / auth_failed.
260                // If so, wait for the soonest one to recover and retry.
261                let account_states = s.state.account_states();
262                let now = now_ms();
263                let soonest_ms = s.config.accounts.iter()
264                    .filter_map(|a| {
265                        let st = account_states.get(&a.name)?;
266                        if st.disabled { return None; } // auth_failed or permanently off
267                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
268                    })
269                    .min();
270
271                match soonest_ms {
272                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
273                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
274                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
275                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
276                        tried.clear(); // accounts may have recovered; try them again
277                    }
278                    _ => return Err(ProxyError::AllAccountsUnavailable),
279                }
280                continue;
281            }
282        };
283
284        let account_name = account.name.clone();
285
286        // Use the live (possibly refreshed) token rather than the one baked into config.
287        // For OpenAI/chatgpt.com accounts, use the id_token (short-lived OIDC JWT) as
288        // the bearer — chatgpt.com's API authenticates via id_token, not access_token.
289        let token = {
290            let creds = s.credentials.read().await;
291            let cred = creds.get(&account_name)
292                .cloned()
293                .or_else(|| account.credential.clone());
294            match cred {
295                Some(c) => c.access_token,
296                None => String::new(),
297            }
298        };
299
300        // Detect request and account protocols.  When they differ, translate
301        // the request body + path before forwarding and translate the response
302        // back so the client always sees its native wire format.
303        let req_is_anthropic = path.starts_with("/v1/messages");
304        let acct_is_anthropic = matches!(account.provider, Provider::Anthropic);
305
306        let (fwd_path, fwd_body, fwd_headers) = if req_is_anthropic == acct_is_anthropic {
307            (path.clone(), body_bytes.clone(), headers.clone())
308        } else if req_is_anthropic {
309            // Anthropic client → OpenAI account: translate A→O, strip Anthropic headers.
310            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
311            let translated = translate_anthropic_req_to_openai(val);
312            let mut h = headers.clone();
313            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
314                h.remove(*name);
315            }
316            (
317                "/v1/chat/completions".to_owned(),
318                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
319                h,
320            )
321        } else {
322            // OpenAI client → Anthropic account: translate O→A.
323            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
324            let translated = translate_to_anthropic(val);
325            (
326                "/v1/messages".to_owned(),
327                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
328                headers.clone(),
329            )
330        };
331
332        let upstream = account.provider.default_upstream_url();
333        let response = s.forwarder
334            .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
335            .await
336            .map_err(|e| {
337                error!("Forward error: {:#}", e);
338                ProxyError::Upstream
339            })?;
340
341        match response.status().as_u16() {
342            200..=299 => {
343                s.state.set_last_used(&account_name);
344                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
345                    s.state.update_rate_limits(&account_name, info);
346                }
347                // Translate response back to the client's expected protocol.
348                let response = if req_is_anthropic == acct_is_anthropic {
349                    response
350                } else if req_is_anthropic {
351                    // Got OpenAI response; client expects Anthropic.
352                    translate_response_openai_to_anthropic(response, &model).await
353                } else {
354                    // Got Anthropic response; client expects OpenAI.
355                    translate_response_anthropic_to_openai(response).await
356                };
357                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
358            }
359            429 => {
360                let info = account.provider.parse_rate_limits(response.headers());
361                // Sleep until the actual reset time if the headers tell us when that is;
362                // otherwise fall back to 60s so we don't hammer the API.
363                let cooldown_ms = info.as_ref()
364                    .and_then(|i| i.reset_5h.or(i.reset_7d))
365                    .map(|reset_secs| {
366                        let reset_ms = reset_secs.saturating_mul(1_000);
367                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
368                    })
369                    .unwrap_or(60_000);
370                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
371                if let Some(info) = info {
372                    s.state.update_rate_limits(&account_name, info);
373                }
374                s.state.set_cooldown(&account_name, cooldown_ms);
375                if cooldown_ms >= 5 * 60_000 {
376                    let mins = cooldown_ms / 60_000;
377                    notify(
378                        "shunt: Rate Limited",
379                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
380                        "Ping",
381                    );
382                }
383                tried.insert(account_name);
384            }
385            529 => {
386                warn!(account = %account_name, "529 overloaded — cooling 30s");
387                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
388                    s.state.update_rate_limits(&account_name, info);
389                }
390                s.state.set_cooldown(&account_name, 30_000);
391                tried.insert(account_name);
392            }
393            401 => {
394                if !refreshed.contains(&account_name) {
395                    // Access token invalidated (e.g. user logged out) — try refresh.
396                    //
397                    // Acquire the per-account refresh lock so concurrent requests
398                    // for the same account serialise here. The first waiter to get
399                    // the lock does the actual OAuth refresh; subsequent waiters
400                    // re-check credentials and skip the refresh if the token was
401                    // already rotated while they were queued.
402                    let account_lock = {
403                        let mut locks = s.refresh_locks.lock().unwrap();
404                        locks.entry(account_name.clone())
405                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
406                            .clone()
407                    };
408                    let _guard = account_lock.lock().await;
409
410                    // Re-read credentials after acquiring the lock — another task
411                    // may have already refreshed while we were waiting.
412                    let cred_before = {
413                        let creds = s.credentials.read().await;
414                        creds.get(&account_name).cloned()
415                            .or_else(|| account.credential.clone())
416                    };
417                    let Some(cred) = cred_before else {
418                        tried.insert(account_name);
419                        continue;
420                    };
421
422                    // Check if the token already changed while we were waiting.
423                    let token_before = cred.access_token.clone();
424                    let already_refreshed = {
425                        let creds = s.credentials.read().await;
426                        creds.get(&account_name)
427                            .map(|c| c.access_token != token_before)
428                            .unwrap_or(false)
429                    };
430
431                    if already_refreshed {
432                        // Another concurrent request already refreshed — just retry.
433                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
434                        refreshed.insert(account_name);
435                    } else {
436                        match tokio::time::timeout(
437                            std::time::Duration::from_secs(10),
438                            account.provider.refresh_token(&cred),
439                        ).await {
440                            Ok(Ok(fresh)) => {
441                                warn!(account = %account_name, "401 — token refreshed, retrying");
442                                {
443                                    let mut creds = s.credentials.write().await;
444                                    creds.insert(account_name.clone(), fresh.clone());
445                                }
446                                // Persist to disk so the refreshed token survives a restart.
447                                let name = account_name.clone();
448                                let fresh = fresh.clone();
449                                tokio::task::spawn_blocking(move || {
450                                    let mut store = CredentialsStore::load();
451                                    store.accounts.insert(name, fresh.clone());
452                                    store.save().ok();
453                                    if fresh.id_token.is_some() {
454                                        crate::oauth::write_codex_auth_file(&fresh);
455                                    }
456                                });
457                                // Mark as refreshed but don't add to tried — retry this account.
458                                refreshed.insert(account_name);
459                            }
460                            _ => {
461                                // Refresh failed/timed out — cool down, don't permanently disable.
462                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
463                                s.state.set_cooldown(&account_name, 5 * 60_000);
464                                tried.insert(account_name);
465                            }
466                        }
467                    }
468                } else {
469                    // Already refreshed once and still 401 — cool down this account.
470                    error!(account = %account_name, "401 after refresh — cooling 5min");
471                    s.state.set_cooldown(&account_name, 5 * 60_000);
472                    tried.insert(account_name);
473                }
474            }
475            403 => {
476                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
477                error!(account = %account_name, "403 forbidden — cooling 30min");
478                s.state.set_cooldown(&account_name, 30 * 60_000);
479                notify(
480                    "shunt: Account Forbidden",
481                    &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
482                    "Basso",
483                );
484                tried.insert(account_name);
485            }
486            _ => {
487                // 400, 404, 500, etc. — return as-is, no retry
488                return Ok(response);
489            }
490        }
491    }
492}
493
494// ---------------------------------------------------------------------------
495// Usage extraction
496// ---------------------------------------------------------------------------
497
498/// Intercept a successful response to record token usage, then pass it through.
499///
500/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
501/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
502async fn tap_usage(
503    resp: Response,
504    state: &StateStore,
505    account: &str,
506    model: &str,
507    req_start_ms: u64,
508) -> Response {
509    use axum::body::Body;
510    use crate::state::RequestLog;
511
512    if quota::is_streaming_response(&resp) {
513        let state = state.clone();
514        let account = account.to_owned();
515        let model = model.to_owned();
516        let on_complete = Arc::new(move |input: u64, output: u64| {
517            state.record_usage(&account, input, output);
518            state.record_global(&model, input, output);
519            state.record_request(RequestLog {
520                ts_ms: req_start_ms,
521                account: account.clone(),
522                model: model.clone(),
523                status: 200,
524                input_tokens: input,
525                output_tokens: output,
526                duration_ms: now_ms().saturating_sub(req_start_ms),
527            });
528        });
529        let (parts, body) = resp.into_parts();
530        let wrapped = quota::wrap_streaming_body(body, on_complete);
531        return Response::from_parts(parts, wrapped);
532    }
533
534    // Non-streaming: buffer, extract, rebuild
535    let (parts, body) = resp.into_parts();
536    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
537        Ok(b) => b,
538        Err(_) => return Response::from_parts(parts, Body::empty()),
539    };
540    let (input, output) = quota::extract_usage_from_json(&bytes);
541    state.record_usage(account, input, output);
542    state.record_global(model, input, output);
543    state.record_request(RequestLog {
544        ts_ms: req_start_ms,
545        account: account.to_owned(),
546        model: model.to_owned(),
547        status: 200,
548        input_tokens: input,
549        output_tokens: output,
550        duration_ms: now_ms().saturating_sub(req_start_ms),
551    });
552    Response::from_parts(parts, Body::from(bytes))
553}
554
555
556// ---------------------------------------------------------------------------
557// Rate limit prefetch
558// ---------------------------------------------------------------------------
559
560/// For any account with no rate-limit data yet, make a cheap request directly
561/// to the upstream API so we populate metrics without waiting for a real user
562/// request. Runs as a background task after startup.
563pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
564    let client = reqwest::Client::builder()
565        .timeout(std::time::Duration::from_secs(20))
566        .build()
567        .unwrap_or_default();
568
569    for account in &config.accounts {
570        // Skip if we already have data for this account.
571        let rl = state.rate_limit_snapshot();
572        if let Some(r) = rl.get(&account.name) {
573            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
574                continue;
575            }
576        }
577
578        // Skip accounts with no credentials or no prefetch support.
579        let creds = match account.credential.clone() {
580            Some(c) => c,
581            None => continue,
582        };
583
584        let Some((path, body)) = account.provider.prefetch_request() else {
585            // No POST prefetch for this provider — do a lightweight GET auth check instead.
586            if let Some(probe_path) = account.provider.auth_probe_get_path() {
587                auth_probe_get(&client, probe_path, account, &state).await;
588            }
589            continue;
590        };
591        let url = format!("{}{}", config.server.upstream_url, path);
592
593        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
594
595        let r = match resp {
596            Ok(r) => r,
597            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
598        };
599
600        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
601            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
602            let fresh = match account.provider.refresh_token(&creds).await {
603                Ok(f) => f,
604                Err(e) => {
605                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
606                    state.set_auth_failed(&account.name);
607                    continue;
608                }
609            };
610            let mut store = crate::config::CredentialsStore::load();
611            store.accounts.insert(account.name.clone(), fresh.clone());
612            store.save().ok();
613            if fresh.id_token.is_some() {
614                crate::oauth::write_codex_auth_file(&fresh);
615            }
616            // Update live credentials so the proxy uses the fresh token immediately.
617            live_creds.write().await.insert(account.name.clone(), fresh.clone());
618
619            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
620                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
621                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
622                    state.set_auth_failed(&account.name);
623                }
624                Ok(r2) => {
625                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
626                        state.update_rate_limits(&account.name, info);
627                    }
628                }
629                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
630            }
631        } else {
632            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
633            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
634                state.update_rate_limits(&account.name, info);
635            }
636        }
637    }
638}
639
640/// Build and send a prefetch request for the given provider + token.
641async fn prefetch_send(
642    client: &reqwest::Client,
643    url: &str,
644    provider: &crate::provider::Provider,
645    token: &str,
646    body: &serde_json::Value,
647) -> anyhow::Result<reqwest::Response> {
648    let mut headers = reqwest::header::HeaderMap::new();
649    provider.inject_auth_headers(&mut headers, token)?;
650    for (name, value) in provider.prefetch_extra_headers() {
651        headers.insert(
652            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
653            reqwest::header::HeaderValue::from_static(value),
654        );
655    }
656    Ok(client.post(url).headers(headers).json(body).send().await?)
657}
658
659/// GET a cheap endpoint to verify credentials are still valid for providers that
660/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
661/// marks the account as `reauth_required` if the refresh also fails.
662async fn auth_probe_get(
663    client: &reqwest::Client,
664    path: &str,
665    account: &crate::config::AccountConfig,
666    state: &StateStore,
667) {
668    let creds = match account.credential.clone() {
669        Some(c) => c,
670        None => return,
671    };
672    let upstream = match account.provider {
673        crate::provider::Provider::OpenAI => "https://chatgpt.com",
674        crate::provider::Provider::Anthropic => "https://api.anthropic.com",
675    };
676    let url = format!("{}{}", upstream, path);
677
678    let do_get = |token: &str| -> reqwest::RequestBuilder {
679        let mut headers = reqwest::header::HeaderMap::new();
680        let _ = account.provider.inject_auth_headers(&mut headers, token);
681        client.get(&url).headers(headers)
682    };
683
684    let probe_token = &creds.access_token;
685    let resp = match do_get(probe_token).send().await {
686        Ok(r) => r,
687        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
688    };
689
690    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
691        tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
692        let fresh = match account.provider.refresh_token(&creds).await {
693            Ok(f) => f,
694            Err(e) => {
695                tracing::warn!(account = %account.name, "token refresh failed: {e}");
696                state.set_auth_failed(&account.name);
697                return;
698            }
699        };
700        let mut store = crate::config::CredentialsStore::load();
701        store.accounts.insert(account.name.clone(), fresh.clone());
702        store.save().ok();
703        if fresh.id_token.is_some() {
704            crate::oauth::write_codex_auth_file(&fresh);
705        }
706
707        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
708        match do_get(fresh_token).send().await {
709            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
710                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
711                state.set_auth_failed(&account.name);
712            }
713            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
714            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
715        }
716    } else {
717        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
718        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
719        // with codex CLI, which also tries to refresh at startup using the same token.
720        // Proactive refreshing is handled solely by openai_token_refresh_loop.
721    }
722}
723
724// ---------------------------------------------------------------------------
725// Proactive OpenAI token refresh loop
726// ---------------------------------------------------------------------------
727
728/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
729/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
730fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
731    let now_ms = std::time::SystemTime::now()
732        .duration_since(std::time::UNIX_EPOCH)
733        .unwrap_or_default()
734        .as_millis() as u64;
735    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
736        .unwrap_or(cred.expires_at);
737    exp_ms < now_ms + threshold_mins * 60 * 1_000
738}
739
740/// Sync live_creds from auth.json if auth.json has a newer token.
741///
742/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
743/// we pull that in so we don't use a stale refresh_token that codex already rotated.
744async fn sync_live_creds_from_auth_json(
745    account_name: &str,
746    live_creds: &LiveCredentials,
747) {
748    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
749    let current_exp = live_creds.read().await
750        .get(account_name)
751        .map(|c| c.expires_at)
752        .unwrap_or(0);
753    if from_file.expires_at > current_exp {
754        tracing::info!(account = %account_name, "synced fresher token from auth.json");
755        live_creds.write().await.insert(account_name.to_owned(), from_file);
756    }
757}
758
759/// Perform a single proactive refresh for one account and persist the result.
760async fn do_proactive_refresh(
761    account: &crate::config::AccountConfig,
762    creds: &crate::oauth::OAuthCredential,
763    live_creds: &LiveCredentials,
764    state: &StateStore,
765) {
766    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
767    match account.provider.refresh_token(creds).await {
768        Ok(fresh) => {
769            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
770            {
771                let mut map = live_creds.write().await;
772                map.insert(account.name.clone(), fresh.clone());
773            }
774            let mut store = crate::config::CredentialsStore::load();
775            store.accounts.insert(account.name.clone(), fresh.clone());
776            store.save().ok();
777            if fresh.id_token.is_some() {
778                crate::oauth::write_codex_auth_file(&fresh);
779            }
780            state.clear_auth_failed(&account.name);
781        }
782        Err(e) => {
783            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
784            state.set_auth_failed(&account.name);
785        }
786    }
787}
788
789
790/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
791///
792/// Strategy: never proactively rotate the refresh_token — that races with
793/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
794/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
795/// On-demand refresh (401 handler) covers the case where Codex isn't running
796/// and the token has actually expired.
797pub async fn openai_token_refresh_loop(
798    config: Arc<Config>,
799    state: StateStore,
800    live_creds: LiveCredentials,
801) {
802    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
803    for account in config.accounts.iter()
804        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
805    {
806        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
807            continue;
808        }
809        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
810
811        let creds = {
812            let map = live_creds.read().await;
813            map.get(&account.name).cloned().or_else(|| account.credential.clone())
814        };
815        if let Some(creds) = creds {
816            if access_token_expires_soon(&creds, 30) {
817                // access_token is nearly expired — refresh now so shunt can serve requests immediately.
818                do_proactive_refresh(account, &creds, &live_creds, &state).await;
819            } else {
820                tracing::info!(account = %account.name, "access_token fresh at startup");
821            }
822        }
823    }
824
825    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
826    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
827    loop {
828        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
829        for account in config.accounts.iter()
830            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
831        {
832            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
833        }
834    }
835}
836
837// ---------------------------------------------------------------------------
838// Error type
839// ---------------------------------------------------------------------------
840
841enum ProxyError {
842    BodyRead,
843    Upstream,
844    AllAccountsUnavailable,
845    Unauthorized,
846}
847
848impl IntoResponse for ProxyError {
849    fn into_response(self) -> Response {
850        let (status, msg) = match self {
851            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
852            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
853            ProxyError::AllAccountsUnavailable => {
854                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
855            }
856            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
857        };
858
859        (status, axum::Json(json!({
860            "type": "error",
861            "error": {"type": "api_error", "message": msg}
862        }))).into_response()
863    }
864}
865
866// ---------------------------------------------------------------------------
867// Recovery watcher — periodically retries token refresh for auth_failed accounts
868// ---------------------------------------------------------------------------
869
870/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
871/// auth_failed account. If refresh succeeds the account is brought back online
872/// without a process restart. If all accounts remain unrecoverable, fires a
873/// macOS notification (at most once per hour).
874pub async fn recovery_watcher(
875    config: Arc<Config>,
876    state: StateStore,
877    credentials: LiveCredentials,
878) {
879    use std::time::{Duration, Instant};
880    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
881    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
882
883    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
884    let mut last_notified: Option<Instant> = None;
885
886    loop {
887        tokio::time::sleep(CHECK_INTERVAL).await;
888
889        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
890        let failed = state.auth_failed_accounts(&name_refs);
891        if failed.is_empty() {
892            last_notified = None;
893            continue;
894        }
895
896        tracing::warn!(
897            accounts = ?failed,
898            "recovery: {} account(s) auth_failed, attempting token refresh",
899            failed.len()
900        );
901
902        let mut any_recovered = false;
903
904        for name in &failed {
905            let cred = {
906                let map = credentials.read().await;
907                map.get(*name).cloned()
908            };
909            let Some(cred) = cred else { continue };
910            if cred.refresh_token.is_empty() { continue; }
911
912            let provider = config.accounts.iter()
913                .find(|a| a.name == *name)
914                .map(|a| a.provider.clone())
915                .unwrap_or_default();
916
917            let result = tokio::time::timeout(
918                Duration::from_secs(20),
919                provider.refresh_token(&cred),
920            ).await;
921
922            match result {
923                Ok(Ok(fresh)) => {
924                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
925                    {
926                        let mut map = credentials.write().await;
927                        map.insert(name.to_string(), fresh.clone());
928                    }
929                    let name_owned = name.to_string();
930                    let fresh_owned = fresh.clone();
931                    tokio::task::spawn_blocking(move || {
932                        let mut store = crate::config::CredentialsStore::load();
933                        store.accounts.insert(name_owned, fresh_owned.clone());
934                        store.save().ok();
935                        if fresh_owned.id_token.is_some() {
936                            crate::oauth::write_codex_auth_file(&fresh_owned);
937                        }
938                    });
939                    state.clear_auth_failed(name);
940                    any_recovered = true;
941                }
942                Ok(Err(e)) => {
943                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
944                    notify(
945                        "shunt: Reauth Required",
946                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
947                        "Basso",
948                    );
949                }
950                Err(_) => {
951                    tracing::error!(account = %name, "recovery: token refresh timed out");
952                    notify(
953                        "shunt: Reauth Required",
954                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
955                        "Basso",
956                    );
957                }
958            }
959        }
960
961        if any_recovered {
962            tracing::info!("recovery: at least one account is back online");
963            continue;
964        }
965
966        // All accounts still auth_failed after refresh attempts — notify.
967        let still_failed = state.auth_failed_accounts(&name_refs);
968        if still_failed.len() == account_names.len() {
969            let should_notify = last_notified
970                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
971                .unwrap_or(true);
972            if should_notify {
973                error!(
974                    "ALL accounts are offline (auth failed). \
975                     Run `shunt add-account` to re-authorize."
976                );
977                notify(
978                    "shunt: All Accounts Offline",
979                    "All accounts need re-authorization. Run `shunt add-account`.",
980                    "Basso",
981                );
982                last_notified = Some(Instant::now());
983            }
984        }
985    }
986}
987
988/// Sends a single lightweight prefetch request for `account` immediately after its
989/// cooldown expires, so the router has fresh rate-limit headers before the next
990/// real request arrives.
991async fn post_cooldown_prefetch(
992    client: &reqwest::Client,
993    account: &crate::config::AccountConfig,
994    token: &str,
995    state: &StateStore,
996    upstream_url: &str,
997) {
998    let Some((path, body)) = account.provider.prefetch_request() else {
999        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1000            auth_probe_get(client, probe_path, account, state).await;
1001        }
1002        return;
1003    };
1004    let url = format!("{upstream_url}{path}");
1005    match prefetch_send(client, &url, &account.provider, token, &body).await {
1006        Ok(r) => {
1007            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1008                state.update_rate_limits(&account.name, info);
1009                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1010            }
1011        }
1012        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1013    }
1014}
1015
1016/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1017/// so each account re-enters rotation with fresh rate-limit metrics.
1018///
1019/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1020/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1021/// the next cooldown deadline rather than polling at a fixed interval.
1022///
1023/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1024/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1025/// is triggered so the router always has fresh utilization metrics.
1026pub async fn cooldown_watcher(
1027    config: Arc<Config>,
1028    state: StateStore,
1029    credentials: LiveCredentials,
1030) {
1031    /// Re-fetch rate-limit headers if data is older than 1 hour.
1032    const STALE_RL_MS: u64 = 60 * 60_000;
1033
1034    let client = reqwest::Client::builder()
1035        .timeout(std::time::Duration::from_secs(20))
1036        .build()
1037        .unwrap_or_default();
1038
1039    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1040    // Prevents re-triggering on every poll after expiry.
1041    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1042    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1043    let mut notify_on_resume: HashSet<String> = HashSet::new();
1044    // Epoch-ms of the last successful stale-prefetch per account.
1045    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1046
1047    loop {
1048        let states = state.account_states();
1049        let rl_snapshot = state.rate_limit_snapshot();
1050        let now = now_ms();
1051        let mut next_wake_ms: Option<u64> = None;
1052
1053        for account in &config.accounts {
1054            let Some(st) = states.get(&account.name) else { continue };
1055            if st.disabled { continue; } // auth_failed or permanently disabled
1056            let cdl = st.cooldown_until_ms;
1057
1058            if cdl > 0 && cdl <= now {
1059                // Cooldown expired — skip if we already handled this exact deadline
1060                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1061                if !handled {
1062                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1063                    let token = {
1064                        let creds = credentials.read().await;
1065                        creds.get(&account.name).map(|c| c.access_token.clone())
1066                    };
1067                    if let Some(token) = token {
1068                        post_cooldown_prefetch(
1069                            &client, account, &token, &state,
1070                            &config.server.upstream_url,
1071                        ).await;
1072                    }
1073                    if notify_on_resume.remove(&account.name) {
1074                        notify(
1075                            "shunt: Account Resumed",
1076                            &format!("Account '{}' is back online.", account.name),
1077                            "Glass",
1078                        );
1079                    }
1080                    last_resumed.insert(account.name.clone(), cdl);
1081                    last_stale_prefetch.insert(account.name.clone(), now);
1082                }
1083            } else if cdl > now {
1084                // Still cooling — schedule wake at expiry; flag for notification if long
1085                let remaining = cdl - now;
1086                if remaining >= 5 * 60_000 {
1087                    notify_on_resume.insert(account.name.clone());
1088                }
1089                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1090            } else {
1091                // Not in cooldown — check for stale rate-limit data
1092                let rl_age = rl_snapshot
1093                    .get(&account.name)
1094                    .map(|r| now.saturating_sub(r.updated_ms))
1095                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1096                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1097                let fetched_ago = now.saturating_sub(last_fetched);
1098
1099                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1100                    tracing::debug!(
1101                        account = %account.name,
1102                        age_min = rl_age / 60_000,
1103                        "rate-limit data stale — refreshing"
1104                    );
1105                    let token = {
1106                        let creds = credentials.read().await;
1107                        creds.get(&account.name).map(|c| c.access_token.clone())
1108                    };
1109                    if let Some(token) = token {
1110                        post_cooldown_prefetch(
1111                            &client, account, &token, &state,
1112                            &config.server.upstream_url,
1113                        ).await;
1114                    }
1115                    last_stale_prefetch.insert(account.name.clone(), now);
1116                }
1117            }
1118        }
1119
1120        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1121        let sleep_ms = next_wake_ms
1122            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1123            .unwrap_or(30_000);
1124        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1125    }
1126}
1127
1128use crate::notify::notify;
1129
1130// ---------------------------------------------------------------------------
1131// OpenAI-compatible API (translates to Anthropic Claude)
1132// ---------------------------------------------------------------------------
1133//
1134// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1135// anthropic_base_url is configured, it translates the request to Anthropic
1136// Messages format and forwards it to the Anthropic proxy (which handles
1137// account selection, token management, and rate limiting).
1138// The response is translated back to OpenAI Chat Completions format.
1139
/// Resolve the Claude model that should serve a request made for `openai_model`.
///
/// * Names starting with `claude-` are returned unchanged.
/// * Known flagship OpenAI aliases map to Claude Opus.
/// * Known "mini" OpenAI aliases map to Claude Haiku.
/// * Everything else (including the empty string) maps to Claude Sonnet.
fn map_model(openai_model: &str) -> String {
    const OPUS_ALIASES: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: [&str; 4] = [
        "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini",
    ];

    let resolved = if openai_model.starts_with("claude-") {
        openai_model
    } else if OPUS_ALIASES.contains(&openai_model) {
        "claude-opus-4-6"
    } else if HAIKU_ALIASES.contains(&openai_model) {
        "claude-haiku-4-5-20251001"
    } else {
        "claude-sonnet-4-6"
    };
    resolved.to_owned()
}
1156
1157/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
1158fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1159    let model = body["model"].as_str().unwrap_or("gpt-4o");
1160    let claude_model = map_model(model);
1161
1162    // Extract system message from messages array.
1163    let mut system: Option<String> = None;
1164    let mut messages = Vec::new();
1165    if let Some(arr) = body["messages"].as_array() {
1166        for msg in arr {
1167            let role = msg["role"].as_str().unwrap_or("");
1168            if role == "system" {
1169                // system can be a string or array of content parts
1170                let content = msg["content"].as_str()
1171                    .map(|s| s.to_owned())
1172                    .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1173                system = Some(content);
1174            } else if role == "tool" {
1175                // OpenAI tool result → Anthropic tool_result content block
1176                let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1177                let content = msg["content"].as_str().unwrap_or("").to_owned();
1178                messages.push(json!({
1179                    "role": "user",
1180                    "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1181                }));
1182            } else {
1183                // Check for tool_calls in assistant messages
1184                if let Some(tool_calls) = msg["tool_calls"].as_array() {
1185                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1186                    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1187                        content_blocks.push(json!({"type": "text", "text": text}));
1188                    }
1189                    for tc in tool_calls {
1190                        content_blocks.push(json!({
1191                            "type": "tool_use",
1192                            "id": tc["id"].as_str().unwrap_or(""),
1193                            "name": tc["function"]["name"].as_str().unwrap_or(""),
1194                            "input": serde_json::from_str::<serde_json::Value>(
1195                                tc["function"]["arguments"].as_str().unwrap_or("{}")
1196                            ).unwrap_or(json!({})),
1197                        }));
1198                    }
1199                    messages.push(json!({"role": "assistant", "content": content_blocks}));
1200                } else {
1201                    let content = msg["content"].as_str().unwrap_or("").to_owned();
1202                    messages.push(json!({ "role": role, "content": content }));
1203                }
1204            }
1205        }
1206    }
1207
1208    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1209    let stream = body["stream"].as_bool().unwrap_or(false);
1210
1211    let mut req = json!({
1212        "model": claude_model,
1213        "messages": messages,
1214        "max_tokens": max_tokens,
1215        "stream": stream,
1216    });
1217
1218    if let Some(sys) = system {
1219        req["system"] = json!(sys);
1220    }
1221    if let Some(temp) = body.get("temperature") {
1222        req["temperature"] = temp.clone();
1223    }
1224    if let Some(sp) = body.get("stop") {
1225        req["stop_sequences"] = sp.clone();
1226    }
1227
1228    // Translate OpenAI tools → Anthropic tools format
1229    if let Some(tools) = body["tools"].as_array() {
1230        let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1231            let func = &t["function"];
1232            Some(json!({
1233                "name": func["name"].as_str()?,
1234                "description": func["description"].as_str().unwrap_or(""),
1235                "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1236            }))
1237        }).collect();
1238        if !claude_tools.is_empty() {
1239            req["tools"] = json!(claude_tools);
1240        }
1241    }
1242
1243    req
1244}
1245
1246/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
1247fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1248    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1249    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1250
1251    // Extract text content and tool_use blocks.
1252    let mut text_content = String::new();
1253    let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1254    if let Some(blocks) = body["content"].as_array() {
1255        for (idx, block) in blocks.iter().enumerate() {
1256            match block["type"].as_str() {
1257                Some("text") => {
1258                    text_content.push_str(block["text"].as_str().unwrap_or(""));
1259                }
1260                Some("tool_use") => {
1261                    let args = match &block["input"] {
1262                        serde_json::Value::String(s) => s.clone(),
1263                        v => serde_json::to_string(v).unwrap_or_default(),
1264                    };
1265                    tool_calls.push(json!({
1266                        "id": block["id"].as_str().unwrap_or(""),
1267                        "type": "function",
1268                        "index": idx,
1269                        "function": {
1270                            "name": block["name"].as_str().unwrap_or(""),
1271                            "arguments": args,
1272                        }
1273                    }));
1274                }
1275                _ => {}
1276            }
1277        }
1278    }
1279
1280    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1281    let finish_reason = match stop_reason {
1282        "end_turn"   => "stop",
1283        "tool_use"   => "tool_calls",
1284        "max_tokens" => "length",
1285        other        => other,
1286    };
1287
1288    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1289    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1290
1291    let mut message = json!({"role": "assistant", "content": text_content});
1292    if !tool_calls.is_empty() {
1293        message["tool_calls"] = json!(tool_calls);
1294    }
1295
1296    json!({
1297        "id": id,
1298        "object": "chat.completion",
1299        "model": model,
1300        "choices": [{
1301            "index": 0,
1302            "message": message,
1303            "finish_reason": finish_reason,
1304        }],
1305        "usage": {
1306            "prompt_tokens": input_tokens,
1307            "completion_tokens": output_tokens,
1308            "total_tokens": input_tokens + output_tokens,
1309        }
1310    })
1311}
1312
1313fn uuid_v4() -> String {
1314    use crate::oauth::rand_bytes;
1315    let b: [u8; 16] = rand_bytes();
1316    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1317        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1318        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1319        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1320        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1321        {
1322            let mut v = 0u64;
1323            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1324            v
1325        }
1326    )
1327}
1328
1329/// GET /v1/models — return Claude models in OpenAI format.
1330async fn openai_models_handler() -> impl IntoResponse {
1331    axum::Json(json!({
1332        "object": "list",
1333        "data": [
1334            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1335            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1336            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1337        ]
1338    }))
1339}
1340
1341/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1342async fn openai_compat_handler(
1343    State(s): State<AppState>,
1344    req: Request,
1345) -> Result<Response, ProxyError> {
1346    let Some(ref anthropic_url) = s.anthropic_base_url else {
1347        // No Anthropic proxy configured — fall back to normal forwarding
1348        return proxy_handler(State(s), req).await;
1349    };
1350
1351    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1352        .await
1353        .map_err(|_| ProxyError::BodyRead)?;
1354
1355    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1356        .unwrap_or(json!({}));
1357
1358    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1359    let anthropic_body = translate_to_anthropic(openai_body);
1360
1361    let client = reqwest::Client::builder()
1362        .timeout(std::time::Duration::from_secs(300))
1363        .build()
1364        .map_err(|_| ProxyError::Upstream)?;
1365
1366    let resp = client
1367        .post(format!("{anthropic_url}/v1/messages"))
1368        .header("content-type", "application/json")
1369        .header("anthropic-version", "2023-06-01")
1370        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1371        .header("x-shunt-compat", "openai")
1372        .json(&anthropic_body)
1373        .send()
1374        .await
1375        .map_err(|_| ProxyError::Upstream)?;
1376
1377    if !resp.status().is_success() {
1378        let status = resp.status();
1379        let body = resp.text().await.unwrap_or_default();
1380        let code = status.as_u16();
1381        return Ok(axum::response::Response::builder()
1382            .status(code)
1383            .header("content-type", "application/json")
1384            .body(axum::body::Body::from(body))
1385            .unwrap());
1386    }
1387
1388    if stream {
1389        // Translate Anthropic SSE stream → OpenAI SSE stream
1390        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1391        let stream = translate_anthropic_stream(resp, chat_id);
1392        Ok(axum::response::Response::builder()
1393            .status(200)
1394            .header("content-type", "text/event-stream")
1395            .header("cache-control", "no-cache")
1396            .body(axum::body::Body::from_stream(stream))
1397            .unwrap())
1398    } else {
1399        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1400        let openai_resp = translate_from_anthropic(anthropic_resp);
1401        Ok(axum::Json(openai_resp).into_response())
1402    }
1403}
1404
1405/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1406/// Handles text content, tool_use blocks, and finish reasons.
1407fn translate_anthropic_stream(
1408    resp: reqwest::Response,
1409    chat_id: String,
1410) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1411    use futures_util::StreamExt;
1412
1413    let id = chat_id;
1414    let byte_stream = resp.bytes_stream();
1415
1416    async_stream::stream! {
1417        let mut buf = String::new();
1418        // Per-block state: block_index -> (tool_call_oai_index, tool_id, tool_name)
1419        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1420        let mut tool_call_count: usize = 0;
1421        futures_util::pin_mut!(byte_stream);
1422
1423        // Send initial role chunk
1424        let init = format!(
1425            "data: {}\n\n",
1426            serde_json::to_string(&json!({
1427                "id": id,
1428                "object": "chat.completion.chunk",
1429                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1430            })).unwrap()
1431        );
1432        yield Ok(bytes::Bytes::from(init));
1433
1434        while let Some(chunk) = byte_stream.next().await {
1435            let chunk = match chunk {
1436                Ok(c) => c,
1437                Err(_) => break,
1438            };
1439            buf.push_str(&String::from_utf8_lossy(&chunk));
1440
1441            // Process complete SSE lines
1442            while let Some(nl) = buf.find('\n') {
1443                let line = buf[..nl].trim_end_matches('\r').to_owned();
1444                buf = buf[nl + 1..].to_owned();
1445
1446                if !line.starts_with("data: ") { continue; }
1447                let data = &line["data: ".len()..];
1448                if data == "[DONE]" { continue; }
1449
1450                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1451                let event_type = event["type"].as_str().unwrap_or("");
1452
1453                let maybe_chunk = match event_type {
1454                    "content_block_start" => {
1455                        let block_idx = event["index"].as_u64().unwrap_or(0);
1456                        let cb = &event["content_block"];
1457                        if cb["type"].as_str() == Some("tool_use") {
1458                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1459                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1460                            let oai_idx = tool_call_count;
1461                            tool_call_count += 1;
1462                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1463                            Some(json!({
1464                                "id": id,
1465                                "object": "chat.completion.chunk",
1466                                "choices": [{"index": 0, "delta": {
1467                                    "tool_calls": [{
1468                                        "index": oai_idx,
1469                                        "id": tool_id,
1470                                        "type": "function",
1471                                        "function": {"name": tool_name, "arguments": ""}
1472                                    }]
1473                                }, "finish_reason": null}]
1474                            }))
1475                        } else {
1476                            None
1477                        }
1478                    }
1479                    "content_block_delta" => {
1480                        let block_idx = event["index"].as_u64().unwrap_or(0);
1481                        let delta = &event["delta"];
1482                        match delta["type"].as_str() {
1483                            Some("text_delta") => {
1484                                let text = delta["text"].as_str().unwrap_or("");
1485                                if text.is_empty() { continue; }
1486                                Some(json!({
1487                                    "id": id,
1488                                    "object": "chat.completion.chunk",
1489                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1490                                }))
1491                            }
1492                            Some("input_json_delta") => {
1493                                let args = delta["partial_json"].as_str().unwrap_or("");
1494                                if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1495                                    Some(json!({
1496                                        "id": id,
1497                                        "object": "chat.completion.chunk",
1498                                        "choices": [{"index": 0, "delta": {
1499                                            "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1500                                        }, "finish_reason": null}]
1501                                    }))
1502                                } else {
1503                                    None
1504                                }
1505                            }
1506                            _ => None,
1507                        }
1508                    }
1509                    "message_delta" => {
1510                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1511                        let finish = match stop_reason {
1512                            "end_turn"  => "stop",
1513                            "tool_use"  => "tool_calls",
1514                            "max_tokens" => "length",
1515                            other       => other,
1516                        };
1517                        Some(json!({
1518                            "id": id,
1519                            "object": "chat.completion.chunk",
1520                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1521                        }))
1522                    }
1523                    _ => None,
1524                };
1525
1526                if let Some(c) = maybe_chunk {
1527                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1528                    yield Ok(bytes::Bytes::from(out));
1529                }
1530            }
1531        }
1532
1533        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1534    }
1535}
1536
1537// ---------------------------------------------------------------------------
1538// Cross-protocol translation: Anthropic ↔ OpenAI
1539// ---------------------------------------------------------------------------
1540
1541/// Translate an Anthropic `/v1/messages` request body to OpenAI `/v1/chat/completions` format.
1542/// Used when routing an Anthropic-protocol request to an OpenAI/Codex account.
1543fn translate_anthropic_req_to_openai(body: serde_json::Value) -> serde_json::Value {
1544    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1545    let stream = body["stream"].as_bool().unwrap_or(false);
1546    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1547
1548    let mut messages: Vec<serde_json::Value> = Vec::new();
1549
1550    // Prepend system prompt if present.
1551    if let Some(sys) = body["system"].as_str().filter(|s| !s.is_empty()) {
1552        messages.push(json!({"role": "system", "content": sys}));
1553    }
1554
1555    if let Some(arr) = body["messages"].as_array() {
1556        for msg in arr {
1557            let role = msg["role"].as_str().unwrap_or("user");
1558
1559            if let Some(blocks) = msg["content"].as_array() {
1560                // Check for tool_result blocks (user turn carrying tool results).
1561                let has_tool_result = blocks.iter().any(|b| b["type"] == "tool_result");
1562                if has_tool_result {
1563                    for b in blocks {
1564                        if b["type"] == "tool_result" {
1565                            let content = b["content"].as_str()
1566                                .map(|s| s.to_owned())
1567                                .unwrap_or_else(|| serde_json::to_string(&b["content"]).unwrap_or_default());
1568                            messages.push(json!({
1569                                "role": "tool",
1570                                "tool_call_id": b["tool_use_id"].as_str().unwrap_or(""),
1571                                "content": content,
1572                            }));
1573                        }
1574                    }
1575                    continue;
1576                }
1577
1578                // Regular content blocks — may include text and tool_use.
1579                let mut text = String::new();
1580                let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1581                for b in blocks {
1582                    match b["type"].as_str() {
1583                        Some("text") => text.push_str(b["text"].as_str().unwrap_or("")),
1584                        Some("tool_use") => {
1585                            let args = match &b["input"] {
1586                                serde_json::Value::String(s) => s.clone(),
1587                                v => serde_json::to_string(v).unwrap_or_default(),
1588                            };
1589                            tool_calls.push(json!({
1590                                "id": b["id"].as_str().unwrap_or(""),
1591                                "type": "function",
1592                                "function": {"name": b["name"].as_str().unwrap_or(""), "arguments": args},
1593                            }));
1594                        }
1595                        _ => {}
1596                    }
1597                }
1598                let mut m = json!({"role": role, "content": text});
1599                if !tool_calls.is_empty() {
1600                    m["tool_calls"] = json!(tool_calls);
1601                }
1602                messages.push(m);
1603            } else if let Some(s) = msg["content"].as_str() {
1604                messages.push(json!({"role": role, "content": s}));
1605            }
1606        }
1607    }
1608
1609    let mut req = json!({
1610        "model": model,
1611        "messages": messages,
1612        "max_tokens": max_tokens,
1613        "stream": stream,
1614    });
1615
1616    // Request usage data in stream final chunk.
1617    if stream {
1618        req["stream_options"] = json!({"include_usage": true});
1619    }
1620    if let Some(t) = body.get("temperature") { req["temperature"] = t.clone(); }
1621    if let Some(sp) = body.get("stop_sequences") { req["stop"] = sp.clone(); }
1622
1623    // Anthropic tools → OpenAI tools.
1624    if let Some(tools) = body["tools"].as_array() {
1625        let oai: Vec<serde_json::Value> = tools.iter().map(|t| json!({
1626            "type": "function",
1627            "function": {
1628                "name": t["name"].as_str().unwrap_or(""),
1629                "description": t["description"].as_str().unwrap_or(""),
1630                "parameters": t.get("input_schema").cloned()
1631                    .unwrap_or(json!({"type": "object", "properties": {}})),
1632            }
1633        })).collect();
1634        if !oai.is_empty() { req["tools"] = json!(oai); }
1635    }
1636
1637    if let Some(tc) = body.get("tool_choice") {
1638        req["tool_choice"] = match tc["type"].as_str() {
1639            Some("any")  => json!({"type": "required"}),
1640            Some("tool") => json!({"type": "function", "function": {"name": tc["name"]}}),
1641            _            => json!("auto"),
1642        };
1643    }
1644
1645    req
1646}
1647
1648/// Translate an OpenAI `/v1/chat/completions` non-streaming response to Anthropic format.
1649fn translate_openai_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
1650    let id = format!("msg_{}", &uuid_v4()[..8]);
1651    let choice = &body["choices"][0];
1652    let msg = &choice["message"];
1653
1654    let mut content: Vec<serde_json::Value> = Vec::new();
1655    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1656        content.push(json!({"type": "text", "text": text}));
1657    }
1658    if let Some(tcs) = msg["tool_calls"].as_array() {
1659        for tc in tcs {
1660            content.push(json!({
1661                "type": "tool_use",
1662                "id": tc["id"].as_str().unwrap_or(""),
1663                "name": tc["function"]["name"].as_str().unwrap_or(""),
1664                "input": serde_json::from_str::<serde_json::Value>(
1665                    tc["function"]["arguments"].as_str().unwrap_or("{}")
1666                ).unwrap_or(json!({})),
1667            }));
1668        }
1669    }
1670
1671    let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
1672        "stop"       => "end_turn",
1673        "tool_calls" => "tool_use",
1674        "length"     => "max_tokens",
1675        other        => other,
1676    };
1677
1678    json!({
1679        "id": id,
1680        "type": "message",
1681        "role": "assistant",
1682        "model": model,
1683        "content": content,
1684        "stop_reason": stop_reason,
1685        "stop_sequence": null,
1686        "usage": {
1687            "input_tokens":  body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
1688            "output_tokens": body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
1689        }
1690    })
1691}
1692
1693/// Translate the response back from OpenAI format to Anthropic format.
1694/// Handles both streaming and non-streaming responses.
1695async fn translate_response_openai_to_anthropic(resp: Response, model: &str) -> Response {
1696    use axum::body::Body;
1697    let msg_id = format!("msg_{}", &uuid_v4()[..8]);
1698    let model = model.to_owned();
1699
1700    if quota::is_streaming_response(&resp) {
1701        let (mut parts, body) = resp.into_parts();
1702        parts.headers.insert(
1703            axum::http::header::CONTENT_TYPE,
1704            axum::http::HeaderValue::from_static("text/event-stream"),
1705        );
1706        let stream = translate_openai_stream_to_anthropic(body, model, msg_id);
1707        Response::from_parts(parts, Body::from_stream(stream))
1708    } else {
1709        let (mut parts, body) = resp.into_parts();
1710        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1711        let openai_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1712        let anthropic_val = translate_openai_resp_to_anthropic(openai_val, &model);
1713        let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
1714        parts.headers.insert(
1715            axum::http::header::CONTENT_TYPE,
1716            axum::http::HeaderValue::from_static("application/json"),
1717        );
1718        Response::from_parts(parts, Body::from(out))
1719    }
1720}
1721
1722/// Translate the response back from Anthropic format to OpenAI format.
1723async fn translate_response_anthropic_to_openai(resp: Response) -> Response {
1724    use axum::body::Body;
1725    let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1726
1727    if quota::is_streaming_response(&resp) {
1728        let (parts, body) = resp.into_parts();
1729        let stream = translate_body_anthropic_to_openai(body, chat_id);
1730        Response::from_parts(parts, Body::from_stream(stream))
1731    } else {
1732        let (mut parts, body) = resp.into_parts();
1733        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1734        let anthropic_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1735        let openai_val = translate_from_anthropic(anthropic_val);
1736        let out = serde_json::to_vec(&openai_val).unwrap_or_default();
1737        parts.headers.insert(
1738            axum::http::header::CONTENT_TYPE,
1739            axum::http::HeaderValue::from_static("application/json"),
1740        );
1741        Response::from_parts(parts, Body::from(out))
1742    }
1743}
1744
1745/// Stream-translate an OpenAI SSE response body into Anthropic SSE events.
1746///
1747/// Emits: `message_start` → `content_block_start` → N×`content_block_delta`
1748///       → `content_block_stop` → `message_delta` → `message_stop`
1749fn translate_openai_stream_to_anthropic(
1750    body: axum::body::Body,
1751    model: String,
1752    msg_id: String,
1753) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1754    use futures_util::StreamExt;
1755
1756    async_stream::stream! {
1757        // Send message_start immediately (input_tokens unknown yet, use 0).
1758        let start_evt = format!(
1759            "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
1760            serde_json::to_string(&json!({
1761                "type": "message_start",
1762                "message": {
1763                    "id": msg_id, "type": "message", "role": "assistant",
1764                    "content": [], "model": model, "stop_reason": null,
1765                    "usage": {"input_tokens": 0, "output_tokens": 0}
1766                }
1767            })).unwrap()
1768        );
1769        yield Ok(bytes::Bytes::from(start_evt));
1770
1771        let mut buf = String::new();
1772        let mut content_block_open = false;
1773        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1774        let mut tool_call_count: usize = 0;
1775        let mut output_tokens: u64 = 0;
1776        let mut input_tokens: u64 = 0;
1777        let byte_stream = body.into_data_stream();
1778        futures_util::pin_mut!(byte_stream);
1779
1780        while let Some(chunk) = byte_stream.next().await {
1781            let chunk = match chunk { Ok(c) => c, Err(_) => break };
1782            buf.push_str(&String::from_utf8_lossy(&chunk));
1783
1784            while let Some(nl) = buf.find('\n') {
1785                let line = buf[..nl].trim_end_matches('\r').to_owned();
1786                buf = buf[nl + 1..].to_owned();
1787                if !line.starts_with("data: ") { continue; }
1788                let data = &line["data: ".len()..];
1789                if data == "[DONE]" { continue; }
1790                let Ok(ev) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1791
1792                // Collect usage from final chunk (stream_options.include_usage).
1793                if let Some(u) = ev.get("usage") {
1794                    input_tokens  = u["prompt_tokens"].as_u64().unwrap_or(input_tokens);
1795                    output_tokens = u["completion_tokens"].as_u64().unwrap_or(output_tokens);
1796                }
1797
1798                let choice = &ev["choices"][0];
1799                let delta = &choice["delta"];
1800                let finish = choice["finish_reason"].as_str();
1801
1802                // Text delta.
1803                if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
1804                    if !content_block_open {
1805                        content_block_open = true;
1806                        let cb = format!(
1807                            "event: content_block_start\ndata: {}\n\n",
1808                            serde_json::to_string(&json!({
1809                                "type": "content_block_start", "index": 0,
1810                                "content_block": {"type": "text", "text": ""}
1811                            })).unwrap()
1812                        );
1813                        yield Ok(bytes::Bytes::from(cb));
1814                    }
1815                    let d = format!(
1816                        "event: content_block_delta\ndata: {}\n\n",
1817                        serde_json::to_string(&json!({
1818                            "type": "content_block_delta", "index": 0,
1819                            "delta": {"type": "text_delta", "text": text}
1820                        })).unwrap()
1821                    );
1822                    yield Ok(bytes::Bytes::from(d));
1823                }
1824
1825                // Tool call deltas.
1826                if let Some(tcs) = delta["tool_calls"].as_array() {
1827                    for tc in tcs {
1828                        let oai_idx = tc["index"].as_u64().unwrap_or(0);
1829                        // New tool call: emit content_block_start for tool_use.
1830                        if let Some(id) = tc["id"].as_str() {
1831                            let name = tc["function"]["name"].as_str().unwrap_or("").to_owned();
1832                            let my_idx = tool_call_count;
1833                            tool_call_count += 1;
1834                            tool_blocks.insert(oai_idx, (my_idx, id.to_owned(), name.clone()));
1835                            let cb = format!(
1836                                "event: content_block_start\ndata: {}\n\n",
1837                                serde_json::to_string(&json!({
1838                                    "type": "content_block_start",
1839                                    "index": my_idx + 1, // +1: text block at 0
1840                                    "content_block": {"type": "tool_use", "id": id, "name": name, "input": {}}
1841                                })).unwrap()
1842                            );
1843                            yield Ok(bytes::Bytes::from(cb));
1844                        }
1845                        // Streaming arguments.
1846                        if let Some(args_chunk) = tc["function"]["arguments"].as_str() {
1847                            if let Some(&(my_idx, _, _)) = tool_blocks.get(&oai_idx) {
1848                                let d = format!(
1849                                    "event: content_block_delta\ndata: {}\n\n",
1850                                    serde_json::to_string(&json!({
1851                                        "type": "content_block_delta",
1852                                        "index": my_idx + 1,
1853                                        "delta": {"type": "input_json_delta", "partial_json": args_chunk}
1854                                    })).unwrap()
1855                                );
1856                                yield Ok(bytes::Bytes::from(d));
1857                            }
1858                        }
1859                    }
1860                }
1861
1862                // Finish reason → close blocks + message_delta + message_stop.
1863                if let Some(fr) = finish {
1864                    let stop_reason = match fr {
1865                        "stop"       => "end_turn",
1866                        "tool_calls" => "tool_use",
1867                        "length"     => "max_tokens",
1868                        other        => other,
1869                    };
1870
1871                    // Close open content/tool blocks.
1872                    if content_block_open {
1873                        yield Ok(bytes::Bytes::from(format!(
1874                            "event: content_block_stop\ndata: {}\n\n",
1875                            serde_json::to_string(&json!({"type":"content_block_stop","index":0})).unwrap()
1876                        )));
1877                    }
1878                    for (_, (my_idx, _, _)) in &tool_blocks {
1879                        yield Ok(bytes::Bytes::from(format!(
1880                            "event: content_block_stop\ndata: {}\n\n",
1881                            serde_json::to_string(&json!({"type":"content_block_stop","index": my_idx + 1})).unwrap()
1882                        )));
1883                    }
1884
1885                    yield Ok(bytes::Bytes::from(format!(
1886                        "event: message_delta\ndata: {}\n\n",
1887                        serde_json::to_string(&json!({
1888                            "type": "message_delta",
1889                            "delta": {"stop_reason": stop_reason, "stop_sequence": null},
1890                            "usage": {"output_tokens": output_tokens}
1891                        })).unwrap()
1892                    )));
1893                    yield Ok(bytes::Bytes::from(
1894                        "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
1895                    ));
1896                }
1897            }
1898        }
1899    }
1900}
1901
1902/// Stream-translate an Anthropic SSE response body (from axum `Body`) into OpenAI SSE format.
1903/// Equivalent to `translate_anthropic_stream` but consumes an axum `Body` instead of a
1904/// `reqwest::Response`, so it can be used after the forwarder returns.
1905fn translate_body_anthropic_to_openai(
1906    body: axum::body::Body,
1907    chat_id: String,
1908) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1909    use futures_util::StreamExt;
1910
1911    async_stream::stream! {
1912        let id = chat_id;
1913
1914        // Initial role chunk.
1915        let init = format!(
1916            "data: {}\n\n",
1917            serde_json::to_string(&json!({
1918                "id": id, "object": "chat.completion.chunk",
1919                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1920            })).unwrap()
1921        );
1922        yield Ok(bytes::Bytes::from(init));
1923
1924        let mut buf = String::new();
1925        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1926        let mut tool_call_count: usize = 0;
1927        let byte_stream = body.into_data_stream();
1928        futures_util::pin_mut!(byte_stream);
1929
1930        while let Some(chunk) = byte_stream.next().await {
1931            let chunk = match chunk { Ok(c) => c, Err(_) => break };
1932            buf.push_str(&String::from_utf8_lossy(&chunk));
1933
1934            while let Some(nl) = buf.find('\n') {
1935                let line = buf[..nl].trim_end_matches('\r').to_owned();
1936                buf = buf[nl + 1..].to_owned();
1937                if !line.starts_with("data: ") { continue; }
1938                let data = &line["data: ".len()..];
1939                if data == "[DONE]" { continue; }
1940                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1941                let event_type = event["type"].as_str().unwrap_or("");
1942
1943                let maybe_chunk = match event_type {
1944                    "content_block_start" => {
1945                        let block_idx = event["index"].as_u64().unwrap_or(0);
1946                        let cb = &event["content_block"];
1947                        if cb["type"].as_str() == Some("tool_use") {
1948                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1949                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1950                            let oai_idx = tool_call_count;
1951                            tool_call_count += 1;
1952                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1953                            Some(json!({
1954                                "id": id, "object": "chat.completion.chunk",
1955                                "choices": [{"index": 0, "delta": {
1956                                    "tool_calls": [{"index": oai_idx, "id": tool_id, "type": "function",
1957                                        "function": {"name": tool_name, "arguments": ""}}]
1958                                }, "finish_reason": null}]
1959                            }))
1960                        } else { None }
1961                    }
1962                    "content_block_delta" => {
1963                        let block_idx = event["index"].as_u64().unwrap_or(0);
1964                        let delta = &event["delta"];
1965                        match delta["type"].as_str() {
1966                            Some("text_delta") => {
1967                                let text = delta["text"].as_str().unwrap_or("");
1968                                if text.is_empty() { continue; }
1969                                Some(json!({
1970                                    "id": id, "object": "chat.completion.chunk",
1971                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1972                                }))
1973                            }
1974                            Some("input_json_delta") => {
1975                                let args = delta["partial_json"].as_str().unwrap_or("");
1976                                tool_blocks.get(&block_idx).map(|(oai_idx, _, _)| json!({
1977                                    "id": id, "object": "chat.completion.chunk",
1978                                    "choices": [{"index": 0, "delta": {
1979                                        "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1980                                    }, "finish_reason": null}]
1981                                }))
1982                            }
1983                            _ => None,
1984                        }
1985                    }
1986                    "message_delta" => {
1987                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1988                        let finish = match stop_reason {
1989                            "end_turn"   => "stop",
1990                            "tool_use"   => "tool_calls",
1991                            "max_tokens" => "length",
1992                            other        => other,
1993                        };
1994                        Some(json!({
1995                            "id": id, "object": "chat.completion.chunk",
1996                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1997                        }))
1998                    }
1999                    _ => None,
2000                };
2001
2002                if let Some(c) = maybe_chunk {
2003                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
2004                    yield Ok(bytes::Bytes::from(out));
2005                }
2006            }
2007        }
2008        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
2009    }
2010}