1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
/// Shared state handed to every axum handler in this module.
///
/// The struct is `Clone`; the `Arc`-wrapped fields are shared between clones.
#[derive(Clone)]
struct AppState {
    /// Parsed configuration (accounts, server settings, model mapping).
    config: Arc<Config>,
    /// Client used by `proxy_handler` to forward requests upstream.
    forwarder: Arc<Forwarder>,
    /// Per-account runtime state: cooldowns, auth failures, pinning, usage.
    state: StateStore,
    /// Live credentials keyed by account name; replaced when a token is refreshed.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// One async mutex per account name, taken around the 401 token-refresh path.
    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Value of `now_ms()` when the state was built; reported by `/status`.
    started_ms: u64,
    /// Optional Anthropic base URL supplied by the caller.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm usage.
    anthropic_base_url: Option<String>,
}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46 Ok(app)
47}
48
/// Shared, mutable map from account name to that account's current [`Credential`].
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
51
52fn build_app_state(
56 config: Config,
57 state: StateStore,
58 anthropic_base_url: Option<String>,
59) -> anyhow::Result<(AppState, LiveCredentials)> {
60 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
61
62 for a in &config.accounts {
63 if a.provider.auth_kind() == crate::provider::AuthKind::None {
64 state.clear_auth_failed(&a.name);
66 } else if a.credential.is_none() {
67 state.set_auth_failed(&a.name);
68 }
69 }
70
71 let credentials: LiveCredentials = Arc::new(RwLock::new(
72 config.accounts.iter()
73 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
74 .collect::<HashMap<_, _>>(),
75 ));
76
77 let app_state = AppState {
78 config: Arc::new(config),
79 forwarder: Arc::new(forwarder),
80 state,
81 credentials: Arc::clone(&credentials),
82 refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
83 started_ms: now_ms(),
84 anthropic_base_url,
85 };
86
87 Ok((app_state, credentials))
88}
89
90pub fn create_proxy_app(
91 config: Config,
92 state: StateStore,
93 anthropic_base_url: Option<String>,
94) -> anyhow::Result<(Router, LiveCredentials)> {
95 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
96
97 let app = Router::new()
98 .route("/v1/messages", post(proxy_handler))
99 .route("/v1/messages/count_tokens", post(proxy_handler))
100 .route("/v1/chat/completions", post(openai_compat_handler))
101 .route("/v1/models", get(openai_models_handler))
102 .fallback(proxy_handler)
103 .with_state(app_state);
104
105 Ok((app, credentials))
106}
107
108pub fn create_control_app(
111 config: Config,
112 state: StateStore,
113) -> anyhow::Result<Router> {
114 let (app_state, _) = build_app_state(config, state, None)?;
115
116 let app = Router::new()
117 .route("/health", get(health))
118 .route("/status", get(status_handler))
119 .route("/use", post(use_handler))
120 .with_state(app_state);
121
122 Ok(app)
123}
124
125pub fn create_app_with_state(
129 config: Config,
130 state: StateStore,
131 anthropic_base_url: Option<String>,
132) -> anyhow::Result<(Router, LiveCredentials)> {
133 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
134
135 let app = Router::new()
136 .route("/health", get(health))
138 .route("/status", get(status_handler))
139 .route("/use", post(use_handler))
140 .route("/v1/messages", post(proxy_handler))
142 .route("/v1/messages/count_tokens", post(proxy_handler))
143 .route("/v1/chat/completions", post(openai_compat_handler))
144 .route("/v1/models", get(openai_models_handler))
145 .fallback(proxy_handler)
146 .with_state(app_state);
147
148 Ok((app, credentials))
149}
150
151async fn health() -> impl IntoResponse {
152 axum::Json(json!({"status": "ok"}))
153}
154
155async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
156 let account_states = s.state.account_states();
157 let quotas = s.state.quota_snapshot();
158 let rate_limits = s.state.rate_limit_snapshot();
159
160 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
161 let st = account_states.get(&a.name);
162 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
163 "reauth_required"
164 } else if st.map(|s| s.disabled).unwrap_or(false) {
165 "disabled"
166 } else if s.state.is_available(&a.name) {
167 "available"
168 } else {
169 "cooling"
170 };
171
172 let quota = quotas.get(&a.name);
173 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
174 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
175 let tokens_used = quota.map(|q| json!({
176 "input": q.input_tokens,
177 "output": q.output_tokens,
178 "total": q.total_tokens(),
179 }));
180
181 let rl = rate_limits.get(&a.name);
182 let rate_limit = rl.map(|r| json!({
183 "utilization_5h": r.utilization_5h,
184 "reset_5h": r.reset_5h,
185 "status_5h": r.status_5h,
186 "utilization_7d": r.utilization_7d,
187 "reset_7d": r.reset_7d,
188 "status_7d": r.status_7d,
189 "representative_claim": r.representative_claim,
190 "updated_ms": r.updated_ms,
191 }));
192
193 let acc_state = account_states.get(&a.name);
194 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
195 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
196 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
197 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
198 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
199 let reset_5h = rl.and_then(|r| r.reset_5h);
200 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
201 let reset_7d = rl.and_then(|r| r.reset_7d);
202 let available = s.state.is_available(&a.name);
203
204 json!({
205 "name": a.name,
206 "email": email,
207 "plan_type": a.plan_type,
208 "provider": a.provider.to_string(),
209 "status": avail_status,
210 "available": available,
211 "disabled": disabled,
212 "auth_failed": auth_failed,
213 "cooldown_until_ms": cooldown_until_ms,
214 "utilization_5h": utilization_5h,
215 "reset_5h": reset_5h,
216 "utilization_7d": utilization_7d,
217 "reset_7d": reset_7d,
218 "window_expires_ms": window_expires_ms,
219 "tokens_used": tokens_used,
220 "rate_limit": rate_limit,
221 })
222 }).collect();
223
224 let recent_requests = s.state.recent_requests_snapshot();
225 let savings = s.state.savings_snapshot();
226
227 axum::Json(json!({
228 "version": env!("CARGO_PKG_VERSION"),
229 "started_ms": s.started_ms,
230 "accounts": accounts,
231 "pinned_account": s.state.get_pinned(),
232 "last_used_account": s.state.get_last_used(),
233 "recent_requests": recent_requests,
234 "savings": savings,
235 }))
236}
237
238async fn use_handler(
239 State(s): State<AppState>,
240 axum::Json(body): axum::Json<serde_json::Value>,
241) -> Response {
242 let account = body["account"].as_str().map(|s| s.to_owned());
243 if let Some(ref name) = account {
245 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
246 return (StatusCode::BAD_REQUEST, axum::Json(json!({
247 "error": format!("unknown account '{name}'")
248 }))).into_response();
249 }
250 let pinned = if name == "auto" { None } else { Some(name.clone()) };
251 s.state.set_pinned(pinned);
252 axum::Json(json!({ "pinned": name })).into_response()
253 } else {
254 s.state.set_pinned(None);
255 axum::Json(json!({ "pinned": null })).into_response()
256 }
257}
258
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
263
264async fn proxy_handler(
265 State(s): State<AppState>,
266 req: Request,
267) -> Result<Response, ProxyError> {
268 if let Some(ref expected) = s.config.server.remote_key {
270 let provided = req.headers()
271 .get("x-api-key")
272 .and_then(|v| v.to_str().ok())
273 .unwrap_or("");
274 if provided != expected {
275 return Err(ProxyError::Unauthorized);
276 }
277 }
278
279 let method = req.method().as_str().to_owned();
280 let path = req.uri().path().to_owned();
281 let headers = req.headers().clone();
282
283 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
284 .await
285 .map_err(|_| ProxyError::BodyRead)?;
286
287 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
288 .ok()
289 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
290 .unwrap_or_default();
291 let req_start_ms = now_ms();
292
293 let fp = router::fingerprint(&body_bytes);
294 let fp_ref = fp.as_deref();
295
296 let mut tried: HashSet<String> = HashSet::new();
297 let mut refreshed: HashSet<String> = HashSet::new();
299 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
301
302 loop {
303 let account = match router::pick_account(
304 &s.config.accounts, &s.state, fp_ref, &tried,
305 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
306 ) {
307 Some(a) => a,
308 None => {
309 let account_states = s.state.account_states();
313 let now = now_ms();
314 let soonest_ms = s.config.accounts.iter()
315 .filter_map(|a| {
316 let st = account_states.get(&a.name)?;
317 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
319 })
320 .min();
321
322 match soonest_ms {
323 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
324 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
326 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
327 tried.clear(); }
329 _ => return Err(ProxyError::AllAccountsUnavailable),
330 }
331 continue;
332 }
333 };
334
335 let account_name = account.name.clone();
336
337 let token = {
342 let creds = s.credentials.read().await;
343 let cred = creds.get(&account_name)
344 .cloned()
345 .or_else(|| account.credential.clone());
346 match cred {
347 Some(c) => c.bearer_token().to_owned(),
348 None => String::new(),
349 }
350 };
351
352 let req_is_anthropic = path.starts_with("/v1/messages");
356 let acct_is_anthropic = account.provider.wire_protocol()
357 == crate::provider::WireProtocol::Anthropic;
358 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
361
362 let mut log_model = model.clone();
365
366 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
367 (path.clone(), body_bytes.clone(), headers.clone())
369 } else if req_is_anthropic && acct_is_chatgpt {
370 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
372 let translated = translate_anthropic_req_to_chatgpt(&val);
373 let mut h = headers.clone();
374 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
375 h.remove(*name);
376 }
377 (
378 "/backend-api/conversation".to_owned(),
379 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
380 h,
381 )
382 } else if req_is_anthropic {
383 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
385 let target_model = resolve_model(&model, account, &s.config.model_mapping);
387 log_model = target_model.clone();
388 let translated = translate_anthropic_req_to_openai(val, &target_model);
389 let mut h = headers.clone();
390 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
391 h.remove(*name);
392 }
393 (
394 "/v1/chat/completions".to_owned(),
395 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
396 h,
397 )
398 } else {
399 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
401 let translated = translate_to_anthropic(val);
402 (
403 "/v1/messages".to_owned(),
404 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
405 headers.clone(),
406 )
407 };
408
409 let upstream = account.upstream_url.as_deref()
412 .unwrap_or(&s.config.server.upstream_url);
413
414 if req_is_anthropic && acct_is_chatgpt {
417 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
418 let sentinel_client = reqwest::Client::builder()
419 .timeout(std::time::Duration::from_secs(3))
420 .build()
421 .unwrap_or_default();
422 let sentinel_opt = tokio::time::timeout(
423 std::time::Duration::from_secs(3),
424 fetch_sentinel_token(&sentinel_client, upstream, &token),
425 ).await.ok().flatten();
426 if let Some(sentinel) = sentinel_opt {
427 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
428 b"openai-sentinel-chat-requirements-token",
429 ) {
430 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
431 fwd_headers.insert(name, val);
432 }
433 }
434 }
435 }
436
437 let response = if acct_is_chatgpt {
440 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
441 match tokio::time::timeout(
442 std::time::Duration::from_secs(15),
443 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
444 ).await {
445 Ok(Ok(r)) => r,
446 Ok(Err(e)) => {
447 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
448 s.state.set_cooldown(&account_name, 5 * 60_000);
449 tried.insert(account_name);
450 continue;
451 }
452 Err(_) => {
453 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
454 s.state.set_cooldown(&account_name, 5 * 60_000);
455 tried.insert(account_name);
456 continue;
457 }
458 }
459 } else {
460 s.forwarder
461 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
462 .await
463 .map_err(|e| {
464 error!("Forward error: {:#}", e);
465 ProxyError::Upstream
466 })?
467 };
468
469 match response.status().as_u16() {
470 200..=299 => {
471 s.state.set_last_used(&account_name);
472 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
473 s.state.update_rate_limits(&account_name, info);
474 }
475 let response = if req_is_anthropic == acct_is_anthropic {
477 response
478 } else if req_is_anthropic && acct_is_chatgpt {
479 translate_response_chatgpt_to_anthropic(response, &model).await
481 } else if req_is_anthropic {
482 translate_response_openai_to_anthropic(response, &model).await
484 } else {
485 translate_response_anthropic_to_openai(response).await
487 };
488 return Ok(tap_usage(response, &s.state, &account_name, &log_model, req_start_ms).await);
489 }
490 429 => {
491 let info = account.provider.parse_rate_limits(response.headers());
492 let cooldown_ms = info.as_ref()
495 .and_then(|i| i.reset_5h.or(i.reset_7d))
496 .map(|reset_secs| {
497 let reset_ms = reset_secs.saturating_mul(1_000);
498 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
500 .unwrap_or(60_000);
501 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
502 if let Some(info) = info {
503 s.state.update_rate_limits(&account_name, info);
504 }
505 s.state.set_cooldown(&account_name, cooldown_ms);
506 if cooldown_ms >= 5 * 60_000 {
507 let mins = cooldown_ms / 60_000;
508 notify(
509 "shunt: Rate Limited",
510 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
511 "Ping",
512 );
513 }
514 tried.insert(account_name);
515 }
516 529 => {
517 warn!(account = %account_name, "529 overloaded — cooling 30s");
518 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
519 s.state.update_rate_limits(&account_name, info);
520 }
521 s.state.set_cooldown(&account_name, 30_000);
522 tried.insert(account_name);
523 }
524 401 => {
525 if !refreshed.contains(&account_name) {
526 let account_lock = {
534 let mut locks = s.refresh_locks.lock().unwrap();
535 locks.entry(account_name.clone())
536 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
537 .clone()
538 };
539 let _guard = account_lock.lock().await;
540
541 let cred_before = {
544 let creds = s.credentials.read().await;
545 creds.get(&account_name).cloned()
546 .or_else(|| account.credential.clone())
547 };
548 let Some(cred) = cred_before else {
549 tried.insert(account_name);
550 continue;
551 };
552
553 let token_before = cred.access_token().to_owned();
555 let already_refreshed = {
556 let creds = s.credentials.read().await;
557 creds.get(&account_name)
558 .map(|c| c.access_token() != token_before)
559 .unwrap_or(false)
560 };
561
562 if already_refreshed {
563 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
565 refreshed.insert(account_name);
566 } else if let Some(oauth_cred) = cred.as_oauth() {
567 match tokio::time::timeout(
569 std::time::Duration::from_secs(10),
570 account.provider.refresh_token(oauth_cred),
571 ).await {
572 Ok(Ok(fresh)) => {
573 warn!(account = %account_name, "401 — token refreshed, retrying");
574 {
575 let mut creds = s.credentials.write().await;
576 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
577 }
578 let name = account_name.clone();
580 let fresh = fresh.clone();
581 tokio::task::spawn_blocking(move || {
582 let mut store = CredentialsStore::load();
583 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
584 store.save().ok();
585 if fresh.id_token.is_some() {
586 crate::oauth::write_codex_auth_file(&fresh);
587 }
588 });
589 refreshed.insert(account_name);
591 }
592 _ => {
593 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
595 s.state.set_cooldown(&account_name, 5 * 60_000);
596 tried.insert(account_name);
597 }
598 }
599 } else {
600 error!(account = %account_name, "401 — API key rejected, cooling 5min");
602 s.state.set_cooldown(&account_name, 5 * 60_000);
603 tried.insert(account_name);
604 }
605 } else {
606 error!(account = %account_name, "401 after refresh — cooling 5min");
608 s.state.set_cooldown(&account_name, 5 * 60_000);
609 tried.insert(account_name);
610 }
611 }
612 403 => {
613 if acct_is_anthropic {
617 error!(account = %account_name, "403 forbidden — cooling 30min");
618 s.state.set_cooldown(&account_name, 30 * 60_000);
619 notify(
620 "shunt: Account Forbidden",
621 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
622 "Basso",
623 );
624 } else {
625 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
626 s.state.set_cooldown(&account_name, 5 * 60_000);
627 }
628 tried.insert(account_name);
629 }
630 _ => {
631 return Ok(response);
633 }
634 }
635 }
636}
637
638async fn tap_usage(
647 resp: Response,
648 state: &StateStore,
649 account: &str,
650 model: &str,
651 req_start_ms: u64,
652) -> Response {
653 use axum::body::Body;
654 use crate::state::RequestLog;
655
656 if quota::is_streaming_response(&resp) {
657 let state = state.clone();
658 let account = account.to_owned();
659 let model = model.to_owned();
660 let on_complete = Arc::new(move |input: u64, output: u64| {
661 state.record_usage(&account, input, output);
662 state.record_global(&model, input, output);
663 state.record_request(RequestLog {
664 ts_ms: req_start_ms,
665 account: account.clone(),
666 model: model.clone(),
667 status: 200,
668 input_tokens: input,
669 output_tokens: output,
670 duration_ms: now_ms().saturating_sub(req_start_ms),
671 });
672 });
673 let (parts, body) = resp.into_parts();
674 let wrapped = quota::wrap_streaming_body(body, on_complete);
675 return Response::from_parts(parts, wrapped);
676 }
677
678 let (parts, body) = resp.into_parts();
680 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
681 Ok(b) => b,
682 Err(_) => return Response::from_parts(parts, Body::empty()),
683 };
684 let (input, output) = quota::extract_usage_from_json(&bytes);
685 state.record_usage(account, input, output);
686 state.record_global(model, input, output);
687 state.record_request(RequestLog {
688 ts_ms: req_start_ms,
689 account: account.to_owned(),
690 model: model.to_owned(),
691 status: 200,
692 input_tokens: input,
693 output_tokens: output,
694 duration_ms: now_ms().saturating_sub(req_start_ms),
695 });
696 Response::from_parts(parts, Body::from(bytes))
697}
698
699
700pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
708 let client = reqwest::Client::builder()
709 .timeout(std::time::Duration::from_secs(20))
710 .build()
711 .unwrap_or_default();
712
713 for account in &config.accounts {
714 let rl = state.rate_limit_snapshot();
716 if let Some(r) = rl.get(&account.name) {
717 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
718 continue;
719 }
720 }
721
722 let cred = match account.credential.clone() {
724 Some(c) => c,
725 None => continue,
726 };
727
728 let Some((path, body)) = account.provider.prefetch_request() else {
729 if let Some(probe_path) = account.provider.auth_probe_get_path() {
731 auth_probe_get(&client, probe_path, account, &state).await;
732 }
733 continue;
734 };
735 let url = format!("{}{}", config.server.upstream_url, path);
736
737 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
738
739 let r = match resp {
740 Ok(r) => r,
741 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
742 };
743
744 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
745 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
746 let Some(oauth_cred) = cred.as_oauth() else {
747 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
749 state.set_auth_failed(&account.name);
750 continue;
751 };
752 let fresh = match account.provider.refresh_token(oauth_cred).await {
753 Ok(f) => f,
754 Err(e) => {
755 tracing::warn!(account = %account.name, "token refresh failed: {e}");
756 state.set_auth_failed(&account.name);
757 continue;
758 }
759 };
760 let mut store = crate::config::CredentialsStore::load();
761 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
762 store.save().ok();
763 if fresh.id_token.is_some() {
764 crate::oauth::write_codex_auth_file(&fresh);
765 }
766 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
768
769 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
770 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
771 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
772 state.set_auth_failed(&account.name);
773 }
774 Ok(r2) => {
775 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
776 state.update_rate_limits(&account.name, info);
777 }
778 }
779 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
780 }
781 } else {
782 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
783 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
784 state.update_rate_limits(&account.name, info);
785 }
786 }
787 }
788}
789
790async fn prefetch_send(
792 client: &reqwest::Client,
793 url: &str,
794 provider: &crate::provider::Provider,
795 token: &str,
796 body: &serde_json::Value,
797) -> anyhow::Result<reqwest::Response> {
798 let mut headers = reqwest::header::HeaderMap::new();
799 provider.inject_auth_headers(&mut headers, token)?;
800 for (name, value) in provider.prefetch_extra_headers() {
801 headers.insert(
802 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
803 reqwest::header::HeaderValue::from_static(value),
804 );
805 }
806 Ok(client.post(url).headers(headers).json(body).send().await?)
807}
808
809async fn auth_probe_get(
813 client: &reqwest::Client,
814 path: &str,
815 account: &crate::config::AccountConfig,
816 state: &StateStore,
817) {
818 let cred = match account.credential.clone() {
819 Some(c) => c,
820 None => return,
821 };
822 let upstream = account.upstream_url.as_deref()
823 .unwrap_or_else(|| account.provider.default_upstream_url());
824 let url = format!("{}{}", upstream, path);
825
826 let do_get = |token: &str| -> reqwest::RequestBuilder {
827 let mut headers = reqwest::header::HeaderMap::new();
828 let _ = account.provider.inject_auth_headers(&mut headers, token);
829 client.get(&url).headers(headers)
830 };
831
832 let resp = match do_get(cred.bearer_token()).send().await {
833 Ok(r) => r,
834 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
835 };
836
837 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
838 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
839 let Some(oauth_cred) = cred.as_oauth() else {
840 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
842 state.set_auth_failed(&account.name);
843 return;
844 };
845 let fresh = match account.provider.refresh_token(oauth_cred).await {
846 Ok(f) => f,
847 Err(e) => {
848 tracing::warn!(account = %account.name, "token refresh failed: {e}");
849 state.set_auth_failed(&account.name);
850 return;
851 }
852 };
853 let mut store = crate::config::CredentialsStore::load();
854 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
855 store.save().ok();
856 if fresh.id_token.is_some() {
857 crate::oauth::write_codex_auth_file(&fresh);
858 }
859
860 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
861 match do_get(fresh_token).send().await {
862 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
863 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
864 state.set_auth_failed(&account.name);
865 }
866 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
867 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
868 }
869 } else {
870 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
871 }
875}
876
877fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
884 let now_ms = std::time::SystemTime::now()
885 .duration_since(std::time::UNIX_EPOCH)
886 .unwrap_or_default()
887 .as_millis() as u64;
888 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
889 .unwrap_or(cred.expires_at);
890 exp_ms < now_ms + threshold_mins * 60 * 1_000
891}
892
893async fn sync_live_creds_from_auth_json(
898 account_name: &str,
899 live_creds: &LiveCredentials,
900) {
901 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
902 let current_exp = live_creds.read().await
903 .get(account_name)
904 .and_then(|c| c.as_oauth())
905 .map(|c| c.expires_at)
906 .unwrap_or(0);
907 if from_file.expires_at > current_exp {
908 tracing::info!(account = %account_name, "synced fresher token from auth.json");
909 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
910 }
911}
912
913async fn do_proactive_refresh(
915 account: &crate::config::AccountConfig,
916 creds: &crate::oauth::OAuthCredential,
917 live_creds: &LiveCredentials,
918 state: &StateStore,
919) {
920 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
921 match account.provider.refresh_token(creds).await {
922 Ok(fresh) => {
923 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
924 {
925 let mut map = live_creds.write().await;
926 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
927 }
928 let mut store = crate::config::CredentialsStore::load();
929 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
930 store.save().ok();
931 if fresh.id_token.is_some() {
932 crate::oauth::write_codex_auth_file(&fresh);
933 }
934 state.clear_auth_failed(&account.name);
935 }
936 Err(e) => {
937 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
938 state.set_auth_failed(&account.name);
939 }
940 }
941}
942
943
944pub async fn openai_token_refresh_loop(
952 config: Arc<Config>,
953 state: StateStore,
954 live_creds: LiveCredentials,
955) {
956 for account in config.accounts.iter()
958 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
959 {
960 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
961 continue;
962 }
963 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
964
965 let creds = {
966 let map = live_creds.read().await;
967 map.get(&account.name).cloned().or_else(|| account.credential.clone())
968 };
969 if let Some(creds) = creds {
970 if let Some(oauth) = creds.as_oauth() {
971 if access_token_expires_soon(oauth, 30) {
972 do_proactive_refresh(account, oauth, &live_creds, &state).await;
974 } else {
975 tracing::info!(account = %account.name, "access_token fresh at startup");
976 }
977 }
978 }
979 }
980
981 loop {
984 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
985 for account in config.accounts.iter()
986 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
987 {
988 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
989 }
990 }
991}
992
/// Failures that `proxy_handler` reports to the client itself rather than relaying
/// from upstream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyError {
    /// The request body could not be read.
    BodyRead,
    /// Forwarding the request upstream failed.
    Upstream,
    /// No account is currently able to serve the request.
    AllAccountsUnavailable,
    /// The `x-api-key` header did not match the configured remote key.
    Unauthorized,
}
1003
1004impl IntoResponse for ProxyError {
1005 fn into_response(self) -> Response {
1006 let (status, msg) = match self {
1007 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1008 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1009 ProxyError::AllAccountsUnavailable => {
1010 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1011 }
1012 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1013 };
1014
1015 (status, axum::Json(json!({
1016 "type": "error",
1017 "error": {"type": "api_error", "message": msg}
1018 }))).into_response()
1019 }
1020}
1021
1022pub async fn recovery_watcher(
1031 config: Arc<Config>,
1032 state: StateStore,
1033 credentials: LiveCredentials,
1034) {
1035 use std::time::{Duration, Instant};
1036 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1037 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1038
1039 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1040 let mut last_notified: Option<Instant> = None;
1041
1042 loop {
1043 tokio::time::sleep(CHECK_INTERVAL).await;
1044
1045 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1046 let failed = state.auth_failed_accounts(&name_refs);
1047 if failed.is_empty() {
1048 last_notified = None;
1049 continue;
1050 }
1051
1052 tracing::warn!(
1053 accounts = ?failed,
1054 "recovery: {} account(s) auth_failed, attempting token refresh",
1055 failed.len()
1056 );
1057
1058 let mut any_recovered = false;
1059
1060 for name in &failed {
1061 let cred = {
1062 let map = credentials.read().await;
1063 map.get(*name).cloned()
1064 };
1065 let Some(cred) = cred else { continue };
1066 if !cred.has_refresh_token() { continue; }
1067 let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1068
1069 let provider = config.accounts.iter()
1070 .find(|a| a.name == *name)
1071 .map(|a| a.provider.clone())
1072 .unwrap_or_default();
1073
1074 let result = tokio::time::timeout(
1075 Duration::from_secs(20),
1076 provider.refresh_token(&oauth_cred),
1077 ).await;
1078
1079 match result {
1080 Ok(Ok(fresh)) => {
1081 tracing::info!(account = %name, "recovery: token refreshed — account back online");
1082 {
1083 let mut map = credentials.write().await;
1084 map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1085 }
1086 let name_owned = name.to_string();
1087 let fresh_owned = fresh.clone();
1088 tokio::task::spawn_blocking(move || {
1089 let mut store = crate::config::CredentialsStore::load();
1090 store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1091 store.save().ok();
1092 if fresh_owned.id_token.is_some() {
1093 crate::oauth::write_codex_auth_file(&fresh_owned);
1094 }
1095 });
1096 state.clear_auth_failed(name);
1097 any_recovered = true;
1098 }
1099 Ok(Err(e)) => {
1100 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1101 notify(
1102 "shunt: Reauth Required",
1103 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1104 "Basso",
1105 );
1106 }
1107 Err(_) => {
1108 tracing::error!(account = %name, "recovery: token refresh timed out");
1109 notify(
1110 "shunt: Reauth Required",
1111 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1112 "Basso",
1113 );
1114 }
1115 }
1116 }
1117
1118 if any_recovered {
1119 tracing::info!("recovery: at least one account is back online");
1120 continue;
1121 }
1122
1123 let still_failed = state.auth_failed_accounts(&name_refs);
1125 if still_failed.len() == account_names.len() {
1126 let should_notify = last_notified
1127 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1128 .unwrap_or(true);
1129 if should_notify {
1130 error!(
1131 "ALL accounts are offline (auth failed). \
1132 Run `shunt add-account` to re-authorize."
1133 );
1134 notify(
1135 "shunt: All Accounts Offline",
1136 "All accounts need re-authorization. Run `shunt add-account`.",
1137 "Basso",
1138 );
1139 last_notified = Some(Instant::now());
1140 }
1141 }
1142 }
1143}
1144
1145async fn post_cooldown_prefetch(
1149 client: &reqwest::Client,
1150 account: &crate::config::AccountConfig,
1151 token: &str,
1152 state: &StateStore,
1153 upstream_url: &str,
1154) {
1155 let Some((path, body)) = account.provider.prefetch_request() else {
1156 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1157 auth_probe_get(client, probe_path, account, state).await;
1158 }
1159 return;
1160 };
1161 let url = format!("{upstream_url}{path}");
1162 match prefetch_send(client, &url, &account.provider, token, &body).await {
1163 Ok(r) => {
1164 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1165 state.update_rate_limits(&account.name, info);
1166 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1167 }
1168 }
1169 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1170 }
1171}
1172
1173pub async fn cooldown_watcher(
1184 config: Arc<Config>,
1185 state: StateStore,
1186 credentials: LiveCredentials,
1187) {
1188 const STALE_RL_MS: u64 = 60 * 60_000;
1190
1191 let client = reqwest::Client::builder()
1192 .timeout(std::time::Duration::from_secs(20))
1193 .build()
1194 .unwrap_or_default();
1195
1196 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1199 let mut notify_on_resume: HashSet<String> = HashSet::new();
1201 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1203
1204 loop {
1205 let states = state.account_states();
1206 let rl_snapshot = state.rate_limit_snapshot();
1207 let now = now_ms();
1208 let mut next_wake_ms: Option<u64> = None;
1209
1210 for account in &config.accounts {
1211 let Some(st) = states.get(&account.name) else { continue };
1212 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1214
1215 if cdl > 0 && cdl <= now {
1216 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1218 if !handled {
1219 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1220 let token = {
1221 let creds = credentials.read().await;
1222 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1223 };
1224 if let Some(token) = token {
1225 post_cooldown_prefetch(
1226 &client, account, &token, &state,
1227 &config.server.upstream_url,
1228 ).await;
1229 }
1230 if notify_on_resume.remove(&account.name) {
1231 notify(
1232 "shunt: Account Resumed",
1233 &format!("Account '{}' is back online.", account.name),
1234 "Glass",
1235 );
1236 }
1237 last_resumed.insert(account.name.clone(), cdl);
1238 last_stale_prefetch.insert(account.name.clone(), now);
1239 }
1240 } else if cdl > now {
1241 let remaining = cdl - now;
1243 if remaining >= 5 * 60_000 {
1244 notify_on_resume.insert(account.name.clone());
1245 }
1246 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1247 } else {
1248 let rl_age = rl_snapshot
1250 .get(&account.name)
1251 .map(|r| now.saturating_sub(r.updated_ms))
1252 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1254 let fetched_ago = now.saturating_sub(last_fetched);
1255
1256 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1257 tracing::debug!(
1258 account = %account.name,
1259 age_min = rl_age / 60_000,
1260 "rate-limit data stale — refreshing"
1261 );
1262 let token = {
1263 let creds = credentials.read().await;
1264 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1265 };
1266 if let Some(token) = token {
1267 post_cooldown_prefetch(
1268 &client, account, &token, &state,
1269 &config.server.upstream_url,
1270 ).await;
1271 }
1272 last_stale_prefetch.insert(account.name.clone(), now);
1273 }
1274 }
1275 }
1276
1277 let sleep_ms = next_wake_ms
1279 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1280 .unwrap_or(30_000);
1281 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1282 }
1283}
1284
1285use crate::notify::notify;
1286use crate::translate::{
1287 translate_to_anthropic,
1288 translate_from_anthropic,
1289 uuid_v4,
1290 translate_anthropic_stream,
1291 translate_anthropic_req_to_chatgpt,
1292 translate_response_chatgpt_to_anthropic,
1293 translate_anthropic_req_to_openai,
1294 translate_response_openai_to_anthropic,
1295 translate_response_anthropic_to_openai,
1296};
1297
1298async fn openai_models_handler() -> impl IntoResponse {
1313 axum::Json(json!({
1314 "object": "list",
1315 "data": [
1316 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1317 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1318 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1319 ]
1320 }))
1321}
1322
1323async fn openai_compat_handler(
1325 State(s): State<AppState>,
1326 req: Request,
1327) -> Result<Response, ProxyError> {
1328 let Some(ref anthropic_url) = s.anthropic_base_url else {
1329 return proxy_handler(State(s), req).await;
1331 };
1332
1333 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1334 .await
1335 .map_err(|_| ProxyError::BodyRead)?;
1336
1337 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1338 .unwrap_or(json!({}));
1339
1340 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1341 let anthropic_body = translate_to_anthropic(openai_body);
1342
1343 let client = reqwest::Client::builder()
1344 .timeout(std::time::Duration::from_secs(300))
1345 .build()
1346 .map_err(|_| ProxyError::Upstream)?;
1347
1348 let resp = client
1349 .post(format!("{anthropic_url}/v1/messages"))
1350 .header("content-type", "application/json")
1351 .header("anthropic-version", "2023-06-01")
1352 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1353 .header("x-shunt-compat", "openai")
1354 .json(&anthropic_body)
1355 .send()
1356 .await
1357 .map_err(|_| ProxyError::Upstream)?;
1358
1359 if !resp.status().is_success() {
1360 let status = resp.status();
1361 let body = resp.text().await.unwrap_or_default();
1362 let code = status.as_u16();
1363 return Ok(axum::response::Response::builder()
1364 .status(code)
1365 .header("content-type", "application/json")
1366 .body(axum::body::Body::from(body))
1367 .unwrap());
1368 }
1369
1370 if stream {
1371 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1373 let stream = translate_anthropic_stream(resp, chat_id);
1374 Ok(axum::response::Response::builder()
1375 .status(200)
1376 .header("content-type", "text/event-stream")
1377 .header("cache-control", "no-cache")
1378 .body(axum::body::Body::from_stream(stream))
1379 .unwrap())
1380 } else {
1381 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1382 let openai_resp = translate_from_anthropic(anthropic_resp);
1383 Ok(axum::Json(openai_resp).into_response())
1384 }
1385}
1386
1387async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1394 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1395 let resp = client
1396 .get(&url)
1397 .header("Authorization", format!("Bearer {}", token))
1398 .send()
1399 .await
1400 .ok()?;
1401 if !resp.status().is_success() {
1402 return None;
1403 }
1404 let json: serde_json::Value = resp.json().await.ok()?;
1405 if json["proofofwork"]["required"].as_bool() == Some(true) {
1406 return None;
1407 }
1408 json["token"].as_str().map(ToOwned::to_owned)
1409}
1410
1411
1412fn resolve_model(
1417 incoming: &str,
1418 account: &crate::config::AccountConfig,
1419 mapping: &std::collections::HashMap<String, String>,
1420) -> String {
1421 if let Some(m) = &account.model {
1423 return m.clone();
1424 }
1425 if let Some(m) = mapping.get(incoming) {
1427 return m.clone();
1428 }
1429 let default = account.provider.default_model();
1431 if !default.is_empty() {
1432 return default.to_owned();
1433 }
1434 incoming.to_owned()
1436}
1437