Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use parking_lot::Mutex as ParkingMutex;
7
8use axum::extract::{Request, State};
9use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum::Router;
13use bytes::Bytes;
14use serde_json::json;
15use tokio::sync::RwLock;
16use tracing::{error, info, warn};
17
18use crate::config::{state_path, Config, CredentialsStore};
19use crate::credential::Credential;
20use crate::forwarder::Forwarder;
21use crate::provider::Provider;
22use crate::quota;
23use crate::router;
24use crate::state::StateStore;
25use crate::telemetry::TelemetryClient;
26
27/// 100 MB limit — sufficient for any LLM request including large context windows.
28const MAX_REQUEST_BODY: usize = 100 * 1024 * 1024;
29
30#[derive(Clone)]
31struct AppState {
32    config: Arc<Config>,
33    forwarder: Arc<Forwarder>,
34    state: StateStore,
35    /// Live credentials — can be refreshed at runtime without restarting.
36    credentials: Arc<RwLock<HashMap<String, Credential>>>,
37    /// Per-account mutex that serialises concurrent token-refresh attempts.
38    ///
39    /// When multiple in-flight requests hit a 401 for the same account at the
40    /// same time, only one should call the upstream OAuth endpoint; the others
41    /// should wait and then re-use the fresh token instead of each making their
42    /// own refresh call (which would rotate the refresh_token out from under the
43    /// others and cause cascading auth failures).
44    refresh_locks: Arc<ParkingMutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
45    /// Epoch-ms when this proxy instance started.
46    started_ms: u64,
47    /// If set, /v1/chat/completions requests are translated and forwarded here
48    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
49    anthropic_base_url: Option<String>,
50    /// Optional relay-server telemetry client.
51    telemetry: Option<TelemetryClient>,
52    /// Per-IP token-bucket rate limiter (#16). None when rate_limit_rpm == 0.
53    rate_limiter: Option<Arc<ParkingMutex<HashMap<IpAddr, TokenBucket>>>>,
54}
55
/// Simple token bucket used for per-IP rate limiting.
///
/// Tokens refill continuously at `rpm / 60` tokens per second. Each allowed
/// request consumes exactly one token.
struct TokenBucket {
    /// Tokens currently available (fractional values are allowed).
    tokens: f64,
    /// Instant up to which `tokens` has been refilled.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts with `capacity` tokens.
    ///
    /// The first call to [`check_and_consume`](Self::check_and_consume) clamps
    /// the balance to the burst cap. A `capacity` above that cap therefore has
    /// no effect beyond the first request.
    fn new(capacity: f64) -> Self {
        Self { tokens: capacity, last_refill: Instant::now() }
    }

    /// Maximum number of tokens the bucket can hold for a given `rpm`.
    ///
    /// This is ten seconds' worth of refill (`rpm / 6`), but never fewer than
    /// 10 tokens.
    fn burst(rpm: f64) -> f64 {
        (rpm / 6.0).max(10.0)
    }

    /// Refill tokens in proportion to the time elapsed since the previous call,
    /// capped at [`burst`](Self::burst), then try to consume one token.
    ///
    /// Returns `true` if the request is allowed.
    fn check_and_consume(&mut self, rpm: f64) -> bool {
        // Read the clock exactly once. Calling `elapsed()` and then
        // `Instant::now()` separately never credits the time between the two
        // reads, which leaks a sliver of refill on every call.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        self.tokens = (self.tokens + elapsed * rpm / 60.0).min(Self::burst(rpm));
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
83
84pub fn create_app(config: Config) -> anyhow::Result<Router> {
85    let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
86    Ok(app)
87}
88
89/// Shared live credentials map — can be written to without restarting the proxy.
90pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
91
92/// Create a pure proxy app (no management routes).
93/// Registers /v1/messages, /v1/chat/completions, /v1/models, and a fallback.
94/// Build a shared `AppState` and the `LiveCredentials` handle it references.
95fn build_app_state(
96    config: Config,
97    state: StateStore,
98    anthropic_base_url: Option<String>,
99) -> anyhow::Result<(AppState, LiveCredentials)> {
100    let forwarder = Forwarder::new(config.server.request_timeout_secs)?;
101
102    for a in &config.accounts {
103        if a.provider.auth_kind() == crate::provider::AuthKind::None {
104            // Local providers never need credentials — clear any stale auth_failed from disk.
105            state.clear_auth_failed(&a.name);
106        } else if a.credential.is_none() {
107            state.set_auth_failed(&a.name);
108        }
109    }
110
111    let credentials: LiveCredentials = Arc::new(RwLock::new(
112        config.accounts.iter()
113            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
114            .collect::<HashMap<_, _>>(),
115    ));
116
117    let telemetry = config.server.telemetry_url.as_deref().map(|url| {
118        TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
119    });
120
121    let rate_limiter = if config.server.rate_limit_rpm > 0 {
122        Some(Arc::new(ParkingMutex::new(HashMap::<IpAddr, TokenBucket>::new())))
123    } else {
124        None
125    };
126
127    let app_state = AppState {
128        config: Arc::new(config),
129        forwarder: Arc::new(forwarder),
130        state,
131        credentials: Arc::clone(&credentials),
132        refresh_locks: Arc::new(ParkingMutex::new(HashMap::new())),
133        started_ms: now_ms(),
134        anthropic_base_url,
135        telemetry,
136        rate_limiter,
137    };
138
139    Ok((app_state, credentials))
140}
141
142pub fn create_proxy_app(
143    config: Config,
144    state: StateStore,
145    anthropic_base_url: Option<String>,
146) -> anyhow::Result<(Router, LiveCredentials)> {
147    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
148
149    let app = Router::new()
150        .route("/v1/messages", post(proxy_handler))
151        .route("/v1/messages/count_tokens", post(proxy_handler))
152        .route("/v1/chat/completions", post(openai_compat_handler))
153        .route("/v1/models", get(openai_models_handler))
154        .fallback(proxy_handler)
155        .with_state(app_state);
156
157    Ok((app, credentials))
158}
159
160/// Create a control plane app (management routes only — sees ALL accounts).
161/// Registers /health, /status, /use.
162pub fn create_control_app(
163    config: Config,
164    state: StateStore,
165) -> anyhow::Result<Router> {
166    let (app_state, _) = build_app_state(config, state, None)?;
167
168    let app = Router::new()
169        .route("/health", get(health))
170        .route("/status", get(status_handler))
171        .route("/use", post(use_handler))
172        .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
173        .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
174        .route("/burst-limit", get(burst_limit_get_handler).post(burst_limit_set_handler).delete(burst_limit_clear_handler))
175        .route("/fallback", get(fallback_get_handler).post(fallback_set_handler).delete(fallback_clear_handler))
176        .route("/effort", get(effort_get_handler).post(effort_set_handler).delete(effort_clear_handler))
177        .route("/thinking", get(thinking_get_handler).post(thinking_set_handler).delete(thinking_clear_handler))
178        .route("/alerts", get(alerts_get_handler).post(alerts_set_handler))
179        .with_state(app_state);
180
181    Ok(app)
182}
183
184/// Combined app used by tests and the single-port fallback mode.
185/// Includes both proxy routes and management routes (/health, /status, /use)
186/// sharing a single AppState so state changes are visible across all routes.
187pub fn create_app_with_state(
188    config: Config,
189    state: StateStore,
190    anthropic_base_url: Option<String>,
191) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
192    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
193    let telemetry = app_state.telemetry.clone();
194
195    let app = Router::new()
196        // Management routes
197        .route("/health", get(health))
198        .route("/status", get(status_handler))
199        .route("/use", post(use_handler))
200        .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
201        .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
202        .route("/burst-limit", get(burst_limit_get_handler).post(burst_limit_set_handler).delete(burst_limit_clear_handler))
203        .route("/fallback", get(fallback_get_handler).post(fallback_set_handler).delete(fallback_clear_handler))
204        .route("/effort", get(effort_get_handler).post(effort_set_handler).delete(effort_clear_handler))
205        .route("/thinking", get(thinking_get_handler).post(thinking_set_handler).delete(thinking_clear_handler))
206        .route("/alerts", get(alerts_get_handler).post(alerts_set_handler))
207        // Proxy routes
208        .route("/v1/messages", post(proxy_handler))
209        .route("/v1/messages/count_tokens", post(proxy_handler))
210        .route("/v1/chat/completions", post(openai_compat_handler))
211        .route("/v1/models", get(openai_models_handler))
212        .fallback(proxy_handler)
213        .with_state(app_state);
214
215    Ok((app, credentials, telemetry))
216}
217
218/// Build a status JSON snapshot from config + state — used by the heartbeat loop.
219pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
220    let account_states = state.account_states();
221    let rate_limits    = state.rate_limit_snapshot();
222
223    let accounts: Vec<_> = config.accounts.iter().map(|a| {
224        let st            = account_states.get(&a.name);
225        let rl            = rate_limits.get(&a.name);
226        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
227        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
228        let reset_5h       = rl.and_then(|r| r.reset_5h);
229        let reset_7d       = rl.and_then(|r| r.reset_7d);
230        let disabled       = st.map(|s| s.disabled).unwrap_or(false);
231        let auth_failed    = st.map(|s| s.auth_failed).unwrap_or(false);
232        let health_check_failed = st.map(|s| s.health_check_failed).unwrap_or(false);
233        let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
234        let available      = state.is_available(&a.name);
235        let email          = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
236
237        json!({
238            "name": a.name,
239            "email": email,
240            "provider": a.provider.to_string(),
241            "available": available,
242            "disabled": disabled,
243            "auth_failed": auth_failed,
244            "health_check_failed": health_check_failed,
245            "cooldown_until_ms": cooldown_until_ms,
246            "utilization_5h": utilization_5h,
247            "reset_5h": reset_5h,
248            "utilization_7d": utilization_7d,
249            "reset_7d": reset_7d,
250        })
251    }).collect();
252
253    json!({
254        "started_ms": started_ms,
255        "accounts": accounts,
256        "pinned_account": state.get_pinned(),
257        "last_used_account": state.get_last_used(),
258    })
259}
260
261async fn health() -> impl IntoResponse {
262    axum::Json(json!({ "status": "ok", "version": env!("CARGO_PKG_VERSION") }))
263}
264
265async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
266    let account_states = s.state.account_states();
267    let quotas = s.state.quota_snapshot();
268    let rate_limits = s.state.rate_limit_snapshot();
269
270    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
271        let st = account_states.get(&a.name);
272        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
273            "reauth_required"
274        } else if st.map(|s| s.disabled).unwrap_or(false) {
275            "disabled"
276        } else if st.map(|s| s.health_check_failed).unwrap_or(false) {
277            "unhealthy"
278        } else if s.state.is_available(&a.name) {
279            "available"
280        } else {
281            "cooling"
282        };
283
284        let quota = quotas.get(&a.name);
285        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
286        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
287        let tokens_used = quota.map(|q| json!({
288            "input": q.input_tokens,
289            "output": q.output_tokens,
290            "total": q.total_tokens(),
291        }));
292
293        let rl = rate_limits.get(&a.name);
294        let rate_limit = rl.map(|r| json!({
295            "utilization_5h": r.utilization_5h,
296            "reset_5h": r.reset_5h,
297            "status_5h": r.status_5h,
298            "utilization_7d": r.utilization_7d,
299            "reset_7d": r.reset_7d,
300            "status_7d": r.status_7d,
301            "representative_claim": r.representative_claim,
302            "updated_ms": r.updated_ms,
303        }));
304
305        let acc_state = account_states.get(&a.name);
306        let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
307        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
308        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
309        let health_check_failed = acc_state.map(|s| s.health_check_failed).unwrap_or(false);
310        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
311        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
312        let reset_5h = rl.and_then(|r| r.reset_5h);
313        let status_5h = rl.and_then(|r| r.status_5h.clone());
314        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
315        let reset_7d = rl.and_then(|r| r.reset_7d);
316        let status_7d = rl.and_then(|r| r.status_7d.clone());
317        let available = s.state.is_available(&a.name);
318
319        json!({
320            "name": a.name,
321            "email": email,
322            "plan_type": a.plan_type,
323            "provider": a.provider.to_string(),
324            "status": avail_status,
325            "available": available,
326            "disabled": disabled,
327            "auth_failed": auth_failed,
328            "health_check_failed": health_check_failed,
329            "cooldown_until_ms": cooldown_until_ms,
330            "utilization_5h": utilization_5h,
331            "reset_5h": reset_5h,
332            "status_5h": status_5h,
333            "utilization_7d": utilization_7d,
334            "reset_7d": reset_7d,
335            "status_7d": status_7d,
336            "window_expires_ms": window_expires_ms,
337            "tokens_used": tokens_used,
338            "rate_limit": rate_limit,
339        })
340    }).collect();
341
342    let recent_requests = s.state.recent_requests_snapshot();
343    let savings = s.state.savings_snapshot();
344
345    axum::Json(json!({
346        "version": env!("CARGO_PKG_VERSION"),
347        "started_ms": s.started_ms,
348        "accounts": accounts,
349        "pinned_account": s.state.get_pinned(),
350        "last_used_account": s.state.get_last_used(),
351        "recent_requests": recent_requests,
352        "savings": savings,
353    }))
354}
355
356async fn use_handler(
357    State(s): State<AppState>,
358    axum::Json(body): axum::Json<serde_json::Value>,
359) -> Response {
360    let account = body["account"].as_str().map(|s| s.to_owned());
361    // Validate the account name exists (unless clearing to auto)
362    if let Some(ref name) = account {
363        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
364            return (StatusCode::BAD_REQUEST, axum::Json(json!({
365                "error": format!("unknown account '{name}'")
366            }))).into_response();
367        }
368        let pinned = if name == "auto" { None } else { Some(name.clone()) };
369        s.state.set_pinned(pinned);
370        axum::Json(json!({ "pinned": name })).into_response()
371    } else {
372        s.state.set_pinned(None);
373        axum::Json(json!({ "pinned": null })).into_response()
374    }
375}
376
377async fn model_get_handler(State(s): State<AppState>) -> impl IntoResponse {
378    let model = s.state.get_model_override();
379    axum::Json(json!({ "model": model }))
380}
381
382async fn model_set_handler(
383    State(s): State<AppState>,
384    axum::Json(body): axum::Json<serde_json::Value>,
385) -> Response {
386    let Some(model) = body["model"].as_str() else {
387        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing model field" }))).into_response();
388    };
389    s.state.set_model_override(model.to_owned());
390    info!(model, "model override set");
391    axum::Json(json!({ "model": model })).into_response()
392}
393
394async fn model_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
395    s.state.clear_model_override();
396    info!("model override cleared");
397    axum::Json(json!({ "model": null }))
398}
399
400async fn strategy_get_handler(State(s): State<AppState>) -> impl IntoResponse {
401    let (strategy_str, source) = match s.state.get_routing_strategy() {
402        Some(st) => (st.as_str(), "override"),
403        None => (s.config.server.routing_strategy.as_str(), "config"),
404    };
405    axum::Json(json!({ "strategy": strategy_str, "source": source }))
406}
407
408async fn strategy_set_handler(
409    State(s): State<AppState>,
410    axum::Json(body): axum::Json<serde_json::Value>,
411) -> Response {
412    let Some(name) = body["strategy"].as_str() else {
413        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing strategy field" }))).into_response();
414    };
415    let Some(strategy) = crate::config::RoutingStrategy::from_str(name) else {
416        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": format!("unknown strategy '{name}'") }))).into_response();
417    };
418    s.state.set_routing_strategy(strategy);
419    info!(strategy = name, "routing strategy override set");
420    axum::Json(json!({ "strategy": strategy.as_str(), "source": "override" })).into_response()
421}
422
423async fn strategy_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
424    s.state.clear_routing_strategy();
425    info!("routing strategy override cleared");
426    let strategy_str = s.config.server.routing_strategy.as_str();
427    axum::Json(json!({ "strategy": strategy_str, "source": "config" }))
428}
429
430// ── Burst RPM limit ────────────────────────────────────────────────────────
431
432async fn burst_limit_get_handler(State(s): State<AppState>) -> impl IntoResponse {
433    let (limit, source) = match s.state.get_burst_rpm_limit_override() {
434        Some(l) => (l, "override"),
435        None => (s.config.server.burst_rpm_limit, if s.config.server.burst_rpm_limit == 10 { "default" } else { "config" }),
436    };
437    axum::Json(json!({ "burst_rpm_limit": limit, "source": source }))
438}
439
440async fn burst_limit_set_handler(
441    State(s): State<AppState>,
442    axum::Json(body): axum::Json<serde_json::Value>,
443) -> Response {
444    let Some(limit) = body["burst_rpm_limit"].as_u64() else {
445        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing burst_rpm_limit field (integer)" }))).into_response();
446    };
447    let limit = limit as u32;
448    s.state.set_burst_rpm_limit_override(limit);
449    info!(limit, "burst RPM limit override set");
450    axum::Json(json!({ "burst_rpm_limit": limit, "source": "override" })).into_response()
451}
452
453async fn burst_limit_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
454    s.state.clear_burst_rpm_limit_override();
455    info!("burst RPM limit override cleared");
456    let limit = s.config.server.burst_rpm_limit;
457    axum::Json(json!({ "burst_rpm_limit": limit, "source": if limit == 10 { "default" } else { "config" } }))
458}
459
460// ── Fallback model ─────────────────────────────────────────────────────────
461
462async fn fallback_get_handler(State(s): State<AppState>) -> impl IntoResponse {
463    match s.state.get_fallback_model_override() {
464        Some(Some(model)) => axum::Json(json!({ "fallback_model": model, "source": "override" })),
465        Some(None) => axum::Json(json!({ "fallback_model": null, "source": "override", "disabled": true })),
466        None => match &s.config.server.fallback_model {
467            Some(model) => axum::Json(json!({ "fallback_model": model, "source": "config" })),
468            None => axum::Json(json!({ "fallback_model": "auto", "source": "auto" })),
469        },
470    }
471}
472
473async fn fallback_set_handler(
474    State(s): State<AppState>,
475    axum::Json(body): axum::Json<serde_json::Value>,
476) -> Response {
477    if body["fallback_model"].is_null() || body.get("disabled").and_then(|v| v.as_bool()) == Some(true) {
478        s.state.set_fallback_model_override(None);
479        info!("fallback model explicitly disabled");
480        return axum::Json(json!({ "fallback_model": null, "source": "override", "disabled": true })).into_response();
481    }
482    let Some(model) = body["fallback_model"].as_str() else {
483        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing fallback_model field" }))).into_response();
484    };
485    let model = model.to_owned();
486    s.state.set_fallback_model_override(Some(model.clone()));
487    info!(model = %model, "fallback model override set");
488    axum::Json(json!({ "fallback_model": model, "source": "override" })).into_response()
489}
490
491async fn fallback_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
492    s.state.clear_fallback_model_override();
493    info!("fallback model override cleared");
494    match &s.config.server.fallback_model {
495        Some(model) => axum::Json(json!({ "fallback_model": model, "source": "config" })),
496        None => axum::Json(json!({ "fallback_model": "auto", "source": "auto" })),
497    }
498}
499
500async fn effort_get_handler(State(s): State<AppState>) -> impl IntoResponse {
501    match s.state.get_effort_override() {
502        Some(effort) => axum::Json(json!({ "effort": effort, "source": "override" })),
503        None => axum::Json(json!({ "effort": null, "source": "passthrough" })),
504    }
505}
506
507async fn effort_set_handler(
508    State(s): State<AppState>,
509    axum::Json(body): axum::Json<serde_json::Value>,
510) -> Response {
511    let Some(effort) = body["effort"].as_str() else {
512        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing effort string field" }))).into_response();
513    };
514    let valid = ["low", "medium", "high", "xhigh", "max"];
515    if !valid.contains(&effort) {
516        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "effort must be one of: low, medium, high, max" }))).into_response();
517    }
518    s.state.set_effort_override(effort.to_owned());
519    info!(effort, "effort override set");
520    axum::Json(json!({ "effort": effort, "source": "override" })).into_response()
521}
522
523async fn effort_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
524    s.state.clear_effort_override();
525    info!("effort override cleared");
526    axum::Json(json!({ "effort": null, "source": "passthrough" }))
527}
528
529async fn thinking_get_handler(State(s): State<AppState>) -> impl IntoResponse {
530    match s.state.get_thinking_override() {
531        Some(mode) => axum::Json(json!({ "thinking": mode, "source": "override" })),
532        None => axum::Json(json!({ "thinking": null, "source": "passthrough" })),
533    }
534}
535
536async fn thinking_set_handler(
537    State(s): State<AppState>,
538    axum::Json(body): axum::Json<serde_json::Value>,
539) -> Response {
540    let Some(mode) = body["thinking"].as_str() else {
541        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing thinking string field" }))).into_response();
542    };
543    let valid = ["adaptive", "disabled"];
544    if !valid.contains(&mode) {
545        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "thinking must be one of: adaptive, disabled" }))).into_response();
546    }
547    s.state.set_thinking_override(mode.to_owned());
548    info!(mode, "thinking override set");
549    axum::Json(json!({ "thinking": mode, "source": "override" })).into_response()
550}
551
552async fn thinking_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
553    s.state.clear_thinking_override();
554    info!("thinking override cleared");
555    axum::Json(json!({ "thinking": null, "source": "passthrough" }))
556}
557
558async fn alerts_get_handler(State(s): State<AppState>) -> impl IntoResponse {
559    let muted = s.state.get_alerts_muted();
560    axum::Json(json!({ "muted": muted }))
561}
562
563async fn alerts_set_handler(
564    State(s): State<AppState>,
565    axum::Json(body): axum::Json<serde_json::Value>,
566) -> Response {
567    let Some(muted) = body["muted"].as_bool() else {
568        return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing muted bool field" }))).into_response();
569    };
570    s.state.set_alerts_muted(muted);
571    info!(muted, "alerts mute state changed");
572    axum::Json(json!({ "muted": muted })).into_response()
573}
574
575use crate::state::now_ms_pub as now_ms;
576
577/// Extract client IP for rate limiting.
578///
579/// `X-Real-IP` is only trusted when `trust_proxy_headers` is explicitly enabled
580/// in config — otherwise any client could spoof the header to rotate its bucket.
581/// When not trusted (the default), all requests share a single loopback bucket,
582/// giving a global RPM cap rather than a per-IP one.
583fn extract_client_ip(req: &Request, trust_proxy_headers: bool) -> IpAddr {
584    if trust_proxy_headers {
585        if let Some(ip) = req.headers()
586            .get("x-real-ip")
587            .and_then(|v| v.to_str().ok())
588            .and_then(|s| s.parse().ok())
589        {
590            return ip;
591        }
592    }
593    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
594}
595
596async fn proxy_handler(
597    State(s): State<AppState>,
598    req: Request,
599) -> Result<Response, ProxyError> {
600    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
601    if let Some(ref expected) = s.config.server.remote_key {
602        let provided = req.headers()
603            .get("x-api-key")
604            .and_then(|v| v.to_str().ok())
605            .unwrap_or("");
606        if provided != expected {
607            return Err(ProxyError::Unauthorized);
608        }
609    }
610
611    // #16: per-IP rate limiting (token bucket, configurable via rate_limit_rpm).
612    if let Some(ref rl) = s.rate_limiter {
613        let ip = extract_client_ip(&req, s.config.server.trust_proxy_headers);
614        let rpm = s.config.server.rate_limit_rpm as f64;
615        let allowed = rl.lock().entry(ip).or_insert_with(|| TokenBucket::new(rpm)).check_and_consume(rpm);
616        if !allowed {
617            return Err(ProxyError::RateLimited);
618        }
619    }
620
621    let method = req.method().as_str().to_owned();
622    let path = req.uri().path().to_owned();
623    let headers = req.headers().clone();
624
625    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
626        .await
627        .map_err(|_| ProxyError::BodyRead)?;
628
629    // Apply model override: if set, patch the `model` field in the JSON body before forwarding.
630    // Also strip unsupported params for models that don't support them (e.g. Haiku).
631    // Parse once and reuse the value to extract the model name (avoids double-parse).
632    let (mut body_bytes, model) = if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
633        let mut changed = false;
634        if let Some(override_model) = s.state.get_model_override() {
635            if val.get("model").is_some() {
636                val["model"] = serde_json::Value::String(override_model);
637                changed = true;
638            }
639        }
640        // Apply effort override: inject output_config.effort before simple-model stripping.
641        if let Some(effort) = s.state.get_effort_override() {
642            if val.get("output_config").is_none() {
643                val["output_config"] = serde_json::json!({});
644            }
645            val["output_config"]["effort"] = serde_json::Value::String(effort);
646            changed = true;
647        }
648        // Apply thinking mode override: inject thinking object before simple-model stripping.
649        if let Some(thinking_mode) = s.state.get_thinking_override() {
650            val["thinking"] = serde_json::json!({ "type": thinking_mode });
651            changed = true;
652        }
653        let resolved_model = val["model"].as_str().unwrap_or("").to_owned();
654        if is_simple_model(&resolved_model) {
655            if let Some(obj) = val.as_object_mut() {
656                // Strip features unsupported by simpler models (Haiku).
657                for key in &["thinking", "effort", "reasoning_effort"] {
658                    if obj.remove(*key).is_some() { changed = true; }
659                }
660                // Strip effort from output_config if present
661                if let Some(serde_json::Value::Object(oc)) = obj.get_mut("output_config") {
662                    if oc.remove("effort").is_some() { changed = true; }
663                    // Remove output_config entirely if empty
664                    if oc.is_empty() { obj.remove("output_config"); }
665                }
666                // Strip context_management (thinking-related edit rules)
667                if obj.remove("context_management").is_some() { changed = true; }
668                // Remove extended-thinking beta flag
669                if let Some(serde_json::Value::Array(betas)) = obj.get_mut("betas") {
670                    let before = betas.len();
671                    betas.retain(|b| b.as_str() != Some("interleaved-thinking-2025-05-14"));
672                    if betas.len() != before { changed = true; }
673                }
674            }
675        }
676        let model = val["model"].as_str().unwrap_or("").to_owned();
677        let bytes = if changed {
678            Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()))
679        } else {
680            body_bytes
681        };
682        (bytes, model)
683    } else {
684        (body_bytes, String::new())
685    };
686
687    // Strip thinking/effort-related beta flags from the anthropic-beta header for simple models.
688    let mut headers = headers;
689    if is_simple_model(&model) {
690        if let Some(beta_val) = headers.get("anthropic-beta").and_then(|v| v.to_str().ok().map(|s| s.to_owned())) {
691            let filtered: Vec<&str> = beta_val.split(',')
692                .map(|s| s.trim())
693                .filter(|b| !b.contains("thinking") && !b.contains("effort"))
694                .collect();
695            let new_beta = filtered.join(",");
696            if filtered.is_empty() {
697                headers.remove("anthropic-beta");
698            } else if let Ok(v) = axum::http::HeaderValue::from_str(&new_beta) {
699                headers.insert("anthropic-beta", v);
700            }
701        }
702    }
703
704    let req_start_ms = now_ms();
705    let request_id = uuid::Uuid::new_v4().to_string()[..8].to_owned();
706
707    let fp = router::fingerprint(&body_bytes);
708    let fp_ref = fp.as_deref();
709
710    let mut tried: HashSet<String> = HashSet::new();
711    // Track accounts we've already attempted a token refresh for this request.
712    let mut refreshed: HashSet<String> = HashSet::new();
713    // Guard: only attempt model fallback once per request.
714    let mut fell_back = false;
715    // Cap wait to the configured request timeout — the client's TCP connection
716    // won't survive 5 hours anyway; return 503 so the client can retry.
717    let wait_deadline_ms = now_ms() + s.config.server.request_timeout_secs.saturating_mul(1_000);
718
719    loop {
720        let effective_strategy = s.state.get_routing_strategy()
721            .unwrap_or(s.config.server.routing_strategy);
722        let snap = s.state.routing_snapshot();
723        let effective_burst_rpm = s.state.get_burst_rpm_limit_override()
724            .unwrap_or(s.config.server.burst_rpm_limit);
725        let account = match router::pick_account(
726            &s.config.accounts, &s.state, &snap, fp_ref, &tried,
727            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
728            effective_strategy, effective_burst_rpm,
729        ) {
730            Some(a) => a,
731            None => {
732                // Model fallback: before waiting, try switching to a cheaper model.
733                // Rate limits are often per-model, so the fallback may succeed immediately.
734                // Override chain: runtime override → config → auto-detect from model name.
735                if !fell_back && !model.is_empty() {
736                    let fallback: Option<String> = match s.state.get_fallback_model_override() {
737                        Some(Some(m)) => Some(m),              // explicit override
738                        Some(None) => None,                     // explicitly disabled
739                        None => s.config.server.fallback_model.clone()
740                            .or_else(|| auto_fallback_model(&model).map(|s| s.to_owned())),
741                    };
742                    if let Some(ref fb) = fallback {
743                        if model != *fb {
744                            if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
745                                warn!(from = %model, to = %fb, "all accounts cooling — falling back to cheaper model");
746                                val["model"] = serde_json::Value::String(fb.clone());
747                                body_bytes = Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()));
748                                fell_back = true;
749                                tried.clear();
750                                continue;
751                            }
752                        }
753                    }
754                }
755
756                // Check whether any accounts are just temporarily cooling down
757                // (429/529 backoff) rather than permanently disabled / auth_failed.
758                // If so, wait for the soonest one to recover and retry.
759                let account_states = s.state.account_states();
760                let now = now_ms();
761                let soonest_ms = s.config.accounts.iter()
762                    .filter_map(|a| {
763                        let st = account_states.get(&a.name)?;
764                        if st.disabled { return None; } // auth_failed or permanently off
765                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
766                    })
767                    .min();
768
769                match soonest_ms {
770                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
771                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
772                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
773                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
774                        tried.clear(); // accounts may have recovered; try them again
775                    }
776                    _ => return Err(ProxyError::AllAccountsUnavailable),
777                }
778                continue;
779            }
780        };
781
782        let account_name = account.name.clone();
783        s.state.record_request_burst(&account_name);
784
785        // Use the live (possibly refreshed) token rather than the one baked into config.
786        // For OpenAI/chatgpt.com accounts, Credential::bearer_token() returns id_token
787        // (short-lived OIDC JWT) which chatgpt.com requires. For all other providers it
788        // returns access_token. API-key accounts return the key directly.
789        let token = {
790            let creds = s.credentials.read().await;
791            let cred = creds.get(&account_name)
792                .cloned()
793                .or_else(|| account.credential.clone());
794            match cred {
795                Some(c) => c.bearer_token().to_owned(),
796                None => String::new(),
797            }
798        };
799
800        // Detect request and account protocols.  When they differ, translate
801        // the request body + path before forwarding and translate the response
802        // back so the client always sees its native wire format.
803        let req_is_anthropic = path.starts_with("/v1/messages");
804        let acct_is_anthropic = account.provider.wire_protocol()
805            == crate::provider::WireProtocol::Anthropic;
806        // chatgpt.com (Provider::OpenAI) uses a proprietary backend-api path + sentinel token.
807        // All other OpenAI-compat providers (OpenAIApi, Groq, Mistral, …) use /v1/chat/completions.
808        let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
809
810        // log_model: what we actually send to the upstream (after resolve_model).
811        // Defaults to the incoming model; overridden in the OpenAI-compat branch.
812        let mut log_model = model.clone();
813
814        let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
815            // Same wire protocol — pass through unchanged.
816            (path.clone(), body_bytes.clone(), headers.clone())
817        } else if req_is_anthropic && acct_is_chatgpt {
818            // Anthropic client → chatgpt.com account: translate to backend-api format.
819            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
820            let translated = translate_anthropic_req_to_chatgpt(&val);
821            let mut h = headers.clone();
822            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
823                h.remove(*name);
824            }
825            (
826                "/backend-api/conversation".to_owned(),
827                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
828                h,
829            )
830        } else if req_is_anthropic {
831            // Anthropic client → standard OpenAI-compat account (OpenAIApi, Groq, Mistral, …).
832            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
833            // Resolve the target model: account pin → global mapping → provider default.
834            let target_model = resolve_model(&model, account, &s.config.model_mapping);
835            log_model = target_model.clone();
836            let translated = translate_anthropic_req_to_openai(val, &target_model);
837            let mut h = headers.clone();
838            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
839                h.remove(*name);
840            }
841            (
842                "/v1/chat/completions".to_owned(),
843                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
844                h,
845            )
846        } else {
847            // OpenAI client → Anthropic account: translate O→A.
848            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
849            let translated = translate_to_anthropic(val);
850            (
851                "/v1/messages".to_owned(),
852                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
853                headers.clone(),
854            )
855        };
856
857        // Resolve upstream URL: per-account override (set at load time for non-primary
858        // providers, or explicitly in tests) → config server URL.
859        let upstream = account.upstream_url.as_deref()
860            .unwrap_or(&s.config.server.upstream_url);
861
862        // Inject chatgpt.com sentinel token — only for the chatgpt.com proprietary path.
863        // Wrap in tokio::time::timeout (3s) to guarantee we don't block on Cloudflare challenges.
864        if req_is_anthropic && acct_is_chatgpt {
865            tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
866            let sentinel_client = reqwest::Client::builder()
867                .timeout(std::time::Duration::from_secs(3))
868                .build()
869                .unwrap_or_default();
870            let sentinel_opt = tokio::time::timeout(
871                std::time::Duration::from_secs(3),
872                fetch_sentinel_token(&sentinel_client, upstream, &token),
873            ).await.ok().flatten();
874            if let Some(sentinel) = sentinel_opt {
875                if let Ok(name) = axum::http::header::HeaderName::from_bytes(
876                    b"openai-sentinel-chat-requirements-token",
877                ) {
878                    if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
879                        fwd_headers.insert(name, val);
880                    }
881                }
882            }
883        }
884
885        // Apply a hard 15s cap only for chatgpt.com: Cloudflare may hold the TCP connection
886        // open indefinitely for certain TLS fingerprints.  Standard API providers don't need this.
887        let response = if acct_is_chatgpt {
888            tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
889            match tokio::time::timeout(
890                std::time::Duration::from_secs(15),
891                s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
892            ).await {
893                Ok(Ok(r)) => r,
894                Ok(Err(e)) => {
895                    error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
896                    s.state.set_cooldown(&account_name, 5 * 60_000);
897                    tried.insert(account_name);
898                    continue;
899                }
900                Err(_) => {
901                    warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
902                    s.state.set_cooldown(&account_name, 5 * 60_000);
903                    tried.insert(account_name);
904                    continue;
905                }
906            }
907        } else {
908            s.forwarder
909                .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
910                .await
911                .map_err(|e| {
912                    error!("Forward error: {:#}", e);
913                    ProxyError::Upstream
914                })?
915        };
916
917        match response.status().as_u16() {
918            200..=299 => {
919                s.state.set_last_used(&account_name);
920                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
921                    s.state.update_rate_limits(&account_name, info);
922                }
923                // Translate response back to the client's expected protocol.
924                let response = if req_is_anthropic == acct_is_anthropic {
925                    response
926                } else if req_is_anthropic && acct_is_chatgpt {
927                    // Got chatgpt.com response; client expects Anthropic.
928                    translate_response_chatgpt_to_anthropic(response, &model).await
929                } else if req_is_anthropic {
930                    // Got standard OpenAI-compat response; client expects Anthropic.
931                    translate_response_openai_to_anthropic(response, &model).await
932                } else {
933                    // Got Anthropic response; client expects OpenAI.
934                    translate_response_anthropic_to_openai(response).await
935                };
936                return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms, &request_id, &path, tried.len()).await);
937            }
938            429 => {
939                let info = account.provider.parse_rate_limits(response.headers());
940                // Use Retry-After header for burst limits, or derive from reset window for
941                // quota exhaustion. Cap at 5 min — the `exhausted` flag (set by
942                // update_rate_limits below) already excludes the account from routing until
943                // reset_5h/7d passes, so a long cooldown is redundant and just causes the
944                // "all accounts cooling" wait to block requests for hours.
945                // Add 30s to burst-limit retry-after to prevent immediate re-selection:
946                // after a burst 429, maximus would pick the same account again (its 5h/7d
947                // quota looks fine), immediately hitting the limit again. The 30s buffer
948                // gives other accounts time in the window before this one recovers.
949                let retry_after_ms = response.headers()
950                    .get("retry-after")
951                    .and_then(|v| v.to_str().ok())
952                    .and_then(|s| s.parse::<u64>().ok())
953                    .map(|secs| secs.saturating_mul(1_000).max(500).saturating_add(30_000));
954                const MAX_429_COOLDOWN_MS: u64 = 5 * 60_000;
955                let cooldown_ms = info.as_ref()
956                    .and_then(|i| i.reset_5h.or(i.reset_7d))
957                    .map(|reset_secs| {
958                        let reset_ms = reset_secs.saturating_mul(1_000);
959                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
960                    })
961                    .or(retry_after_ms)
962                    .unwrap_or(90_000) // 90s default (was 60s) — avoids immediate re-cycle
963                    .min(MAX_429_COOLDOWN_MS);
964                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling");
965                if let Some(info) = info {
966                    s.state.update_rate_limits(&account_name, info);
967                }
968                s.state.set_cooldown_staggered(&account_name, cooldown_ms);
969                if cooldown_ms >= 5 * 60_000 && !s.state.get_alerts_muted() {
970                    let mins = cooldown_ms / 60_000;
971                    notify(
972                        "shunt: Rate Limited",
973                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
974                        "Ping",
975                    );
976                }
977                tried.insert(account_name);
978            }
979            529 => {
980                warn!(account = %account_name, "529 overloaded — cooling 30s");
981                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
982                    s.state.update_rate_limits(&account_name, info);
983                }
984                s.state.set_cooldown(&account_name, 30_000);
985                tried.insert(account_name);
986            }
987            401 => {
988                if !refreshed.contains(&account_name) {
989                    // Access token invalidated (e.g. user logged out) — try refresh.
990                    //
991                    // Acquire the per-account refresh lock so concurrent requests
992                    // for the same account serialise here. The first waiter to get
993                    // the lock does the actual OAuth refresh; subsequent waiters
994                    // re-check credentials and skip the refresh if the token was
995                    // already rotated while they were queued.
996                    let account_lock = {
997                        let mut locks = s.refresh_locks.lock();
998                        locks.entry(account_name.clone())
999                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
1000                            .clone()
1001                    };
1002                    let _guard = account_lock.lock().await;
1003
1004                    // Re-read credentials after acquiring the lock — another task
1005                    // may have already refreshed while we were waiting.
1006                    let cred_before = {
1007                        let creds = s.credentials.read().await;
1008                        creds.get(&account_name).cloned()
1009                            .or_else(|| account.credential.clone())
1010                    };
1011                    let Some(cred) = cred_before else {
1012                        tried.insert(account_name);
1013                        continue;
1014                    };
1015
1016                    // Check if the token already changed while we were waiting.
1017                    let token_before = cred.access_token().to_owned();
1018                    let already_refreshed = {
1019                        let creds = s.credentials.read().await;
1020                        creds.get(&account_name)
1021                            .map(|c| c.access_token() != token_before)
1022                            .unwrap_or(false)
1023                    };
1024
1025                    if already_refreshed {
1026                        // Another concurrent request already refreshed — just retry.
1027                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
1028                        refreshed.insert(account_name);
1029                    } else if let Some(oauth_cred) = cred.as_oauth() {
1030                        // OAuth account — attempt token refresh.
1031                        match tokio::time::timeout(
1032                            std::time::Duration::from_secs(10),
1033                            account.provider.refresh_token(oauth_cred),
1034                        ).await {
1035                            Ok(Ok(fresh)) => {
1036                                warn!(account = %account_name, "401 — token refreshed, retrying");
1037                                {
1038                                    let mut creds = s.credentials.write().await;
1039                                    creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
1040                                }
1041                                // Persist to disk so the refreshed token survives a restart.
1042                                let name = account_name.clone();
1043                                let fresh = fresh.clone();
1044                                tokio::task::spawn_blocking(move || {
1045                                    let mut store = CredentialsStore::load();
1046                                    store.accounts.insert(name, Credential::Oauth(fresh.clone()));
1047                                    store.save().ok();
1048                                    if fresh.id_token.is_some() {
1049                                        crate::oauth::write_codex_auth_file(&fresh);
1050                                    }
1051                                });
1052                                // Mark as refreshed but don't add to tried — retry this account.
1053                                refreshed.insert(account_name);
1054                            }
1055                            _ => {
1056                                // Refresh failed/timed out — cool down, don't permanently disable.
1057                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
1058                                s.state.set_cooldown(&account_name, 5 * 60_000);
1059                                tried.insert(account_name);
1060                            }
1061                        }
1062                    } else {
1063                        // API-key account — 401 means the key is invalid; no refresh possible.
1064                        error!(account = %account_name, "401 — API key rejected, cooling 5min");
1065                        s.state.set_cooldown(&account_name, 5 * 60_000);
1066                        tried.insert(account_name);
1067                    }
1068                } else {
1069                    // Already refreshed once and still 401 — cool down this account.
1070                    error!(account = %account_name, "401 after refresh — cooling 5min");
1071                    s.state.set_cooldown(&account_name, 5 * 60_000);
1072                    tried.insert(account_name);
1073                }
1074            }
1075            403 => {
1076                // Forbidden — could be a Cloudflare challenge (non-Anthropic providers)
1077                // or a genuine subscription/org block (Anthropic). Use a short cooldown
1078                // for non-Anthropic accounts so a CF block doesn't lock them out for 30m.
1079                if acct_is_anthropic {
1080                    error!(account = %account_name, "403 forbidden — cooling 30min");
1081                    s.state.set_cooldown(&account_name, 30 * 60_000);
1082                    if !s.state.get_alerts_muted() {
1083                        notify(
1084                            "shunt: Account Forbidden",
1085                            &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
1086                            "Basso",
1087                        );
1088                    }
1089                } else {
1090                    warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
1091                    s.state.set_cooldown(&account_name, 5 * 60_000);
1092                }
1093                tried.insert(account_name);
1094            }
1095            _ => {
1096                // 400, 404, 500, etc. — return as-is, no retry
1097                return Ok(response);
1098            }
1099        }
1100    }
1101}
1102
1103// ---------------------------------------------------------------------------
1104// Usage extraction
1105// ---------------------------------------------------------------------------
1106
1107/// Intercept a successful response to record token usage, then pass it through.
1108///
1109/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
1110/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
1111async fn tap_usage(
1112    resp: Response,
1113    state: &StateStore,
1114    telemetry: Option<&TelemetryClient>,
1115    account: &str,
1116    model: &str,
1117    req_start_ms: u64,
1118    request_id: &str,
1119    path: &str,
1120    retries: usize,
1121) -> Response {
1122    use axum::body::Body;
1123    use crate::state::RequestLog;
1124
1125    let streaming = quota::is_streaming_response(&resp);
1126
1127    if streaming {
1128        let state      = state.clone();
1129        let telem      = telemetry.cloned();
1130        let account    = account.to_owned();
1131        let model      = model.to_owned();
1132        let request_id = request_id.to_owned();
1133        let path       = path.to_owned();
1134        let on_complete = Arc::new(move |input: u64, output: u64| {
1135            let duration_ms = now_ms().saturating_sub(req_start_ms);
1136            info!(
1137                request_id = %request_id,
1138                account    = %account,
1139                model      = %model,
1140                status     = 200,
1141                latency_ms = duration_ms,
1142                path       = %path,
1143                stream     = true,
1144                input_tokens  = input,
1145                output_tokens = output,
1146                retries    = retries,
1147                "request complete"
1148            );
1149            let log = RequestLog {
1150                ts_ms: req_start_ms,
1151                account: account.clone(),
1152                model: model.clone(),
1153                status: 200,
1154                input_tokens: input,
1155                output_tokens: output,
1156                duration_ms,
1157            };
1158            state.record_usage(&account, input, output);
1159            state.record_global(&model, input, output);
1160            if let Some(ref t) = telem { t.push_event(&log); }
1161            state.record_request(log);
1162        });
1163        let (parts, body) = resp.into_parts();
1164        let wrapped = quota::wrap_streaming_body(body, on_complete);
1165        return Response::from_parts(parts, wrapped);
1166    }
1167
1168    // Non-streaming: buffer, extract, rebuild
1169    let (parts, body) = resp.into_parts();
1170    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
1171        Ok(b) => b,
1172        Err(_) => return Response::from_parts(parts, Body::empty()),
1173    };
1174    let (input, output) = quota::extract_usage_from_json(&bytes);
1175    let duration_ms = now_ms().saturating_sub(req_start_ms);
1176    info!(
1177        request_id    = %request_id,
1178        account       = %account,
1179        model         = %model,
1180        status        = 200,
1181        latency_ms    = duration_ms,
1182        path          = %path,
1183        stream        = false,
1184        input_tokens  = input,
1185        output_tokens = output,
1186        retries       = retries,
1187        "request complete"
1188    );
1189    let log = RequestLog {
1190        ts_ms: req_start_ms,
1191        account: account.to_owned(),
1192        model: model.to_owned(),
1193        status: 200,
1194        input_tokens: input,
1195        output_tokens: output,
1196        duration_ms,
1197    };
1198    state.record_usage(account, input, output);
1199    state.record_global(model, input, output);
1200    if let Some(t) = telemetry { t.push_event(&log); }
1201    state.record_request(log);
1202    Response::from_parts(parts, Body::from(bytes))
1203}
1204
1205
1206// ---------------------------------------------------------------------------
1207// Rate limit prefetch
1208// ---------------------------------------------------------------------------
1209
1210/// For any account with no rate-limit data yet, make a cheap request directly
1211/// to the upstream API so we populate metrics without waiting for a real user
1212/// request. Runs as a background task after startup.
1213pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
1214    let client = reqwest::Client::builder()
1215        .timeout(std::time::Duration::from_secs(20))
1216        .build()
1217        .unwrap_or_default();
1218
1219    let existing_rl = state.rate_limit_snapshot();
1220    for account in &config.accounts {
1221        // Skip if we already have data for this account.
1222        if let Some(r) = existing_rl.get(&account.name) {
1223            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
1224                continue;
1225            }
1226        }
1227
1228        // Skip accounts with no credentials or no prefetch support.
1229        let cred = match account.credential.clone() {
1230            Some(c) => c,
1231            None => continue,
1232        };
1233
1234        let Some((path, body)) = account.provider.prefetch_request() else {
1235            // No POST prefetch for this provider — do a lightweight GET auth check instead.
1236            if let Some(probe_path) = account.provider.auth_probe_get_path() {
1237                auth_probe_get(&client, probe_path, account, &state).await;
1238            }
1239            continue;
1240        };
1241        let url = format!("{}{}", config.server.upstream_url, path);
1242
1243        let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
1244
1245        let r = match resp {
1246            Ok(r) => r,
1247            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
1248        };
1249
1250        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
1251            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
1252            let Some(oauth_cred) = cred.as_oauth() else {
1253                // API-key account — 401 during prefetch means the key is invalid.
1254                tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
1255                state.set_auth_failed(&account.name);
1256                continue;
1257            };
1258            let fresh = match account.provider.refresh_token(oauth_cred).await {
1259                Ok(f) => f,
1260                Err(e) => {
1261                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
1262                    state.set_auth_failed(&account.name);
1263                    continue;
1264                }
1265            };
1266            let mut store = crate::config::CredentialsStore::load();
1267            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1268            store.save().ok();
1269            if fresh.id_token.is_some() {
1270                crate::oauth::write_codex_auth_file(&fresh);
1271            }
1272            // Update live credentials so the proxy uses the fresh token immediately.
1273            live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1274
1275            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
1276                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1277                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1278                    state.set_auth_failed(&account.name);
1279                }
1280                Ok(r2) => {
1281                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
1282                        state.update_rate_limits(&account.name, info);
1283                    }
1284                }
1285                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
1286            }
1287        } else {
1288            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
1289            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1290                state.update_rate_limits(&account.name, info);
1291            }
1292        }
1293    }
1294}
1295
1296/// Build and send a prefetch request for the given provider + token.
1297async fn prefetch_send(
1298    client: &reqwest::Client,
1299    url: &str,
1300    provider: &crate::provider::Provider,
1301    token: &str,
1302    body: &serde_json::Value,
1303) -> anyhow::Result<reqwest::Response> {
1304    let mut headers = reqwest::header::HeaderMap::new();
1305    provider.inject_auth_headers(&mut headers, token)?;
1306    for (name, value) in provider.prefetch_extra_headers() {
1307        headers.insert(
1308            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
1309            reqwest::header::HeaderValue::from_static(value),
1310        );
1311    }
1312    Ok(client.post(url).headers(headers).json(body).send().await?)
1313}
1314
1315/// GET a cheap endpoint to verify credentials are still valid for providers that
1316/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
1317/// marks the account as `reauth_required` if the refresh also fails.
1318async fn auth_probe_get(
1319    client: &reqwest::Client,
1320    path: &str,
1321    account: &crate::config::AccountConfig,
1322    state: &StateStore,
1323) {
1324    let cred = match account.credential.clone() {
1325        Some(c) => c,
1326        None => return,
1327    };
1328    let upstream = account.upstream_url.as_deref()
1329        .unwrap_or_else(|| account.provider.default_upstream_url());
1330    let url = format!("{}{}", upstream, path);
1331
1332    let do_get = |token: &str| -> reqwest::RequestBuilder {
1333        let mut headers = reqwest::header::HeaderMap::new();
1334        let _ = account.provider.inject_auth_headers(&mut headers, token);
1335        client.get(&url).headers(headers)
1336    };
1337
1338    let resp = match do_get(cred.bearer_token()).send().await {
1339        Ok(r) => r,
1340        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
1341    };
1342
1343    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
1344        tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
1345        let Some(oauth_cred) = cred.as_oauth() else {
1346            // API-key account — key is invalid; no refresh possible.
1347            tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
1348            state.set_auth_failed(&account.name);
1349            return;
1350        };
1351        let fresh = match account.provider.refresh_token(oauth_cred).await {
1352            Ok(f) => f,
1353            Err(e) => {
1354                tracing::warn!(account = %account.name, "token refresh failed: {e}");
1355                state.set_auth_failed(&account.name);
1356                return;
1357            }
1358        };
1359        let mut store = crate::config::CredentialsStore::load();
1360        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1361        store.save().ok();
1362        if fresh.id_token.is_some() {
1363            crate::oauth::write_codex_auth_file(&fresh);
1364        }
1365
1366        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
1367        match do_get(fresh_token).send().await {
1368            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1369                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1370                state.set_auth_failed(&account.name);
1371            }
1372            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
1373            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
1374        }
1375    } else {
1376        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
1377        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
1378        // with codex CLI, which also tries to refresh at startup using the same token.
1379        // Proactive refreshing is handled solely by openai_token_refresh_loop.
1380    }
1381}
1382
1383// ---------------------------------------------------------------------------
1384// Proactive OpenAI token refresh loop
1385// ---------------------------------------------------------------------------
1386
1387/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
1388/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
1389fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
1390    let now_ms = std::time::SystemTime::now()
1391        .duration_since(std::time::UNIX_EPOCH)
1392        .unwrap_or_default()
1393        .as_millis() as u64;
1394    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
1395        .unwrap_or(cred.expires_at);
1396    exp_ms < now_ms + threshold_mins * 60 * 1_000
1397}
1398
1399/// Sync live_creds from auth.json if auth.json has a newer token.
1400///
1401/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
1402/// we pull that in so we don't use a stale refresh_token that codex already rotated.
1403async fn sync_live_creds_from_auth_json(
1404    account_name: &str,
1405    live_creds: &LiveCredentials,
1406) {
1407    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
1408    let current_exp = live_creds.read().await
1409        .get(account_name)
1410        .and_then(|c| c.as_oauth())
1411        .map(|c| c.expires_at)
1412        .unwrap_or(0);
1413    if from_file.expires_at > current_exp {
1414        tracing::info!(account = %account_name, "synced fresher token from auth.json");
1415        live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
1416    }
1417}
1418
1419/// Perform a single proactive refresh for one account and persist the result.
1420async fn do_proactive_refresh(
1421    account: &crate::config::AccountConfig,
1422    creds: &crate::oauth::OAuthCredential,
1423    live_creds: &LiveCredentials,
1424    state: &StateStore,
1425) {
1426    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
1427    match account.provider.refresh_token(creds).await {
1428        Ok(fresh) => {
1429            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
1430            {
1431                let mut map = live_creds.write().await;
1432                map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1433            }
1434            let mut store = crate::config::CredentialsStore::load();
1435            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1436            store.save().ok();
1437            if fresh.id_token.is_some() {
1438                crate::oauth::write_codex_auth_file(&fresh);
1439            }
1440            state.clear_auth_failed(&account.name);
1441        }
1442        Err(e) => {
1443            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
1444            state.set_auth_failed(&account.name);
1445        }
1446    }
1447}
1448
1449
1450/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
1451///
1452/// Strategy: never proactively rotate the refresh_token — that races with
1453/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
1454/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
1455/// On-demand refresh (401 handler) covers the case where Codex isn't running
1456/// and the token has actually expired.
1457pub async fn openai_token_refresh_loop(
1458    config: Arc<Config>,
1459    state: StateStore,
1460    live_creds: LiveCredentials,
1461) {
1462    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
1463    for account in config.accounts.iter()
1464        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1465    {
1466        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1467            continue;
1468        }
1469        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1470
1471        let creds = {
1472            let map = live_creds.read().await;
1473            map.get(&account.name).cloned().or_else(|| account.credential.clone())
1474        };
1475        if let Some(creds) = creds {
1476            if let Some(oauth) = creds.as_oauth() {
1477                if access_token_expires_soon(oauth, 30) {
1478                    // access_token is nearly expired — refresh now so shunt can serve requests immediately.
1479                    do_proactive_refresh(account, oauth, &live_creds, &state).await;
1480                } else {
1481                    tracing::info!(account = %account.name, "access_token fresh at startup");
1482                }
1483            }
1484        }
1485    }
1486
1487    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
1488    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
1489    loop {
1490        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1491        for account in config.accounts.iter()
1492            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1493        {
1494            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1495        }
1496    }
1497}
1498
1499// ---------------------------------------------------------------------------
1500// Error type
1501// ---------------------------------------------------------------------------
1502
1503enum ProxyError {
1504    BodyRead,
1505    Upstream,
1506    AllAccountsUnavailable,
1507    Unauthorized,
1508    RateLimited,
1509}
1510
1511impl IntoResponse for ProxyError {
1512    fn into_response(self) -> Response {
1513        match self {
1514            ProxyError::RateLimited => {
1515                let mut resp = (
1516                    StatusCode::TOO_MANY_REQUESTS,
1517                    axum::Json(json!({
1518                        "type": "error",
1519                        "error": {"type": "rate_limit_error", "message": "too many requests — slow down"}
1520                    })),
1521                ).into_response();
1522                resp.headers_mut().insert(
1523                    axum::http::header::RETRY_AFTER,
1524                    axum::http::HeaderValue::from_static("60"),
1525                );
1526                resp
1527            }
1528            other => {
1529                let (status, msg) = match other {
1530                    ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1531                    ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1532                    ProxyError::AllAccountsUnavailable => {
1533                        (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1534                    }
1535                    ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1536                    ProxyError::RateLimited => unreachable!(),
1537                };
1538                (status, axum::Json(json!({
1539                    "type": "error",
1540                    "error": {"type": "api_error", "message": msg}
1541                }))).into_response()
1542            }
1543        }
1544    }
1545}
1546
1547// ---------------------------------------------------------------------------
1548// Recovery watcher — periodically retries token refresh for auth_failed accounts
1549// ---------------------------------------------------------------------------
1550
1551/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
1552/// auth_failed account. If refresh succeeds the account is brought back online
1553/// without a process restart. If all accounts remain unrecoverable, fires a
1554/// macOS notification (at most once per hour).
1555pub async fn recovery_watcher(
1556    config: Arc<Config>,
1557    state: StateStore,
1558    credentials: LiveCredentials,
1559) {
1560    use std::time::{Duration, Instant};
1561    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1562    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1563
1564    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1565    let mut last_notified: Option<Instant> = None;
1566
1567    loop {
1568        tokio::time::sleep(CHECK_INTERVAL).await;
1569
1570        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1571        let failed = state.auth_failed_accounts(&name_refs);
1572        if failed.is_empty() {
1573            last_notified = None;
1574            continue;
1575        }
1576
1577        tracing::warn!(
1578            accounts = ?failed,
1579            "recovery: {} account(s) auth_failed, attempting token refresh",
1580            failed.len()
1581        );
1582
1583        let mut any_recovered = false;
1584
1585        for name in &failed {
1586            let cred = {
1587                let map = credentials.read().await;
1588                map.get(*name).cloned()
1589            };
1590            let Some(cred) = cred else { continue };
1591            if !cred.has_refresh_token() { continue; }
1592            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1593
1594            let provider = config.accounts.iter()
1595                .find(|a| a.name == *name)
1596                .map(|a| a.provider.clone())
1597                .unwrap_or_default();
1598
1599            let result = tokio::time::timeout(
1600                Duration::from_secs(20),
1601                provider.refresh_token(&oauth_cred),
1602            ).await;
1603
1604            match result {
1605                Ok(Ok(fresh)) => {
1606                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
1607                    {
1608                        let mut map = credentials.write().await;
1609                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1610                    }
1611                    let name_owned = name.to_string();
1612                    let fresh_owned = fresh.clone();
1613                    tokio::task::spawn_blocking(move || {
1614                        let mut store = crate::config::CredentialsStore::load();
1615                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1616                        store.save().ok();
1617                        if fresh_owned.id_token.is_some() {
1618                            crate::oauth::write_codex_auth_file(&fresh_owned);
1619                        }
1620                    });
1621                    state.clear_auth_failed(name);
1622                    any_recovered = true;
1623                }
1624                Ok(Err(e)) => {
1625                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1626                    if !state.get_alerts_muted() {
1627                        notify(
1628                            "shunt: Reauth Required",
1629                            &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1630                            "Basso",
1631                        );
1632                    }
1633                }
1634                Err(_) => {
1635                    tracing::error!(account = %name, "recovery: token refresh timed out");
1636                    if !state.get_alerts_muted() {
1637                        notify(
1638                            "shunt: Reauth Required",
1639                            &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1640                            "Basso",
1641                        );
1642                    }
1643                }
1644            }
1645        }
1646
1647        if any_recovered {
1648            tracing::info!("recovery: at least one account is back online");
1649            continue;
1650        }
1651
1652        // All accounts still auth_failed after refresh attempts — notify.
1653        let still_failed = state.auth_failed_accounts(&name_refs);
1654        if still_failed.len() == account_names.len() {
1655            let should_notify = last_notified
1656                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1657                .unwrap_or(true);
1658            if should_notify {
1659                error!(
1660                    "ALL accounts are offline (auth failed). \
1661                     Run `shunt add-account` to re-authorize."
1662                );
1663                if !state.get_alerts_muted() {
1664                    notify(
1665                        "shunt: All Accounts Offline",
1666                        "All accounts need re-authorization. Run `shunt add-account`.",
1667                        "Basso",
1668                    );
1669                }
1670                last_notified = Some(Instant::now());
1671            }
1672        }
1673    }
1674}
1675
1676/// Sends a single lightweight prefetch request for `account` immediately after its
1677/// cooldown expires, so the router has fresh rate-limit headers before the next
1678/// real request arrives.
1679async fn post_cooldown_prefetch(
1680    client: &reqwest::Client,
1681    account: &crate::config::AccountConfig,
1682    token: &str,
1683    state: &StateStore,
1684    upstream_url: &str,
1685) {
1686    let Some((path, body)) = account.provider.prefetch_request() else {
1687        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1688            auth_probe_get(client, probe_path, account, state).await;
1689        }
1690        return;
1691    };
1692    let url = format!("{upstream_url}{path}");
1693    match prefetch_send(client, &url, &account.provider, token, &body).await {
1694        Ok(r) => {
1695            // If the prefetch itself gets 429'd, the upstream rate limit hasn't
1696            // actually reset yet — re-cool the account to prevent waiting
1697            // requests from immediately hitting the same 429.
1698            if r.status().as_u16() == 429 {
1699                let retry_after_ms = r.headers()
1700                    .get("retry-after")
1701                    .and_then(|v| v.to_str().ok())
1702                    .and_then(|s| s.parse::<u64>().ok())
1703                    .map(|secs| secs.saturating_mul(1_000).max(500))
1704                    .unwrap_or(30_000);
1705                tracing::warn!(account = %account.name, retry_after_ms, "post-cooldown prefetch got 429 — re-cooling");
1706                state.set_cooldown_staggered(&account.name, retry_after_ms);
1707                return;
1708            }
1709            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1710                state.update_rate_limits(&account.name, info);
1711                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1712            }
1713        }
1714        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1715    }
1716}
1717
1718/// Periodic health-check loop: probes every account at a configurable interval
1719/// to detect dead/invalid accounts before real traffic hits them.
1720///
1721/// Uses exponential backoff (base_interval * 2^min(failures, 3)) per account,
1722/// capped at ~40 min. Marks accounts as `health_check_failed` after 2 consecutive
1723/// failures (tolerates one transient blip). On 401, delegates to `set_auth_failed`.
1724pub async fn health_check_loop(
1725    config: Arc<Config>,
1726    state: StateStore,
1727    live_creds: LiveCredentials,
1728) {
1729    if !config.server.health_check_enabled {
1730        return;
1731    }
1732
1733    // Let prefetch_rate_limits finish first.
1734    tokio::time::sleep(std::time::Duration::from_secs(15)).await;
1735
1736    let base_interval_ms = config.server.health_check_interval_secs * 1000;
1737    let timeout = std::time::Duration::from_secs(config.server.health_check_timeout_secs);
1738    let client = reqwest::Client::builder()
1739        .timeout(timeout)
1740        .build()
1741        .unwrap_or_default();
1742
1743    const FAILURE_THRESHOLD: u32 = 2;
1744    const MAX_BACKOFF_EXP: u32 = 3; // 2^3 = 8x → 40 min at 5-min base
1745
1746    loop {
1747        for account in &config.accounts {
1748            // Skip accounts already handled by recovery_watcher.
1749            {
1750                let states = state.account_states();
1751                if let Some(acc_state) = states.get(&account.name) {
1752                    if acc_state.disabled || acc_state.auth_failed {
1753                        continue;
1754                    }
1755                }
1756            }
1757
1758            // Exponential backoff based on consecutive failure count.
1759            let (last_check_ms, failures) = state.health_check_info(&account.name);
1760            let backoff_factor = 1u64 << failures.min(MAX_BACKOFF_EXP);
1761            let effective_interval_ms = base_interval_ms.saturating_mul(backoff_factor);
1762            let now = crate::state::now_ms_pub();
1763            if last_check_ms > 0 && now.saturating_sub(last_check_ms) < effective_interval_ms {
1764                continue;
1765            }
1766
1767            state.update_last_health_check(&account.name);
1768
1769            // Resolve current credential from live_creds (may have been refreshed).
1770            let cred = {
1771                let creds = live_creds.read().await;
1772                creds.get(&account.name).cloned()
1773            }.or_else(|| account.credential.clone());
1774
1775            let cred = match cred {
1776                Some(c) => c,
1777                None => {
1778                    // Local providers have no cred — probe reachability via GET /v1/models.
1779                    if let Some(probe_path) = account.provider.auth_probe_get_path() {
1780                        let upstream = account.upstream_url.as_deref()
1781                            .unwrap_or_else(|| account.provider.default_upstream_url());
1782                        let url = format!("{upstream}{probe_path}");
1783                        match client.get(&url).send().await {
1784                            Ok(r) if r.status().is_success() => {
1785                                if state.is_health_check_failed(&account.name) {
1786                                    tracing::info!(account = %account.name, "health check recovered");
1787                                }
1788                                state.clear_health_check_failed(&account.name);
1789                            }
1790                            Ok(r) => {
1791                                let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1792                                tracing::warn!(account = %account.name, status = %r.status(),
1793                                    failures = count, "health check failed");
1794                            }
1795                            Err(e) => {
1796                                let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1797                                tracing::warn!(account = %account.name, failures = count,
1798                                    "health check unreachable: {e}");
1799                            }
1800                        }
1801                    }
1802                    continue;
1803                }
1804            };
1805
1806            let token = cred.bearer_token();
1807            let upstream = account.upstream_url.as_deref()
1808                .unwrap_or(&config.server.upstream_url);
1809
1810            // Try POST prefetch (Anthropic) or GET auth probe (other providers).
1811            if let Some((path, body)) = account.provider.prefetch_request() {
1812                let url = format!("{upstream}{path}");
1813                match prefetch_send(&client, &url, &account.provider, token, &body).await {
1814                    Ok(r) => {
1815                        let status = r.status();
1816                        if status == reqwest::StatusCode::UNAUTHORIZED {
1817                            // Attempt refresh for OAuth accounts.
1818                            if let Some(oauth_cred) = cred.as_oauth() {
1819                                match account.provider.refresh_token(oauth_cred).await {
1820                                    Ok(fresh) => {
1821                                        let mut store = crate::config::CredentialsStore::load();
1822                                        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1823                                        store.save().ok();
1824                                        live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1825                                        state.clear_auth_failed(&account.name);
1826                                        if state.is_health_check_failed(&account.name) {
1827                                            state.clear_health_check_failed(&account.name);
1828                                        }
1829                                        tracing::info!(account = %account.name, "health check: token refreshed");
1830                                    }
1831                                    Err(e) => {
1832                                        tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1833                                        state.set_auth_failed(&account.name);
1834                                    }
1835                                }
1836                            } else {
1837                                tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1838                                state.set_auth_failed(&account.name);
1839                            }
1840                        } else if status.is_server_error() {
1841                            let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1842                            tracing::warn!(account = %account.name, status = %status,
1843                                failures = count, "health check: server error");
1844                        } else {
1845                            // Success — update rate limits if available.
1846                            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1847                                state.update_rate_limits(&account.name, info);
1848                            }
1849                            if state.is_health_check_failed(&account.name) {
1850                                tracing::info!(account = %account.name, "health check recovered");
1851                            }
1852                            state.clear_health_check_failed(&account.name);
1853                        }
1854                    }
1855                    Err(e) => {
1856                        let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1857                        tracing::warn!(account = %account.name, failures = count,
1858                            "health check probe failed: {e}");
1859                    }
1860                }
1861            } else if let Some(probe_path) = account.provider.auth_probe_get_path() {
1862                let probe_upstream = account.upstream_url.as_deref()
1863                    .unwrap_or_else(|| account.provider.default_upstream_url());
1864                let url = format!("{probe_upstream}{probe_path}");
1865                let mut headers = reqwest::header::HeaderMap::new();
1866                let _ = account.provider.inject_auth_headers(&mut headers, token);
1867                match client.get(&url).headers(headers).send().await {
1868                    Ok(r) => {
1869                        let status = r.status();
1870                        if status == reqwest::StatusCode::UNAUTHORIZED {
1871                            if let Some(oauth_cred) = cred.as_oauth() {
1872                                match account.provider.refresh_token(oauth_cred).await {
1873                                    Ok(fresh) => {
1874                                        let mut store = crate::config::CredentialsStore::load();
1875                                        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1876                                        store.save().ok();
1877                                        live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1878                                        state.clear_auth_failed(&account.name);
1879                                        state.clear_health_check_failed(&account.name);
1880                                        tracing::info!(account = %account.name, "health check: token refreshed (GET probe)");
1881                                    }
1882                                    Err(e) => {
1883                                        tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1884                                        state.set_auth_failed(&account.name);
1885                                    }
1886                                }
1887                            } else {
1888                                tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1889                                state.set_auth_failed(&account.name);
1890                            }
1891                        } else if status.is_server_error() {
1892                            let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1893                            tracing::warn!(account = %account.name, status = %status,
1894                                failures = count, "health check: server error (GET probe)");
1895                        } else {
1896                            if state.is_health_check_failed(&account.name) {
1897                                tracing::info!(account = %account.name, "health check recovered");
1898                            }
1899                            state.clear_health_check_failed(&account.name);
1900                        }
1901                    }
1902                    Err(e) => {
1903                        let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1904                        tracing::warn!(account = %account.name, failures = count,
1905                            "health check probe failed: {e}");
1906                    }
1907                }
1908            }
1909        }
1910
1911        tokio::time::sleep(std::time::Duration::from_secs(config.server.health_check_interval_secs)).await;
1912    }
1913}
1914
1915/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1916/// so each account re-enters rotation with fresh rate-limit metrics.
1917///
1918/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1919/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1920/// the next cooldown deadline rather than polling at a fixed interval.
1921///
1922/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1923/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1924/// is triggered so the router always has fresh utilization metrics.
1925pub async fn cooldown_watcher(
1926    config: Arc<Config>,
1927    state: StateStore,
1928    credentials: LiveCredentials,
1929) {
1930    /// Re-fetch rate-limit headers if data is older than 1 hour.
1931    const STALE_RL_MS: u64 = 60 * 60_000;
1932
1933    let client = reqwest::Client::builder()
1934        .timeout(std::time::Duration::from_secs(20))
1935        .build()
1936        .unwrap_or_default();
1937
1938    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1939    // Prevents re-triggering on every poll after expiry.
1940    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1941    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1942    let mut notify_on_resume: HashSet<String> = HashSet::new();
1943    // Epoch-ms of the last successful stale-prefetch per account.
1944    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1945
1946    loop {
1947        let states = state.account_states();
1948        let rl_snapshot = state.rate_limit_snapshot();
1949        let now = now_ms();
1950        let mut next_wake_ms: Option<u64> = None;
1951
1952        for account in &config.accounts {
1953            let Some(st) = states.get(&account.name) else { continue };
1954            if st.disabled { continue; } // auth_failed or permanently disabled
1955            let cdl = st.cooldown_until_ms;
1956
1957            if cdl > 0 && cdl <= now {
1958                // Cooldown expired — skip if we already handled this exact deadline
1959                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1960                if !handled {
1961                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1962                    let token = {
1963                        let creds = credentials.read().await;
1964                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1965                    };
1966                    if let Some(token) = token {
1967                        post_cooldown_prefetch(
1968                            &client, account, &token, &state,
1969                            &config.server.upstream_url,
1970                        ).await;
1971                    }
1972                    if notify_on_resume.remove(&account.name) && !state.get_alerts_muted() {
1973                        notify(
1974                            "shunt: Account Resumed",
1975                            &format!("Account '{}' is back online.", account.name),
1976                            "Glass",
1977                        );
1978                    }
1979                    last_resumed.insert(account.name.clone(), cdl);
1980                    last_stale_prefetch.insert(account.name.clone(), now);
1981                }
1982            } else if cdl > now {
1983                // Still cooling — schedule wake at expiry; flag for notification if long
1984                let remaining = cdl - now;
1985                if remaining >= 5 * 60_000 {
1986                    notify_on_resume.insert(account.name.clone());
1987                }
1988                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1989            } else {
1990                // Not in cooldown — check for stale rate-limit data
1991                let rl_age = rl_snapshot
1992                    .get(&account.name)
1993                    .map(|r| now.saturating_sub(r.updated_ms))
1994                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1995                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1996                let fetched_ago = now.saturating_sub(last_fetched);
1997
1998                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1999                    tracing::debug!(
2000                        account = %account.name,
2001                        age_min = rl_age / 60_000,
2002                        "rate-limit data stale — refreshing"
2003                    );
2004                    let token = {
2005                        let creds = credentials.read().await;
2006                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
2007                    };
2008                    if let Some(token) = token {
2009                        post_cooldown_prefetch(
2010                            &client, account, &token, &state,
2011                            &config.server.upstream_url,
2012                        ).await;
2013                    }
2014                    last_stale_prefetch.insert(account.name.clone(), now);
2015                }
2016            }
2017        }
2018
2019        // Sleep exactly until the next cooldown expires; fall back to 30s poll
2020        let sleep_ms = next_wake_ms
2021            .map(|wake| wake.saturating_sub(now_ms()).max(50))
2022            .unwrap_or(30_000);
2023        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
2024    }
2025}
2026
2027use crate::notify::notify;
2028use crate::translate::{
2029    translate_to_anthropic,
2030    translate_from_anthropic,
2031    uuid_v4,
2032    translate_anthropic_stream,
2033    translate_anthropic_req_to_chatgpt,
2034    translate_response_chatgpt_to_anthropic,
2035    translate_anthropic_req_to_openai,
2036    translate_response_openai_to_anthropic,
2037    translate_response_anthropic_to_openai,
2038};
2039
2040// ---------------------------------------------------------------------------
2041// OpenAI-compatible API (translates to Anthropic Claude)
2042// ---------------------------------------------------------------------------
2043//
// When the OpenAI proxy receives a request at /v1/chat/completions and an
// `anthropic_base_url` is configured, it translates the request to Anthropic
// Messages format and forwards it to the Anthropic proxy (which handles
// account selection, token management, and rate limiting); the response is
// translated back to OpenAI Chat Completions format. Without an
// `anthropic_base_url`, the request goes through the regular proxy path.
2049
2050
2051
2052
2053/// GET /v1/models — return Claude models in OpenAI format.
2054async fn openai_models_handler() -> impl IntoResponse {
2055    axum::Json(json!({
2056        "object": "list",
2057        "data": [
2058            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
2059            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
2060            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
2061        ]
2062    }))
2063}
2064
2065/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
2066async fn openai_compat_handler(
2067    State(s): State<AppState>,
2068    req: Request,
2069) -> Result<Response, ProxyError> {
2070    let Some(ref anthropic_url) = s.anthropic_base_url else {
2071        // No Anthropic proxy configured — fall back to normal forwarding
2072        return proxy_handler(State(s), req).await;
2073    };
2074
2075    let body_bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
2076        .await
2077        .map_err(|_| ProxyError::BodyRead)?;
2078
2079    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
2080        .unwrap_or(json!({}));
2081
2082    let stream = openai_body["stream"].as_bool().unwrap_or(false);
2083    let anthropic_body = translate_to_anthropic(openai_body);
2084
2085    let client = reqwest::Client::builder()
2086        .timeout(std::time::Duration::from_secs(300))
2087        .build()
2088        .map_err(|_| ProxyError::Upstream)?;
2089
2090    let mut req_builder = client
2091        .post(format!("{anthropic_url}/v1/messages"))
2092        .header("content-type", "application/json")
2093        .header("anthropic-version", "2023-06-01")
2094        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
2095        .header("x-shunt-compat", "openai");
2096    if let Some(ref key) = s.config.server.remote_key {
2097        req_builder = req_builder.header("x-api-key", key.as_str());
2098    }
2099    let resp = req_builder
2100        .json(&anthropic_body)
2101        .send()
2102        .await
2103        .map_err(|_| ProxyError::Upstream)?;
2104
2105    if !resp.status().is_success() {
2106        let status = resp.status();
2107        let body = resp.text().await.unwrap_or_default();
2108        let code = status.as_u16();
2109        return Ok(axum::response::Response::builder()
2110            .status(code)
2111            .header("content-type", "application/json")
2112            .body(axum::body::Body::from(body))
2113            .unwrap());
2114    }
2115
2116    if stream {
2117        // Translate Anthropic SSE stream → OpenAI SSE stream
2118        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
2119        let stream = translate_anthropic_stream(resp, chat_id);
2120        Ok(axum::response::Response::builder()
2121            .status(200)
2122            .header("content-type", "text/event-stream")
2123            .header("cache-control", "no-cache")
2124            .body(axum::body::Body::from_stream(stream))
2125            .unwrap())
2126    } else {
2127        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
2128        let openai_resp = translate_from_anthropic(anthropic_resp);
2129        Ok(axum::Json(openai_resp).into_response())
2130    }
2131}
2132
2133// ---------------------------------------------------------------------------
2134// ChatGPT backend API translation (chatgpt.com /backend-api/conversation)
2135// ---------------------------------------------------------------------------
2136
2137/// Fetch the sentinel token required by chatgpt.com's backend API.
2138/// Returns None if the request fails or proof-of-work is required.
2139async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
2140    let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
2141    let resp = client
2142        .get(&url)
2143        .header("Authorization", format!("Bearer {}", token))
2144        .send()
2145        .await
2146        .ok()?;
2147    if !resp.status().is_success() {
2148        return None;
2149    }
2150    let json: serde_json::Value = resp.json().await.ok()?;
2151    if json["proofofwork"]["required"].as_bool() == Some(true) {
2152        return None;
2153    }
2154    json["token"].as_str().map(ToOwned::to_owned)
2155}
2156
2157
/// Whether `model` lacks extended-thinking / effort support, meaning those
/// params must be stripped before forwarding.
///
/// Currently: any model id containing `"haiku"` (case-sensitive match).
fn is_simple_model(model: &str) -> bool {
    const SIMPLE_FAMILIES: &[&str] = &["haiku"];
    SIMPLE_FAMILIES.iter().any(|&family| model.contains(family))
}
2163
/// Pick a fallback model one tier down from `model`, if there is one.
///
/// Ladder: a name containing `"opus"` falls back to sonnet, one containing
/// `"sonnet"` falls back to haiku; anything else has no fallback (`None`).
fn auto_fallback_model(model: &str) -> Option<&'static str> {
    // Checked in order, so a name matching several families takes the first row.
    const LADDER: [(&str, &str); 2] = [
        ("opus", "claude-sonnet-4-6"),
        ("sonnet", "claude-haiku-4-5-20251001"),
    ];
    LADDER
        .iter()
        .find(|&&(family, _)| model.contains(family))
        .map(|&(_, fallback)| fallback)
}
2175
2176/// Resolve the target model name for a non-Anthropic account.
2177///
2178/// Priority: per-account `model` pin → global `model_mapping` → provider `default_model()`.
2179/// If the provider is `Local` (default_model = ""), the incoming model name is passed through.
2180fn resolve_model(
2181    incoming: &str,
2182    account: &crate::config::AccountConfig,
2183    mapping: &std::collections::HashMap<String, String>,
2184) -> String {
2185    // 1. Per-account pin (highest priority).
2186    if let Some(m) = &account.model {
2187        return m.clone();
2188    }
2189    // 2. Global mapping for this specific incoming model name.
2190    if let Some(m) = mapping.get(incoming) {
2191        return m.clone();
2192    }
2193    // 3. Provider default.
2194    let default = account.provider.default_model();
2195    if !default.is_empty() {
2196        return default.to_owned();
2197    }
2198    // 4. Pass through (Local provider — model name is server-defined).
2199    incoming.to_owned()
2200}
2201