1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24 config: Arc<Config>,
25 forwarder: Arc<Forwarder>,
26 state: StateStore,
27 credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29 refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37 started_ms: u64,
39 anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46 Ok(app)
47}
48
49pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
51
52pub fn create_app_with_state(
53 config: Config,
54 state: StateStore,
55 anthropic_base_url: Option<String>,
56) -> anyhow::Result<(Router, LiveCredentials)> {
57 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
58
59 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
62 state.set_auth_failed(&a.name);
63 }
64
65 let credentials: LiveCredentials = Arc::new(RwLock::new(
66 config.accounts.iter()
67 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
68 .collect::<HashMap<_, _>>(),
69 ));
70
71 let app_state = AppState {
72 config: Arc::new(config),
73 forwarder: Arc::new(forwarder),
74 state,
75 credentials: Arc::clone(&credentials),
76 refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
77 started_ms: now_ms(),
78 anthropic_base_url,
79 };
80
81 let provider = app_state.config.accounts.first()
86 .map(|a| &a.provider)
87 .cloned()
88 .unwrap_or_default();
89
90 let proxy_routes = match provider {
91 Provider::Anthropic => Router::new()
92 .route("/v1/messages", post(proxy_handler))
93 .route("/v1/messages/count_tokens", post(proxy_handler)),
94 Provider::OpenAI => Router::new()
95 .route("/v1/chat/completions", post(openai_compat_handler))
96 .route("/v1/models", get(openai_models_handler))
97 .fallback(proxy_handler),
98 };
99
100 let app = Router::new()
101 .route("/health", get(health))
102 .route("/status", get(status_handler))
103 .route("/use", post(use_handler))
104 .merge(proxy_routes)
105 .with_state(app_state);
106
107 Ok((app, credentials))
108}
109
110async fn health() -> impl IntoResponse {
111 axum::Json(json!({"status": "ok"}))
112}
113
114async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
115 let account_states = s.state.account_states();
116 let quotas = s.state.quota_snapshot();
117 let rate_limits = s.state.rate_limit_snapshot();
118
119 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
120 let st = account_states.get(&a.name);
121 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
122 "reauth_required"
123 } else if st.map(|s| s.disabled).unwrap_or(false) {
124 "disabled"
125 } else if s.state.is_available(&a.name) {
126 "available"
127 } else {
128 "cooling"
129 };
130
131 let quota = quotas.get(&a.name);
132 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
133 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
134 let tokens_used = quota.map(|q| json!({
135 "input": q.input_tokens,
136 "output": q.output_tokens,
137 "total": q.total_tokens(),
138 }));
139
140 let rl = rate_limits.get(&a.name);
141 let rate_limit = rl.map(|r| json!({
142 "utilization_5h": r.utilization_5h,
143 "reset_5h": r.reset_5h,
144 "status_5h": r.status_5h,
145 "utilization_7d": r.utilization_7d,
146 "reset_7d": r.reset_7d,
147 "status_7d": r.status_7d,
148 "representative_claim": r.representative_claim,
149 "updated_ms": r.updated_ms,
150 }));
151
152 let acc_state = account_states.get(&a.name);
153 let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
154 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
155 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
156 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
157 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
158 let reset_5h = rl.and_then(|r| r.reset_5h);
159 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
160 let reset_7d = rl.and_then(|r| r.reset_7d);
161 let available = s.state.is_available(&a.name);
162
163 json!({
164 "name": a.name,
165 "email": email,
166 "plan_type": a.plan_type,
167 "status": avail_status,
168 "available": available,
169 "disabled": disabled,
170 "auth_failed": auth_failed,
171 "cooldown_until_ms": cooldown_until_ms,
172 "utilization_5h": utilization_5h,
173 "reset_5h": reset_5h,
174 "utilization_7d": utilization_7d,
175 "reset_7d": reset_7d,
176 "window_expires_ms": window_expires_ms,
177 "tokens_used": tokens_used,
178 "rate_limit": rate_limit,
179 })
180 }).collect();
181
182 let recent_requests = s.state.recent_requests_snapshot();
183 let savings = s.state.savings_snapshot();
184
185 axum::Json(json!({
186 "version": env!("CARGO_PKG_VERSION"),
187 "started_ms": s.started_ms,
188 "accounts": accounts,
189 "pinned_account": s.state.get_pinned(),
190 "last_used_account": s.state.get_last_used(),
191 "recent_requests": recent_requests,
192 "savings": savings,
193 }))
194}
195
196async fn use_handler(
197 State(s): State<AppState>,
198 axum::Json(body): axum::Json<serde_json::Value>,
199) -> impl IntoResponse {
200 let account = body["account"].as_str().map(|s| s.to_owned());
201 if let Some(ref name) = account {
203 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
204 return axum::Json(json!({
205 "error": format!("unknown account '{name}'")
206 }));
207 }
208 let pinned = if name == "auto" { None } else { Some(name.clone()) };
209 s.state.set_pinned(pinned);
210 axum::Json(json!({ "pinned": name }))
211 } else {
212 s.state.set_pinned(None);
213 axum::Json(json!({ "pinned": null }))
214 }
215}
216
/// Milliseconds since the Unix epoch according to the system clock; yields 0
/// if the clock reads earlier than the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
221
222async fn proxy_handler(
223 State(s): State<AppState>,
224 req: Request,
225) -> Result<Response, ProxyError> {
226 if let Some(ref expected) = s.config.server.remote_key {
228 let provided = req.headers()
229 .get("x-api-key")
230 .and_then(|v| v.to_str().ok())
231 .unwrap_or("");
232 if provided != expected {
233 return Err(ProxyError::Unauthorized);
234 }
235 }
236
237 let method = req.method().as_str().to_owned();
238 let path = req.uri().path().to_owned();
239 let headers = req.headers().clone();
240
241 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
242 .await
243 .map_err(|_| ProxyError::BodyRead)?;
244
245 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
246 .ok()
247 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
248 .unwrap_or_default();
249 let req_start_ms = now_ms();
250
251 let fp = router::fingerprint(&body_bytes);
252 let fp_ref = fp.as_deref();
253
254 let mut tried: HashSet<String> = HashSet::new();
255 let mut refreshed: HashSet<String> = HashSet::new();
257 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
259
260 loop {
261 let account = match router::pick_account(
262 &s.config.accounts, &s.state, fp_ref, &tried,
263 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
264 ) {
265 Some(a) => a,
266 None => {
267 let account_states = s.state.account_states();
271 let now = now_ms();
272 let soonest_ms = s.config.accounts.iter()
273 .filter_map(|a| {
274 let st = account_states.get(&a.name)?;
275 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
277 })
278 .min();
279
280 match soonest_ms {
281 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
282 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
284 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
285 tried.clear(); }
287 _ => return Err(ProxyError::AllAccountsUnavailable),
288 }
289 continue;
290 }
291 };
292
293 let account_name = account.name.clone();
294
295 let token = {
299 let creds = s.credentials.read().await;
300 let cred = creds.get(&account_name)
301 .cloned()
302 .or_else(|| account.credential.clone());
303 match cred {
304 Some(c) => c.access_token,
305 None => String::new(),
306 }
307 };
308
309 let response = s.forwarder
310 .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
311 .await
312 .map_err(|e| {
313 error!("Forward error: {:#}", e);
314 ProxyError::Upstream
315 })?;
316
317 match response.status().as_u16() {
318 200..=299 => {
319 s.state.set_last_used(&account_name);
320 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
321 s.state.update_rate_limits(&account_name, info);
322 }
323 return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
324 }
325 429 => {
326 let info = account.provider.parse_rate_limits(response.headers());
327 let cooldown_ms = info.as_ref()
330 .and_then(|i| i.reset_5h.or(i.reset_7d))
331 .map(|reset_secs| {
332 let reset_ms = reset_secs.saturating_mul(1_000);
333 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
335 .unwrap_or(60_000);
336 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
337 if let Some(info) = info {
338 s.state.update_rate_limits(&account_name, info);
339 }
340 s.state.set_cooldown(&account_name, cooldown_ms);
341 if cooldown_ms >= 5 * 60_000 {
342 let mins = cooldown_ms / 60_000;
343 notify(
344 "shunt: Rate Limited",
345 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
346 "Ping",
347 );
348 }
349 tried.insert(account_name);
350 }
351 529 => {
352 warn!(account = %account_name, "529 overloaded — cooling 30s");
353 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
354 s.state.update_rate_limits(&account_name, info);
355 }
356 s.state.set_cooldown(&account_name, 30_000);
357 tried.insert(account_name);
358 }
359 401 => {
360 if !refreshed.contains(&account_name) {
361 let account_lock = {
369 let mut locks = s.refresh_locks.lock().unwrap();
370 locks.entry(account_name.clone())
371 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
372 .clone()
373 };
374 let _guard = account_lock.lock().await;
375
376 let cred_before = {
379 let creds = s.credentials.read().await;
380 creds.get(&account_name).cloned()
381 .or_else(|| account.credential.clone())
382 };
383 let Some(cred) = cred_before else {
384 tried.insert(account_name);
385 continue;
386 };
387
388 let token_before = cred.access_token.clone();
390 let already_refreshed = {
391 let creds = s.credentials.read().await;
392 creds.get(&account_name)
393 .map(|c| c.access_token != token_before)
394 .unwrap_or(false)
395 };
396
397 if already_refreshed {
398 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
400 refreshed.insert(account_name);
401 } else {
402 match tokio::time::timeout(
403 std::time::Duration::from_secs(10),
404 account.provider.refresh_token(&cred),
405 ).await {
406 Ok(Ok(fresh)) => {
407 warn!(account = %account_name, "401 — token refreshed, retrying");
408 {
409 let mut creds = s.credentials.write().await;
410 creds.insert(account_name.clone(), fresh.clone());
411 }
412 let name = account_name.clone();
414 let fresh = fresh.clone();
415 tokio::task::spawn_blocking(move || {
416 let mut store = CredentialsStore::load();
417 store.accounts.insert(name, fresh.clone());
418 store.save().ok();
419 if fresh.id_token.is_some() {
420 crate::oauth::write_codex_auth_file(&fresh);
421 }
422 });
423 refreshed.insert(account_name);
425 }
426 _ => {
427 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
429 s.state.set_cooldown(&account_name, 5 * 60_000);
430 tried.insert(account_name);
431 }
432 }
433 }
434 } else {
435 error!(account = %account_name, "401 after refresh — cooling 5min");
437 s.state.set_cooldown(&account_name, 5 * 60_000);
438 tried.insert(account_name);
439 }
440 }
441 403 => {
442 error!(account = %account_name, "403 forbidden — cooling 30min");
444 s.state.set_cooldown(&account_name, 30 * 60_000);
445 notify(
446 "shunt: Account Forbidden",
447 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
448 "Basso",
449 );
450 tried.insert(account_name);
451 }
452 _ => {
453 return Ok(response);
455 }
456 }
457 }
458}
459
460async fn tap_usage(
469 resp: Response,
470 state: &StateStore,
471 account: &str,
472 model: &str,
473 req_start_ms: u64,
474) -> Response {
475 use axum::body::Body;
476 use crate::state::RequestLog;
477
478 if quota::is_streaming_response(&resp) {
479 let state = state.clone();
480 let account = account.to_owned();
481 let model = model.to_owned();
482 let on_complete = Arc::new(move |input: u64, output: u64| {
483 state.record_usage(&account, input, output);
484 state.record_global(&model, input, output);
485 state.record_request(RequestLog {
486 ts_ms: req_start_ms,
487 account: account.clone(),
488 model: model.clone(),
489 status: 200,
490 input_tokens: input,
491 output_tokens: output,
492 duration_ms: now_ms().saturating_sub(req_start_ms),
493 });
494 });
495 let (parts, body) = resp.into_parts();
496 let wrapped = quota::wrap_streaming_body(body, on_complete);
497 return Response::from_parts(parts, wrapped);
498 }
499
500 let (parts, body) = resp.into_parts();
502 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
503 Ok(b) => b,
504 Err(_) => return Response::from_parts(parts, Body::empty()),
505 };
506 let (input, output) = quota::extract_usage_from_json(&bytes);
507 state.record_usage(account, input, output);
508 state.record_global(model, input, output);
509 state.record_request(RequestLog {
510 ts_ms: req_start_ms,
511 account: account.to_owned(),
512 model: model.to_owned(),
513 status: 200,
514 input_tokens: input,
515 output_tokens: output,
516 duration_ms: now_ms().saturating_sub(req_start_ms),
517 });
518 Response::from_parts(parts, Body::from(bytes))
519}
520
521
522pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
530 let client = reqwest::Client::builder()
531 .timeout(std::time::Duration::from_secs(20))
532 .build()
533 .unwrap_or_default();
534
535 for account in &config.accounts {
536 let rl = state.rate_limit_snapshot();
538 if let Some(r) = rl.get(&account.name) {
539 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
540 continue;
541 }
542 }
543
544 let creds = match account.credential.clone() {
546 Some(c) => c,
547 None => continue,
548 };
549
550 let Some((path, body)) = account.provider.prefetch_request() else {
551 if let Some(probe_path) = account.provider.auth_probe_get_path() {
553 auth_probe_get(&client, probe_path, account, &state).await;
554 }
555 continue;
556 };
557 let url = format!("{}{}", config.server.upstream_url, path);
558
559 let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
560
561 let r = match resp {
562 Ok(r) => r,
563 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
564 };
565
566 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
567 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
568 let fresh = match account.provider.refresh_token(&creds).await {
569 Ok(f) => f,
570 Err(e) => {
571 tracing::warn!(account = %account.name, "token refresh failed: {e}");
572 state.set_auth_failed(&account.name);
573 continue;
574 }
575 };
576 let mut store = crate::config::CredentialsStore::load();
577 store.accounts.insert(account.name.clone(), fresh.clone());
578 store.save().ok();
579 if fresh.id_token.is_some() {
580 crate::oauth::write_codex_auth_file(&fresh);
581 }
582 live_creds.write().await.insert(account.name.clone(), fresh.clone());
584
585 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
586 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
587 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
588 state.set_auth_failed(&account.name);
589 }
590 Ok(r2) => {
591 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
592 state.update_rate_limits(&account.name, info);
593 }
594 }
595 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
596 }
597 } else {
598 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
599 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
600 state.update_rate_limits(&account.name, info);
601 }
602 }
603 }
604}
605
606async fn prefetch_send(
608 client: &reqwest::Client,
609 url: &str,
610 provider: &crate::provider::Provider,
611 token: &str,
612 body: &serde_json::Value,
613) -> anyhow::Result<reqwest::Response> {
614 let mut headers = reqwest::header::HeaderMap::new();
615 provider.inject_auth_headers(&mut headers, token)?;
616 for (name, value) in provider.prefetch_extra_headers() {
617 headers.insert(
618 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
619 reqwest::header::HeaderValue::from_static(value),
620 );
621 }
622 Ok(client.post(url).headers(headers).json(body).send().await?)
623}
624
625async fn auth_probe_get(
629 client: &reqwest::Client,
630 path: &str,
631 account: &crate::config::AccountConfig,
632 state: &StateStore,
633) {
634 let creds = match account.credential.clone() {
635 Some(c) => c,
636 None => return,
637 };
638 let upstream = match account.provider {
639 crate::provider::Provider::OpenAI => "https://chatgpt.com",
640 crate::provider::Provider::Anthropic => "https://api.anthropic.com",
641 };
642 let url = format!("{}{}", upstream, path);
643
644 let do_get = |token: &str| -> reqwest::RequestBuilder {
645 let mut headers = reqwest::header::HeaderMap::new();
646 let _ = account.provider.inject_auth_headers(&mut headers, token);
647 client.get(&url).headers(headers)
648 };
649
650 let probe_token = &creds.access_token;
651 let resp = match do_get(probe_token).send().await {
652 Ok(r) => r,
653 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
654 };
655
656 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
657 tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
658 let fresh = match account.provider.refresh_token(&creds).await {
659 Ok(f) => f,
660 Err(e) => {
661 tracing::warn!(account = %account.name, "token refresh failed: {e}");
662 state.set_auth_failed(&account.name);
663 return;
664 }
665 };
666 let mut store = crate::config::CredentialsStore::load();
667 store.accounts.insert(account.name.clone(), fresh.clone());
668 store.save().ok();
669 if fresh.id_token.is_some() {
670 crate::oauth::write_codex_auth_file(&fresh);
671 }
672
673 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
674 match do_get(fresh_token).send().await {
675 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
676 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
677 state.set_auth_failed(&account.name);
678 }
679 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
680 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
681 }
682 } else {
683 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
684 }
688}
689
690fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
697 let now_ms = std::time::SystemTime::now()
698 .duration_since(std::time::UNIX_EPOCH)
699 .unwrap_or_default()
700 .as_millis() as u64;
701 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
702 .unwrap_or(cred.expires_at);
703 exp_ms < now_ms + threshold_mins * 60 * 1_000
704}
705
706async fn sync_live_creds_from_auth_json(
711 account_name: &str,
712 live_creds: &LiveCredentials,
713) {
714 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
715 let current_exp = live_creds.read().await
716 .get(account_name)
717 .map(|c| c.expires_at)
718 .unwrap_or(0);
719 if from_file.expires_at > current_exp {
720 tracing::info!(account = %account_name, "synced fresher token from auth.json");
721 live_creds.write().await.insert(account_name.to_owned(), from_file);
722 }
723}
724
725async fn do_proactive_refresh(
727 account: &crate::config::AccountConfig,
728 creds: &crate::oauth::OAuthCredential,
729 live_creds: &LiveCredentials,
730 state: &StateStore,
731) {
732 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
733 match account.provider.refresh_token(creds).await {
734 Ok(fresh) => {
735 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
736 {
737 let mut map = live_creds.write().await;
738 map.insert(account.name.clone(), fresh.clone());
739 }
740 let mut store = crate::config::CredentialsStore::load();
741 store.accounts.insert(account.name.clone(), fresh.clone());
742 store.save().ok();
743 if fresh.id_token.is_some() {
744 crate::oauth::write_codex_auth_file(&fresh);
745 }
746 state.clear_auth_failed(&account.name);
747 }
748 Err(e) => {
749 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
750 state.set_auth_failed(&account.name);
751 }
752 }
753}
754
755
756pub async fn openai_token_refresh_loop(
764 config: Arc<Config>,
765 state: StateStore,
766 live_creds: LiveCredentials,
767) {
768 for account in config.accounts.iter()
770 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
771 {
772 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
773 continue;
774 }
775 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
776
777 let creds = {
778 let map = live_creds.read().await;
779 map.get(&account.name).cloned().or_else(|| account.credential.clone())
780 };
781 if let Some(creds) = creds {
782 if access_token_expires_soon(&creds, 30) {
783 do_proactive_refresh(account, &creds, &live_creds, &state).await;
785 } else {
786 tracing::info!(account = %account.name, "access_token fresh at startup");
787 }
788 }
789 }
790
791 loop {
794 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
795 for account in config.accounts.iter()
796 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
797 {
798 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
799 }
800 }
801}
802
/// Failures `proxy_handler` reports to the client; each variant maps to an
/// HTTP status and message in the `IntoResponse` impl.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// Forwarding the request upstream failed.
    Upstream,
    /// No account could be picked and waiting for one was not possible.
    AllAccountsUnavailable,
    /// The client's `x-api-key` did not match the configured remote key.
    Unauthorized,
}
813
814impl IntoResponse for ProxyError {
815 fn into_response(self) -> Response {
816 let (status, msg) = match self {
817 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
818 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
819 ProxyError::AllAccountsUnavailable => {
820 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
821 }
822 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
823 };
824
825 (status, axum::Json(json!({
826 "type": "error",
827 "error": {"type": "api_error", "message": msg}
828 }))).into_response()
829 }
830}
831
832pub async fn recovery_watcher(
841 config: Arc<Config>,
842 state: StateStore,
843 credentials: LiveCredentials,
844) {
845 use std::time::{Duration, Instant};
846 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
847 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
848
849 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
850 let mut last_notified: Option<Instant> = None;
851
852 loop {
853 tokio::time::sleep(CHECK_INTERVAL).await;
854
855 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
856 let failed = state.auth_failed_accounts(&name_refs);
857 if failed.is_empty() {
858 last_notified = None;
859 continue;
860 }
861
862 tracing::warn!(
863 accounts = ?failed,
864 "recovery: {} account(s) auth_failed, attempting token refresh",
865 failed.len()
866 );
867
868 let mut any_recovered = false;
869
870 for name in &failed {
871 let cred = {
872 let map = credentials.read().await;
873 map.get(*name).cloned()
874 };
875 let Some(cred) = cred else { continue };
876 if cred.refresh_token.is_empty() { continue; }
877
878 let provider = config.accounts.iter()
879 .find(|a| a.name == *name)
880 .map(|a| a.provider.clone())
881 .unwrap_or_default();
882
883 let result = tokio::time::timeout(
884 Duration::from_secs(20),
885 provider.refresh_token(&cred),
886 ).await;
887
888 match result {
889 Ok(Ok(fresh)) => {
890 tracing::info!(account = %name, "recovery: token refreshed — account back online");
891 {
892 let mut map = credentials.write().await;
893 map.insert(name.to_string(), fresh.clone());
894 }
895 let name_owned = name.to_string();
896 let fresh_owned = fresh.clone();
897 tokio::task::spawn_blocking(move || {
898 let mut store = crate::config::CredentialsStore::load();
899 store.accounts.insert(name_owned, fresh_owned.clone());
900 store.save().ok();
901 if fresh_owned.id_token.is_some() {
902 crate::oauth::write_codex_auth_file(&fresh_owned);
903 }
904 });
905 state.clear_auth_failed(name);
906 any_recovered = true;
907 }
908 Ok(Err(e)) => {
909 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
910 notify(
911 "shunt: Reauth Required",
912 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
913 "Basso",
914 );
915 }
916 Err(_) => {
917 tracing::error!(account = %name, "recovery: token refresh timed out");
918 notify(
919 "shunt: Reauth Required",
920 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
921 "Basso",
922 );
923 }
924 }
925 }
926
927 if any_recovered {
928 tracing::info!("recovery: at least one account is back online");
929 continue;
930 }
931
932 let still_failed = state.auth_failed_accounts(&name_refs);
934 if still_failed.len() == account_names.len() {
935 let should_notify = last_notified
936 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
937 .unwrap_or(true);
938 if should_notify {
939 error!(
940 "ALL accounts are offline (auth failed). \
941 Run `shunt add-account` to re-authorize."
942 );
943 notify(
944 "shunt: All Accounts Offline",
945 "All accounts need re-authorization. Run `shunt add-account`.",
946 "Basso",
947 );
948 last_notified = Some(Instant::now());
949 }
950 }
951 }
952}
953
954async fn post_cooldown_prefetch(
958 client: &reqwest::Client,
959 account: &crate::config::AccountConfig,
960 token: &str,
961 state: &StateStore,
962 upstream_url: &str,
963) {
964 let Some((path, body)) = account.provider.prefetch_request() else {
965 if let Some(probe_path) = account.provider.auth_probe_get_path() {
966 auth_probe_get(client, probe_path, account, state).await;
967 }
968 return;
969 };
970 let url = format!("{upstream_url}{path}");
971 match prefetch_send(client, &url, &account.provider, token, &body).await {
972 Ok(r) => {
973 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
974 state.update_rate_limits(&account.name, info);
975 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
976 }
977 }
978 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
979 }
980}
981
982pub async fn cooldown_watcher(
993 config: Arc<Config>,
994 state: StateStore,
995 credentials: LiveCredentials,
996) {
997 const STALE_RL_MS: u64 = 60 * 60_000;
999
1000 let client = reqwest::Client::builder()
1001 .timeout(std::time::Duration::from_secs(20))
1002 .build()
1003 .unwrap_or_default();
1004
1005 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1008 let mut notify_on_resume: HashSet<String> = HashSet::new();
1010 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1012
1013 loop {
1014 let states = state.account_states();
1015 let rl_snapshot = state.rate_limit_snapshot();
1016 let now = now_ms();
1017 let mut next_wake_ms: Option<u64> = None;
1018
1019 for account in &config.accounts {
1020 let Some(st) = states.get(&account.name) else { continue };
1021 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1023
1024 if cdl > 0 && cdl <= now {
1025 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1027 if !handled {
1028 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1029 let token = {
1030 let creds = credentials.read().await;
1031 creds.get(&account.name).map(|c| c.access_token.clone())
1032 };
1033 if let Some(token) = token {
1034 post_cooldown_prefetch(
1035 &client, account, &token, &state,
1036 &config.server.upstream_url,
1037 ).await;
1038 }
1039 if notify_on_resume.remove(&account.name) {
1040 notify(
1041 "shunt: Account Resumed",
1042 &format!("Account '{}' is back online.", account.name),
1043 "Glass",
1044 );
1045 }
1046 last_resumed.insert(account.name.clone(), cdl);
1047 last_stale_prefetch.insert(account.name.clone(), now);
1048 }
1049 } else if cdl > now {
1050 let remaining = cdl - now;
1052 if remaining >= 5 * 60_000 {
1053 notify_on_resume.insert(account.name.clone());
1054 }
1055 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1056 } else {
1057 let rl_age = rl_snapshot
1059 .get(&account.name)
1060 .map(|r| now.saturating_sub(r.updated_ms))
1061 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1063 let fetched_ago = now.saturating_sub(last_fetched);
1064
1065 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1066 tracing::debug!(
1067 account = %account.name,
1068 age_min = rl_age / 60_000,
1069 "rate-limit data stale — refreshing"
1070 );
1071 let token = {
1072 let creds = credentials.read().await;
1073 creds.get(&account.name).map(|c| c.access_token.clone())
1074 };
1075 if let Some(token) = token {
1076 post_cooldown_prefetch(
1077 &client, account, &token, &state,
1078 &config.server.upstream_url,
1079 ).await;
1080 }
1081 last_stale_prefetch.insert(account.name.clone(), now);
1082 }
1083 }
1084 }
1085
1086 let sleep_ms = next_wake_ms
1088 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1089 .unwrap_or(30_000);
1090 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1091 }
1092}
1093
1094use crate::notify::notify;
1095
/// Maps an OpenAI-style model name to the Claude model to request.
///
/// Names starting with `claude-` pass through unchanged. A fixed set of
/// names maps to the Opus model, another fixed set to the Haiku model, and
/// every other name maps to the Sonnet model.
fn map_model(openai_model: &str) -> String {
    const OPUS_ALIASES: &[&str] = &[
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: &[&str] = &[
        "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini",
    ];

    let target = if openai_model.starts_with("claude-") {
        openai_model
    } else if OPUS_ALIASES.contains(&openai_model) {
        "claude-opus-4-6"
    } else if HAIKU_ALIASES.contains(&openai_model) {
        "claude-haiku-4-5-20251001"
    } else {
        "claude-sonnet-4-6"
    };
    target.to_owned()
}
1122
1123fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1125 let model = body["model"].as_str().unwrap_or("gpt-4o");
1126 let claude_model = map_model(model);
1127
1128 let mut system: Option<String> = None;
1130 let mut messages = Vec::new();
1131 if let Some(arr) = body["messages"].as_array() {
1132 for msg in arr {
1133 let role = msg["role"].as_str().unwrap_or("");
1134 if role == "system" {
1135 let content = msg["content"].as_str()
1137 .map(|s| s.to_owned())
1138 .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1139 system = Some(content);
1140 } else if role == "tool" {
1141 let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1143 let content = msg["content"].as_str().unwrap_or("").to_owned();
1144 messages.push(json!({
1145 "role": "user",
1146 "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1147 }));
1148 } else {
1149 if let Some(tool_calls) = msg["tool_calls"].as_array() {
1151 let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1152 if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1153 content_blocks.push(json!({"type": "text", "text": text}));
1154 }
1155 for tc in tool_calls {
1156 content_blocks.push(json!({
1157 "type": "tool_use",
1158 "id": tc["id"].as_str().unwrap_or(""),
1159 "name": tc["function"]["name"].as_str().unwrap_or(""),
1160 "input": serde_json::from_str::<serde_json::Value>(
1161 tc["function"]["arguments"].as_str().unwrap_or("{}")
1162 ).unwrap_or(json!({})),
1163 }));
1164 }
1165 messages.push(json!({"role": "assistant", "content": content_blocks}));
1166 } else {
1167 let content = msg["content"].as_str().unwrap_or("").to_owned();
1168 messages.push(json!({ "role": role, "content": content }));
1169 }
1170 }
1171 }
1172 }
1173
1174 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1175 let stream = body["stream"].as_bool().unwrap_or(false);
1176
1177 let mut req = json!({
1178 "model": claude_model,
1179 "messages": messages,
1180 "max_tokens": max_tokens,
1181 "stream": stream,
1182 });
1183
1184 if let Some(sys) = system {
1185 req["system"] = json!(sys);
1186 }
1187 if let Some(temp) = body.get("temperature") {
1188 req["temperature"] = temp.clone();
1189 }
1190 if let Some(sp) = body.get("stop") {
1191 req["stop_sequences"] = sp.clone();
1192 }
1193
1194 if let Some(tools) = body["tools"].as_array() {
1196 let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1197 let func = &t["function"];
1198 Some(json!({
1199 "name": func["name"].as_str()?,
1200 "description": func["description"].as_str().unwrap_or(""),
1201 "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1202 }))
1203 }).collect();
1204 if !claude_tools.is_empty() {
1205 req["tools"] = json!(claude_tools);
1206 }
1207 }
1208
1209 req
1210}
1211
1212fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1214 let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1215 let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1216
1217 let mut text_content = String::new();
1219 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1220 if let Some(blocks) = body["content"].as_array() {
1221 for (idx, block) in blocks.iter().enumerate() {
1222 match block["type"].as_str() {
1223 Some("text") => {
1224 text_content.push_str(block["text"].as_str().unwrap_or(""));
1225 }
1226 Some("tool_use") => {
1227 let args = match &block["input"] {
1228 serde_json::Value::String(s) => s.clone(),
1229 v => serde_json::to_string(v).unwrap_or_default(),
1230 };
1231 tool_calls.push(json!({
1232 "id": block["id"].as_str().unwrap_or(""),
1233 "type": "function",
1234 "index": idx,
1235 "function": {
1236 "name": block["name"].as_str().unwrap_or(""),
1237 "arguments": args,
1238 }
1239 }));
1240 }
1241 _ => {}
1242 }
1243 }
1244 }
1245
1246 let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1247 let finish_reason = match stop_reason {
1248 "end_turn" => "stop",
1249 "tool_use" => "tool_calls",
1250 "max_tokens" => "length",
1251 other => other,
1252 };
1253
1254 let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1255 let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1256
1257 let mut message = json!({"role": "assistant", "content": text_content});
1258 if !tool_calls.is_empty() {
1259 message["tool_calls"] = json!(tool_calls);
1260 }
1261
1262 json!({
1263 "id": id,
1264 "object": "chat.completion",
1265 "model": model,
1266 "choices": [{
1267 "index": 0,
1268 "message": message,
1269 "finish_reason": finish_reason,
1270 }],
1271 "usage": {
1272 "prompt_tokens": input_tokens,
1273 "completion_tokens": output_tokens,
1274 "total_tokens": input_tokens + output_tokens,
1275 }
1276 })
1277}
1278
1279fn uuid_v4() -> String {
1280 use crate::oauth::rand_bytes;
1281 let b: [u8; 16] = rand_bytes();
1282 format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1283 u32::from_be_bytes(b[0..4].try_into().unwrap()),
1284 u16::from_be_bytes(b[4..6].try_into().unwrap()),
1285 u16::from_be_bytes(b[6..8].try_into().unwrap()),
1286 u16::from_be_bytes(b[8..10].try_into().unwrap()),
1287 {
1288 let mut v = 0u64;
1289 for &x in &b[10..16] { v = (v << 8) | x as u64; }
1290 v
1291 }
1292 )
1293}
1294
1295async fn openai_models_handler() -> impl IntoResponse {
1297 axum::Json(json!({
1298 "object": "list",
1299 "data": [
1300 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1301 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1302 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1303 ]
1304 }))
1305}
1306
1307async fn openai_compat_handler(
1309 State(s): State<AppState>,
1310 req: Request,
1311) -> Result<Response, ProxyError> {
1312 let Some(ref anthropic_url) = s.anthropic_base_url else {
1313 return proxy_handler(State(s), req).await;
1315 };
1316
1317 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1318 .await
1319 .map_err(|_| ProxyError::BodyRead)?;
1320
1321 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1322 .unwrap_or(json!({}));
1323
1324 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1325 let anthropic_body = translate_to_anthropic(openai_body);
1326
1327 let client = reqwest::Client::builder()
1328 .timeout(std::time::Duration::from_secs(300))
1329 .build()
1330 .map_err(|_| ProxyError::Upstream)?;
1331
1332 let resp = client
1333 .post(format!("{anthropic_url}/v1/messages"))
1334 .header("content-type", "application/json")
1335 .header("anthropic-version", "2023-06-01")
1336 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1337 .header("x-shunt-compat", "openai")
1338 .json(&anthropic_body)
1339 .send()
1340 .await
1341 .map_err(|_| ProxyError::Upstream)?;
1342
1343 if !resp.status().is_success() {
1344 let status = resp.status();
1345 let body = resp.text().await.unwrap_or_default();
1346 let code = status.as_u16();
1347 return Ok(axum::response::Response::builder()
1348 .status(code)
1349 .header("content-type", "application/json")
1350 .body(axum::body::Body::from(body))
1351 .unwrap());
1352 }
1353
1354 if stream {
1355 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1357 let stream = translate_anthropic_stream(resp, chat_id);
1358 Ok(axum::response::Response::builder()
1359 .status(200)
1360 .header("content-type", "text/event-stream")
1361 .header("cache-control", "no-cache")
1362 .body(axum::body::Body::from_stream(stream))
1363 .unwrap())
1364 } else {
1365 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1366 let openai_resp = translate_from_anthropic(anthropic_resp);
1367 Ok(axum::Json(openai_resp).into_response())
1368 }
1369}
1370
1371fn translate_anthropic_stream(
1374 resp: reqwest::Response,
1375 chat_id: String,
1376) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1377 use futures_util::StreamExt;
1378
1379 let id = chat_id;
1380 let byte_stream = resp.bytes_stream();
1381
1382 async_stream::stream! {
1383 let mut buf = String::new();
1384 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1386 let mut tool_call_count: usize = 0;
1387 futures_util::pin_mut!(byte_stream);
1388
1389 let init = format!(
1391 "data: {}\n\n",
1392 serde_json::to_string(&json!({
1393 "id": id,
1394 "object": "chat.completion.chunk",
1395 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1396 })).unwrap()
1397 );
1398 yield Ok(bytes::Bytes::from(init));
1399
1400 while let Some(chunk) = byte_stream.next().await {
1401 let chunk = match chunk {
1402 Ok(c) => c,
1403 Err(_) => break,
1404 };
1405 buf.push_str(&String::from_utf8_lossy(&chunk));
1406
1407 while let Some(nl) = buf.find('\n') {
1409 let line = buf[..nl].trim_end_matches('\r').to_owned();
1410 buf = buf[nl + 1..].to_owned();
1411
1412 if !line.starts_with("data: ") { continue; }
1413 let data = &line["data: ".len()..];
1414 if data == "[DONE]" { continue; }
1415
1416 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1417 let event_type = event["type"].as_str().unwrap_or("");
1418
1419 let maybe_chunk = match event_type {
1420 "content_block_start" => {
1421 let block_idx = event["index"].as_u64().unwrap_or(0);
1422 let cb = &event["content_block"];
1423 if cb["type"].as_str() == Some("tool_use") {
1424 let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1425 let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1426 let oai_idx = tool_call_count;
1427 tool_call_count += 1;
1428 tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1429 Some(json!({
1430 "id": id,
1431 "object": "chat.completion.chunk",
1432 "choices": [{"index": 0, "delta": {
1433 "tool_calls": [{
1434 "index": oai_idx,
1435 "id": tool_id,
1436 "type": "function",
1437 "function": {"name": tool_name, "arguments": ""}
1438 }]
1439 }, "finish_reason": null}]
1440 }))
1441 } else {
1442 None
1443 }
1444 }
1445 "content_block_delta" => {
1446 let block_idx = event["index"].as_u64().unwrap_or(0);
1447 let delta = &event["delta"];
1448 match delta["type"].as_str() {
1449 Some("text_delta") => {
1450 let text = delta["text"].as_str().unwrap_or("");
1451 if text.is_empty() { continue; }
1452 Some(json!({
1453 "id": id,
1454 "object": "chat.completion.chunk",
1455 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1456 }))
1457 }
1458 Some("input_json_delta") => {
1459 let args = delta["partial_json"].as_str().unwrap_or("");
1460 if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1461 Some(json!({
1462 "id": id,
1463 "object": "chat.completion.chunk",
1464 "choices": [{"index": 0, "delta": {
1465 "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1466 }, "finish_reason": null}]
1467 }))
1468 } else {
1469 None
1470 }
1471 }
1472 _ => None,
1473 }
1474 }
1475 "message_delta" => {
1476 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1477 let finish = match stop_reason {
1478 "end_turn" => "stop",
1479 "tool_use" => "tool_calls",
1480 "max_tokens" => "length",
1481 other => other,
1482 };
1483 Some(json!({
1484 "id": id,
1485 "object": "chat.completion.chunk",
1486 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1487 }))
1488 }
1489 _ => None,
1490 };
1491
1492 if let Some(c) = maybe_chunk {
1493 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1494 yield Ok(bytes::Bytes::from(out));
1495 }
1496 }
1497 }
1498
1499 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1500 }
1501}