1use std::collections::{HashMap, HashSet};
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use parking_lot::Mutex as ParkingMutex;
7
8use axum::extract::{Request, State};
9use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum::Router;
13use bytes::Bytes;
14use serde_json::json;
15use tokio::sync::RwLock;
16use tracing::{error, info, warn};
17
18use crate::config::{state_path, Config, CredentialsStore};
19use crate::credential::Credential;
20use crate::forwarder::Forwarder;
21use crate::provider::Provider;
22use crate::quota;
23use crate::router;
24use crate::state::StateStore;
25use crate::telemetry::TelemetryClient;
26
/// Upper bound, in bytes (100 MiB), on how much of an incoming request body
/// `proxy_handler` buffers; passed as the limit to `axum::body::to_bytes`.
const MAX_REQUEST_BODY: usize = 100 * 1024 * 1024;
29
/// Shared state handed to every axum handler through the `State` extractor.
///
/// Built once by `build_app_state`; cloned per request by axum.
#[derive(Clone)]
struct AppState {
    /// Configuration the server was started with.
    config: Arc<Config>,
    /// Upstream forwarder, constructed from `config.server.request_timeout_secs`.
    forwarder: Arc<Forwarder>,
    /// Runtime state (pins, overrides, cooldowns, rate-limit snapshots) read
    /// and updated by the handlers.
    state: StateStore,
    /// Credentials keyed by account name, seeded from the accounts in `config`
    /// that carry a credential. The same `Arc` is returned to the caller as
    /// `LiveCredentials`.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// One async mutex per account name, created on demand; `proxy_handler`
    /// takes it on a 401 before re-reading the account's credential.
    refresh_locks: Arc<ParkingMutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// `now_ms()` captured when the state was built; reported by `/status`.
    started_ms: u64,
    /// Optional base URL supplied by the caller of the `create_*` constructors.
    /// NOTE(review): not read in the visible part of this file — confirm the consumer.
    anthropic_base_url: Option<String>,
    /// Telemetry client, present only when `config.server.telemetry_url` is set.
    telemetry: Option<TelemetryClient>,
    /// Per-client-IP token buckets, present only when
    /// `config.server.rate_limit_rpm` is greater than zero.
    rate_limiter: Option<Arc<ParkingMutex<HashMap<IpAddr, TokenBucket>>>>,
}
55
/// Token bucket used for per-client request rate limiting.
///
/// Tokens refill continuously at `rpm / 60` per second and are capped at a
/// burst size of `max(rpm / 6, 10)`. Each admitted request costs one token.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant up to which refill has already been credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket holding `capacity` tokens, with refill starting now.
    ///
    /// The balance is clamped to the burst size on the first call to
    /// `check_and_consume`, so a `capacity` above the burst size grants no
    /// additional requests.
    fn new(capacity: f64) -> Self {
        Self { tokens: capacity, last_refill: Instant::now() }
    }

    /// Credits the refill accrued since the previous call, then tries to spend
    /// one token.
    ///
    /// `rpm` is the sustained rate in requests per minute. Returns `true` when
    /// the request is admitted (a token was consumed) and `false` when the
    /// bucket holds less than one token.
    fn check_and_consume(&mut self, rpm: f64) -> bool {
        // Read the clock exactly once. Measuring the elapsed time and then
        // taking a second `Instant::now()` for `last_refill` would drop the
        // time between the two reads from the refill accounting on every call.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        let burst = (rpm / 6.0).max(10.0);
        self.tokens = (self.tokens + elapsed * rpm / 60.0).min(burst);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
83
84pub fn create_app(config: Config) -> anyhow::Result<Router> {
85 let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
86 Ok(app)
87}
88
/// Shared, async-lockable map of credentials keyed by account name, as handed
/// back to callers by the app constructors in this module.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
91
92fn build_app_state(
96 config: Config,
97 state: StateStore,
98 anthropic_base_url: Option<String>,
99) -> anyhow::Result<(AppState, LiveCredentials)> {
100 let forwarder = Forwarder::new(config.server.request_timeout_secs)?;
101
102 for a in &config.accounts {
103 if a.provider.auth_kind() == crate::provider::AuthKind::None {
104 state.clear_auth_failed(&a.name);
106 } else if a.credential.is_none() {
107 state.set_auth_failed(&a.name);
108 }
109 }
110
111 let credentials: LiveCredentials = Arc::new(RwLock::new(
112 config.accounts.iter()
113 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
114 .collect::<HashMap<_, _>>(),
115 ));
116
117 let telemetry = config.server.telemetry_url.as_deref().map(|url| {
118 TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
119 });
120
121 let rate_limiter = if config.server.rate_limit_rpm > 0 {
122 Some(Arc::new(ParkingMutex::new(HashMap::<IpAddr, TokenBucket>::new())))
123 } else {
124 None
125 };
126
127 let app_state = AppState {
128 config: Arc::new(config),
129 forwarder: Arc::new(forwarder),
130 state,
131 credentials: Arc::clone(&credentials),
132 refresh_locks: Arc::new(ParkingMutex::new(HashMap::new())),
133 started_ms: now_ms(),
134 anthropic_base_url,
135 telemetry,
136 rate_limiter,
137 };
138
139 Ok((app_state, credentials))
140}
141
142pub fn create_proxy_app(
143 config: Config,
144 state: StateStore,
145 anthropic_base_url: Option<String>,
146) -> anyhow::Result<(Router, LiveCredentials)> {
147 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
148
149 let app = Router::new()
150 .route("/v1/messages", post(proxy_handler))
151 .route("/v1/messages/count_tokens", post(proxy_handler))
152 .route("/v1/chat/completions", post(openai_compat_handler))
153 .route("/v1/models", get(openai_models_handler))
154 .fallback(proxy_handler)
155 .with_state(app_state);
156
157 Ok((app, credentials))
158}
159
160pub fn create_control_app(
163 config: Config,
164 state: StateStore,
165) -> anyhow::Result<Router> {
166 let (app_state, _) = build_app_state(config, state, None)?;
167
168 let app = Router::new()
169 .route("/health", get(health))
170 .route("/status", get(status_handler))
171 .route("/use", post(use_handler))
172 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
173 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
174 .route("/burst-limit", get(burst_limit_get_handler).post(burst_limit_set_handler).delete(burst_limit_clear_handler))
175 .route("/fallback", get(fallback_get_handler).post(fallback_set_handler).delete(fallback_clear_handler))
176 .route("/effort", get(effort_get_handler).post(effort_set_handler).delete(effort_clear_handler))
177 .route("/thinking", get(thinking_get_handler).post(thinking_set_handler).delete(thinking_clear_handler))
178 .route("/alerts", get(alerts_get_handler).post(alerts_set_handler))
179 .with_state(app_state);
180
181 Ok(app)
182}
183
184pub fn create_app_with_state(
188 config: Config,
189 state: StateStore,
190 anthropic_base_url: Option<String>,
191) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
192 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
193 let telemetry = app_state.telemetry.clone();
194
195 let app = Router::new()
196 .route("/health", get(health))
198 .route("/status", get(status_handler))
199 .route("/use", post(use_handler))
200 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
201 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
202 .route("/burst-limit", get(burst_limit_get_handler).post(burst_limit_set_handler).delete(burst_limit_clear_handler))
203 .route("/fallback", get(fallback_get_handler).post(fallback_set_handler).delete(fallback_clear_handler))
204 .route("/effort", get(effort_get_handler).post(effort_set_handler).delete(effort_clear_handler))
205 .route("/thinking", get(thinking_get_handler).post(thinking_set_handler).delete(thinking_clear_handler))
206 .route("/alerts", get(alerts_get_handler).post(alerts_set_handler))
207 .route("/v1/messages", post(proxy_handler))
209 .route("/v1/messages/count_tokens", post(proxy_handler))
210 .route("/v1/chat/completions", post(openai_compat_handler))
211 .route("/v1/models", get(openai_models_handler))
212 .fallback(proxy_handler)
213 .with_state(app_state);
214
215 Ok((app, credentials, telemetry))
216}
217
218pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
220 let account_states = state.account_states();
221 let rate_limits = state.rate_limit_snapshot();
222
223 let accounts: Vec<_> = config.accounts.iter().map(|a| {
224 let st = account_states.get(&a.name);
225 let rl = rate_limits.get(&a.name);
226 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
227 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
228 let reset_5h = rl.and_then(|r| r.reset_5h);
229 let reset_7d = rl.and_then(|r| r.reset_7d);
230 let disabled = st.map(|s| s.disabled).unwrap_or(false);
231 let auth_failed = st.map(|s| s.auth_failed).unwrap_or(false);
232 let health_check_failed = st.map(|s| s.health_check_failed).unwrap_or(false);
233 let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
234 let available = state.is_available(&a.name);
235 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
236
237 json!({
238 "name": a.name,
239 "email": email,
240 "provider": a.provider.to_string(),
241 "available": available,
242 "disabled": disabled,
243 "auth_failed": auth_failed,
244 "health_check_failed": health_check_failed,
245 "cooldown_until_ms": cooldown_until_ms,
246 "utilization_5h": utilization_5h,
247 "reset_5h": reset_5h,
248 "utilization_7d": utilization_7d,
249 "reset_7d": reset_7d,
250 })
251 }).collect();
252
253 json!({
254 "started_ms": started_ms,
255 "accounts": accounts,
256 "pinned_account": state.get_pinned(),
257 "last_used_account": state.get_last_used(),
258 })
259}
260
261async fn health() -> impl IntoResponse {
262 axum::Json(json!({ "status": "ok", "version": env!("CARGO_PKG_VERSION") }))
263}
264
265async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
266 let account_states = s.state.account_states();
267 let quotas = s.state.quota_snapshot();
268 let rate_limits = s.state.rate_limit_snapshot();
269
270 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
271 let st = account_states.get(&a.name);
272 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
273 "reauth_required"
274 } else if st.map(|s| s.disabled).unwrap_or(false) {
275 "disabled"
276 } else if st.map(|s| s.health_check_failed).unwrap_or(false) {
277 "unhealthy"
278 } else if s.state.is_available(&a.name) {
279 "available"
280 } else {
281 "cooling"
282 };
283
284 let quota = quotas.get(&a.name);
285 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
286 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
287 let tokens_used = quota.map(|q| json!({
288 "input": q.input_tokens,
289 "output": q.output_tokens,
290 "total": q.total_tokens(),
291 }));
292
293 let rl = rate_limits.get(&a.name);
294 let rate_limit = rl.map(|r| json!({
295 "utilization_5h": r.utilization_5h,
296 "reset_5h": r.reset_5h,
297 "status_5h": r.status_5h,
298 "utilization_7d": r.utilization_7d,
299 "reset_7d": r.reset_7d,
300 "status_7d": r.status_7d,
301 "representative_claim": r.representative_claim,
302 "updated_ms": r.updated_ms,
303 }));
304
305 let acc_state = account_states.get(&a.name);
306 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
307 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
308 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
309 let health_check_failed = acc_state.map(|s| s.health_check_failed).unwrap_or(false);
310 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
311 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
312 let reset_5h = rl.and_then(|r| r.reset_5h);
313 let status_5h = rl.and_then(|r| r.status_5h.clone());
314 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
315 let reset_7d = rl.and_then(|r| r.reset_7d);
316 let status_7d = rl.and_then(|r| r.status_7d.clone());
317 let available = s.state.is_available(&a.name);
318
319 json!({
320 "name": a.name,
321 "email": email,
322 "plan_type": a.plan_type,
323 "provider": a.provider.to_string(),
324 "status": avail_status,
325 "available": available,
326 "disabled": disabled,
327 "auth_failed": auth_failed,
328 "health_check_failed": health_check_failed,
329 "cooldown_until_ms": cooldown_until_ms,
330 "utilization_5h": utilization_5h,
331 "reset_5h": reset_5h,
332 "status_5h": status_5h,
333 "utilization_7d": utilization_7d,
334 "reset_7d": reset_7d,
335 "status_7d": status_7d,
336 "window_expires_ms": window_expires_ms,
337 "tokens_used": tokens_used,
338 "rate_limit": rate_limit,
339 })
340 }).collect();
341
342 let recent_requests = s.state.recent_requests_snapshot();
343 let savings = s.state.savings_snapshot();
344
345 axum::Json(json!({
346 "version": env!("CARGO_PKG_VERSION"),
347 "started_ms": s.started_ms,
348 "accounts": accounts,
349 "pinned_account": s.state.get_pinned(),
350 "last_used_account": s.state.get_last_used(),
351 "recent_requests": recent_requests,
352 "savings": savings,
353 }))
354}
355
356async fn use_handler(
357 State(s): State<AppState>,
358 axum::Json(body): axum::Json<serde_json::Value>,
359) -> Response {
360 let account = body["account"].as_str().map(|s| s.to_owned());
361 if let Some(ref name) = account {
363 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
364 return (StatusCode::BAD_REQUEST, axum::Json(json!({
365 "error": format!("unknown account '{name}'")
366 }))).into_response();
367 }
368 let pinned = if name == "auto" { None } else { Some(name.clone()) };
369 s.state.set_pinned(pinned);
370 axum::Json(json!({ "pinned": name })).into_response()
371 } else {
372 s.state.set_pinned(None);
373 axum::Json(json!({ "pinned": null })).into_response()
374 }
375}
376
377async fn model_get_handler(State(s): State<AppState>) -> impl IntoResponse {
378 let model = s.state.get_model_override();
379 axum::Json(json!({ "model": model }))
380}
381
382async fn model_set_handler(
383 State(s): State<AppState>,
384 axum::Json(body): axum::Json<serde_json::Value>,
385) -> Response {
386 let Some(model) = body["model"].as_str() else {
387 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing model field" }))).into_response();
388 };
389 s.state.set_model_override(model.to_owned());
390 info!(model, "model override set");
391 axum::Json(json!({ "model": model })).into_response()
392}
393
394async fn model_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
395 s.state.clear_model_override();
396 info!("model override cleared");
397 axum::Json(json!({ "model": null }))
398}
399
400async fn strategy_get_handler(State(s): State<AppState>) -> impl IntoResponse {
401 let (strategy_str, source) = match s.state.get_routing_strategy() {
402 Some(st) => (st.as_str(), "override"),
403 None => (s.config.server.routing_strategy.as_str(), "config"),
404 };
405 axum::Json(json!({ "strategy": strategy_str, "source": source }))
406}
407
408async fn strategy_set_handler(
409 State(s): State<AppState>,
410 axum::Json(body): axum::Json<serde_json::Value>,
411) -> Response {
412 let Some(name) = body["strategy"].as_str() else {
413 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing strategy field" }))).into_response();
414 };
415 let Some(strategy) = crate::config::RoutingStrategy::from_str(name) else {
416 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": format!("unknown strategy '{name}'") }))).into_response();
417 };
418 s.state.set_routing_strategy(strategy);
419 info!(strategy = name, "routing strategy override set");
420 axum::Json(json!({ "strategy": strategy.as_str(), "source": "override" })).into_response()
421}
422
423async fn strategy_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
424 s.state.clear_routing_strategy();
425 info!("routing strategy override cleared");
426 let strategy_str = s.config.server.routing_strategy.as_str();
427 axum::Json(json!({ "strategy": strategy_str, "source": "config" }))
428}
429
430async fn burst_limit_get_handler(State(s): State<AppState>) -> impl IntoResponse {
433 let (limit, source) = match s.state.get_burst_rpm_limit_override() {
434 Some(l) => (l, "override"),
435 None => (s.config.server.burst_rpm_limit, if s.config.server.burst_rpm_limit == 10 { "default" } else { "config" }),
436 };
437 axum::Json(json!({ "burst_rpm_limit": limit, "source": source }))
438}
439
440async fn burst_limit_set_handler(
441 State(s): State<AppState>,
442 axum::Json(body): axum::Json<serde_json::Value>,
443) -> Response {
444 let Some(limit) = body["burst_rpm_limit"].as_u64() else {
445 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing burst_rpm_limit field (integer)" }))).into_response();
446 };
447 let limit = limit as u32;
448 s.state.set_burst_rpm_limit_override(limit);
449 info!(limit, "burst RPM limit override set");
450 axum::Json(json!({ "burst_rpm_limit": limit, "source": "override" })).into_response()
451}
452
453async fn burst_limit_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
454 s.state.clear_burst_rpm_limit_override();
455 info!("burst RPM limit override cleared");
456 let limit = s.config.server.burst_rpm_limit;
457 axum::Json(json!({ "burst_rpm_limit": limit, "source": if limit == 10 { "default" } else { "config" } }))
458}
459
460async fn fallback_get_handler(State(s): State<AppState>) -> impl IntoResponse {
463 match s.state.get_fallback_model_override() {
464 Some(Some(model)) => axum::Json(json!({ "fallback_model": model, "source": "override" })),
465 Some(None) => axum::Json(json!({ "fallback_model": null, "source": "override", "disabled": true })),
466 None => match &s.config.server.fallback_model {
467 Some(model) => axum::Json(json!({ "fallback_model": model, "source": "config" })),
468 None => axum::Json(json!({ "fallback_model": "auto", "source": "auto" })),
469 },
470 }
471}
472
473async fn fallback_set_handler(
474 State(s): State<AppState>,
475 axum::Json(body): axum::Json<serde_json::Value>,
476) -> Response {
477 if body["fallback_model"].is_null() || body.get("disabled").and_then(|v| v.as_bool()) == Some(true) {
478 s.state.set_fallback_model_override(None);
479 info!("fallback model explicitly disabled");
480 return axum::Json(json!({ "fallback_model": null, "source": "override", "disabled": true })).into_response();
481 }
482 let Some(model) = body["fallback_model"].as_str() else {
483 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing fallback_model field" }))).into_response();
484 };
485 let model = model.to_owned();
486 s.state.set_fallback_model_override(Some(model.clone()));
487 info!(model = %model, "fallback model override set");
488 axum::Json(json!({ "fallback_model": model, "source": "override" })).into_response()
489}
490
491async fn fallback_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
492 s.state.clear_fallback_model_override();
493 info!("fallback model override cleared");
494 match &s.config.server.fallback_model {
495 Some(model) => axum::Json(json!({ "fallback_model": model, "source": "config" })),
496 None => axum::Json(json!({ "fallback_model": "auto", "source": "auto" })),
497 }
498}
499
500async fn effort_get_handler(State(s): State<AppState>) -> impl IntoResponse {
501 match s.state.get_effort_override() {
502 Some(effort) => axum::Json(json!({ "effort": effort, "source": "override" })),
503 None => axum::Json(json!({ "effort": null, "source": "passthrough" })),
504 }
505}
506
507async fn effort_set_handler(
508 State(s): State<AppState>,
509 axum::Json(body): axum::Json<serde_json::Value>,
510) -> Response {
511 let Some(effort) = body["effort"].as_str() else {
512 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing effort string field" }))).into_response();
513 };
514 let valid = ["low", "medium", "high", "xhigh", "max"];
515 if !valid.contains(&effort) {
516 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "effort must be one of: low, medium, high, max" }))).into_response();
517 }
518 s.state.set_effort_override(effort.to_owned());
519 info!(effort, "effort override set");
520 axum::Json(json!({ "effort": effort, "source": "override" })).into_response()
521}
522
523async fn effort_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
524 s.state.clear_effort_override();
525 info!("effort override cleared");
526 axum::Json(json!({ "effort": null, "source": "passthrough" }))
527}
528
529async fn thinking_get_handler(State(s): State<AppState>) -> impl IntoResponse {
530 match s.state.get_thinking_override() {
531 Some(mode) => axum::Json(json!({ "thinking": mode, "source": "override" })),
532 None => axum::Json(json!({ "thinking": null, "source": "passthrough" })),
533 }
534}
535
536async fn thinking_set_handler(
537 State(s): State<AppState>,
538 axum::Json(body): axum::Json<serde_json::Value>,
539) -> Response {
540 let Some(mode) = body["thinking"].as_str() else {
541 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing thinking string field" }))).into_response();
542 };
543 let valid = ["adaptive", "disabled"];
544 if !valid.contains(&mode) {
545 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "thinking must be one of: adaptive, disabled" }))).into_response();
546 }
547 s.state.set_thinking_override(mode.to_owned());
548 info!(mode, "thinking override set");
549 axum::Json(json!({ "thinking": mode, "source": "override" })).into_response()
550}
551
552async fn thinking_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
553 s.state.clear_thinking_override();
554 info!("thinking override cleared");
555 axum::Json(json!({ "thinking": null, "source": "passthrough" }))
556}
557
558async fn alerts_get_handler(State(s): State<AppState>) -> impl IntoResponse {
559 let muted = s.state.get_alerts_muted();
560 axum::Json(json!({ "muted": muted }))
561}
562
563async fn alerts_set_handler(
564 State(s): State<AppState>,
565 axum::Json(body): axum::Json<serde_json::Value>,
566) -> Response {
567 let Some(muted) = body["muted"].as_bool() else {
568 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing muted bool field" }))).into_response();
569 };
570 s.state.set_alerts_muted(muted);
571 info!(muted, "alerts mute state changed");
572 axum::Json(json!({ "muted": muted })).into_response()
573}
574
575use crate::state::now_ms_pub as now_ms;
576
577fn extract_client_ip(req: &Request, trust_proxy_headers: bool) -> IpAddr {
584 if trust_proxy_headers {
585 if let Some(ip) = req.headers()
586 .get("x-real-ip")
587 .and_then(|v| v.to_str().ok())
588 .and_then(|s| s.parse().ok())
589 {
590 return ip;
591 }
592 }
593 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
594}
595
596async fn proxy_handler(
597 State(s): State<AppState>,
598 req: Request,
599) -> Result<Response, ProxyError> {
600 if let Some(ref expected) = s.config.server.remote_key {
602 let provided = req.headers()
603 .get("x-api-key")
604 .and_then(|v| v.to_str().ok())
605 .unwrap_or("");
606 if provided != expected {
607 return Err(ProxyError::Unauthorized);
608 }
609 }
610
611 if let Some(ref rl) = s.rate_limiter {
613 let ip = extract_client_ip(&req, s.config.server.trust_proxy_headers);
614 let rpm = s.config.server.rate_limit_rpm as f64;
615 let allowed = rl.lock().entry(ip).or_insert_with(|| TokenBucket::new(rpm)).check_and_consume(rpm);
616 if !allowed {
617 return Err(ProxyError::RateLimited);
618 }
619 }
620
621 let method = req.method().as_str().to_owned();
622 let path = req.uri().path().to_owned();
623 let headers = req.headers().clone();
624
625 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
626 .await
627 .map_err(|_| ProxyError::BodyRead)?;
628
629 let (mut body_bytes, model) = if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
633 let mut changed = false;
634 if let Some(override_model) = s.state.get_model_override() {
635 if val.get("model").is_some() {
636 val["model"] = serde_json::Value::String(override_model);
637 changed = true;
638 }
639 }
640 if let Some(effort) = s.state.get_effort_override() {
642 if val.get("output_config").is_none() {
643 val["output_config"] = serde_json::json!({});
644 }
645 val["output_config"]["effort"] = serde_json::Value::String(effort);
646 changed = true;
647 }
648 if let Some(thinking_mode) = s.state.get_thinking_override() {
650 val["thinking"] = serde_json::json!({ "type": thinking_mode });
651 changed = true;
652 }
653 let resolved_model = val["model"].as_str().unwrap_or("").to_owned();
654 if is_simple_model(&resolved_model) {
655 if let Some(obj) = val.as_object_mut() {
656 for key in &["thinking", "effort", "reasoning_effort"] {
658 if obj.remove(*key).is_some() { changed = true; }
659 }
660 if let Some(serde_json::Value::Object(oc)) = obj.get_mut("output_config") {
662 if oc.remove("effort").is_some() { changed = true; }
663 if oc.is_empty() { obj.remove("output_config"); }
665 }
666 if obj.remove("context_management").is_some() { changed = true; }
668 if let Some(serde_json::Value::Array(betas)) = obj.get_mut("betas") {
670 let before = betas.len();
671 betas.retain(|b| b.as_str() != Some("interleaved-thinking-2025-05-14"));
672 if betas.len() != before { changed = true; }
673 }
674 }
675 }
676 let model = val["model"].as_str().unwrap_or("").to_owned();
677 let bytes = if changed {
678 Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()))
679 } else {
680 body_bytes
681 };
682 (bytes, model)
683 } else {
684 (body_bytes, String::new())
685 };
686
687 let mut headers = headers;
689 if is_simple_model(&model) {
690 if let Some(beta_val) = headers.get("anthropic-beta").and_then(|v| v.to_str().ok().map(|s| s.to_owned())) {
691 let filtered: Vec<&str> = beta_val.split(',')
692 .map(|s| s.trim())
693 .filter(|b| !b.contains("thinking") && !b.contains("effort"))
694 .collect();
695 let new_beta = filtered.join(",");
696 if filtered.is_empty() {
697 headers.remove("anthropic-beta");
698 } else if let Ok(v) = axum::http::HeaderValue::from_str(&new_beta) {
699 headers.insert("anthropic-beta", v);
700 }
701 }
702 }
703
704 let req_start_ms = now_ms();
705 let request_id = uuid::Uuid::new_v4().to_string()[..8].to_owned();
706
707 let fp = router::fingerprint(&body_bytes);
708 let fp_ref = fp.as_deref();
709
710 let mut tried: HashSet<String> = HashSet::new();
711 let mut refreshed: HashSet<String> = HashSet::new();
713 let mut fell_back = false;
715 let wait_deadline_ms = now_ms() + s.config.server.request_timeout_secs.saturating_mul(1_000);
718
719 loop {
720 let effective_strategy = s.state.get_routing_strategy()
721 .unwrap_or(s.config.server.routing_strategy);
722 let snap = s.state.routing_snapshot();
723 let effective_burst_rpm = s.state.get_burst_rpm_limit_override()
724 .unwrap_or(s.config.server.burst_rpm_limit);
725 let account = match router::pick_account(
726 &s.config.accounts, &s.state, &snap, fp_ref, &tried,
727 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
728 effective_strategy, effective_burst_rpm,
729 ) {
730 Some(a) => a,
731 None => {
732 if !fell_back && !model.is_empty() {
736 let fallback: Option<String> = match s.state.get_fallback_model_override() {
737 Some(Some(m)) => Some(m), Some(None) => None, None => s.config.server.fallback_model.clone()
740 .or_else(|| auto_fallback_model(&model).map(|s| s.to_owned())),
741 };
742 if let Some(ref fb) = fallback {
743 if model != *fb {
744 if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
745 warn!(from = %model, to = %fb, "all accounts cooling — falling back to cheaper model");
746 val["model"] = serde_json::Value::String(fb.clone());
747 body_bytes = Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()));
748 fell_back = true;
749 tried.clear();
750 continue;
751 }
752 }
753 }
754 }
755
756 let account_states = s.state.account_states();
760 let now = now_ms();
761 let soonest_ms = s.config.accounts.iter()
762 .filter_map(|a| {
763 let st = account_states.get(&a.name)?;
764 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
766 })
767 .min();
768
769 match soonest_ms {
770 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
771 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
773 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
774 tried.clear(); }
776 _ => return Err(ProxyError::AllAccountsUnavailable),
777 }
778 continue;
779 }
780 };
781
782 let account_name = account.name.clone();
783 s.state.record_request_burst(&account_name);
784
785 let token = {
790 let creds = s.credentials.read().await;
791 let cred = creds.get(&account_name)
792 .cloned()
793 .or_else(|| account.credential.clone());
794 match cred {
795 Some(c) => c.bearer_token().to_owned(),
796 None => String::new(),
797 }
798 };
799
800 let req_is_anthropic = path.starts_with("/v1/messages");
804 let acct_is_anthropic = account.provider.wire_protocol()
805 == crate::provider::WireProtocol::Anthropic;
806 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
809
810 let mut log_model = model.clone();
813
814 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
815 (path.clone(), body_bytes.clone(), headers.clone())
817 } else if req_is_anthropic && acct_is_chatgpt {
818 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
820 let translated = translate_anthropic_req_to_chatgpt(&val);
821 let mut h = headers.clone();
822 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
823 h.remove(*name);
824 }
825 (
826 "/backend-api/conversation".to_owned(),
827 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
828 h,
829 )
830 } else if req_is_anthropic {
831 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
833 let target_model = resolve_model(&model, account, &s.config.model_mapping);
835 log_model = target_model.clone();
836 let translated = translate_anthropic_req_to_openai(val, &target_model);
837 let mut h = headers.clone();
838 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
839 h.remove(*name);
840 }
841 (
842 "/v1/chat/completions".to_owned(),
843 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
844 h,
845 )
846 } else {
847 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
849 let translated = translate_to_anthropic(val);
850 (
851 "/v1/messages".to_owned(),
852 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
853 headers.clone(),
854 )
855 };
856
857 let upstream = account.upstream_url.as_deref()
860 .unwrap_or(&s.config.server.upstream_url);
861
862 if req_is_anthropic && acct_is_chatgpt {
865 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
866 let sentinel_client = reqwest::Client::builder()
867 .timeout(std::time::Duration::from_secs(3))
868 .build()
869 .unwrap_or_default();
870 let sentinel_opt = tokio::time::timeout(
871 std::time::Duration::from_secs(3),
872 fetch_sentinel_token(&sentinel_client, upstream, &token),
873 ).await.ok().flatten();
874 if let Some(sentinel) = sentinel_opt {
875 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
876 b"openai-sentinel-chat-requirements-token",
877 ) {
878 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
879 fwd_headers.insert(name, val);
880 }
881 }
882 }
883 }
884
885 let response = if acct_is_chatgpt {
888 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
889 match tokio::time::timeout(
890 std::time::Duration::from_secs(15),
891 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
892 ).await {
893 Ok(Ok(r)) => r,
894 Ok(Err(e)) => {
895 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
896 s.state.set_cooldown(&account_name, 5 * 60_000);
897 tried.insert(account_name);
898 continue;
899 }
900 Err(_) => {
901 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
902 s.state.set_cooldown(&account_name, 5 * 60_000);
903 tried.insert(account_name);
904 continue;
905 }
906 }
907 } else {
908 s.forwarder
909 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
910 .await
911 .map_err(|e| {
912 error!("Forward error: {:#}", e);
913 ProxyError::Upstream
914 })?
915 };
916
917 match response.status().as_u16() {
918 200..=299 => {
919 s.state.set_last_used(&account_name);
920 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
921 s.state.update_rate_limits(&account_name, info);
922 }
923 let response = if req_is_anthropic == acct_is_anthropic {
925 response
926 } else if req_is_anthropic && acct_is_chatgpt {
927 translate_response_chatgpt_to_anthropic(response, &model).await
929 } else if req_is_anthropic {
930 translate_response_openai_to_anthropic(response, &model).await
932 } else {
933 translate_response_anthropic_to_openai(response).await
935 };
936 return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms, &request_id, &path, tried.len()).await);
937 }
938 429 => {
939 let info = account.provider.parse_rate_limits(response.headers());
940 let retry_after_ms = response.headers()
950 .get("retry-after")
951 .and_then(|v| v.to_str().ok())
952 .and_then(|s| s.parse::<u64>().ok())
953 .map(|secs| secs.saturating_mul(1_000).max(500).saturating_add(30_000));
954 const MAX_429_COOLDOWN_MS: u64 = 5 * 60_000;
955 let cooldown_ms = info.as_ref()
956 .and_then(|i| i.reset_5h.or(i.reset_7d))
957 .map(|reset_secs| {
958 let reset_ms = reset_secs.saturating_mul(1_000);
959 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
961 .or(retry_after_ms)
962 .unwrap_or(90_000) .min(MAX_429_COOLDOWN_MS);
964 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling");
965 if let Some(info) = info {
966 s.state.update_rate_limits(&account_name, info);
967 }
968 s.state.set_cooldown_staggered(&account_name, cooldown_ms);
969 if cooldown_ms >= 5 * 60_000 && !s.state.get_alerts_muted() {
970 let mins = cooldown_ms / 60_000;
971 notify(
972 "shunt: Rate Limited",
973 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
974 "Ping",
975 );
976 }
977 tried.insert(account_name);
978 }
979 529 => {
980 warn!(account = %account_name, "529 overloaded — cooling 30s");
981 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
982 s.state.update_rate_limits(&account_name, info);
983 }
984 s.state.set_cooldown(&account_name, 30_000);
985 tried.insert(account_name);
986 }
987 401 => {
988 if !refreshed.contains(&account_name) {
989 let account_lock = {
997 let mut locks = s.refresh_locks.lock();
998 locks.entry(account_name.clone())
999 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
1000 .clone()
1001 };
1002 let _guard = account_lock.lock().await;
1003
1004 let cred_before = {
1007 let creds = s.credentials.read().await;
1008 creds.get(&account_name).cloned()
1009 .or_else(|| account.credential.clone())
1010 };
1011 let Some(cred) = cred_before else {
1012 tried.insert(account_name);
1013 continue;
1014 };
1015
1016 let token_before = cred.access_token().to_owned();
1018 let already_refreshed = {
1019 let creds = s.credentials.read().await;
1020 creds.get(&account_name)
1021 .map(|c| c.access_token() != token_before)
1022 .unwrap_or(false)
1023 };
1024
1025 if already_refreshed {
1026 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
1028 refreshed.insert(account_name);
1029 } else if let Some(oauth_cred) = cred.as_oauth() {
1030 match tokio::time::timeout(
1032 std::time::Duration::from_secs(10),
1033 account.provider.refresh_token(oauth_cred),
1034 ).await {
1035 Ok(Ok(fresh)) => {
1036 warn!(account = %account_name, "401 — token refreshed, retrying");
1037 {
1038 let mut creds = s.credentials.write().await;
1039 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
1040 }
1041 let name = account_name.clone();
1043 let fresh = fresh.clone();
1044 tokio::task::spawn_blocking(move || {
1045 let mut store = CredentialsStore::load();
1046 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
1047 store.save().ok();
1048 if fresh.id_token.is_some() {
1049 crate::oauth::write_codex_auth_file(&fresh);
1050 }
1051 });
1052 refreshed.insert(account_name);
1054 }
1055 _ => {
1056 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
1058 s.state.set_cooldown(&account_name, 5 * 60_000);
1059 tried.insert(account_name);
1060 }
1061 }
1062 } else {
1063 error!(account = %account_name, "401 — API key rejected, cooling 5min");
1065 s.state.set_cooldown(&account_name, 5 * 60_000);
1066 tried.insert(account_name);
1067 }
1068 } else {
1069 error!(account = %account_name, "401 after refresh — cooling 5min");
1071 s.state.set_cooldown(&account_name, 5 * 60_000);
1072 tried.insert(account_name);
1073 }
1074 }
1075 403 => {
1076 if acct_is_anthropic {
1080 error!(account = %account_name, "403 forbidden — cooling 30min");
1081 s.state.set_cooldown(&account_name, 30 * 60_000);
1082 if !s.state.get_alerts_muted() {
1083 notify(
1084 "shunt: Account Forbidden",
1085 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
1086 "Basso",
1087 );
1088 }
1089 } else {
1090 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
1091 s.state.set_cooldown(&account_name, 5 * 60_000);
1092 }
1093 tried.insert(account_name);
1094 }
1095 _ => {
1096 return Ok(response);
1098 }
1099 }
1100 }
1101}
1102
1103async fn tap_usage(
1112 resp: Response,
1113 state: &StateStore,
1114 telemetry: Option<&TelemetryClient>,
1115 account: &str,
1116 model: &str,
1117 req_start_ms: u64,
1118 request_id: &str,
1119 path: &str,
1120 retries: usize,
1121) -> Response {
1122 use axum::body::Body;
1123 use crate::state::RequestLog;
1124
1125 let streaming = quota::is_streaming_response(&resp);
1126
1127 if streaming {
1128 let state = state.clone();
1129 let telem = telemetry.cloned();
1130 let account = account.to_owned();
1131 let model = model.to_owned();
1132 let request_id = request_id.to_owned();
1133 let path = path.to_owned();
1134 let on_complete = Arc::new(move |input: u64, output: u64| {
1135 let duration_ms = now_ms().saturating_sub(req_start_ms);
1136 info!(
1137 request_id = %request_id,
1138 account = %account,
1139 model = %model,
1140 status = 200,
1141 latency_ms = duration_ms,
1142 path = %path,
1143 stream = true,
1144 input_tokens = input,
1145 output_tokens = output,
1146 retries = retries,
1147 "request complete"
1148 );
1149 let log = RequestLog {
1150 ts_ms: req_start_ms,
1151 account: account.clone(),
1152 model: model.clone(),
1153 status: 200,
1154 input_tokens: input,
1155 output_tokens: output,
1156 duration_ms,
1157 };
1158 state.record_usage(&account, input, output);
1159 state.record_global(&model, input, output);
1160 if let Some(ref t) = telem { t.push_event(&log); }
1161 state.record_request(log);
1162 });
1163 let (parts, body) = resp.into_parts();
1164 let wrapped = quota::wrap_streaming_body(body, on_complete);
1165 return Response::from_parts(parts, wrapped);
1166 }
1167
1168 let (parts, body) = resp.into_parts();
1170 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
1171 Ok(b) => b,
1172 Err(_) => return Response::from_parts(parts, Body::empty()),
1173 };
1174 let (input, output) = quota::extract_usage_from_json(&bytes);
1175 let duration_ms = now_ms().saturating_sub(req_start_ms);
1176 info!(
1177 request_id = %request_id,
1178 account = %account,
1179 model = %model,
1180 status = 200,
1181 latency_ms = duration_ms,
1182 path = %path,
1183 stream = false,
1184 input_tokens = input,
1185 output_tokens = output,
1186 retries = retries,
1187 "request complete"
1188 );
1189 let log = RequestLog {
1190 ts_ms: req_start_ms,
1191 account: account.to_owned(),
1192 model: model.to_owned(),
1193 status: 200,
1194 input_tokens: input,
1195 output_tokens: output,
1196 duration_ms,
1197 };
1198 state.record_usage(account, input, output);
1199 state.record_global(model, input, output);
1200 if let Some(t) = telemetry { t.push_event(&log); }
1201 state.record_request(log);
1202 Response::from_parts(parts, Body::from(bytes))
1203}
1204
1205
1206pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
1214 let client = reqwest::Client::builder()
1215 .timeout(std::time::Duration::from_secs(20))
1216 .build()
1217 .unwrap_or_default();
1218
1219 let existing_rl = state.rate_limit_snapshot();
1220 for account in &config.accounts {
1221 if let Some(r) = existing_rl.get(&account.name) {
1223 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
1224 continue;
1225 }
1226 }
1227
1228 let cred = match account.credential.clone() {
1230 Some(c) => c,
1231 None => continue,
1232 };
1233
1234 let Some((path, body)) = account.provider.prefetch_request() else {
1235 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1237 auth_probe_get(&client, probe_path, account, &state).await;
1238 }
1239 continue;
1240 };
1241 let url = format!("{}{}", config.server.upstream_url, path);
1242
1243 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
1244
1245 let r = match resp {
1246 Ok(r) => r,
1247 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
1248 };
1249
1250 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
1251 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
1252 let Some(oauth_cred) = cred.as_oauth() else {
1253 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
1255 state.set_auth_failed(&account.name);
1256 continue;
1257 };
1258 let fresh = match account.provider.refresh_token(oauth_cred).await {
1259 Ok(f) => f,
1260 Err(e) => {
1261 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1262 state.set_auth_failed(&account.name);
1263 continue;
1264 }
1265 };
1266 let mut store = crate::config::CredentialsStore::load();
1267 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1268 store.save().ok();
1269 if fresh.id_token.is_some() {
1270 crate::oauth::write_codex_auth_file(&fresh);
1271 }
1272 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1274
1275 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
1276 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1277 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1278 state.set_auth_failed(&account.name);
1279 }
1280 Ok(r2) => {
1281 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
1282 state.update_rate_limits(&account.name, info);
1283 }
1284 }
1285 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
1286 }
1287 } else {
1288 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
1289 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1290 state.update_rate_limits(&account.name, info);
1291 }
1292 }
1293 }
1294}
1295
1296async fn prefetch_send(
1298 client: &reqwest::Client,
1299 url: &str,
1300 provider: &crate::provider::Provider,
1301 token: &str,
1302 body: &serde_json::Value,
1303) -> anyhow::Result<reqwest::Response> {
1304 let mut headers = reqwest::header::HeaderMap::new();
1305 provider.inject_auth_headers(&mut headers, token)?;
1306 for (name, value) in provider.prefetch_extra_headers() {
1307 headers.insert(
1308 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
1309 reqwest::header::HeaderValue::from_static(value),
1310 );
1311 }
1312 Ok(client.post(url).headers(headers).json(body).send().await?)
1313}
1314
1315async fn auth_probe_get(
1319 client: &reqwest::Client,
1320 path: &str,
1321 account: &crate::config::AccountConfig,
1322 state: &StateStore,
1323) {
1324 let cred = match account.credential.clone() {
1325 Some(c) => c,
1326 None => return,
1327 };
1328 let upstream = account.upstream_url.as_deref()
1329 .unwrap_or_else(|| account.provider.default_upstream_url());
1330 let url = format!("{}{}", upstream, path);
1331
1332 let do_get = |token: &str| -> reqwest::RequestBuilder {
1333 let mut headers = reqwest::header::HeaderMap::new();
1334 let _ = account.provider.inject_auth_headers(&mut headers, token);
1335 client.get(&url).headers(headers)
1336 };
1337
1338 let resp = match do_get(cred.bearer_token()).send().await {
1339 Ok(r) => r,
1340 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
1341 };
1342
1343 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
1344 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
1345 let Some(oauth_cred) = cred.as_oauth() else {
1346 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
1348 state.set_auth_failed(&account.name);
1349 return;
1350 };
1351 let fresh = match account.provider.refresh_token(oauth_cred).await {
1352 Ok(f) => f,
1353 Err(e) => {
1354 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1355 state.set_auth_failed(&account.name);
1356 return;
1357 }
1358 };
1359 let mut store = crate::config::CredentialsStore::load();
1360 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1361 store.save().ok();
1362 if fresh.id_token.is_some() {
1363 crate::oauth::write_codex_auth_file(&fresh);
1364 }
1365
1366 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
1367 match do_get(fresh_token).send().await {
1368 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1369 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1370 state.set_auth_failed(&account.name);
1371 }
1372 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
1373 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
1374 }
1375 } else {
1376 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
1377 }
1381}
1382
1383fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
1390 let now_ms = std::time::SystemTime::now()
1391 .duration_since(std::time::UNIX_EPOCH)
1392 .unwrap_or_default()
1393 .as_millis() as u64;
1394 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
1395 .unwrap_or(cred.expires_at);
1396 exp_ms < now_ms + threshold_mins * 60 * 1_000
1397}
1398
1399async fn sync_live_creds_from_auth_json(
1404 account_name: &str,
1405 live_creds: &LiveCredentials,
1406) {
1407 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
1408 let current_exp = live_creds.read().await
1409 .get(account_name)
1410 .and_then(|c| c.as_oauth())
1411 .map(|c| c.expires_at)
1412 .unwrap_or(0);
1413 if from_file.expires_at > current_exp {
1414 tracing::info!(account = %account_name, "synced fresher token from auth.json");
1415 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
1416 }
1417}
1418
1419async fn do_proactive_refresh(
1421 account: &crate::config::AccountConfig,
1422 creds: &crate::oauth::OAuthCredential,
1423 live_creds: &LiveCredentials,
1424 state: &StateStore,
1425) {
1426 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
1427 match account.provider.refresh_token(creds).await {
1428 Ok(fresh) => {
1429 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
1430 {
1431 let mut map = live_creds.write().await;
1432 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1433 }
1434 let mut store = crate::config::CredentialsStore::load();
1435 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1436 store.save().ok();
1437 if fresh.id_token.is_some() {
1438 crate::oauth::write_codex_auth_file(&fresh);
1439 }
1440 state.clear_auth_failed(&account.name);
1441 }
1442 Err(e) => {
1443 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
1444 state.set_auth_failed(&account.name);
1445 }
1446 }
1447}
1448
1449
1450pub async fn openai_token_refresh_loop(
1458 config: Arc<Config>,
1459 state: StateStore,
1460 live_creds: LiveCredentials,
1461) {
1462 for account in config.accounts.iter()
1464 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1465 {
1466 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1467 continue;
1468 }
1469 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1470
1471 let creds = {
1472 let map = live_creds.read().await;
1473 map.get(&account.name).cloned().or_else(|| account.credential.clone())
1474 };
1475 if let Some(creds) = creds {
1476 if let Some(oauth) = creds.as_oauth() {
1477 if access_token_expires_soon(oauth, 30) {
1478 do_proactive_refresh(account, oauth, &live_creds, &state).await;
1480 } else {
1481 tracing::info!(account = %account.name, "access_token fresh at startup");
1482 }
1483 }
1484 }
1485 }
1486
1487 loop {
1490 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1491 for account in config.accounts.iter()
1492 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1493 {
1494 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1495 }
1496 }
1497}
1498
/// Errors the proxy handlers can return to the client.
///
/// Each variant maps to one HTTP status in the `IntoResponse` implementation.
/// The type is a plain tag with no payload, so it derives the usual value
/// traits; `Debug` in particular makes it usable with `unwrap`/`expect` and in
/// log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// Forwarding the request upstream failed.
    Upstream,
    /// No account is currently able to serve the request.
    AllAccountsUnavailable,
    /// The caller did not present a valid API key.
    Unauthorized,
    /// The caller is being rate limited by the proxy.
    RateLimited,
}
1510
1511impl IntoResponse for ProxyError {
1512 fn into_response(self) -> Response {
1513 match self {
1514 ProxyError::RateLimited => {
1515 let mut resp = (
1516 StatusCode::TOO_MANY_REQUESTS,
1517 axum::Json(json!({
1518 "type": "error",
1519 "error": {"type": "rate_limit_error", "message": "too many requests — slow down"}
1520 })),
1521 ).into_response();
1522 resp.headers_mut().insert(
1523 axum::http::header::RETRY_AFTER,
1524 axum::http::HeaderValue::from_static("60"),
1525 );
1526 resp
1527 }
1528 other => {
1529 let (status, msg) = match other {
1530 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1531 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1532 ProxyError::AllAccountsUnavailable => {
1533 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1534 }
1535 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1536 ProxyError::RateLimited => unreachable!(),
1537 };
1538 (status, axum::Json(json!({
1539 "type": "error",
1540 "error": {"type": "api_error", "message": msg}
1541 }))).into_response()
1542 }
1543 }
1544 }
1545}
1546
1547pub async fn recovery_watcher(
1556 config: Arc<Config>,
1557 state: StateStore,
1558 credentials: LiveCredentials,
1559) {
1560 use std::time::{Duration, Instant};
1561 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1562 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1563
1564 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1565 let mut last_notified: Option<Instant> = None;
1566
1567 loop {
1568 tokio::time::sleep(CHECK_INTERVAL).await;
1569
1570 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1571 let failed = state.auth_failed_accounts(&name_refs);
1572 if failed.is_empty() {
1573 last_notified = None;
1574 continue;
1575 }
1576
1577 tracing::warn!(
1578 accounts = ?failed,
1579 "recovery: {} account(s) auth_failed, attempting token refresh",
1580 failed.len()
1581 );
1582
1583 let mut any_recovered = false;
1584
1585 for name in &failed {
1586 let cred = {
1587 let map = credentials.read().await;
1588 map.get(*name).cloned()
1589 };
1590 let Some(cred) = cred else { continue };
1591 if !cred.has_refresh_token() { continue; }
1592 let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1593
1594 let provider = config.accounts.iter()
1595 .find(|a| a.name == *name)
1596 .map(|a| a.provider.clone())
1597 .unwrap_or_default();
1598
1599 let result = tokio::time::timeout(
1600 Duration::from_secs(20),
1601 provider.refresh_token(&oauth_cred),
1602 ).await;
1603
1604 match result {
1605 Ok(Ok(fresh)) => {
1606 tracing::info!(account = %name, "recovery: token refreshed — account back online");
1607 {
1608 let mut map = credentials.write().await;
1609 map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1610 }
1611 let name_owned = name.to_string();
1612 let fresh_owned = fresh.clone();
1613 tokio::task::spawn_blocking(move || {
1614 let mut store = crate::config::CredentialsStore::load();
1615 store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1616 store.save().ok();
1617 if fresh_owned.id_token.is_some() {
1618 crate::oauth::write_codex_auth_file(&fresh_owned);
1619 }
1620 });
1621 state.clear_auth_failed(name);
1622 any_recovered = true;
1623 }
1624 Ok(Err(e)) => {
1625 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1626 if !state.get_alerts_muted() {
1627 notify(
1628 "shunt: Reauth Required",
1629 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1630 "Basso",
1631 );
1632 }
1633 }
1634 Err(_) => {
1635 tracing::error!(account = %name, "recovery: token refresh timed out");
1636 if !state.get_alerts_muted() {
1637 notify(
1638 "shunt: Reauth Required",
1639 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1640 "Basso",
1641 );
1642 }
1643 }
1644 }
1645 }
1646
1647 if any_recovered {
1648 tracing::info!("recovery: at least one account is back online");
1649 continue;
1650 }
1651
1652 let still_failed = state.auth_failed_accounts(&name_refs);
1654 if still_failed.len() == account_names.len() {
1655 let should_notify = last_notified
1656 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1657 .unwrap_or(true);
1658 if should_notify {
1659 error!(
1660 "ALL accounts are offline (auth failed). \
1661 Run `shunt add-account` to re-authorize."
1662 );
1663 if !state.get_alerts_muted() {
1664 notify(
1665 "shunt: All Accounts Offline",
1666 "All accounts need re-authorization. Run `shunt add-account`.",
1667 "Basso",
1668 );
1669 }
1670 last_notified = Some(Instant::now());
1671 }
1672 }
1673 }
1674}
1675
1676async fn post_cooldown_prefetch(
1680 client: &reqwest::Client,
1681 account: &crate::config::AccountConfig,
1682 token: &str,
1683 state: &StateStore,
1684 upstream_url: &str,
1685) {
1686 let Some((path, body)) = account.provider.prefetch_request() else {
1687 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1688 auth_probe_get(client, probe_path, account, state).await;
1689 }
1690 return;
1691 };
1692 let url = format!("{upstream_url}{path}");
1693 match prefetch_send(client, &url, &account.provider, token, &body).await {
1694 Ok(r) => {
1695 if r.status().as_u16() == 429 {
1699 let retry_after_ms = r.headers()
1700 .get("retry-after")
1701 .and_then(|v| v.to_str().ok())
1702 .and_then(|s| s.parse::<u64>().ok())
1703 .map(|secs| secs.saturating_mul(1_000).max(500))
1704 .unwrap_or(30_000);
1705 tracing::warn!(account = %account.name, retry_after_ms, "post-cooldown prefetch got 429 — re-cooling");
1706 state.set_cooldown_staggered(&account.name, retry_after_ms);
1707 return;
1708 }
1709 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1710 state.update_rate_limits(&account.name, info);
1711 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1712 }
1713 }
1714 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1715 }
1716}
1717
1718pub async fn health_check_loop(
1725 config: Arc<Config>,
1726 state: StateStore,
1727 live_creds: LiveCredentials,
1728) {
1729 if !config.server.health_check_enabled {
1730 return;
1731 }
1732
1733 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
1735
1736 let base_interval_ms = config.server.health_check_interval_secs * 1000;
1737 let timeout = std::time::Duration::from_secs(config.server.health_check_timeout_secs);
1738 let client = reqwest::Client::builder()
1739 .timeout(timeout)
1740 .build()
1741 .unwrap_or_default();
1742
1743 const FAILURE_THRESHOLD: u32 = 2;
1744 const MAX_BACKOFF_EXP: u32 = 3; loop {
1747 for account in &config.accounts {
1748 {
1750 let states = state.account_states();
1751 if let Some(acc_state) = states.get(&account.name) {
1752 if acc_state.disabled || acc_state.auth_failed {
1753 continue;
1754 }
1755 }
1756 }
1757
1758 let (last_check_ms, failures) = state.health_check_info(&account.name);
1760 let backoff_factor = 1u64 << failures.min(MAX_BACKOFF_EXP);
1761 let effective_interval_ms = base_interval_ms.saturating_mul(backoff_factor);
1762 let now = crate::state::now_ms_pub();
1763 if last_check_ms > 0 && now.saturating_sub(last_check_ms) < effective_interval_ms {
1764 continue;
1765 }
1766
1767 state.update_last_health_check(&account.name);
1768
1769 let cred = {
1771 let creds = live_creds.read().await;
1772 creds.get(&account.name).cloned()
1773 }.or_else(|| account.credential.clone());
1774
1775 let cred = match cred {
1776 Some(c) => c,
1777 None => {
1778 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1780 let upstream = account.upstream_url.as_deref()
1781 .unwrap_or_else(|| account.provider.default_upstream_url());
1782 let url = format!("{upstream}{probe_path}");
1783 match client.get(&url).send().await {
1784 Ok(r) if r.status().is_success() => {
1785 if state.is_health_check_failed(&account.name) {
1786 tracing::info!(account = %account.name, "health check recovered");
1787 }
1788 state.clear_health_check_failed(&account.name);
1789 }
1790 Ok(r) => {
1791 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1792 tracing::warn!(account = %account.name, status = %r.status(),
1793 failures = count, "health check failed");
1794 }
1795 Err(e) => {
1796 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1797 tracing::warn!(account = %account.name, failures = count,
1798 "health check unreachable: {e}");
1799 }
1800 }
1801 }
1802 continue;
1803 }
1804 };
1805
1806 let token = cred.bearer_token();
1807 let upstream = account.upstream_url.as_deref()
1808 .unwrap_or(&config.server.upstream_url);
1809
1810 if let Some((path, body)) = account.provider.prefetch_request() {
1812 let url = format!("{upstream}{path}");
1813 match prefetch_send(&client, &url, &account.provider, token, &body).await {
1814 Ok(r) => {
1815 let status = r.status();
1816 if status == reqwest::StatusCode::UNAUTHORIZED {
1817 if let Some(oauth_cred) = cred.as_oauth() {
1819 match account.provider.refresh_token(oauth_cred).await {
1820 Ok(fresh) => {
1821 let mut store = crate::config::CredentialsStore::load();
1822 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1823 store.save().ok();
1824 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1825 state.clear_auth_failed(&account.name);
1826 if state.is_health_check_failed(&account.name) {
1827 state.clear_health_check_failed(&account.name);
1828 }
1829 tracing::info!(account = %account.name, "health check: token refreshed");
1830 }
1831 Err(e) => {
1832 tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1833 state.set_auth_failed(&account.name);
1834 }
1835 }
1836 } else {
1837 tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1838 state.set_auth_failed(&account.name);
1839 }
1840 } else if status.is_server_error() {
1841 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1842 tracing::warn!(account = %account.name, status = %status,
1843 failures = count, "health check: server error");
1844 } else {
1845 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1847 state.update_rate_limits(&account.name, info);
1848 }
1849 if state.is_health_check_failed(&account.name) {
1850 tracing::info!(account = %account.name, "health check recovered");
1851 }
1852 state.clear_health_check_failed(&account.name);
1853 }
1854 }
1855 Err(e) => {
1856 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1857 tracing::warn!(account = %account.name, failures = count,
1858 "health check probe failed: {e}");
1859 }
1860 }
1861 } else if let Some(probe_path) = account.provider.auth_probe_get_path() {
1862 let probe_upstream = account.upstream_url.as_deref()
1863 .unwrap_or_else(|| account.provider.default_upstream_url());
1864 let url = format!("{probe_upstream}{probe_path}");
1865 let mut headers = reqwest::header::HeaderMap::new();
1866 let _ = account.provider.inject_auth_headers(&mut headers, token);
1867 match client.get(&url).headers(headers).send().await {
1868 Ok(r) => {
1869 let status = r.status();
1870 if status == reqwest::StatusCode::UNAUTHORIZED {
1871 if let Some(oauth_cred) = cred.as_oauth() {
1872 match account.provider.refresh_token(oauth_cred).await {
1873 Ok(fresh) => {
1874 let mut store = crate::config::CredentialsStore::load();
1875 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1876 store.save().ok();
1877 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1878 state.clear_auth_failed(&account.name);
1879 state.clear_health_check_failed(&account.name);
1880 tracing::info!(account = %account.name, "health check: token refreshed (GET probe)");
1881 }
1882 Err(e) => {
1883 tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1884 state.set_auth_failed(&account.name);
1885 }
1886 }
1887 } else {
1888 tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1889 state.set_auth_failed(&account.name);
1890 }
1891 } else if status.is_server_error() {
1892 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1893 tracing::warn!(account = %account.name, status = %status,
1894 failures = count, "health check: server error (GET probe)");
1895 } else {
1896 if state.is_health_check_failed(&account.name) {
1897 tracing::info!(account = %account.name, "health check recovered");
1898 }
1899 state.clear_health_check_failed(&account.name);
1900 }
1901 }
1902 Err(e) => {
1903 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1904 tracing::warn!(account = %account.name, failures = count,
1905 "health check probe failed: {e}");
1906 }
1907 }
1908 }
1909 }
1910
1911 tokio::time::sleep(std::time::Duration::from_secs(config.server.health_check_interval_secs)).await;
1912 }
1913}
1914
1915pub async fn cooldown_watcher(
1926 config: Arc<Config>,
1927 state: StateStore,
1928 credentials: LiveCredentials,
1929) {
1930 const STALE_RL_MS: u64 = 60 * 60_000;
1932
1933 let client = reqwest::Client::builder()
1934 .timeout(std::time::Duration::from_secs(20))
1935 .build()
1936 .unwrap_or_default();
1937
1938 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1941 let mut notify_on_resume: HashSet<String> = HashSet::new();
1943 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1945
1946 loop {
1947 let states = state.account_states();
1948 let rl_snapshot = state.rate_limit_snapshot();
1949 let now = now_ms();
1950 let mut next_wake_ms: Option<u64> = None;
1951
1952 for account in &config.accounts {
1953 let Some(st) = states.get(&account.name) else { continue };
1954 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1956
1957 if cdl > 0 && cdl <= now {
1958 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1960 if !handled {
1961 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1962 let token = {
1963 let creds = credentials.read().await;
1964 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1965 };
1966 if let Some(token) = token {
1967 post_cooldown_prefetch(
1968 &client, account, &token, &state,
1969 &config.server.upstream_url,
1970 ).await;
1971 }
1972 if notify_on_resume.remove(&account.name) && !state.get_alerts_muted() {
1973 notify(
1974 "shunt: Account Resumed",
1975 &format!("Account '{}' is back online.", account.name),
1976 "Glass",
1977 );
1978 }
1979 last_resumed.insert(account.name.clone(), cdl);
1980 last_stale_prefetch.insert(account.name.clone(), now);
1981 }
1982 } else if cdl > now {
1983 let remaining = cdl - now;
1985 if remaining >= 5 * 60_000 {
1986 notify_on_resume.insert(account.name.clone());
1987 }
1988 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1989 } else {
1990 let rl_age = rl_snapshot
1992 .get(&account.name)
1993 .map(|r| now.saturating_sub(r.updated_ms))
1994 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1996 let fetched_ago = now.saturating_sub(last_fetched);
1997
1998 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1999 tracing::debug!(
2000 account = %account.name,
2001 age_min = rl_age / 60_000,
2002 "rate-limit data stale — refreshing"
2003 );
2004 let token = {
2005 let creds = credentials.read().await;
2006 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
2007 };
2008 if let Some(token) = token {
2009 post_cooldown_prefetch(
2010 &client, account, &token, &state,
2011 &config.server.upstream_url,
2012 ).await;
2013 }
2014 last_stale_prefetch.insert(account.name.clone(), now);
2015 }
2016 }
2017 }
2018
2019 let sleep_ms = next_wake_ms
2021 .map(|wake| wake.saturating_sub(now_ms()).max(50))
2022 .unwrap_or(30_000);
2023 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
2024 }
2025}
2026
2027use crate::notify::notify;
2028use crate::translate::{
2029 translate_to_anthropic,
2030 translate_from_anthropic,
2031 uuid_v4,
2032 translate_anthropic_stream,
2033 translate_anthropic_req_to_chatgpt,
2034 translate_response_chatgpt_to_anthropic,
2035 translate_anthropic_req_to_openai,
2036 translate_response_openai_to_anthropic,
2037 translate_response_anthropic_to_openai,
2038};
2039
2040async fn openai_models_handler() -> impl IntoResponse {
2055 axum::Json(json!({
2056 "object": "list",
2057 "data": [
2058 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
2059 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
2060 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
2061 ]
2062 }))
2063}
2064
2065async fn openai_compat_handler(
2067 State(s): State<AppState>,
2068 req: Request,
2069) -> Result<Response, ProxyError> {
2070 let Some(ref anthropic_url) = s.anthropic_base_url else {
2071 return proxy_handler(State(s), req).await;
2073 };
2074
2075 let body_bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
2076 .await
2077 .map_err(|_| ProxyError::BodyRead)?;
2078
2079 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
2080 .unwrap_or(json!({}));
2081
2082 let stream = openai_body["stream"].as_bool().unwrap_or(false);
2083 let anthropic_body = translate_to_anthropic(openai_body);
2084
2085 let client = reqwest::Client::builder()
2086 .timeout(std::time::Duration::from_secs(300))
2087 .build()
2088 .map_err(|_| ProxyError::Upstream)?;
2089
2090 let mut req_builder = client
2091 .post(format!("{anthropic_url}/v1/messages"))
2092 .header("content-type", "application/json")
2093 .header("anthropic-version", "2023-06-01")
2094 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
2095 .header("x-shunt-compat", "openai");
2096 if let Some(ref key) = s.config.server.remote_key {
2097 req_builder = req_builder.header("x-api-key", key.as_str());
2098 }
2099 let resp = req_builder
2100 .json(&anthropic_body)
2101 .send()
2102 .await
2103 .map_err(|_| ProxyError::Upstream)?;
2104
2105 if !resp.status().is_success() {
2106 let status = resp.status();
2107 let body = resp.text().await.unwrap_or_default();
2108 let code = status.as_u16();
2109 return Ok(axum::response::Response::builder()
2110 .status(code)
2111 .header("content-type", "application/json")
2112 .body(axum::body::Body::from(body))
2113 .unwrap());
2114 }
2115
2116 if stream {
2117 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
2119 let stream = translate_anthropic_stream(resp, chat_id);
2120 Ok(axum::response::Response::builder()
2121 .status(200)
2122 .header("content-type", "text/event-stream")
2123 .header("cache-control", "no-cache")
2124 .body(axum::body::Body::from_stream(stream))
2125 .unwrap())
2126 } else {
2127 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
2128 let openai_resp = translate_from_anthropic(anthropic_resp);
2129 Ok(axum::Json(openai_resp).into_response())
2130 }
2131}
2132
2133async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
2140 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
2141 let resp = client
2142 .get(&url)
2143 .header("Authorization", format!("Bearer {}", token))
2144 .send()
2145 .await
2146 .ok()?;
2147 if !resp.status().is_success() {
2148 return None;
2149 }
2150 let json: serde_json::Value = resp.json().await.ok()?;
2151 if json["proofofwork"]["required"].as_bool() == Some(true) {
2152 return None;
2153 }
2154 json["token"].as_str().map(ToOwned::to_owned)
2155}
2156
2157
/// Reports whether `model` contains the (case-sensitive) substring `"haiku"`.
fn is_simple_model(model: &str) -> bool {
    const SIMPLE_MARKER: &str = "haiku";
    model.contains(SIMPLE_MARKER)
}
2163
/// Picks the model to fall back to for `model`, if any.
///
/// A name containing `"opus"` falls back to `"claude-sonnet-4-6"`; otherwise a
/// name containing `"sonnet"` falls back to `"claude-haiku-4-5-20251001"`.
/// Any other name has no fallback and yields `None`.
fn auto_fallback_model(model: &str) -> Option<&'static str> {
    // Ordered (marker, fallback) pairs; the first matching marker wins.
    const FALLBACKS: [(&str, &str); 2] = [
        ("opus", "claude-sonnet-4-6"),
        ("sonnet", "claude-haiku-4-5-20251001"),
    ];

    FALLBACKS
        .iter()
        .find(|&&(marker, _)| model.contains(marker))
        .map(|&(_, fallback)| fallback)
}
2175
2176fn resolve_model(
2181 incoming: &str,
2182 account: &crate::config::AccountConfig,
2183 mapping: &std::collections::HashMap<String, String>,
2184) -> String {
2185 if let Some(m) = &account.model {
2187 return m.clone();
2188 }
2189 if let Some(m) = mapping.get(incoming) {
2191 return m.clone();
2192 }
2193 let default = account.provider.default_model();
2195 if !default.is_empty() {
2196 return default.to_owned();
2197 }
2198 incoming.to_owned()
2200}
2201