1use std::collections::{HashMap, HashSet};
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use parking_lot::Mutex as ParkingMutex;
7
8use axum::extract::{Request, State};
9use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum::Router;
13use bytes::Bytes;
14use serde_json::json;
15use tokio::sync::RwLock;
16use tracing::{error, info, warn};
17
18use crate::config::{state_path, Config, CredentialsStore};
19use crate::credential::Credential;
20use crate::forwarder::Forwarder;
21use crate::provider::Provider;
22use crate::quota;
23use crate::router;
24use crate::state::StateStore;
25use crate::telemetry::TelemetryClient;
26
/// Upper bound, in bytes (100 MiB), on how much of an incoming request body
/// `proxy_handler` will buffer before giving up with a body-read error.
const MAX_REQUEST_BODY: usize = 100 * 1024 * 1024;
29
/// Shared state handed to every axum handler in this module.
///
/// The struct is `Clone`; the larger members are behind `Arc` so a clone
/// shares them rather than copying.
#[derive(Clone)]
struct AppState {
    /// Immutable server and account configuration.
    config: Arc<Config>,
    /// Upstream HTTP forwarder, built with `config.server.request_timeout_secs`.
    forwarder: Arc<Forwarder>,
    /// Account / routing state store (cooldowns, pins, overrides, usage).
    state: StateStore,
    /// Live credentials keyed by account name; replaced in place when a token
    /// is refreshed.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// One async mutex per account name, taken while handling a 401 so token
    /// refreshes for the same account are serialized.
    refresh_locks: Arc<ParkingMutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Milliseconds since the Unix epoch at which this state was built.
    started_ms: u64,
    /// Optional Anthropic base URL supplied by the caller of the app builders.
    /// NOTE(review): not read anywhere in the visible part of this file —
    /// confirm where it is consumed.
    anthropic_base_url: Option<String>,
    /// Telemetry client; present only when `server.telemetry_url` is configured.
    telemetry: Option<TelemetryClient>,
    /// Per-client-IP token buckets; `None` when `server.rate_limit_rpm` is 0.
    rate_limiter: Option<Arc<ParkingMutex<HashMap<IpAddr, TokenBucket>>>>,
}
55
/// A token bucket used for per-client request rate limiting.
///
/// Tokens refill continuously at `rpm / 60` tokens per second and are capped
/// at a burst size of `max(rpm / 6, 10)`. Each admitted request costs one
/// token.
struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket holding `capacity` tokens, timestamped now.
    ///
    /// The first call to [`TokenBucket::check_and_consume`] clamps the token
    /// count to the burst size, so a `capacity` above the burst only means
    /// "start full".
    fn new(capacity: f64) -> Self {
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the last call, then
    /// tries to take one token.
    ///
    /// Returns `true` (and consumes a token) when the request is allowed,
    /// `false` when the bucket holds less than one token.
    fn check_and_consume(&mut self, rpm: f64) -> bool {
        // Read the clock exactly once: using `elapsed()` and then a second
        // `Instant::now()` would silently drop the time between the two
        // reads from the refill accounting.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        let burst = (rpm / 6.0).max(10.0);
        self.tokens = (self.tokens + elapsed * rpm / 60.0).min(burst);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
83
84pub fn create_app(config: Config) -> anyhow::Result<Router> {
85 let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
86 Ok(app)
87}
88
/// Shared, mutable map of account name → current credential. The app
/// builders return a handle to it and `prefetch_rate_limits` writes refreshed
/// tokens into it.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
91
92fn build_app_state(
96 config: Config,
97 state: StateStore,
98 anthropic_base_url: Option<String>,
99) -> anyhow::Result<(AppState, LiveCredentials)> {
100 let forwarder = Forwarder::new(config.server.request_timeout_secs)?;
101
102 for a in &config.accounts {
103 if a.provider.auth_kind() == crate::provider::AuthKind::None {
104 state.clear_auth_failed(&a.name);
106 } else if a.credential.is_none() {
107 state.set_auth_failed(&a.name);
108 }
109 }
110
111 let credentials: LiveCredentials = Arc::new(RwLock::new(
112 config.accounts.iter()
113 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
114 .collect::<HashMap<_, _>>(),
115 ));
116
117 let telemetry = config.server.telemetry_url.as_deref().map(|url| {
118 TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
119 });
120
121 let rate_limiter = if config.server.rate_limit_rpm > 0 {
122 Some(Arc::new(ParkingMutex::new(HashMap::<IpAddr, TokenBucket>::new())))
123 } else {
124 None
125 };
126
127 let app_state = AppState {
128 config: Arc::new(config),
129 forwarder: Arc::new(forwarder),
130 state,
131 credentials: Arc::clone(&credentials),
132 refresh_locks: Arc::new(ParkingMutex::new(HashMap::new())),
133 started_ms: now_ms(),
134 anthropic_base_url,
135 telemetry,
136 rate_limiter,
137 };
138
139 Ok((app_state, credentials))
140}
141
142pub fn create_proxy_app(
143 config: Config,
144 state: StateStore,
145 anthropic_base_url: Option<String>,
146) -> anyhow::Result<(Router, LiveCredentials)> {
147 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
148
149 let app = Router::new()
150 .route("/v1/messages", post(proxy_handler))
151 .route("/v1/messages/count_tokens", post(proxy_handler))
152 .route("/v1/chat/completions", post(openai_compat_handler))
153 .route("/v1/models", get(openai_models_handler))
154 .fallback(proxy_handler)
155 .with_state(app_state);
156
157 Ok((app, credentials))
158}
159
160pub fn create_control_app(
163 config: Config,
164 state: StateStore,
165) -> anyhow::Result<Router> {
166 let (app_state, _) = build_app_state(config, state, None)?;
167
168 let app = Router::new()
169 .route("/health", get(health))
170 .route("/status", get(status_handler))
171 .route("/use", post(use_handler))
172 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
173 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
174 .with_state(app_state);
175
176 Ok(app)
177}
178
179pub fn create_app_with_state(
183 config: Config,
184 state: StateStore,
185 anthropic_base_url: Option<String>,
186) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
187 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
188 let telemetry = app_state.telemetry.clone();
189
190 let app = Router::new()
191 .route("/health", get(health))
193 .route("/status", get(status_handler))
194 .route("/use", post(use_handler))
195 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
196 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
197 .route("/v1/messages", post(proxy_handler))
199 .route("/v1/messages/count_tokens", post(proxy_handler))
200 .route("/v1/chat/completions", post(openai_compat_handler))
201 .route("/v1/models", get(openai_models_handler))
202 .fallback(proxy_handler)
203 .with_state(app_state);
204
205 Ok((app, credentials, telemetry))
206}
207
208pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
210 let account_states = state.account_states();
211 let rate_limits = state.rate_limit_snapshot();
212
213 let accounts: Vec<_> = config.accounts.iter().map(|a| {
214 let st = account_states.get(&a.name);
215 let rl = rate_limits.get(&a.name);
216 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
217 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
218 let reset_5h = rl.and_then(|r| r.reset_5h);
219 let reset_7d = rl.and_then(|r| r.reset_7d);
220 let disabled = st.map(|s| s.disabled).unwrap_or(false);
221 let auth_failed = st.map(|s| s.auth_failed).unwrap_or(false);
222 let health_check_failed = st.map(|s| s.health_check_failed).unwrap_or(false);
223 let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
224 let available = state.is_available(&a.name);
225 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
226
227 json!({
228 "name": a.name,
229 "email": email,
230 "provider": a.provider.to_string(),
231 "available": available,
232 "disabled": disabled,
233 "auth_failed": auth_failed,
234 "health_check_failed": health_check_failed,
235 "cooldown_until_ms": cooldown_until_ms,
236 "utilization_5h": utilization_5h,
237 "reset_5h": reset_5h,
238 "utilization_7d": utilization_7d,
239 "reset_7d": reset_7d,
240 })
241 }).collect();
242
243 json!({
244 "started_ms": started_ms,
245 "accounts": accounts,
246 "pinned_account": state.get_pinned(),
247 "last_used_account": state.get_last_used(),
248 })
249}
250
251async fn health() -> impl IntoResponse {
252 axum::Json(json!({"status": "ok"}))
253}
254
255async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
256 let account_states = s.state.account_states();
257 let quotas = s.state.quota_snapshot();
258 let rate_limits = s.state.rate_limit_snapshot();
259
260 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
261 let st = account_states.get(&a.name);
262 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
263 "reauth_required"
264 } else if st.map(|s| s.disabled).unwrap_or(false) {
265 "disabled"
266 } else if st.map(|s| s.health_check_failed).unwrap_or(false) {
267 "unhealthy"
268 } else if s.state.is_available(&a.name) {
269 "available"
270 } else {
271 "cooling"
272 };
273
274 let quota = quotas.get(&a.name);
275 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
276 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
277 let tokens_used = quota.map(|q| json!({
278 "input": q.input_tokens,
279 "output": q.output_tokens,
280 "total": q.total_tokens(),
281 }));
282
283 let rl = rate_limits.get(&a.name);
284 let rate_limit = rl.map(|r| json!({
285 "utilization_5h": r.utilization_5h,
286 "reset_5h": r.reset_5h,
287 "status_5h": r.status_5h,
288 "utilization_7d": r.utilization_7d,
289 "reset_7d": r.reset_7d,
290 "status_7d": r.status_7d,
291 "representative_claim": r.representative_claim,
292 "updated_ms": r.updated_ms,
293 }));
294
295 let acc_state = account_states.get(&a.name);
296 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
297 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
298 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
299 let health_check_failed = acc_state.map(|s| s.health_check_failed).unwrap_or(false);
300 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
301 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
302 let reset_5h = rl.and_then(|r| r.reset_5h);
303 let status_5h = rl.and_then(|r| r.status_5h.clone());
304 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
305 let reset_7d = rl.and_then(|r| r.reset_7d);
306 let status_7d = rl.and_then(|r| r.status_7d.clone());
307 let available = s.state.is_available(&a.name);
308
309 json!({
310 "name": a.name,
311 "email": email,
312 "plan_type": a.plan_type,
313 "provider": a.provider.to_string(),
314 "status": avail_status,
315 "available": available,
316 "disabled": disabled,
317 "auth_failed": auth_failed,
318 "health_check_failed": health_check_failed,
319 "cooldown_until_ms": cooldown_until_ms,
320 "utilization_5h": utilization_5h,
321 "reset_5h": reset_5h,
322 "status_5h": status_5h,
323 "utilization_7d": utilization_7d,
324 "reset_7d": reset_7d,
325 "status_7d": status_7d,
326 "window_expires_ms": window_expires_ms,
327 "tokens_used": tokens_used,
328 "rate_limit": rate_limit,
329 })
330 }).collect();
331
332 let recent_requests = s.state.recent_requests_snapshot();
333 let savings = s.state.savings_snapshot();
334
335 axum::Json(json!({
336 "version": env!("CARGO_PKG_VERSION"),
337 "started_ms": s.started_ms,
338 "accounts": accounts,
339 "pinned_account": s.state.get_pinned(),
340 "last_used_account": s.state.get_last_used(),
341 "recent_requests": recent_requests,
342 "savings": savings,
343 }))
344}
345
346async fn use_handler(
347 State(s): State<AppState>,
348 axum::Json(body): axum::Json<serde_json::Value>,
349) -> Response {
350 let account = body["account"].as_str().map(|s| s.to_owned());
351 if let Some(ref name) = account {
353 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
354 return (StatusCode::BAD_REQUEST, axum::Json(json!({
355 "error": format!("unknown account '{name}'")
356 }))).into_response();
357 }
358 let pinned = if name == "auto" { None } else { Some(name.clone()) };
359 s.state.set_pinned(pinned);
360 axum::Json(json!({ "pinned": name })).into_response()
361 } else {
362 s.state.set_pinned(None);
363 axum::Json(json!({ "pinned": null })).into_response()
364 }
365}
366
367async fn model_get_handler(State(s): State<AppState>) -> impl IntoResponse {
368 let model = s.state.get_model_override();
369 axum::Json(json!({ "model": model }))
370}
371
372async fn model_set_handler(
373 State(s): State<AppState>,
374 axum::Json(body): axum::Json<serde_json::Value>,
375) -> Response {
376 let Some(model) = body["model"].as_str() else {
377 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing model field" }))).into_response();
378 };
379 s.state.set_model_override(model.to_owned());
380 info!(model, "model override set");
381 axum::Json(json!({ "model": model })).into_response()
382}
383
384async fn model_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
385 s.state.clear_model_override();
386 info!("model override cleared");
387 axum::Json(json!({ "model": null }))
388}
389
390async fn strategy_get_handler(State(s): State<AppState>) -> impl IntoResponse {
391 let (strategy_str, source) = match s.state.get_routing_strategy() {
392 Some(st) => (st.as_str(), "override"),
393 None => (s.config.server.routing_strategy.as_str(), "config"),
394 };
395 axum::Json(json!({ "strategy": strategy_str, "source": source }))
396}
397
398async fn strategy_set_handler(
399 State(s): State<AppState>,
400 axum::Json(body): axum::Json<serde_json::Value>,
401) -> Response {
402 let Some(name) = body["strategy"].as_str() else {
403 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing strategy field" }))).into_response();
404 };
405 let Some(strategy) = crate::config::RoutingStrategy::from_str(name) else {
406 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": format!("unknown strategy '{name}'") }))).into_response();
407 };
408 s.state.set_routing_strategy(strategy);
409 info!(strategy = name, "routing strategy override set");
410 axum::Json(json!({ "strategy": strategy.as_str(), "source": "override" })).into_response()
411}
412
413async fn strategy_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
414 s.state.clear_routing_strategy();
415 info!("routing strategy override cleared");
416 let strategy_str = s.config.server.routing_strategy.as_str();
417 axum::Json(json!({ "strategy": strategy_str, "source": "config" }))
418}
419
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{Duration, SystemTime};

    let since_epoch = SystemTime::UNIX_EPOCH
        .elapsed()
        .unwrap_or(Duration::ZERO);
    since_epoch.as_millis() as u64
}
424
425fn extract_client_ip(req: &Request, trust_proxy_headers: bool) -> IpAddr {
432 if trust_proxy_headers {
433 if let Some(ip) = req.headers()
434 .get("x-real-ip")
435 .and_then(|v| v.to_str().ok())
436 .and_then(|s| s.parse().ok())
437 {
438 return ip;
439 }
440 }
441 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
442}
443
444async fn proxy_handler(
445 State(s): State<AppState>,
446 req: Request,
447) -> Result<Response, ProxyError> {
448 if let Some(ref expected) = s.config.server.remote_key {
450 let provided = req.headers()
451 .get("x-api-key")
452 .and_then(|v| v.to_str().ok())
453 .unwrap_or("");
454 if provided != expected {
455 return Err(ProxyError::Unauthorized);
456 }
457 }
458
459 if let Some(ref rl) = s.rate_limiter {
461 let ip = extract_client_ip(&req, s.config.server.trust_proxy_headers);
462 let rpm = s.config.server.rate_limit_rpm as f64;
463 let allowed = rl.lock().entry(ip).or_insert_with(|| TokenBucket::new(rpm)).check_and_consume(rpm);
464 if !allowed {
465 return Err(ProxyError::RateLimited);
466 }
467 }
468
469 let method = req.method().as_str().to_owned();
470 let path = req.uri().path().to_owned();
471 let headers = req.headers().clone();
472
473 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
474 .await
475 .map_err(|_| ProxyError::BodyRead)?;
476
477 let body_bytes = if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
480 let mut changed = false;
481 if let Some(override_model) = s.state.get_model_override() {
482 if val.get("model").is_some() {
483 val["model"] = serde_json::Value::String(override_model);
484 changed = true;
485 }
486 }
487 let resolved_model = val["model"].as_str().unwrap_or("").to_owned();
488 if is_simple_model(&resolved_model) {
489 if let Some(obj) = val.as_object_mut() {
490 for key in &["thinking", "effort", "reasoning_effort"] {
492 if obj.remove(*key).is_some() { changed = true; }
493 }
494 if let Some(serde_json::Value::Object(oc)) = obj.get_mut("output_config") {
496 if oc.remove("effort").is_some() { changed = true; }
497 if oc.is_empty() { obj.remove("output_config"); }
499 }
500 if obj.remove("context_management").is_some() { changed = true; }
502 if let Some(serde_json::Value::Array(betas)) = obj.get_mut("betas") {
504 let before = betas.len();
505 betas.retain(|b| b.as_str() != Some("interleaved-thinking-2025-05-14"));
506 if betas.len() != before { changed = true; }
507 }
508 }
509 }
510 if changed {
511 Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()))
512 } else {
513 body_bytes
514 }
515 } else {
516 body_bytes
517 };
518
519 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
520 .ok()
521 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
522 .unwrap_or_default();
523
524 let mut headers = headers;
526 if is_simple_model(&model) {
527 if let Some(beta_val) = headers.get("anthropic-beta").and_then(|v| v.to_str().ok().map(|s| s.to_owned())) {
528 let filtered: Vec<&str> = beta_val.split(',')
529 .map(|s| s.trim())
530 .filter(|b| !b.contains("thinking") && !b.contains("effort"))
531 .collect();
532 let new_beta = filtered.join(",");
533 if filtered.is_empty() {
534 headers.remove("anthropic-beta");
535 } else if let Ok(v) = axum::http::HeaderValue::from_str(&new_beta) {
536 headers.insert("anthropic-beta", v);
537 }
538 }
539 }
540
541 let req_start_ms = now_ms();
542 let request_id = uuid::Uuid::new_v4().to_string()[..8].to_owned();
543
544 let fp = router::fingerprint(&body_bytes);
545 let fp_ref = fp.as_deref();
546
547 let mut tried: HashSet<String> = HashSet::new();
548 let mut refreshed: HashSet<String> = HashSet::new();
550 let wait_deadline_ms = now_ms() + s.config.server.request_timeout_secs.saturating_mul(1_000);
553
554 loop {
555 let effective_strategy = s.state.get_routing_strategy()
556 .unwrap_or(s.config.server.routing_strategy);
557 let account = match router::pick_account(
558 &s.config.accounts, &s.state, fp_ref, &tried,
559 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
560 effective_strategy,
561 ) {
562 Some(a) => a,
563 None => {
564 let account_states = s.state.account_states();
568 let now = now_ms();
569 let soonest_ms = s.config.accounts.iter()
570 .filter_map(|a| {
571 let st = account_states.get(&a.name)?;
572 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
574 })
575 .min();
576
577 match soonest_ms {
578 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
579 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
581 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
582 tried.clear(); }
584 _ => return Err(ProxyError::AllAccountsUnavailable),
585 }
586 continue;
587 }
588 };
589
590 let account_name = account.name.clone();
591
592 let token = {
597 let creds = s.credentials.read().await;
598 let cred = creds.get(&account_name)
599 .cloned()
600 .or_else(|| account.credential.clone());
601 match cred {
602 Some(c) => c.bearer_token().to_owned(),
603 None => String::new(),
604 }
605 };
606
607 let req_is_anthropic = path.starts_with("/v1/messages");
611 let acct_is_anthropic = account.provider.wire_protocol()
612 == crate::provider::WireProtocol::Anthropic;
613 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
616
617 let mut log_model = model.clone();
620
621 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
622 (path.clone(), body_bytes.clone(), headers.clone())
624 } else if req_is_anthropic && acct_is_chatgpt {
625 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
627 let translated = translate_anthropic_req_to_chatgpt(&val);
628 let mut h = headers.clone();
629 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
630 h.remove(*name);
631 }
632 (
633 "/backend-api/conversation".to_owned(),
634 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
635 h,
636 )
637 } else if req_is_anthropic {
638 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
640 let target_model = resolve_model(&model, account, &s.config.model_mapping);
642 log_model = target_model.clone();
643 let translated = translate_anthropic_req_to_openai(val, &target_model);
644 let mut h = headers.clone();
645 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
646 h.remove(*name);
647 }
648 (
649 "/v1/chat/completions".to_owned(),
650 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
651 h,
652 )
653 } else {
654 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
656 let translated = translate_to_anthropic(val);
657 (
658 "/v1/messages".to_owned(),
659 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
660 headers.clone(),
661 )
662 };
663
664 let upstream = account.upstream_url.as_deref()
667 .unwrap_or(&s.config.server.upstream_url);
668
669 if req_is_anthropic && acct_is_chatgpt {
672 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
673 let sentinel_client = reqwest::Client::builder()
674 .timeout(std::time::Duration::from_secs(3))
675 .build()
676 .unwrap_or_default();
677 let sentinel_opt = tokio::time::timeout(
678 std::time::Duration::from_secs(3),
679 fetch_sentinel_token(&sentinel_client, upstream, &token),
680 ).await.ok().flatten();
681 if let Some(sentinel) = sentinel_opt {
682 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
683 b"openai-sentinel-chat-requirements-token",
684 ) {
685 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
686 fwd_headers.insert(name, val);
687 }
688 }
689 }
690 }
691
692 let response = if acct_is_chatgpt {
695 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
696 match tokio::time::timeout(
697 std::time::Duration::from_secs(15),
698 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
699 ).await {
700 Ok(Ok(r)) => r,
701 Ok(Err(e)) => {
702 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
703 s.state.set_cooldown(&account_name, 5 * 60_000);
704 tried.insert(account_name);
705 continue;
706 }
707 Err(_) => {
708 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
709 s.state.set_cooldown(&account_name, 5 * 60_000);
710 tried.insert(account_name);
711 continue;
712 }
713 }
714 } else {
715 s.forwarder
716 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
717 .await
718 .map_err(|e| {
719 error!("Forward error: {:#}", e);
720 ProxyError::Upstream
721 })?
722 };
723
724 match response.status().as_u16() {
725 200..=299 => {
726 s.state.set_last_used(&account_name);
727 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
728 s.state.update_rate_limits(&account_name, info);
729 }
730 let response = if req_is_anthropic == acct_is_anthropic {
732 response
733 } else if req_is_anthropic && acct_is_chatgpt {
734 translate_response_chatgpt_to_anthropic(response, &model).await
736 } else if req_is_anthropic {
737 translate_response_openai_to_anthropic(response, &model).await
739 } else {
740 translate_response_anthropic_to_openai(response).await
742 };
743 return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms, &request_id, &path, tried.len()).await);
744 }
745 429 => {
746 let info = account.provider.parse_rate_limits(response.headers());
747 let retry_after_ms = response.headers()
750 .get("retry-after")
751 .and_then(|v| v.to_str().ok())
752 .and_then(|s| s.parse::<u64>().ok())
753 .map(|secs| secs.saturating_mul(1_000).max(500));
754 let cooldown_ms = info.as_ref()
755 .and_then(|i| i.reset_5h.or(i.reset_7d))
756 .map(|reset_secs| {
757 let reset_ms = reset_secs.saturating_mul(1_000);
758 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
760 .or(retry_after_ms)
761 .unwrap_or(60_000);
762 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
763 if let Some(info) = info {
764 s.state.update_rate_limits(&account_name, info);
765 }
766 s.state.set_cooldown(&account_name, cooldown_ms);
767 if cooldown_ms >= 5 * 60_000 {
768 let mins = cooldown_ms / 60_000;
769 notify(
770 "shunt: Rate Limited",
771 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
772 "Ping",
773 );
774 }
775 tried.insert(account_name);
776 }
777 529 => {
778 warn!(account = %account_name, "529 overloaded — cooling 30s");
779 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
780 s.state.update_rate_limits(&account_name, info);
781 }
782 s.state.set_cooldown(&account_name, 30_000);
783 tried.insert(account_name);
784 }
785 401 => {
786 if !refreshed.contains(&account_name) {
787 let account_lock = {
795 let mut locks = s.refresh_locks.lock();
796 locks.entry(account_name.clone())
797 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
798 .clone()
799 };
800 let _guard = account_lock.lock().await;
801
802 let cred_before = {
805 let creds = s.credentials.read().await;
806 creds.get(&account_name).cloned()
807 .or_else(|| account.credential.clone())
808 };
809 let Some(cred) = cred_before else {
810 tried.insert(account_name);
811 continue;
812 };
813
814 let token_before = cred.access_token().to_owned();
816 let already_refreshed = {
817 let creds = s.credentials.read().await;
818 creds.get(&account_name)
819 .map(|c| c.access_token() != token_before)
820 .unwrap_or(false)
821 };
822
823 if already_refreshed {
824 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
826 refreshed.insert(account_name);
827 } else if let Some(oauth_cred) = cred.as_oauth() {
828 match tokio::time::timeout(
830 std::time::Duration::from_secs(10),
831 account.provider.refresh_token(oauth_cred),
832 ).await {
833 Ok(Ok(fresh)) => {
834 warn!(account = %account_name, "401 — token refreshed, retrying");
835 {
836 let mut creds = s.credentials.write().await;
837 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
838 }
839 let name = account_name.clone();
841 let fresh = fresh.clone();
842 tokio::task::spawn_blocking(move || {
843 let mut store = CredentialsStore::load();
844 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
845 store.save().ok();
846 if fresh.id_token.is_some() {
847 crate::oauth::write_codex_auth_file(&fresh);
848 }
849 });
850 refreshed.insert(account_name);
852 }
853 _ => {
854 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
856 s.state.set_cooldown(&account_name, 5 * 60_000);
857 tried.insert(account_name);
858 }
859 }
860 } else {
861 error!(account = %account_name, "401 — API key rejected, cooling 5min");
863 s.state.set_cooldown(&account_name, 5 * 60_000);
864 tried.insert(account_name);
865 }
866 } else {
867 error!(account = %account_name, "401 after refresh — cooling 5min");
869 s.state.set_cooldown(&account_name, 5 * 60_000);
870 tried.insert(account_name);
871 }
872 }
873 403 => {
874 if acct_is_anthropic {
878 error!(account = %account_name, "403 forbidden — cooling 30min");
879 s.state.set_cooldown(&account_name, 30 * 60_000);
880 notify(
881 "shunt: Account Forbidden",
882 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
883 "Basso",
884 );
885 } else {
886 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
887 s.state.set_cooldown(&account_name, 5 * 60_000);
888 }
889 tried.insert(account_name);
890 }
891 _ => {
892 return Ok(response);
894 }
895 }
896 }
897}
898
899async fn tap_usage(
908 resp: Response,
909 state: &StateStore,
910 telemetry: Option<&TelemetryClient>,
911 account: &str,
912 model: &str,
913 req_start_ms: u64,
914 request_id: &str,
915 path: &str,
916 retries: usize,
917) -> Response {
918 use axum::body::Body;
919 use crate::state::RequestLog;
920
921 let streaming = quota::is_streaming_response(&resp);
922
923 if streaming {
924 let state = state.clone();
925 let telem = telemetry.cloned();
926 let account = account.to_owned();
927 let model = model.to_owned();
928 let request_id = request_id.to_owned();
929 let path = path.to_owned();
930 let on_complete = Arc::new(move |input: u64, output: u64| {
931 let duration_ms = now_ms().saturating_sub(req_start_ms);
932 info!(
933 request_id = %request_id,
934 account = %account,
935 model = %model,
936 status = 200,
937 latency_ms = duration_ms,
938 path = %path,
939 stream = true,
940 input_tokens = input,
941 output_tokens = output,
942 retries = retries,
943 "request complete"
944 );
945 let log = RequestLog {
946 ts_ms: req_start_ms,
947 account: account.clone(),
948 model: model.clone(),
949 status: 200,
950 input_tokens: input,
951 output_tokens: output,
952 duration_ms,
953 };
954 state.record_usage(&account, input, output);
955 state.record_global(&model, input, output);
956 if let Some(ref t) = telem { t.push_event(&log); }
957 state.record_request(log);
958 });
959 let (parts, body) = resp.into_parts();
960 let wrapped = quota::wrap_streaming_body(body, on_complete);
961 return Response::from_parts(parts, wrapped);
962 }
963
964 let (parts, body) = resp.into_parts();
966 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
967 Ok(b) => b,
968 Err(_) => return Response::from_parts(parts, Body::empty()),
969 };
970 let (input, output) = quota::extract_usage_from_json(&bytes);
971 let duration_ms = now_ms().saturating_sub(req_start_ms);
972 info!(
973 request_id = %request_id,
974 account = %account,
975 model = %model,
976 status = 200,
977 latency_ms = duration_ms,
978 path = %path,
979 stream = false,
980 input_tokens = input,
981 output_tokens = output,
982 retries = retries,
983 "request complete"
984 );
985 let log = RequestLog {
986 ts_ms: req_start_ms,
987 account: account.to_owned(),
988 model: model.to_owned(),
989 status: 200,
990 input_tokens: input,
991 output_tokens: output,
992 duration_ms,
993 };
994 state.record_usage(account, input, output);
995 state.record_global(model, input, output);
996 if let Some(t) = telemetry { t.push_event(&log); }
997 state.record_request(log);
998 Response::from_parts(parts, Body::from(bytes))
999}
1000
1001
1002pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
1010 let client = reqwest::Client::builder()
1011 .timeout(std::time::Duration::from_secs(20))
1012 .build()
1013 .unwrap_or_default();
1014
1015 for account in &config.accounts {
1016 let rl = state.rate_limit_snapshot();
1018 if let Some(r) = rl.get(&account.name) {
1019 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
1020 continue;
1021 }
1022 }
1023
1024 let cred = match account.credential.clone() {
1026 Some(c) => c,
1027 None => continue,
1028 };
1029
1030 let Some((path, body)) = account.provider.prefetch_request() else {
1031 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1033 auth_probe_get(&client, probe_path, account, &state).await;
1034 }
1035 continue;
1036 };
1037 let url = format!("{}{}", config.server.upstream_url, path);
1038
1039 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
1040
1041 let r = match resp {
1042 Ok(r) => r,
1043 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
1044 };
1045
1046 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
1047 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
1048 let Some(oauth_cred) = cred.as_oauth() else {
1049 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
1051 state.set_auth_failed(&account.name);
1052 continue;
1053 };
1054 let fresh = match account.provider.refresh_token(oauth_cred).await {
1055 Ok(f) => f,
1056 Err(e) => {
1057 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1058 state.set_auth_failed(&account.name);
1059 continue;
1060 }
1061 };
1062 let mut store = crate::config::CredentialsStore::load();
1063 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1064 store.save().ok();
1065 if fresh.id_token.is_some() {
1066 crate::oauth::write_codex_auth_file(&fresh);
1067 }
1068 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1070
1071 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
1072 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1073 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1074 state.set_auth_failed(&account.name);
1075 }
1076 Ok(r2) => {
1077 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
1078 state.update_rate_limits(&account.name, info);
1079 }
1080 }
1081 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
1082 }
1083 } else {
1084 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
1085 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1086 state.update_rate_limits(&account.name, info);
1087 }
1088 }
1089 }
1090}
1091
1092async fn prefetch_send(
1094 client: &reqwest::Client,
1095 url: &str,
1096 provider: &crate::provider::Provider,
1097 token: &str,
1098 body: &serde_json::Value,
1099) -> anyhow::Result<reqwest::Response> {
1100 let mut headers = reqwest::header::HeaderMap::new();
1101 provider.inject_auth_headers(&mut headers, token)?;
1102 for (name, value) in provider.prefetch_extra_headers() {
1103 headers.insert(
1104 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
1105 reqwest::header::HeaderValue::from_static(value),
1106 );
1107 }
1108 Ok(client.post(url).headers(headers).json(body).send().await?)
1109}
1110
1111async fn auth_probe_get(
1115 client: &reqwest::Client,
1116 path: &str,
1117 account: &crate::config::AccountConfig,
1118 state: &StateStore,
1119) {
1120 let cred = match account.credential.clone() {
1121 Some(c) => c,
1122 None => return,
1123 };
1124 let upstream = account.upstream_url.as_deref()
1125 .unwrap_or_else(|| account.provider.default_upstream_url());
1126 let url = format!("{}{}", upstream, path);
1127
1128 let do_get = |token: &str| -> reqwest::RequestBuilder {
1129 let mut headers = reqwest::header::HeaderMap::new();
1130 let _ = account.provider.inject_auth_headers(&mut headers, token);
1131 client.get(&url).headers(headers)
1132 };
1133
1134 let resp = match do_get(cred.bearer_token()).send().await {
1135 Ok(r) => r,
1136 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
1137 };
1138
1139 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
1140 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
1141 let Some(oauth_cred) = cred.as_oauth() else {
1142 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
1144 state.set_auth_failed(&account.name);
1145 return;
1146 };
1147 let fresh = match account.provider.refresh_token(oauth_cred).await {
1148 Ok(f) => f,
1149 Err(e) => {
1150 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1151 state.set_auth_failed(&account.name);
1152 return;
1153 }
1154 };
1155 let mut store = crate::config::CredentialsStore::load();
1156 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1157 store.save().ok();
1158 if fresh.id_token.is_some() {
1159 crate::oauth::write_codex_auth_file(&fresh);
1160 }
1161
1162 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
1163 match do_get(fresh_token).send().await {
1164 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1165 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1166 state.set_auth_failed(&account.name);
1167 }
1168 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
1169 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
1170 }
1171 } else {
1172 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
1173 }
1177}
1178
1179fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
1186 let now_ms = std::time::SystemTime::now()
1187 .duration_since(std::time::UNIX_EPOCH)
1188 .unwrap_or_default()
1189 .as_millis() as u64;
1190 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
1191 .unwrap_or(cred.expires_at);
1192 exp_ms < now_ms + threshold_mins * 60 * 1_000
1193}
1194
1195async fn sync_live_creds_from_auth_json(
1200 account_name: &str,
1201 live_creds: &LiveCredentials,
1202) {
1203 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
1204 let current_exp = live_creds.read().await
1205 .get(account_name)
1206 .and_then(|c| c.as_oauth())
1207 .map(|c| c.expires_at)
1208 .unwrap_or(0);
1209 if from_file.expires_at > current_exp {
1210 tracing::info!(account = %account_name, "synced fresher token from auth.json");
1211 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
1212 }
1213}
1214
1215async fn do_proactive_refresh(
1217 account: &crate::config::AccountConfig,
1218 creds: &crate::oauth::OAuthCredential,
1219 live_creds: &LiveCredentials,
1220 state: &StateStore,
1221) {
1222 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
1223 match account.provider.refresh_token(creds).await {
1224 Ok(fresh) => {
1225 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
1226 {
1227 let mut map = live_creds.write().await;
1228 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1229 }
1230 let mut store = crate::config::CredentialsStore::load();
1231 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1232 store.save().ok();
1233 if fresh.id_token.is_some() {
1234 crate::oauth::write_codex_auth_file(&fresh);
1235 }
1236 state.clear_auth_failed(&account.name);
1237 }
1238 Err(e) => {
1239 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
1240 state.set_auth_failed(&account.name);
1241 }
1242 }
1243}
1244
1245
1246pub async fn openai_token_refresh_loop(
1254 config: Arc<Config>,
1255 state: StateStore,
1256 live_creds: LiveCredentials,
1257) {
1258 for account in config.accounts.iter()
1260 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1261 {
1262 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1263 continue;
1264 }
1265 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1266
1267 let creds = {
1268 let map = live_creds.read().await;
1269 map.get(&account.name).cloned().or_else(|| account.credential.clone())
1270 };
1271 if let Some(creds) = creds {
1272 if let Some(oauth) = creds.as_oauth() {
1273 if access_token_expires_soon(oauth, 30) {
1274 do_proactive_refresh(account, oauth, &live_creds, &state).await;
1276 } else {
1277 tracing::info!(account = %account.name, "access_token fresh at startup");
1278 }
1279 }
1280 }
1281 }
1282
1283 loop {
1286 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1287 for account in config.accounts.iter()
1288 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1289 {
1290 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1291 }
1292 }
1293}
1294
/// Errors a proxy handler can return; each is rendered as a JSON error body by
/// the `IntoResponse` implementation.
///
/// `Debug` is derived so a `Result<_, ProxyError>` can be unwrapped in tests
/// and the error can be logged with `{:?}`.
#[derive(Debug)]
enum ProxyError {
    /// The request body could not be read (rendered as 400).
    BodyRead,
    /// The upstream request failed (rendered as 502).
    Upstream,
    /// Every account is on cooldown or disabled (rendered as 503).
    AllAccountsUnavailable,
    /// The caller's API key is invalid or missing (rendered as 401).
    Unauthorized,
    /// The caller is sending too many requests (rendered as 429 with `Retry-After`).
    RateLimited,
}
1306
1307impl IntoResponse for ProxyError {
1308 fn into_response(self) -> Response {
1309 match self {
1310 ProxyError::RateLimited => {
1311 let mut resp = (
1312 StatusCode::TOO_MANY_REQUESTS,
1313 axum::Json(json!({
1314 "type": "error",
1315 "error": {"type": "rate_limit_error", "message": "too many requests — slow down"}
1316 })),
1317 ).into_response();
1318 resp.headers_mut().insert(
1319 axum::http::header::RETRY_AFTER,
1320 axum::http::HeaderValue::from_static("60"),
1321 );
1322 resp
1323 }
1324 other => {
1325 let (status, msg) = match other {
1326 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1327 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1328 ProxyError::AllAccountsUnavailable => {
1329 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1330 }
1331 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1332 ProxyError::RateLimited => unreachable!(),
1333 };
1334 (status, axum::Json(json!({
1335 "type": "error",
1336 "error": {"type": "api_error", "message": msg}
1337 }))).into_response()
1338 }
1339 }
1340 }
1341}
1342
/// Periodically tries to bring auth-failed accounts back by refreshing their
/// OAuth tokens. Never returns.
///
/// Every `CHECK_INTERVAL` the watcher asks `state` which configured accounts
/// are auth-failed and attempts a token refresh (20 s timeout) for each one
/// that has a live OAuth credential with a refresh token. A successful refresh
/// updates the live map, persists the credential on a blocking thread, and
/// clears the auth-failed flag. A failed or timed-out refresh sends a
/// per-account notification.
///
/// When nothing recovered and every configured account is still auth-failed,
/// an "all accounts offline" notification is sent, at most once per
/// `NOTIFY_COOLDOWN`.
pub async fn recovery_watcher(
    config: Arc<Config>,
    state: StateStore,
    credentials: LiveCredentials,
) {
    use std::time::{Duration, Instant};
    // How often the auth-failed set is re-examined.
    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
    // Minimum gap between two "all accounts offline" notifications.
    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);

    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
    // When the "all accounts offline" notification was last sent, if ever.
    let mut last_notified: Option<Instant> = None;

    loop {
        tokio::time::sleep(CHECK_INTERVAL).await;

        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
        let failed = state.auth_failed_accounts(&name_refs);
        if failed.is_empty() {
            // Everything is healthy: re-arm the all-offline notification.
            last_notified = None;
            continue;
        }

        tracing::warn!(
            accounts = ?failed,
            "recovery: {} account(s) auth_failed, attempting token refresh",
            failed.len()
        );

        let mut any_recovered = false;

        for name in &failed {
            // Clone the credential out so the read lock is not held across the refresh.
            let cred = {
                let map = credentials.read().await;
                map.get(*name).cloned()
            };
            let Some(cred) = cred else { continue };
            // Only credentials that can be refreshed are worth retrying.
            if !cred.has_refresh_token() { continue; }
            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };

            // Falls back to the default provider if the name is not in the config.
            let provider = config.accounts.iter()
                .find(|a| a.name == *name)
                .map(|a| a.provider.clone())
                .unwrap_or_default();

            let result = tokio::time::timeout(
                Duration::from_secs(20),
                provider.refresh_token(&oauth_cred),
            ).await;

            match result {
                Ok(Ok(fresh)) => {
                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
                    {
                        let mut map = credentials.write().await;
                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
                    }
                    let name_owned = name.to_string();
                    let fresh_owned = fresh.clone();
                    // Disk persistence runs on the blocking pool; the handle is
                    // dropped, so the task is detached and its outcome is not observed.
                    tokio::task::spawn_blocking(move || {
                        let mut store = crate::config::CredentialsStore::load();
                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
                        store.save().ok();
                        if fresh_owned.id_token.is_some() {
                            crate::oauth::write_codex_auth_file(&fresh_owned);
                        }
                    });
                    state.clear_auth_failed(name);
                    any_recovered = true;
                }
                // NOTE(review): this per-account notification is not throttled, so it
                // repeats every CHECK_INTERVAL while the refresh keeps failing — confirm
                // that is intended.
                Ok(Err(e)) => {
                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
                    notify(
                        "shunt: Reauth Required",
                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
                        "Basso",
                    );
                }
                // The outer `Err` is the 20 s timeout elapsing.
                Err(_) => {
                    tracing::error!(account = %name, "recovery: token refresh timed out");
                    notify(
                        "shunt: Reauth Required",
                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
                        "Basso",
                    );
                }
            }
        }

        if any_recovered {
            tracing::info!("recovery: at least one account is back online");
            continue;
        }

        // Re-query: only escalate when every configured account is still failed.
        let still_failed = state.auth_failed_accounts(&name_refs);
        if still_failed.len() == account_names.len() {
            let should_notify = last_notified
                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
                .unwrap_or(true);
            if should_notify {
                error!(
                    "ALL accounts are offline (auth failed). \
                     Run `shunt add-account` to re-authorize."
                );
                notify(
                    "shunt: All Accounts Offline",
                    "All accounts need re-authorization. Run `shunt add-account`.",
                    "Basso",
                );
                last_notified = Some(Instant::now());
            }
        }
    }
}
1465
1466async fn post_cooldown_prefetch(
1470 client: &reqwest::Client,
1471 account: &crate::config::AccountConfig,
1472 token: &str,
1473 state: &StateStore,
1474 upstream_url: &str,
1475) {
1476 let Some((path, body)) = account.provider.prefetch_request() else {
1477 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1478 auth_probe_get(client, probe_path, account, state).await;
1479 }
1480 return;
1481 };
1482 let url = format!("{upstream_url}{path}");
1483 match prefetch_send(client, &url, &account.provider, token, &body).await {
1484 Ok(r) => {
1485 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1486 state.update_rate_limits(&account.name, info);
1487 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1488 }
1489 }
1490 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1491 }
1492}
1493
1494pub async fn health_check_loop(
1501 config: Arc<Config>,
1502 state: StateStore,
1503 live_creds: LiveCredentials,
1504) {
1505 if !config.server.health_check_enabled {
1506 return;
1507 }
1508
1509 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
1511
1512 let base_interval_ms = config.server.health_check_interval_secs * 1000;
1513 let timeout = std::time::Duration::from_secs(config.server.health_check_timeout_secs);
1514 let client = reqwest::Client::builder()
1515 .timeout(timeout)
1516 .build()
1517 .unwrap_or_default();
1518
1519 const FAILURE_THRESHOLD: u32 = 2;
1520 const MAX_BACKOFF_EXP: u32 = 3; loop {
1523 for account in &config.accounts {
1524 {
1526 let states = state.account_states();
1527 if let Some(acc_state) = states.get(&account.name) {
1528 if acc_state.disabled || acc_state.auth_failed {
1529 continue;
1530 }
1531 }
1532 }
1533
1534 let (last_check_ms, failures) = state.health_check_info(&account.name);
1536 let backoff_factor = 1u64 << failures.min(MAX_BACKOFF_EXP);
1537 let effective_interval_ms = base_interval_ms.saturating_mul(backoff_factor);
1538 let now = crate::state::now_ms_pub();
1539 if last_check_ms > 0 && now.saturating_sub(last_check_ms) < effective_interval_ms {
1540 continue;
1541 }
1542
1543 state.update_last_health_check(&account.name);
1544
1545 let cred = {
1547 let creds = live_creds.read().await;
1548 creds.get(&account.name).cloned()
1549 }.or_else(|| account.credential.clone());
1550
1551 let cred = match cred {
1552 Some(c) => c,
1553 None => {
1554 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1556 let upstream = account.upstream_url.as_deref()
1557 .unwrap_or_else(|| account.provider.default_upstream_url());
1558 let url = format!("{upstream}{probe_path}");
1559 match client.get(&url).send().await {
1560 Ok(r) if r.status().is_success() => {
1561 if state.is_health_check_failed(&account.name) {
1562 tracing::info!(account = %account.name, "health check recovered");
1563 }
1564 state.clear_health_check_failed(&account.name);
1565 }
1566 Ok(r) => {
1567 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1568 tracing::warn!(account = %account.name, status = %r.status(),
1569 failures = count, "health check failed");
1570 }
1571 Err(e) => {
1572 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1573 tracing::warn!(account = %account.name, failures = count,
1574 "health check unreachable: {e}");
1575 }
1576 }
1577 }
1578 continue;
1579 }
1580 };
1581
1582 let token = cred.bearer_token();
1583 let upstream = account.upstream_url.as_deref()
1584 .unwrap_or(&config.server.upstream_url);
1585
1586 if let Some((path, body)) = account.provider.prefetch_request() {
1588 let url = format!("{upstream}{path}");
1589 match prefetch_send(&client, &url, &account.provider, token, &body).await {
1590 Ok(r) => {
1591 let status = r.status();
1592 if status == reqwest::StatusCode::UNAUTHORIZED {
1593 if let Some(oauth_cred) = cred.as_oauth() {
1595 match account.provider.refresh_token(oauth_cred).await {
1596 Ok(fresh) => {
1597 let mut store = crate::config::CredentialsStore::load();
1598 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1599 store.save().ok();
1600 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1601 state.clear_auth_failed(&account.name);
1602 if state.is_health_check_failed(&account.name) {
1603 state.clear_health_check_failed(&account.name);
1604 }
1605 tracing::info!(account = %account.name, "health check: token refreshed");
1606 }
1607 Err(e) => {
1608 tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1609 state.set_auth_failed(&account.name);
1610 }
1611 }
1612 } else {
1613 tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1614 state.set_auth_failed(&account.name);
1615 }
1616 } else if status.is_server_error() {
1617 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1618 tracing::warn!(account = %account.name, status = %status,
1619 failures = count, "health check: server error");
1620 } else {
1621 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1623 state.update_rate_limits(&account.name, info);
1624 }
1625 if state.is_health_check_failed(&account.name) {
1626 tracing::info!(account = %account.name, "health check recovered");
1627 }
1628 state.clear_health_check_failed(&account.name);
1629 }
1630 }
1631 Err(e) => {
1632 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1633 tracing::warn!(account = %account.name, failures = count,
1634 "health check probe failed: {e}");
1635 }
1636 }
1637 } else if let Some(probe_path) = account.provider.auth_probe_get_path() {
1638 let probe_upstream = account.upstream_url.as_deref()
1639 .unwrap_or_else(|| account.provider.default_upstream_url());
1640 let url = format!("{probe_upstream}{probe_path}");
1641 let mut headers = reqwest::header::HeaderMap::new();
1642 let _ = account.provider.inject_auth_headers(&mut headers, token);
1643 match client.get(&url).headers(headers).send().await {
1644 Ok(r) => {
1645 let status = r.status();
1646 if status == reqwest::StatusCode::UNAUTHORIZED {
1647 if let Some(oauth_cred) = cred.as_oauth() {
1648 match account.provider.refresh_token(oauth_cred).await {
1649 Ok(fresh) => {
1650 let mut store = crate::config::CredentialsStore::load();
1651 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1652 store.save().ok();
1653 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh));
1654 state.clear_auth_failed(&account.name);
1655 state.clear_health_check_failed(&account.name);
1656 tracing::info!(account = %account.name, "health check: token refreshed (GET probe)");
1657 }
1658 Err(e) => {
1659 tracing::error!(account = %account.name, "health check: refresh failed: {e}");
1660 state.set_auth_failed(&account.name);
1661 }
1662 }
1663 } else {
1664 tracing::error!(account = %account.name, "health check: 401 — API key rejected");
1665 state.set_auth_failed(&account.name);
1666 }
1667 } else if status.is_server_error() {
1668 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1669 tracing::warn!(account = %account.name, status = %status,
1670 failures = count, "health check: server error (GET probe)");
1671 } else {
1672 if state.is_health_check_failed(&account.name) {
1673 tracing::info!(account = %account.name, "health check recovered");
1674 }
1675 state.clear_health_check_failed(&account.name);
1676 }
1677 }
1678 Err(e) => {
1679 let count = state.record_health_check_failure(&account.name, FAILURE_THRESHOLD);
1680 tracing::warn!(account = %account.name, failures = count,
1681 "health check probe failed: {e}");
1682 }
1683 }
1684 }
1685 }
1686
1687 tokio::time::sleep(std::time::Duration::from_secs(config.server.health_check_interval_secs)).await;
1688 }
1689}
1690
/// Watches account cooldowns and refreshes rate-limit data when a cooldown
/// ends or when the stored data has gone stale. Never returns.
///
/// Each pass snapshots account states and rate limits, then for every
/// non-disabled account with a state entry:
/// - cooldown deadline passed and not yet handled: run a post-cooldown
///   prefetch and, if the account was earlier seen with at least 5 minutes of
///   cooldown remaining, send an "Account Resumed" notification;
/// - cooldown still running: remember the deadline as a wake-up time;
/// - no cooldown (`cooldown_until_ms == 0`): prefetch when the rate-limit data
///   is at least `STALE_RL_MS` old and no prefetch was made here within
///   `STALE_RL_MS`.
///
/// The loop then sleeps until the earliest pending deadline (at least 50 ms),
/// or 30 s when no cooldown is pending.
pub async fn cooldown_watcher(
    config: Arc<Config>,
    state: StateStore,
    credentials: LiveCredentials,
) {
    // Rate-limit data older than this (60 minutes, in ms) is considered stale.
    const STALE_RL_MS: u64 = 60 * 60_000;

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(20))
        .build()
        .unwrap_or_default();

    // Account name -> cooldown deadline whose expiry has already been handled.
    let mut last_resumed: HashMap<String, u64> = HashMap::new();
    // Accounts that should trigger a notification when their cooldown ends.
    let mut notify_on_resume: HashSet<String> = HashSet::new();
    // Account name -> time (ms) of the last prefetch issued from this watcher.
    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();

    loop {
        let states = state.account_states();
        let rl_snapshot = state.rate_limit_snapshot();
        let now = now_ms();
        // Earliest still-pending cooldown deadline seen in this pass.
        let mut next_wake_ms: Option<u64> = None;

        for account in &config.accounts {
            let Some(st) = states.get(&account.name) else { continue };
            if st.disabled { continue; }
            let cdl = st.cooldown_until_ms;

            if cdl > 0 && cdl <= now {
                // Cooldown has expired; act once per distinct deadline value.
                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
                if !handled {
                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
                    // Copy the token out so the read lock is released before the request.
                    let token = {
                        let creds = credentials.read().await;
                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
                    };
                    if let Some(token) = token {
                        post_cooldown_prefetch(
                            &client, account, &token, &state,
                            &config.server.upstream_url,
                        ).await;
                    }
                    if notify_on_resume.remove(&account.name) {
                        notify(
                            "shunt: Account Resumed",
                            &format!("Account '{}' is back online.", account.name),
                            "Glass",
                        );
                    }
                    last_resumed.insert(account.name.clone(), cdl);
                    // The resume prefetch also counts as the latest stale-refresh.
                    last_stale_prefetch.insert(account.name.clone(), now);
                }
            } else if cdl > now {
                let remaining = cdl - now;
                // Only cooldowns with at least 5 minutes left warrant a resume notification.
                if remaining >= 5 * 60_000 {
                    notify_on_resume.insert(account.name.clone());
                }
                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
            } else {
                // No cooldown recorded: check whether the rate-limit data is stale.
                // A missing snapshot entry counts as infinitely old.
                let rl_age = rl_snapshot
                    .get(&account.name)
                    .map(|r| now.saturating_sub(r.updated_ms))
                    .unwrap_or(u64::MAX);
                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
                let fetched_ago = now.saturating_sub(last_fetched);

                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
                    tracing::debug!(
                        account = %account.name,
                        age_min = rl_age / 60_000,
                        "rate-limit data stale — refreshing"
                    );
                    let token = {
                        let creds = credentials.read().await;
                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
                    };
                    if let Some(token) = token {
                        post_cooldown_prefetch(
                            &client, account, &token, &state,
                            &config.server.upstream_url,
                        ).await;
                    }
                    // Recorded even when no token was available, so the attempt
                    // is not repeated for another STALE_RL_MS.
                    last_stale_prefetch.insert(account.name.clone(), now);
                }
            }
        }

        // Re-read the clock: the prefetches above may have taken a while.
        let sleep_ms = next_wake_ms
            .map(|wake| wake.saturating_sub(now_ms()).max(50))
            .unwrap_or(30_000);
        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
    }
}
1802
1803use crate::notify::notify;
1804use crate::translate::{
1805 translate_to_anthropic,
1806 translate_from_anthropic,
1807 uuid_v4,
1808 translate_anthropic_stream,
1809 translate_anthropic_req_to_chatgpt,
1810 translate_response_chatgpt_to_anthropic,
1811 translate_anthropic_req_to_openai,
1812 translate_response_openai_to_anthropic,
1813 translate_response_anthropic_to_openai,
1814};
1815
1816async fn openai_models_handler() -> impl IntoResponse {
1831 axum::Json(json!({
1832 "object": "list",
1833 "data": [
1834 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1835 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1836 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1837 ]
1838 }))
1839}
1840
1841async fn openai_compat_handler(
1843 State(s): State<AppState>,
1844 req: Request,
1845) -> Result<Response, ProxyError> {
1846 let Some(ref anthropic_url) = s.anthropic_base_url else {
1847 return proxy_handler(State(s), req).await;
1849 };
1850
1851 let body_bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
1852 .await
1853 .map_err(|_| ProxyError::BodyRead)?;
1854
1855 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1856 .unwrap_or(json!({}));
1857
1858 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1859 let anthropic_body = translate_to_anthropic(openai_body);
1860
1861 let client = reqwest::Client::builder()
1862 .timeout(std::time::Duration::from_secs(300))
1863 .build()
1864 .map_err(|_| ProxyError::Upstream)?;
1865
1866 let mut req_builder = client
1867 .post(format!("{anthropic_url}/v1/messages"))
1868 .header("content-type", "application/json")
1869 .header("anthropic-version", "2023-06-01")
1870 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1871 .header("x-shunt-compat", "openai");
1872 if let Some(ref key) = s.config.server.remote_key {
1873 req_builder = req_builder.header("x-api-key", key.as_str());
1874 }
1875 let resp = req_builder
1876 .json(&anthropic_body)
1877 .send()
1878 .await
1879 .map_err(|_| ProxyError::Upstream)?;
1880
1881 if !resp.status().is_success() {
1882 let status = resp.status();
1883 let body = resp.text().await.unwrap_or_default();
1884 let code = status.as_u16();
1885 return Ok(axum::response::Response::builder()
1886 .status(code)
1887 .header("content-type", "application/json")
1888 .body(axum::body::Body::from(body))
1889 .unwrap());
1890 }
1891
1892 if stream {
1893 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1895 let stream = translate_anthropic_stream(resp, chat_id);
1896 Ok(axum::response::Response::builder()
1897 .status(200)
1898 .header("content-type", "text/event-stream")
1899 .header("cache-control", "no-cache")
1900 .body(axum::body::Body::from_stream(stream))
1901 .unwrap())
1902 } else {
1903 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1904 let openai_resp = translate_from_anthropic(anthropic_resp);
1905 Ok(axum::Json(openai_resp).into_response())
1906 }
1907}
1908
1909async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1916 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1917 let resp = client
1918 .get(&url)
1919 .header("Authorization", format!("Bearer {}", token))
1920 .send()
1921 .await
1922 .ok()?;
1923 if !resp.status().is_success() {
1924 return None;
1925 }
1926 let json: serde_json::Value = resp.json().await.ok()?;
1927 if json["proofofwork"]["required"].as_bool() == Some(true) {
1928 return None;
1929 }
1930 json["token"].as_str().map(ToOwned::to_owned)
1931}
1932
1933
/// Returns `true` when the model name contains the substring `haiku`
/// (case-sensitive).
fn is_simple_model(model: &str) -> bool {
    const SIMPLE_FAMILY: &str = "haiku";
    model.contains(SIMPLE_FAMILY)
}
1939
1940fn resolve_model(
1945 incoming: &str,
1946 account: &crate::config::AccountConfig,
1947 mapping: &std::collections::HashMap<String, String>,
1948) -> String {
1949 if let Some(m) = &account.model {
1951 return m.clone();
1952 }
1953 if let Some(m) = mapping.get(incoming) {
1955 return m.clone();
1956 }
1957 let default = account.provider.default_model();
1959 if !default.is_empty() {
1960 return default.to_owned();
1961 }
1962 incoming.to_owned()
1964}
1965