1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21use crate::telemetry::TelemetryClient;
22
/// Shared state handed (by clone) to every axum handler in this module.
#[derive(Clone)]
struct AppState {
    /// Parsed configuration: accounts, server settings, model mapping.
    config: Arc<Config>,
    /// Client used to forward requests to upstream providers.
    forwarder: Arc<Forwarder>,
    /// Per-account runtime state (cooldowns, pin, rate limits, usage, request log).
    state: StateStore,
    /// Live credentials keyed by account name; updated in place when tokens are refreshed.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// One async mutex per account name, used to serialise 401-triggered token refreshes.
    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Wall-clock time (ms since the Unix epoch) at which this state was built.
    started_ms: u64,
    /// NOTE(review): populated by the proxy app constructors but not read in the code
    /// visible here — presumably consumed by the OpenAI-compat handlers; confirm.
    anthropic_base_url: Option<String>,
    /// Telemetry sink for per-request events, present when `telemetry_url` is configured.
    telemetry: Option<TelemetryClient>,
}
46
47pub fn create_app(config: Config) -> anyhow::Result<Router> {
48 let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
49 Ok(app)
50}
51
/// Shared map of account name → current credential behind an async `RwLock`.
/// Returned from the app constructors so background tasks can publish refreshed
/// tokens that request handlers then pick up.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
54
55fn build_app_state(
59 config: Config,
60 state: StateStore,
61 anthropic_base_url: Option<String>,
62) -> anyhow::Result<(AppState, LiveCredentials)> {
63 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
64
65 for a in &config.accounts {
66 if a.provider.auth_kind() == crate::provider::AuthKind::None {
67 state.clear_auth_failed(&a.name);
69 } else if a.credential.is_none() {
70 state.set_auth_failed(&a.name);
71 }
72 }
73
74 let credentials: LiveCredentials = Arc::new(RwLock::new(
75 config.accounts.iter()
76 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
77 .collect::<HashMap<_, _>>(),
78 ));
79
80 let telemetry = config.server.telemetry_url.as_deref().map(|url| {
81 TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
82 });
83
84 let app_state = AppState {
85 config: Arc::new(config),
86 forwarder: Arc::new(forwarder),
87 state,
88 credentials: Arc::clone(&credentials),
89 refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
90 started_ms: now_ms(),
91 anthropic_base_url,
92 telemetry,
93 };
94
95 Ok((app_state, credentials))
96}
97
98pub fn create_proxy_app(
99 config: Config,
100 state: StateStore,
101 anthropic_base_url: Option<String>,
102) -> anyhow::Result<(Router, LiveCredentials)> {
103 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
104
105 let app = Router::new()
106 .route("/v1/messages", post(proxy_handler))
107 .route("/v1/messages/count_tokens", post(proxy_handler))
108 .route("/v1/chat/completions", post(openai_compat_handler))
109 .route("/v1/models", get(openai_models_handler))
110 .fallback(proxy_handler)
111 .with_state(app_state);
112
113 Ok((app, credentials))
114}
115
116pub fn create_control_app(
119 config: Config,
120 state: StateStore,
121) -> anyhow::Result<Router> {
122 let (app_state, _) = build_app_state(config, state, None)?;
123
124 let app = Router::new()
125 .route("/health", get(health))
126 .route("/status", get(status_handler))
127 .route("/use", post(use_handler))
128 .with_state(app_state);
129
130 Ok(app)
131}
132
133pub fn create_app_with_state(
137 config: Config,
138 state: StateStore,
139 anthropic_base_url: Option<String>,
140) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
141 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
142 let telemetry = app_state.telemetry.clone();
143
144 let app = Router::new()
145 .route("/health", get(health))
147 .route("/status", get(status_handler))
148 .route("/use", post(use_handler))
149 .route("/v1/messages", post(proxy_handler))
151 .route("/v1/messages/count_tokens", post(proxy_handler))
152 .route("/v1/chat/completions", post(openai_compat_handler))
153 .route("/v1/models", get(openai_models_handler))
154 .fallback(proxy_handler)
155 .with_state(app_state);
156
157 Ok((app, credentials, telemetry))
158}
159
160pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
162 let account_states = state.account_states();
163 let quotas = state.quota_snapshot();
164 let rate_limits = state.rate_limit_snapshot();
165
166 let accounts: Vec<_> = config.accounts.iter().map(|a| {
167 let st = account_states.get(&a.name);
168 let rl = rate_limits.get(&a.name);
169 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
170 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
171 let reset_5h = rl.and_then(|r| r.reset_5h);
172 let reset_7d = rl.and_then(|r| r.reset_7d);
173 let disabled = st.map(|s| s.disabled).unwrap_or(false);
174 let auth_failed = st.map(|s| s.auth_failed).unwrap_or(false);
175 let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
176 let available = state.is_available(&a.name);
177 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
178
179 json!({
180 "name": a.name,
181 "email": email,
182 "provider": a.provider.to_string(),
183 "available": available,
184 "disabled": disabled,
185 "auth_failed": auth_failed,
186 "cooldown_until_ms": cooldown_until_ms,
187 "utilization_5h": utilization_5h,
188 "reset_5h": reset_5h,
189 "utilization_7d": utilization_7d,
190 "reset_7d": reset_7d,
191 })
192 }).collect();
193
194 json!({
195 "started_ms": started_ms,
196 "accounts": accounts,
197 "pinned_account": state.get_pinned(),
198 "last_used_account": state.get_last_used(),
199 })
200}
201
202async fn health() -> impl IntoResponse {
203 axum::Json(json!({"status": "ok"}))
204}
205
206async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
207 let account_states = s.state.account_states();
208 let quotas = s.state.quota_snapshot();
209 let rate_limits = s.state.rate_limit_snapshot();
210
211 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
212 let st = account_states.get(&a.name);
213 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
214 "reauth_required"
215 } else if st.map(|s| s.disabled).unwrap_or(false) {
216 "disabled"
217 } else if s.state.is_available(&a.name) {
218 "available"
219 } else {
220 "cooling"
221 };
222
223 let quota = quotas.get(&a.name);
224 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
225 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
226 let tokens_used = quota.map(|q| json!({
227 "input": q.input_tokens,
228 "output": q.output_tokens,
229 "total": q.total_tokens(),
230 }));
231
232 let rl = rate_limits.get(&a.name);
233 let rate_limit = rl.map(|r| json!({
234 "utilization_5h": r.utilization_5h,
235 "reset_5h": r.reset_5h,
236 "status_5h": r.status_5h,
237 "utilization_7d": r.utilization_7d,
238 "reset_7d": r.reset_7d,
239 "status_7d": r.status_7d,
240 "representative_claim": r.representative_claim,
241 "updated_ms": r.updated_ms,
242 }));
243
244 let acc_state = account_states.get(&a.name);
245 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
246 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
247 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
248 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
249 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
250 let reset_5h = rl.and_then(|r| r.reset_5h);
251 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
252 let reset_7d = rl.and_then(|r| r.reset_7d);
253 let available = s.state.is_available(&a.name);
254
255 json!({
256 "name": a.name,
257 "email": email,
258 "plan_type": a.plan_type,
259 "provider": a.provider.to_string(),
260 "status": avail_status,
261 "available": available,
262 "disabled": disabled,
263 "auth_failed": auth_failed,
264 "cooldown_until_ms": cooldown_until_ms,
265 "utilization_5h": utilization_5h,
266 "reset_5h": reset_5h,
267 "utilization_7d": utilization_7d,
268 "reset_7d": reset_7d,
269 "window_expires_ms": window_expires_ms,
270 "tokens_used": tokens_used,
271 "rate_limit": rate_limit,
272 })
273 }).collect();
274
275 let recent_requests = s.state.recent_requests_snapshot();
276 let savings = s.state.savings_snapshot();
277
278 axum::Json(json!({
279 "version": env!("CARGO_PKG_VERSION"),
280 "started_ms": s.started_ms,
281 "accounts": accounts,
282 "pinned_account": s.state.get_pinned(),
283 "last_used_account": s.state.get_last_used(),
284 "recent_requests": recent_requests,
285 "savings": savings,
286 }))
287}
288
289async fn use_handler(
290 State(s): State<AppState>,
291 axum::Json(body): axum::Json<serde_json::Value>,
292) -> Response {
293 let account = body["account"].as_str().map(|s| s.to_owned());
294 if let Some(ref name) = account {
296 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
297 return (StatusCode::BAD_REQUEST, axum::Json(json!({
298 "error": format!("unknown account '{name}'")
299 }))).into_response();
300 }
301 let pinned = if name == "auto" { None } else { Some(name.clone()) };
302 s.state.set_pinned(pinned);
303 axum::Json(json!({ "pinned": name })).into_response()
304 } else {
305 s.state.set_pinned(None);
306 axum::Json(json!({ "pinned": null })).into_response()
307 }
308}
309
/// Current wall-clock time in milliseconds since the Unix epoch (0 if the clock is
/// set before the epoch).
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
314
315async fn proxy_handler(
316 State(s): State<AppState>,
317 req: Request,
318) -> Result<Response, ProxyError> {
319 if let Some(ref expected) = s.config.server.remote_key {
321 let provided = req.headers()
322 .get("x-api-key")
323 .and_then(|v| v.to_str().ok())
324 .unwrap_or("");
325 if provided != expected {
326 return Err(ProxyError::Unauthorized);
327 }
328 }
329
330 let method = req.method().as_str().to_owned();
331 let path = req.uri().path().to_owned();
332 let headers = req.headers().clone();
333
334 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
335 .await
336 .map_err(|_| ProxyError::BodyRead)?;
337
338 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
339 .ok()
340 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
341 .unwrap_or_default();
342 let req_start_ms = now_ms();
343
344 let fp = router::fingerprint(&body_bytes);
345 let fp_ref = fp.as_deref();
346
347 let mut tried: HashSet<String> = HashSet::new();
348 let mut refreshed: HashSet<String> = HashSet::new();
350 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
352
353 loop {
354 let account = match router::pick_account(
355 &s.config.accounts, &s.state, fp_ref, &tried,
356 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
357 ) {
358 Some(a) => a,
359 None => {
360 let account_states = s.state.account_states();
364 let now = now_ms();
365 let soonest_ms = s.config.accounts.iter()
366 .filter_map(|a| {
367 let st = account_states.get(&a.name)?;
368 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
370 })
371 .min();
372
373 match soonest_ms {
374 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
375 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
377 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
378 tried.clear(); }
380 _ => return Err(ProxyError::AllAccountsUnavailable),
381 }
382 continue;
383 }
384 };
385
386 let account_name = account.name.clone();
387
388 let token = {
393 let creds = s.credentials.read().await;
394 let cred = creds.get(&account_name)
395 .cloned()
396 .or_else(|| account.credential.clone());
397 match cred {
398 Some(c) => c.bearer_token().to_owned(),
399 None => String::new(),
400 }
401 };
402
403 let req_is_anthropic = path.starts_with("/v1/messages");
407 let acct_is_anthropic = account.provider.wire_protocol()
408 == crate::provider::WireProtocol::Anthropic;
409 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
412
413 let mut log_model = model.clone();
416
417 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
418 (path.clone(), body_bytes.clone(), headers.clone())
420 } else if req_is_anthropic && acct_is_chatgpt {
421 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
423 let translated = translate_anthropic_req_to_chatgpt(&val);
424 let mut h = headers.clone();
425 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
426 h.remove(*name);
427 }
428 (
429 "/backend-api/conversation".to_owned(),
430 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
431 h,
432 )
433 } else if req_is_anthropic {
434 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
436 let target_model = resolve_model(&model, account, &s.config.model_mapping);
438 log_model = target_model.clone();
439 let translated = translate_anthropic_req_to_openai(val, &target_model);
440 let mut h = headers.clone();
441 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
442 h.remove(*name);
443 }
444 (
445 "/v1/chat/completions".to_owned(),
446 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
447 h,
448 )
449 } else {
450 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
452 let translated = translate_to_anthropic(val);
453 (
454 "/v1/messages".to_owned(),
455 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
456 headers.clone(),
457 )
458 };
459
460 let upstream = account.upstream_url.as_deref()
463 .unwrap_or(&s.config.server.upstream_url);
464
465 if req_is_anthropic && acct_is_chatgpt {
468 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
469 let sentinel_client = reqwest::Client::builder()
470 .timeout(std::time::Duration::from_secs(3))
471 .build()
472 .unwrap_or_default();
473 let sentinel_opt = tokio::time::timeout(
474 std::time::Duration::from_secs(3),
475 fetch_sentinel_token(&sentinel_client, upstream, &token),
476 ).await.ok().flatten();
477 if let Some(sentinel) = sentinel_opt {
478 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
479 b"openai-sentinel-chat-requirements-token",
480 ) {
481 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
482 fwd_headers.insert(name, val);
483 }
484 }
485 }
486 }
487
488 let response = if acct_is_chatgpt {
491 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
492 match tokio::time::timeout(
493 std::time::Duration::from_secs(15),
494 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
495 ).await {
496 Ok(Ok(r)) => r,
497 Ok(Err(e)) => {
498 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
499 s.state.set_cooldown(&account_name, 5 * 60_000);
500 tried.insert(account_name);
501 continue;
502 }
503 Err(_) => {
504 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
505 s.state.set_cooldown(&account_name, 5 * 60_000);
506 tried.insert(account_name);
507 continue;
508 }
509 }
510 } else {
511 s.forwarder
512 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
513 .await
514 .map_err(|e| {
515 error!("Forward error: {:#}", e);
516 ProxyError::Upstream
517 })?
518 };
519
520 match response.status().as_u16() {
521 200..=299 => {
522 s.state.set_last_used(&account_name);
523 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
524 s.state.update_rate_limits(&account_name, info);
525 }
526 let response = if req_is_anthropic == acct_is_anthropic {
528 response
529 } else if req_is_anthropic && acct_is_chatgpt {
530 translate_response_chatgpt_to_anthropic(response, &model).await
532 } else if req_is_anthropic {
533 translate_response_openai_to_anthropic(response, &model).await
535 } else {
536 translate_response_anthropic_to_openai(response).await
538 };
539 return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms).await);
540 }
541 429 => {
542 let info = account.provider.parse_rate_limits(response.headers());
543 let cooldown_ms = info.as_ref()
546 .and_then(|i| i.reset_5h.or(i.reset_7d))
547 .map(|reset_secs| {
548 let reset_ms = reset_secs.saturating_mul(1_000);
549 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
551 .unwrap_or(60_000);
552 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
553 if let Some(info) = info {
554 s.state.update_rate_limits(&account_name, info);
555 }
556 s.state.set_cooldown(&account_name, cooldown_ms);
557 if cooldown_ms >= 5 * 60_000 {
558 let mins = cooldown_ms / 60_000;
559 notify(
560 "shunt: Rate Limited",
561 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
562 "Ping",
563 );
564 }
565 tried.insert(account_name);
566 }
567 529 => {
568 warn!(account = %account_name, "529 overloaded — cooling 30s");
569 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
570 s.state.update_rate_limits(&account_name, info);
571 }
572 s.state.set_cooldown(&account_name, 30_000);
573 tried.insert(account_name);
574 }
575 401 => {
576 if !refreshed.contains(&account_name) {
577 let account_lock = {
585 let mut locks = s.refresh_locks.lock().unwrap();
586 locks.entry(account_name.clone())
587 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
588 .clone()
589 };
590 let _guard = account_lock.lock().await;
591
592 let cred_before = {
595 let creds = s.credentials.read().await;
596 creds.get(&account_name).cloned()
597 .or_else(|| account.credential.clone())
598 };
599 let Some(cred) = cred_before else {
600 tried.insert(account_name);
601 continue;
602 };
603
604 let token_before = cred.access_token().to_owned();
606 let already_refreshed = {
607 let creds = s.credentials.read().await;
608 creds.get(&account_name)
609 .map(|c| c.access_token() != token_before)
610 .unwrap_or(false)
611 };
612
613 if already_refreshed {
614 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
616 refreshed.insert(account_name);
617 } else if let Some(oauth_cred) = cred.as_oauth() {
618 match tokio::time::timeout(
620 std::time::Duration::from_secs(10),
621 account.provider.refresh_token(oauth_cred),
622 ).await {
623 Ok(Ok(fresh)) => {
624 warn!(account = %account_name, "401 — token refreshed, retrying");
625 {
626 let mut creds = s.credentials.write().await;
627 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
628 }
629 let name = account_name.clone();
631 let fresh = fresh.clone();
632 tokio::task::spawn_blocking(move || {
633 let mut store = CredentialsStore::load();
634 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
635 store.save().ok();
636 if fresh.id_token.is_some() {
637 crate::oauth::write_codex_auth_file(&fresh);
638 }
639 });
640 refreshed.insert(account_name);
642 }
643 _ => {
644 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
646 s.state.set_cooldown(&account_name, 5 * 60_000);
647 tried.insert(account_name);
648 }
649 }
650 } else {
651 error!(account = %account_name, "401 — API key rejected, cooling 5min");
653 s.state.set_cooldown(&account_name, 5 * 60_000);
654 tried.insert(account_name);
655 }
656 } else {
657 error!(account = %account_name, "401 after refresh — cooling 5min");
659 s.state.set_cooldown(&account_name, 5 * 60_000);
660 tried.insert(account_name);
661 }
662 }
663 403 => {
664 if acct_is_anthropic {
668 error!(account = %account_name, "403 forbidden — cooling 30min");
669 s.state.set_cooldown(&account_name, 30 * 60_000);
670 notify(
671 "shunt: Account Forbidden",
672 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
673 "Basso",
674 );
675 } else {
676 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
677 s.state.set_cooldown(&account_name, 5 * 60_000);
678 }
679 tried.insert(account_name);
680 }
681 _ => {
682 return Ok(response);
684 }
685 }
686 }
687}
688
689async fn tap_usage(
698 resp: Response,
699 state: &StateStore,
700 telemetry: Option<&TelemetryClient>,
701 account: &str,
702 model: &str,
703 req_start_ms: u64,
704) -> Response {
705 use axum::body::Body;
706 use crate::state::RequestLog;
707
708 if quota::is_streaming_response(&resp) {
709 let state = state.clone();
710 let telem = telemetry.cloned();
711 let account = account.to_owned();
712 let model = model.to_owned();
713 let on_complete = Arc::new(move |input: u64, output: u64| {
714 let log = RequestLog {
715 ts_ms: req_start_ms,
716 account: account.clone(),
717 model: model.clone(),
718 status: 200,
719 input_tokens: input,
720 output_tokens: output,
721 duration_ms: now_ms().saturating_sub(req_start_ms),
722 };
723 state.record_usage(&account, input, output);
724 state.record_global(&model, input, output);
725 if let Some(ref t) = telem { t.push_event(&log); }
726 state.record_request(log);
727 });
728 let (parts, body) = resp.into_parts();
729 let wrapped = quota::wrap_streaming_body(body, on_complete);
730 return Response::from_parts(parts, wrapped);
731 }
732
733 let (parts, body) = resp.into_parts();
735 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
736 Ok(b) => b,
737 Err(_) => return Response::from_parts(parts, Body::empty()),
738 };
739 let (input, output) = quota::extract_usage_from_json(&bytes);
740 let log = RequestLog {
741 ts_ms: req_start_ms,
742 account: account.to_owned(),
743 model: model.to_owned(),
744 status: 200,
745 input_tokens: input,
746 output_tokens: output,
747 duration_ms: now_ms().saturating_sub(req_start_ms),
748 };
749 state.record_usage(account, input, output);
750 state.record_global(model, input, output);
751 if let Some(t) = telemetry { t.push_event(&log); }
752 state.record_request(log);
753 Response::from_parts(parts, Body::from(bytes))
754}
755
756
757pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
765 let client = reqwest::Client::builder()
766 .timeout(std::time::Duration::from_secs(20))
767 .build()
768 .unwrap_or_default();
769
770 for account in &config.accounts {
771 let rl = state.rate_limit_snapshot();
773 if let Some(r) = rl.get(&account.name) {
774 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
775 continue;
776 }
777 }
778
779 let cred = match account.credential.clone() {
781 Some(c) => c,
782 None => continue,
783 };
784
785 let Some((path, body)) = account.provider.prefetch_request() else {
786 if let Some(probe_path) = account.provider.auth_probe_get_path() {
788 auth_probe_get(&client, probe_path, account, &state).await;
789 }
790 continue;
791 };
792 let url = format!("{}{}", config.server.upstream_url, path);
793
794 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
795
796 let r = match resp {
797 Ok(r) => r,
798 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
799 };
800
801 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
802 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
803 let Some(oauth_cred) = cred.as_oauth() else {
804 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
806 state.set_auth_failed(&account.name);
807 continue;
808 };
809 let fresh = match account.provider.refresh_token(oauth_cred).await {
810 Ok(f) => f,
811 Err(e) => {
812 tracing::warn!(account = %account.name, "token refresh failed: {e}");
813 state.set_auth_failed(&account.name);
814 continue;
815 }
816 };
817 let mut store = crate::config::CredentialsStore::load();
818 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
819 store.save().ok();
820 if fresh.id_token.is_some() {
821 crate::oauth::write_codex_auth_file(&fresh);
822 }
823 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
825
826 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
827 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
828 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
829 state.set_auth_failed(&account.name);
830 }
831 Ok(r2) => {
832 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
833 state.update_rate_limits(&account.name, info);
834 }
835 }
836 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
837 }
838 } else {
839 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
840 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
841 state.update_rate_limits(&account.name, info);
842 }
843 }
844 }
845}
846
847async fn prefetch_send(
849 client: &reqwest::Client,
850 url: &str,
851 provider: &crate::provider::Provider,
852 token: &str,
853 body: &serde_json::Value,
854) -> anyhow::Result<reqwest::Response> {
855 let mut headers = reqwest::header::HeaderMap::new();
856 provider.inject_auth_headers(&mut headers, token)?;
857 for (name, value) in provider.prefetch_extra_headers() {
858 headers.insert(
859 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
860 reqwest::header::HeaderValue::from_static(value),
861 );
862 }
863 Ok(client.post(url).headers(headers).json(body).send().await?)
864}
865
866async fn auth_probe_get(
870 client: &reqwest::Client,
871 path: &str,
872 account: &crate::config::AccountConfig,
873 state: &StateStore,
874) {
875 let cred = match account.credential.clone() {
876 Some(c) => c,
877 None => return,
878 };
879 let upstream = account.upstream_url.as_deref()
880 .unwrap_or_else(|| account.provider.default_upstream_url());
881 let url = format!("{}{}", upstream, path);
882
883 let do_get = |token: &str| -> reqwest::RequestBuilder {
884 let mut headers = reqwest::header::HeaderMap::new();
885 let _ = account.provider.inject_auth_headers(&mut headers, token);
886 client.get(&url).headers(headers)
887 };
888
889 let resp = match do_get(cred.bearer_token()).send().await {
890 Ok(r) => r,
891 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
892 };
893
894 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
895 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
896 let Some(oauth_cred) = cred.as_oauth() else {
897 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
899 state.set_auth_failed(&account.name);
900 return;
901 };
902 let fresh = match account.provider.refresh_token(oauth_cred).await {
903 Ok(f) => f,
904 Err(e) => {
905 tracing::warn!(account = %account.name, "token refresh failed: {e}");
906 state.set_auth_failed(&account.name);
907 return;
908 }
909 };
910 let mut store = crate::config::CredentialsStore::load();
911 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
912 store.save().ok();
913 if fresh.id_token.is_some() {
914 crate::oauth::write_codex_auth_file(&fresh);
915 }
916
917 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
918 match do_get(fresh_token).send().await {
919 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
920 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
921 state.set_auth_failed(&account.name);
922 }
923 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
924 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
925 }
926 } else {
927 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
928 }
932}
933
934fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
941 let now_ms = std::time::SystemTime::now()
942 .duration_since(std::time::UNIX_EPOCH)
943 .unwrap_or_default()
944 .as_millis() as u64;
945 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
946 .unwrap_or(cred.expires_at);
947 exp_ms < now_ms + threshold_mins * 60 * 1_000
948}
949
950async fn sync_live_creds_from_auth_json(
955 account_name: &str,
956 live_creds: &LiveCredentials,
957) {
958 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
959 let current_exp = live_creds.read().await
960 .get(account_name)
961 .and_then(|c| c.as_oauth())
962 .map(|c| c.expires_at)
963 .unwrap_or(0);
964 if from_file.expires_at > current_exp {
965 tracing::info!(account = %account_name, "synced fresher token from auth.json");
966 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
967 }
968}
969
970async fn do_proactive_refresh(
972 account: &crate::config::AccountConfig,
973 creds: &crate::oauth::OAuthCredential,
974 live_creds: &LiveCredentials,
975 state: &StateStore,
976) {
977 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
978 match account.provider.refresh_token(creds).await {
979 Ok(fresh) => {
980 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
981 {
982 let mut map = live_creds.write().await;
983 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
984 }
985 let mut store = crate::config::CredentialsStore::load();
986 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
987 store.save().ok();
988 if fresh.id_token.is_some() {
989 crate::oauth::write_codex_auth_file(&fresh);
990 }
991 state.clear_auth_failed(&account.name);
992 }
993 Err(e) => {
994 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
995 state.set_auth_failed(&account.name);
996 }
997 }
998}
999
1000
1001pub async fn openai_token_refresh_loop(
1009 config: Arc<Config>,
1010 state: StateStore,
1011 live_creds: LiveCredentials,
1012) {
1013 for account in config.accounts.iter()
1015 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1016 {
1017 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1018 continue;
1019 }
1020 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1021
1022 let creds = {
1023 let map = live_creds.read().await;
1024 map.get(&account.name).cloned().or_else(|| account.credential.clone())
1025 };
1026 if let Some(creds) = creds {
1027 if let Some(oauth) = creds.as_oauth() {
1028 if access_token_expires_soon(oauth, 30) {
1029 do_proactive_refresh(account, oauth, &live_creds, &state).await;
1031 } else {
1032 tracing::info!(account = %account.name, "access_token fresh at startup");
1033 }
1034 }
1035 }
1036 }
1037
1038 loop {
1041 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1042 for account in config.accounts.iter()
1043 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1044 {
1045 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1046 }
1047 }
1048}
1049
1050enum ProxyError {
1055 BodyRead,
1056 Upstream,
1057 AllAccountsUnavailable,
1058 Unauthorized,
1059}
1060
1061impl IntoResponse for ProxyError {
1062 fn into_response(self) -> Response {
1063 let (status, msg) = match self {
1064 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1065 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1066 ProxyError::AllAccountsUnavailable => {
1067 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1068 }
1069 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1070 };
1071
1072 (status, axum::Json(json!({
1073 "type": "error",
1074 "error": {"type": "api_error", "message": msg}
1075 }))).into_response()
1076 }
1077}
1078
1079pub async fn recovery_watcher(
1088 config: Arc<Config>,
1089 state: StateStore,
1090 credentials: LiveCredentials,
1091) {
1092 use std::time::{Duration, Instant};
1093 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1094 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1095
1096 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1097 let mut last_notified: Option<Instant> = None;
1098
1099 loop {
1100 tokio::time::sleep(CHECK_INTERVAL).await;
1101
1102 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1103 let failed = state.auth_failed_accounts(&name_refs);
1104 if failed.is_empty() {
1105 last_notified = None;
1106 continue;
1107 }
1108
1109 tracing::warn!(
1110 accounts = ?failed,
1111 "recovery: {} account(s) auth_failed, attempting token refresh",
1112 failed.len()
1113 );
1114
1115 let mut any_recovered = false;
1116
1117 for name in &failed {
1118 let cred = {
1119 let map = credentials.read().await;
1120 map.get(*name).cloned()
1121 };
1122 let Some(cred) = cred else { continue };
1123 if !cred.has_refresh_token() { continue; }
1124 let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1125
1126 let provider = config.accounts.iter()
1127 .find(|a| a.name == *name)
1128 .map(|a| a.provider.clone())
1129 .unwrap_or_default();
1130
1131 let result = tokio::time::timeout(
1132 Duration::from_secs(20),
1133 provider.refresh_token(&oauth_cred),
1134 ).await;
1135
1136 match result {
1137 Ok(Ok(fresh)) => {
1138 tracing::info!(account = %name, "recovery: token refreshed — account back online");
1139 {
1140 let mut map = credentials.write().await;
1141 map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1142 }
1143 let name_owned = name.to_string();
1144 let fresh_owned = fresh.clone();
1145 tokio::task::spawn_blocking(move || {
1146 let mut store = crate::config::CredentialsStore::load();
1147 store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1148 store.save().ok();
1149 if fresh_owned.id_token.is_some() {
1150 crate::oauth::write_codex_auth_file(&fresh_owned);
1151 }
1152 });
1153 state.clear_auth_failed(name);
1154 any_recovered = true;
1155 }
1156 Ok(Err(e)) => {
1157 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1158 notify(
1159 "shunt: Reauth Required",
1160 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1161 "Basso",
1162 );
1163 }
1164 Err(_) => {
1165 tracing::error!(account = %name, "recovery: token refresh timed out");
1166 notify(
1167 "shunt: Reauth Required",
1168 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1169 "Basso",
1170 );
1171 }
1172 }
1173 }
1174
1175 if any_recovered {
1176 tracing::info!("recovery: at least one account is back online");
1177 continue;
1178 }
1179
1180 let still_failed = state.auth_failed_accounts(&name_refs);
1182 if still_failed.len() == account_names.len() {
1183 let should_notify = last_notified
1184 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1185 .unwrap_or(true);
1186 if should_notify {
1187 error!(
1188 "ALL accounts are offline (auth failed). \
1189 Run `shunt add-account` to re-authorize."
1190 );
1191 notify(
1192 "shunt: All Accounts Offline",
1193 "All accounts need re-authorization. Run `shunt add-account`.",
1194 "Basso",
1195 );
1196 last_notified = Some(Instant::now());
1197 }
1198 }
1199 }
1200}
1201
1202async fn post_cooldown_prefetch(
1206 client: &reqwest::Client,
1207 account: &crate::config::AccountConfig,
1208 token: &str,
1209 state: &StateStore,
1210 upstream_url: &str,
1211) {
1212 let Some((path, body)) = account.provider.prefetch_request() else {
1213 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1214 auth_probe_get(client, probe_path, account, state).await;
1215 }
1216 return;
1217 };
1218 let url = format!("{upstream_url}{path}");
1219 match prefetch_send(client, &url, &account.provider, token, &body).await {
1220 Ok(r) => {
1221 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1222 state.update_rate_limits(&account.name, info);
1223 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1224 }
1225 }
1226 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1227 }
1228}
1229
1230pub async fn cooldown_watcher(
1241 config: Arc<Config>,
1242 state: StateStore,
1243 credentials: LiveCredentials,
1244) {
1245 const STALE_RL_MS: u64 = 60 * 60_000;
1247
1248 let client = reqwest::Client::builder()
1249 .timeout(std::time::Duration::from_secs(20))
1250 .build()
1251 .unwrap_or_default();
1252
1253 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1256 let mut notify_on_resume: HashSet<String> = HashSet::new();
1258 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1260
1261 loop {
1262 let states = state.account_states();
1263 let rl_snapshot = state.rate_limit_snapshot();
1264 let now = now_ms();
1265 let mut next_wake_ms: Option<u64> = None;
1266
1267 for account in &config.accounts {
1268 let Some(st) = states.get(&account.name) else { continue };
1269 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1271
1272 if cdl > 0 && cdl <= now {
1273 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1275 if !handled {
1276 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1277 let token = {
1278 let creds = credentials.read().await;
1279 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1280 };
1281 if let Some(token) = token {
1282 post_cooldown_prefetch(
1283 &client, account, &token, &state,
1284 &config.server.upstream_url,
1285 ).await;
1286 }
1287 if notify_on_resume.remove(&account.name) {
1288 notify(
1289 "shunt: Account Resumed",
1290 &format!("Account '{}' is back online.", account.name),
1291 "Glass",
1292 );
1293 }
1294 last_resumed.insert(account.name.clone(), cdl);
1295 last_stale_prefetch.insert(account.name.clone(), now);
1296 }
1297 } else if cdl > now {
1298 let remaining = cdl - now;
1300 if remaining >= 5 * 60_000 {
1301 notify_on_resume.insert(account.name.clone());
1302 }
1303 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1304 } else {
1305 let rl_age = rl_snapshot
1307 .get(&account.name)
1308 .map(|r| now.saturating_sub(r.updated_ms))
1309 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1311 let fetched_ago = now.saturating_sub(last_fetched);
1312
1313 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1314 tracing::debug!(
1315 account = %account.name,
1316 age_min = rl_age / 60_000,
1317 "rate-limit data stale — refreshing"
1318 );
1319 let token = {
1320 let creds = credentials.read().await;
1321 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1322 };
1323 if let Some(token) = token {
1324 post_cooldown_prefetch(
1325 &client, account, &token, &state,
1326 &config.server.upstream_url,
1327 ).await;
1328 }
1329 last_stale_prefetch.insert(account.name.clone(), now);
1330 }
1331 }
1332 }
1333
1334 let sleep_ms = next_wake_ms
1336 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1337 .unwrap_or(30_000);
1338 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1339 }
1340}
1341
1342use crate::notify::notify;
1343use crate::translate::{
1344 translate_to_anthropic,
1345 translate_from_anthropic,
1346 uuid_v4,
1347 translate_anthropic_stream,
1348 translate_anthropic_req_to_chatgpt,
1349 translate_response_chatgpt_to_anthropic,
1350 translate_anthropic_req_to_openai,
1351 translate_response_openai_to_anthropic,
1352 translate_response_anthropic_to_openai,
1353};
1354
1355async fn openai_models_handler() -> impl IntoResponse {
1370 axum::Json(json!({
1371 "object": "list",
1372 "data": [
1373 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1374 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1375 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1376 ]
1377 }))
1378}
1379
1380async fn openai_compat_handler(
1382 State(s): State<AppState>,
1383 req: Request,
1384) -> Result<Response, ProxyError> {
1385 let Some(ref anthropic_url) = s.anthropic_base_url else {
1386 return proxy_handler(State(s), req).await;
1388 };
1389
1390 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1391 .await
1392 .map_err(|_| ProxyError::BodyRead)?;
1393
1394 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1395 .unwrap_or(json!({}));
1396
1397 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1398 let anthropic_body = translate_to_anthropic(openai_body);
1399
1400 let client = reqwest::Client::builder()
1401 .timeout(std::time::Duration::from_secs(300))
1402 .build()
1403 .map_err(|_| ProxyError::Upstream)?;
1404
1405 let resp = client
1406 .post(format!("{anthropic_url}/v1/messages"))
1407 .header("content-type", "application/json")
1408 .header("anthropic-version", "2023-06-01")
1409 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1410 .header("x-shunt-compat", "openai")
1411 .json(&anthropic_body)
1412 .send()
1413 .await
1414 .map_err(|_| ProxyError::Upstream)?;
1415
1416 if !resp.status().is_success() {
1417 let status = resp.status();
1418 let body = resp.text().await.unwrap_or_default();
1419 let code = status.as_u16();
1420 return Ok(axum::response::Response::builder()
1421 .status(code)
1422 .header("content-type", "application/json")
1423 .body(axum::body::Body::from(body))
1424 .unwrap());
1425 }
1426
1427 if stream {
1428 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1430 let stream = translate_anthropic_stream(resp, chat_id);
1431 Ok(axum::response::Response::builder()
1432 .status(200)
1433 .header("content-type", "text/event-stream")
1434 .header("cache-control", "no-cache")
1435 .body(axum::body::Body::from_stream(stream))
1436 .unwrap())
1437 } else {
1438 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1439 let openai_resp = translate_from_anthropic(anthropic_resp);
1440 Ok(axum::Json(openai_resp).into_response())
1441 }
1442}
1443
1444async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1451 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1452 let resp = client
1453 .get(&url)
1454 .header("Authorization", format!("Bearer {}", token))
1455 .send()
1456 .await
1457 .ok()?;
1458 if !resp.status().is_success() {
1459 return None;
1460 }
1461 let json: serde_json::Value = resp.json().await.ok()?;
1462 if json["proofofwork"]["required"].as_bool() == Some(true) {
1463 return None;
1464 }
1465 json["token"].as_str().map(ToOwned::to_owned)
1466}
1467
1468
1469fn resolve_model(
1474 incoming: &str,
1475 account: &crate::config::AccountConfig,
1476 mapping: &std::collections::HashMap<String, String>,
1477) -> String {
1478 if let Some(m) = &account.model {
1480 return m.clone();
1481 }
1482 if let Some(m) = mapping.get(incoming) {
1484 return m.clone();
1485 }
1486 let default = account.provider.default_model();
1488 if !default.is_empty() {
1489 return default.to_owned();
1490 }
1491 incoming.to_owned()
1493}
1494