1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24 config: Arc<Config>,
25 forwarder: Arc<Forwarder>,
26 state: StateStore,
27 credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29 started_ms: u64,
31 anthropic_base_url: Option<String>,
34}
35
36pub fn create_app(config: Config) -> anyhow::Result<Router> {
37 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
38 Ok(app)
39}
40
/// Shared, lock-protected map from account name to that account's current
/// OAuth credential. Handed out so background tasks can update tokens in place.
pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
43
44pub fn create_app_with_state(
45 config: Config,
46 state: StateStore,
47 anthropic_base_url: Option<String>,
48) -> anyhow::Result<(Router, LiveCredentials)> {
49 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
50
51 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
54 state.set_auth_failed(&a.name);
55 }
56
57 let credentials: LiveCredentials = Arc::new(RwLock::new(
58 config.accounts.iter()
59 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
60 .collect::<HashMap<_, _>>(),
61 ));
62
63 let app_state = AppState {
64 config: Arc::new(config),
65 forwarder: Arc::new(forwarder),
66 state,
67 credentials: Arc::clone(&credentials),
68 started_ms: now_ms(),
69 anthropic_base_url,
70 };
71
72 let provider = app_state.config.accounts.first()
77 .map(|a| &a.provider)
78 .cloned()
79 .unwrap_or_default();
80
81 let proxy_routes = match provider {
82 Provider::Anthropic => Router::new()
83 .route("/v1/messages", post(proxy_handler))
84 .route("/v1/messages/count_tokens", post(proxy_handler)),
85 Provider::OpenAI => Router::new()
86 .route("/v1/chat/completions", post(openai_compat_handler))
87 .route("/v1/models", get(openai_models_handler))
88 .fallback(proxy_handler),
89 };
90
91 let app = Router::new()
92 .route("/health", get(health))
93 .route("/status", get(status_handler))
94 .route("/use", post(use_handler))
95 .merge(proxy_routes)
96 .with_state(app_state);
97
98 Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102 axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106 let account_states = s.state.account_states();
107 let quotas = s.state.quota_snapshot();
108 let rate_limits = s.state.rate_limit_snapshot();
109
110 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111 let st = account_states.get(&a.name);
112 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113 "reauth_required"
114 } else if st.map(|s| s.disabled).unwrap_or(false) {
115 "disabled"
116 } else if s.state.is_available(&a.name) {
117 "available"
118 } else {
119 "cooling"
120 };
121
122 let quota = quotas.get(&a.name);
123 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125 let tokens_used = quota.map(|q| json!({
126 "input": q.input_tokens,
127 "output": q.output_tokens,
128 "total": q.total_tokens(),
129 }));
130
131 let rl = rate_limits.get(&a.name);
132 let rate_limit = rl.map(|r| json!({
133 "utilization_5h": r.utilization_5h,
134 "reset_5h": r.reset_5h,
135 "status_5h": r.status_5h,
136 "utilization_7d": r.utilization_7d,
137 "reset_7d": r.reset_7d,
138 "status_7d": r.status_7d,
139 "representative_claim": r.representative_claim,
140 "updated_ms": r.updated_ms,
141 }));
142
143 let acc_state = account_states.get(&a.name);
144 let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149 let reset_5h = rl.and_then(|r| r.reset_5h);
150 let total_tokens = quota.map(|q| q.total_tokens()).unwrap_or(0);
151 let available = s.state.is_available(&a.name);
152
153 json!({
154 "name": a.name,
155 "email": email,
156 "plan_type": a.plan_type,
157 "status": avail_status,
158 "available": available,
159 "disabled": disabled,
160 "auth_failed": auth_failed,
161 "cooldown_until_ms": cooldown_until_ms,
162 "utilization_5h": utilization_5h,
163 "reset_5h": reset_5h,
164 "total_tokens": total_tokens,
165 "window_expires_ms": window_expires_ms,
166 "tokens_used": tokens_used,
167 "rate_limit": rate_limit,
168 })
169 }).collect();
170
171 let recent_requests = s.state.recent_requests_snapshot();
172 let savings = s.state.savings_snapshot();
173
174 axum::Json(json!({
175 "version": env!("CARGO_PKG_VERSION"),
176 "started_ms": s.started_ms,
177 "accounts": accounts,
178 "pinned_account": s.state.get_pinned(),
179 "last_used_account": s.state.get_last_used(),
180 "recent_requests": recent_requests,
181 "savings": savings,
182 }))
183}
184
185async fn use_handler(
186 State(s): State<AppState>,
187 axum::Json(body): axum::Json<serde_json::Value>,
188) -> impl IntoResponse {
189 let account = body["account"].as_str().map(|s| s.to_owned());
190 if let Some(ref name) = account {
192 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
193 return axum::Json(json!({
194 "error": format!("unknown account '{name}'")
195 }));
196 }
197 let pinned = if name == "auto" { None } else { Some(name.clone()) };
198 s.state.set_pinned(pinned);
199 axum::Json(json!({ "pinned": name }))
200 } else {
201 s.state.set_pinned(None);
202 axum::Json(json!({ "pinned": null }))
203 }
204}
205
/// Milliseconds since the Unix epoch, or 0 if the system clock reads earlier
/// than the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
210
/// Forwards the incoming request upstream, rotating through accounts until one
/// yields a response that should be returned to the caller.
///
/// Flow:
/// 1. If `server.remote_key` is configured, the `x-api-key` header must match.
/// 2. The body is buffered once so it can be replayed for each attempt.
/// 3. Accounts are picked via `router::pick_account`, skipping names in `tried`.
/// 4. Per upstream status:
///    - 2xx: record last-used + rate limits, then return through `tap_usage`.
///    - 429 / 529: record rate limits, put the account on cooldown, try another.
///    - 401: refresh the token once per account per request and retry; a
///      failed refresh or a second 401 puts the account on cooldown.
///    - 403: long cooldown, try another.
///    - anything else: returned to the caller unchanged.
///
/// # Errors
/// `Unauthorized` (bad/missing key), `BodyRead`, `Upstream` (forwarding
/// failed), or `AllAccountsUnavailable` when no account can be picked.
async fn proxy_handler(
    State(s): State<AppState>,
    req: Request,
) -> Result<Response, ProxyError> {
    // Optional shared-secret gate; a missing or non-UTF-8 header compares as "".
    if let Some(ref expected) = s.config.server.remote_key {
        let provided = req.headers()
            .get("x-api-key")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        if provided != expected {
            return Err(ProxyError::Unauthorized);
        }
    }

    let method = req.method().as_str().to_owned();
    let path = req.uri().path().to_owned();
    let headers = req.headers().clone();

    // Buffer the whole body (no size cap) so every retry can resend it.
    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
        .await
        .map_err(|_| ProxyError::BodyRead)?;

    // Best-effort model name for usage accounting; empty when the body is not
    // JSON or has no string `model` field.
    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
        .ok()
        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
        .unwrap_or_default();
    let req_start_ms = now_ms();

    // Fingerprint of the body, passed to the account picker.
    let fp = router::fingerprint(&body_bytes);
    let fp_ref = fp.as_deref();

    // `tried`: accounts excluded from further picks for this request.
    // `refreshed`: accounts whose token was already refreshed for this request.
    let mut tried: HashSet<String> = HashSet::new();
    let mut refreshed: HashSet<String> = HashSet::new();

    loop {
        let account = match router::pick_account(
            &s.config.accounts, &s.state, fp_ref, &tried,
            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
        ) {
            Some(a) => a,
            None => return Err(ProxyError::AllAccountsUnavailable),
        };

        let account_name = account.name.clone();

        // Prefer the live (possibly refreshed) credential over the configured one.
        let token = {
            let creds = s.credentials.read().await;
            let cred = creds.get(&account_name)
                .cloned()
                .or_else(|| account.credential.clone());
            match cred {
                // OpenAI accounts send the id_token when one is present.
                Some(c) if account.provider == crate::provider::Provider::OpenAI => {
                    c.id_token.unwrap_or(c.access_token)
                }
                Some(c) => c.access_token,
                // No credential at all: forward with an empty token.
                None => String::new(),
            }
        };

        let response = s.forwarder
            .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
            .await
            .map_err(|e| {
                error!("Forward error: {:#}", e);
                ProxyError::Upstream
            })?;

        match response.status().as_u16() {
            200..=299 => {
                s.state.set_last_used(&account_name);
                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
                    s.state.update_rate_limits(&account_name, info);
                }
                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
            }
            429 => {
                warn!(account = %account_name, "429 rate-limited — cooling 60s");
                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
                    s.state.update_rate_limits(&account_name, info);
                }
                s.state.set_cooldown(&account_name, 60_000);
                tried.insert(account_name);
            }
            529 => {
                warn!(account = %account_name, "529 overloaded — cooling 30s");
                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
                    s.state.update_rate_limits(&account_name, info);
                }
                s.state.set_cooldown(&account_name, 30_000);
                tried.insert(account_name);
            }
            401 => {
                if !refreshed.contains(&account_name) {
                    let cred = {
                        let creds = s.credentials.read().await;
                        creds.get(&account_name).cloned()
                            .or_else(|| account.credential.clone())
                    };
                    // Nothing to refresh with: give up on this account.
                    let Some(cred) = cred else {
                        tried.insert(account_name);
                        continue;
                    };
                    // Bound the refresh so a hung token endpoint cannot stall the request.
                    match tokio::time::timeout(
                        std::time::Duration::from_secs(10),
                        account.provider.refresh_token(&cred),
                    ).await {
                        Ok(Ok(fresh)) => {
                            warn!(account = %account_name, "401 — token refreshed, retrying");
                            {
                                let mut creds = s.credentials.write().await;
                                creds.insert(account_name.clone(), fresh.clone());
                            }
                            let name = account_name.clone();
                            let fresh = fresh.clone();
                            // Persist off the async runtime; fire-and-forget, errors ignored.
                            tokio::task::spawn_blocking(move || {
                                let mut store = CredentialsStore::load();
                                store.accounts.insert(name, fresh.clone());
                                store.save().ok();
                                if fresh.id_token.is_some() {
                                    crate::oauth::write_codex_auth_file(&fresh);
                                }
                            });
                            // Not added to `tried`, so the account stays eligible for the retry.
                            refreshed.insert(account_name);
                        }
                        // Refresh error or timeout.
                        _ => {
                            error!(account = %account_name, "401 — token refresh failed, cooling 5min");
                            s.state.set_cooldown(&account_name, 5 * 60_000);
                            tried.insert(account_name);
                        }
                    }
                } else {
                    error!(account = %account_name, "401 after refresh — cooling 5min");
                    s.state.set_cooldown(&account_name, 5 * 60_000);
                    tried.insert(account_name);
                }
            }
            403 => {
                error!(account = %account_name, "403 forbidden — cooling 30min");
                s.state.set_cooldown(&account_name, 30 * 60_000);
                tried.insert(account_name);
            }
            // Every other status is handed back to the caller untouched.
            _ => {
                return Ok(response);
            }
        }
    }
}
370
371async fn tap_usage(
380 resp: Response,
381 state: &StateStore,
382 account: &str,
383 model: &str,
384 req_start_ms: u64,
385) -> Response {
386 use axum::body::Body;
387 use crate::state::RequestLog;
388
389 if quota::is_streaming_response(&resp) {
390 let state = state.clone();
391 let account = account.to_owned();
392 let model = model.to_owned();
393 let on_complete = Arc::new(move |input: u64, output: u64| {
394 state.record_usage(&account, input, output);
395 state.record_global(&model, input, output);
396 state.record_request(RequestLog {
397 ts_ms: req_start_ms,
398 account: account.clone(),
399 model: model.clone(),
400 status: 200,
401 input_tokens: input,
402 output_tokens: output,
403 duration_ms: now_ms().saturating_sub(req_start_ms),
404 });
405 });
406 let (parts, body) = resp.into_parts();
407 let wrapped = quota::wrap_streaming_body(body, on_complete);
408 return Response::from_parts(parts, wrapped);
409 }
410
411 let (parts, body) = resp.into_parts();
413 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
414 Ok(b) => b,
415 Err(_) => return Response::from_parts(parts, Body::empty()),
416 };
417 let (input, output) = quota::extract_usage_from_json(&bytes);
418 state.record_usage(account, input, output);
419 state.record_global(model, input, output);
420 state.record_request(RequestLog {
421 ts_ms: req_start_ms,
422 account: account.to_owned(),
423 model: model.to_owned(),
424 status: 200,
425 input_tokens: input,
426 output_tokens: output,
427 duration_ms: now_ms().saturating_sub(req_start_ms),
428 });
429 Response::from_parts(parts, Body::from(bytes))
430}
431
432
/// Issues one probe request per account to learn its current rate limits.
///
/// Accounts that already have a 5h or 7d utilization reading, or that have no
/// configured credential, are skipped. Providers without a prefetch request
/// fall back to `auth_probe_get` when they expose a probe path. A 401 triggers
/// one token refresh (persisted to the credentials store) and one retry; a
/// failed refresh or a second 401 marks the account auth-failed.
///
/// NOTE(review): this uses the credential from `config`, not the live
/// credentials map, and a refreshed token is only written to the on-disk
/// store here — confirm that is intended.
pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
    // Falls back to a default client if the builder fails.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(20))
        .build()
        .unwrap_or_default();

    for account in &config.accounts {
        // Skip accounts whose utilization is already known.
        let rl = state.rate_limit_snapshot();
        if let Some(r) = rl.get(&account.name) {
            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
                continue;
            }
        }

        let creds = match account.credential.clone() {
            Some(c) => c,
            None => continue,
        };

        // No prefetch request for this provider: optionally run the GET auth probe.
        let Some((path, body)) = account.provider.prefetch_request() else {
            if let Some(probe_path) = account.provider.auth_probe_get_path() {
                auth_probe_get(&client, probe_path, account, &state).await;
            }
            continue;
        };
        let url = format!("{}{}", config.server.upstream_url, path);

        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;

        let r = match resp {
            Ok(r) => r,
            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
        };

        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
            let fresh = match account.provider.refresh_token(&creds).await {
                Ok(f) => f,
                Err(e) => {
                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
                    state.set_auth_failed(&account.name);
                    continue;
                }
            };
            // Persist the refreshed credential; a save failure is ignored.
            let mut store = crate::config::CredentialsStore::load();
            store.accounts.insert(account.name.clone(), fresh.clone());
            store.save().ok();
            if fresh.id_token.is_some() {
                crate::oauth::write_codex_auth_file(&fresh);
            }

            // Retry once with the fresh access token.
            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
                    state.set_auth_failed(&account.name);
                }
                Ok(r2) => {
                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
                        state.update_rate_limits(&account.name, info);
                    }
                }
                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
            }
        } else {
            // Any non-401 status: harvest whatever rate-limit headers are present.
            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
                state.update_rate_limits(&account.name, info);
            }
        }
    }
}
514
515async fn prefetch_send(
517 client: &reqwest::Client,
518 url: &str,
519 provider: &crate::provider::Provider,
520 token: &str,
521 body: &serde_json::Value,
522) -> anyhow::Result<reqwest::Response> {
523 let mut headers = reqwest::header::HeaderMap::new();
524 provider.inject_auth_headers(&mut headers, token)?;
525 for (name, value) in provider.prefetch_extra_headers() {
526 headers.insert(
527 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
528 reqwest::header::HeaderValue::from_static(value),
529 );
530 }
531 Ok(client.post(url).headers(headers).json(body).send().await?)
532}
533
534async fn auth_probe_get(
538 client: &reqwest::Client,
539 path: &str,
540 account: &crate::config::AccountConfig,
541 state: &StateStore,
542) {
543 let creds = match account.credential.clone() {
544 Some(c) => c,
545 None => return,
546 };
547 let upstream = match account.provider {
548 crate::provider::Provider::OpenAI => "https://chatgpt.com",
549 crate::provider::Provider::Anthropic => "https://api.anthropic.com",
550 };
551 let url = format!("{}{}", upstream, path);
552
553 let do_get = |token: &str| -> reqwest::RequestBuilder {
554 let mut headers = reqwest::header::HeaderMap::new();
555 let _ = account.provider.inject_auth_headers(&mut headers, token);
556 client.get(&url).headers(headers)
557 };
558
559 let probe_token = creds.id_token.as_deref().unwrap_or(&creds.access_token);
561 let resp = match do_get(probe_token).send().await {
562 Ok(r) => r,
563 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
564 };
565
566 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
567 tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
568 let fresh = match account.provider.refresh_token(&creds).await {
569 Ok(f) => f,
570 Err(e) => {
571 tracing::warn!(account = %account.name, "token refresh failed: {e}");
572 state.set_auth_failed(&account.name);
573 return;
574 }
575 };
576 let mut store = crate::config::CredentialsStore::load();
577 store.accounts.insert(account.name.clone(), fresh.clone());
578 store.save().ok();
579 if fresh.id_token.is_some() {
580 crate::oauth::write_codex_auth_file(&fresh);
581 }
582
583 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
584 match do_get(fresh_token).send().await {
585 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
586 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
587 state.set_auth_failed(&account.name);
588 }
589 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
590 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
591 }
592 } else {
593 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
594 }
598}
599
600fn id_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
607 let Some(ref id_tok) = cred.id_token else { return true };
608 let Some(exp_ms) = crate::oauth::jwt_exp_ms(id_tok) else { return true };
609 let now_ms = std::time::SystemTime::now()
610 .duration_since(std::time::UNIX_EPOCH)
611 .unwrap_or_default()
612 .as_millis() as u64;
613 exp_ms < now_ms + threshold_mins * 60 * 1_000
614}
615
616async fn sync_live_creds_from_auth_json(
621 account_name: &str,
622 live_creds: &LiveCredentials,
623) {
624 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
625 let current_exp = live_creds.read().await
626 .get(account_name)
627 .map(|c| c.expires_at)
628 .unwrap_or(0);
629 if from_file.expires_at > current_exp {
630 tracing::info!(account = %account_name, "synced fresher token from auth.json");
631 live_creds.write().await.insert(account_name.to_owned(), from_file);
632 }
633}
634
635async fn do_proactive_refresh(
637 account: &crate::config::AccountConfig,
638 creds: &crate::oauth::OAuthCredential,
639 live_creds: &LiveCredentials,
640 state: &StateStore,
641) {
642 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
643 match account.provider.refresh_token(creds).await {
644 Ok(fresh) => {
645 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
646 {
647 let mut map = live_creds.write().await;
648 map.insert(account.name.clone(), fresh.clone());
649 }
650 let mut store = crate::config::CredentialsStore::load();
651 store.accounts.insert(account.name.clone(), fresh.clone());
652 store.save().ok();
653 if fresh.id_token.is_some() {
654 crate::oauth::write_codex_auth_file(&fresh);
655 }
656 state.clear_auth_failed(&account.name);
657 }
658 Err(e) => {
659 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
660 state.set_auth_failed(&account.name);
661 }
662 }
663}
664
665async fn secs_until_next_refresh(config: &Config, live_creds: &LiveCredentials) -> u64 {
668 let now_ms = std::time::SystemTime::now()
669 .duration_since(std::time::UNIX_EPOCH)
670 .unwrap_or_default()
671 .as_millis() as u64;
672 const WAKE_BEFORE_MS: u64 = 15 * 60 * 1_000; const MAX_SLEEP_SECS: u64 = 45 * 60;
674 const MIN_SLEEP_SECS: u64 = 60;
675
676 let mut min_sleep_secs = MAX_SLEEP_SECS;
677
678 for account in config.accounts.iter()
679 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
680 {
681 let creds = live_creds.read().await.get(&account.name).cloned();
682 let Some(creds) = creds else { continue };
683 let Some(ref id_tok) = creds.id_token else { continue };
684 let Some(exp_ms) = crate::oauth::jwt_exp_ms(id_tok) else { continue };
685
686 let wake_ms = exp_ms.saturating_sub(WAKE_BEFORE_MS);
688 let sleep_ms = wake_ms.saturating_sub(now_ms);
689 let sleep_secs = (sleep_ms / 1_000).clamp(MIN_SLEEP_SECS, MAX_SLEEP_SECS);
690 min_sleep_secs = min_sleep_secs.min(sleep_secs);
691 }
692
693 min_sleep_secs
694}
695
696pub async fn openai_token_refresh_loop(
708 config: Arc<Config>,
709 state: StateStore,
710 live_creds: LiveCredentials,
711) {
712 for account in config.accounts.iter()
714 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
715 {
716 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
717 continue;
718 }
719 let creds = {
720 let map = live_creds.read().await;
721 map.get(&account.name).cloned().or_else(|| account.credential.clone())
722 };
723 if let Some(creds) = creds {
724 if id_token_expires_soon(&creds, 2) {
725 do_proactive_refresh(account, &creds, &live_creds, &state).await;
727 } else {
728 tracing::info!(account = %account.name, "id_token fresh at startup — skipping immediate refresh");
729 }
730 }
731 }
732
733 loop {
734 let sleep_secs = secs_until_next_refresh(&config, &live_creds).await;
738 tracing::debug!("next OpenAI token refresh check in {}m {}s",
739 sleep_secs / 60, sleep_secs % 60);
740 tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
741
742 for account in config.accounts.iter()
743 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
744 {
745 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
746 continue;
747 }
748
749 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
751
752 let creds = {
753 let map = live_creds.read().await;
754 map.get(&account.name).cloned().or_else(|| account.credential.clone())
755 };
756 let Some(creds) = creds else { continue };
757
758 if !id_token_expires_soon(&creds, 15) {
760 tracing::debug!(account = %account.name, "id_token still fresh, skipping refresh");
761 continue;
762 }
763
764 do_proactive_refresh(account, &creds, &live_creds, &state).await;
765 }
766 }
767}
768
/// Failures the proxy handlers report to the client; each variant maps to an
/// HTTP status in the `IntoResponse` impl.
enum ProxyError {
    /// The request body could not be read (400).
    BodyRead,
    /// The upstream request could not be completed (502).
    Upstream,
    /// No account could be picked for the request (503).
    AllAccountsUnavailable,
    /// The caller's API key was missing or wrong (401).
    Unauthorized,
}
779
780impl IntoResponse for ProxyError {
781 fn into_response(self) -> Response {
782 let (status, msg) = match self {
783 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
784 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
785 ProxyError::AllAccountsUnavailable => {
786 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
787 }
788 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
789 };
790
791 (status, axum::Json(json!({
792 "type": "error",
793 "error": {"type": "api_error", "message": msg}
794 }))).into_response()
795 }
796}
797
/// Background task that periodically tries to bring auth-failed accounts back
/// online. Never returns.
///
/// Every 120 seconds it asks `state` which accounts are auth-failed and, for
/// each one that has a live credential with a non-empty refresh token,
/// attempts a token refresh bounded by a 20-second timeout. A successful
/// refresh updates the live map, is persisted on a blocking thread, and
/// clears the auth-failed flag. If nothing recovered and every configured
/// account is still failed, an error is logged and a desktop notification is
/// requested, at most once per hour.
pub async fn recovery_watcher(
    config: Arc<Config>,
    state: StateStore,
    credentials: LiveCredentials,
) {
    use std::time::{Duration, Instant};
    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);

    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
    // Time of the last "all offline" notification; reset once no account is failed.
    let mut last_notified: Option<Instant> = None;

    loop {
        tokio::time::sleep(CHECK_INTERVAL).await;

        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
        let failed = state.auth_failed_accounts(&name_refs);
        if failed.is_empty() {
            last_notified = None;
            continue;
        }

        tracing::warn!(
            accounts = ?failed,
            "recovery: {} account(s) auth_failed, attempting token refresh",
            failed.len()
        );

        let mut any_recovered = false;

        for name in &failed {
            // Only the live map is consulted; accounts with no live credential are skipped.
            let cred = {
                let map = credentials.read().await;
                map.get(*name).cloned()
            };
            let Some(cred) = cred else { continue };
            // Without a refresh token there is nothing to attempt.
            if cred.refresh_token.is_empty() { continue; }

            let provider = config.accounts.iter()
                .find(|a| a.name == *name)
                .map(|a| a.provider.clone())
                .unwrap_or_default();

            let result = tokio::time::timeout(
                Duration::from_secs(20),
                provider.refresh_token(&cred),
            ).await;

            match result {
                Ok(Ok(fresh)) => {
                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
                    {
                        let mut map = credentials.write().await;
                        map.insert(name.to_string(), fresh.clone());
                    }
                    let name_owned = name.to_string();
                    let fresh_owned = fresh.clone();
                    // Persist off the async runtime; fire-and-forget, errors ignored.
                    tokio::task::spawn_blocking(move || {
                        let mut store = crate::config::CredentialsStore::load();
                        store.accounts.insert(name_owned, fresh_owned.clone());
                        store.save().ok();
                        if fresh_owned.id_token.is_some() {
                            crate::oauth::write_codex_auth_file(&fresh_owned);
                        }
                    });
                    state.clear_auth_failed(name);
                    any_recovered = true;
                }
                Ok(Err(e)) => {
                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
                }
                Err(_) => {
                    tracing::error!(account = %name, "recovery: token refresh timed out");
                }
            }
        }

        if any_recovered {
            tracing::info!("recovery: at least one account is back online");
            continue;
        }

        // Notify only when every configured account is failed, rate-limited by NOTIFY_COOLDOWN.
        let still_failed = state.auth_failed_accounts(&name_refs);
        if still_failed.len() == account_names.len() {
            let should_notify = last_notified
                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
                .unwrap_or(true);
            if should_notify {
                error!(
                    "ALL accounts are offline (auth failed). \
                    Run `shunt add-account` to re-authorize."
                );
                notify_all_accounts_offline();
                last_notified = Some(Instant::now());
            }
        }
    }
}
905
/// Shows a desktop notification that every account has lost authentication.
///
/// Only implemented on macOS (via `osascript`); on other targets this is a
/// no-op. Failure to launch the command is ignored.
fn notify_all_accounts_offline() {
    #[cfg(target_os = "macos")]
    {
        const SCRIPT: &str = concat!(
            r#"display notification "#,
            r#""All accounts have lost authentication. Run `shunt add-account` to re-authorize." "#,
            r#"with title "shunt: All Accounts Offline" sound name "Basso""#
        );
        let mut command = std::process::Command::new("osascript");
        command.arg("-e").arg(SCRIPT);
        let _ = command.status();
    }
}
918
/// Maps a client-supplied model name to the Claude model that serves it.
///
/// Names starting with `claude-` are bucketed by family keyword (`opus`,
/// then `haiku`, otherwise Sonnet). Known flagship OpenAI names map to Opus,
/// known small ones to Haiku, and everything else to Sonnet.
fn map_model(openai_model: &str) -> &'static str {
    const OPUS: &str = "claude-opus-4-6";
    const SONNET: &str = "claude-sonnet-4-6";
    const HAIKU: &str = "claude-haiku-4-5-20251001";

    const FLAGSHIP: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const SMALL: [&str; 4] = ["gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini"];

    if openai_model.starts_with("claude-") {
        return if openai_model.contains("opus") {
            OPUS
        } else if openai_model.contains("haiku") {
            HAIKU
        } else {
            SONNET
        };
    }

    if FLAGSHIP.contains(&openai_model) {
        OPUS
    } else if SMALL.contains(&openai_model) {
        HAIKU
    } else {
        SONNET
    }
}
948
949fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
951 let model = body["model"].as_str().unwrap_or("gpt-4o");
952 let claude_model = map_model(model).to_owned();
953
954 let mut system: Option<String> = None;
956 let mut messages = Vec::new();
957 if let Some(arr) = body["messages"].as_array() {
958 for msg in arr {
959 let role = msg["role"].as_str().unwrap_or("");
960 let content = msg["content"].as_str().unwrap_or("").to_owned();
961 if role == "system" {
962 system = Some(content);
963 } else {
964 messages.push(json!({ "role": role, "content": content }));
965 }
966 }
967 }
968
969 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
970 let stream = body["stream"].as_bool().unwrap_or(false);
971
972 let mut req = json!({
973 "model": claude_model,
974 "messages": messages,
975 "max_tokens": max_tokens,
976 "stream": stream,
977 });
978
979 if let Some(sys) = system {
980 req["system"] = json!(sys);
981 }
982 if let Some(temp) = body.get("temperature") {
983 req["temperature"] = temp.clone();
984 }
985 if let Some(sp) = body.get("stop") {
986 req["stop_sequences"] = sp.clone();
987 }
988
989 req
990}
991
992fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
994 let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
995 let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
996 let content = body["content"]
997 .as_array()
998 .and_then(|arr| arr.iter().find_map(|b| b["text"].as_str()))
999 .unwrap_or("")
1000 .to_owned();
1001 let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1002 let finish_reason = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1003 let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1004 let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1005
1006 json!({
1007 "id": id,
1008 "object": "chat.completion",
1009 "model": model,
1010 "choices": [{
1011 "index": 0,
1012 "message": { "role": "assistant", "content": content },
1013 "finish_reason": finish_reason,
1014 }],
1015 "usage": {
1016 "prompt_tokens": input_tokens,
1017 "completion_tokens": output_tokens,
1018 "total_tokens": input_tokens + output_tokens,
1019 }
1020 })
1021}
1022
1023fn uuid_v4() -> String {
1024 use crate::oauth::rand_bytes;
1025 let b: [u8; 16] = rand_bytes();
1026 format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1027 u32::from_be_bytes(b[0..4].try_into().unwrap()),
1028 u16::from_be_bytes(b[4..6].try_into().unwrap()),
1029 u16::from_be_bytes(b[6..8].try_into().unwrap()),
1030 u16::from_be_bytes(b[8..10].try_into().unwrap()),
1031 {
1032 let mut v = 0u64;
1033 for &x in &b[10..16] { v = (v << 8) | x as u64; }
1034 v
1035 }
1036 )
1037}
1038
1039async fn openai_models_handler() -> impl IntoResponse {
1041 axum::Json(json!({
1042 "object": "list",
1043 "data": [
1044 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1045 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1046 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1047 ]
1048 }))
1049}
1050
1051async fn openai_compat_handler(
1053 State(s): State<AppState>,
1054 req: Request,
1055) -> Result<Response, ProxyError> {
1056 let Some(ref anthropic_url) = s.anthropic_base_url else {
1057 return proxy_handler(State(s), req).await;
1059 };
1060
1061 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1062 .await
1063 .map_err(|_| ProxyError::BodyRead)?;
1064
1065 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1066 .unwrap_or(json!({}));
1067
1068 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1069 let anthropic_body = translate_to_anthropic(openai_body);
1070
1071 let client = reqwest::Client::builder()
1072 .timeout(std::time::Duration::from_secs(300))
1073 .build()
1074 .map_err(|_| ProxyError::Upstream)?;
1075
1076 let resp = client
1077 .post(format!("{anthropic_url}/v1/messages"))
1078 .header("content-type", "application/json")
1079 .header("anthropic-version", "2023-06-01")
1080 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1081 .header("x-shunt-compat", "openai")
1082 .json(&anthropic_body)
1083 .send()
1084 .await
1085 .map_err(|_| ProxyError::Upstream)?;
1086
1087 if !resp.status().is_success() {
1088 let status = resp.status();
1089 let body = resp.text().await.unwrap_or_default();
1090 let code = status.as_u16();
1091 return Ok(axum::response::Response::builder()
1092 .status(code)
1093 .header("content-type", "application/json")
1094 .body(axum::body::Body::from(body))
1095 .unwrap());
1096 }
1097
1098 if stream {
1099 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1101 let stream = translate_anthropic_stream(resp, chat_id);
1102 Ok(axum::response::Response::builder()
1103 .status(200)
1104 .header("content-type", "text/event-stream")
1105 .header("cache-control", "no-cache")
1106 .body(axum::body::Body::from_stream(stream))
1107 .unwrap())
1108 } else {
1109 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1110 let openai_resp = translate_from_anthropic(anthropic_resp);
1111 Ok(axum::Json(openai_resp).into_response())
1112 }
1113}
1114
1115fn translate_anthropic_stream(
1117 resp: reqwest::Response,
1118 chat_id: String,
1119) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1120 use futures_util::StreamExt;
1121
1122 let id = chat_id;
1123 let byte_stream = resp.bytes_stream();
1124
1125 async_stream::stream! {
1126 let mut buf = String::new();
1127 futures_util::pin_mut!(byte_stream);
1128
1129 let init = format!(
1131 "data: {}\n\n",
1132 serde_json::to_string(&json!({
1133 "id": id,
1134 "object": "chat.completion.chunk",
1135 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1136 })).unwrap()
1137 );
1138 yield Ok(bytes::Bytes::from(init));
1139
1140 while let Some(chunk) = byte_stream.next().await {
1141 let chunk = match chunk {
1142 Ok(c) => c,
1143 Err(_) => break,
1144 };
1145 buf.push_str(&String::from_utf8_lossy(&chunk));
1146
1147 while let Some(nl) = buf.find('\n') {
1149 let line = buf[..nl].trim_end_matches('\r').to_owned();
1150 buf = buf[nl + 1..].to_owned();
1151
1152 if !line.starts_with("data: ") { continue; }
1153 let data = &line["data: ".len()..];
1154 if data == "[DONE]" { continue; }
1155
1156 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1157 let event_type = event["type"].as_str().unwrap_or("");
1158
1159 let maybe_chunk = match event_type {
1160 "content_block_delta" => {
1161 let text = event["delta"]["text"].as_str().unwrap_or("");
1162 if text.is_empty() { continue; }
1163 Some(json!({
1164 "id": id,
1165 "object": "chat.completion.chunk",
1166 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1167 }))
1168 }
1169 "message_delta" => {
1170 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1171 let finish = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1172 Some(json!({
1173 "id": id,
1174 "object": "chat.completion.chunk",
1175 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1176 }))
1177 }
1178 _ => None,
1179 };
1180
1181 if let Some(c) = maybe_chunk {
1182 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1183 yield Ok(bytes::Bytes::from(out));
1184 }
1185 }
1186 }
1187
1188 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1189 }
1190}