1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24 config: Arc<Config>,
25 forwarder: Arc<Forwarder>,
26 state: StateStore,
27 credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29 started_ms: u64,
31 anthropic_base_url: Option<String>,
34}
35
36pub fn create_app(config: Config) -> anyhow::Result<Router> {
37 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
38 Ok(app)
39}
40
41pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
43
44pub fn create_app_with_state(
45 config: Config,
46 state: StateStore,
47 anthropic_base_url: Option<String>,
48) -> anyhow::Result<(Router, LiveCredentials)> {
49 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
50
51 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
54 state.set_auth_failed(&a.name);
55 }
56
57 let credentials: LiveCredentials = Arc::new(RwLock::new(
58 config.accounts.iter()
59 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
60 .collect::<HashMap<_, _>>(),
61 ));
62
63 let app_state = AppState {
64 config: Arc::new(config),
65 forwarder: Arc::new(forwarder),
66 state,
67 credentials: Arc::clone(&credentials),
68 started_ms: now_ms(),
69 anthropic_base_url,
70 };
71
72 let provider = app_state.config.accounts.first()
77 .map(|a| &a.provider)
78 .cloned()
79 .unwrap_or_default();
80
81 let proxy_routes = match provider {
82 Provider::Anthropic => Router::new()
83 .route("/v1/messages", post(proxy_handler))
84 .route("/v1/messages/count_tokens", post(proxy_handler)),
85 Provider::OpenAI => Router::new()
86 .route("/v1/chat/completions", post(openai_compat_handler))
87 .route("/v1/models", get(openai_models_handler))
88 .fallback(proxy_handler),
89 };
90
91 let app = Router::new()
92 .route("/health", get(health))
93 .route("/status", get(status_handler))
94 .route("/use", post(use_handler))
95 .merge(proxy_routes)
96 .with_state(app_state);
97
98 Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102 axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106 let account_states = s.state.account_states();
107 let quotas = s.state.quota_snapshot();
108 let rate_limits = s.state.rate_limit_snapshot();
109
110 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111 let st = account_states.get(&a.name);
112 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113 "reauth_required"
114 } else if st.map(|s| s.disabled).unwrap_or(false) {
115 "disabled"
116 } else if s.state.is_available(&a.name) {
117 "available"
118 } else {
119 "cooling"
120 };
121
122 let quota = quotas.get(&a.name);
123 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125 let tokens_used = quota.map(|q| json!({
126 "input": q.input_tokens,
127 "output": q.output_tokens,
128 "total": q.total_tokens(),
129 }));
130
131 let rl = rate_limits.get(&a.name);
132 let rate_limit = rl.map(|r| json!({
133 "utilization_5h": r.utilization_5h,
134 "reset_5h": r.reset_5h,
135 "status_5h": r.status_5h,
136 "utilization_7d": r.utilization_7d,
137 "reset_7d": r.reset_7d,
138 "status_7d": r.status_7d,
139 "representative_claim": r.representative_claim,
140 "updated_ms": r.updated_ms,
141 }));
142
143 let acc_state = account_states.get(&a.name);
144 let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149 let reset_5h = rl.and_then(|r| r.reset_5h);
150 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
151 let reset_7d = rl.and_then(|r| r.reset_7d);
152 let available = s.state.is_available(&a.name);
153
154 json!({
155 "name": a.name,
156 "email": email,
157 "plan_type": a.plan_type,
158 "status": avail_status,
159 "available": available,
160 "disabled": disabled,
161 "auth_failed": auth_failed,
162 "cooldown_until_ms": cooldown_until_ms,
163 "utilization_5h": utilization_5h,
164 "reset_5h": reset_5h,
165 "utilization_7d": utilization_7d,
166 "reset_7d": reset_7d,
167 "window_expires_ms": window_expires_ms,
168 "tokens_used": tokens_used,
169 "rate_limit": rate_limit,
170 })
171 }).collect();
172
173 let recent_requests = s.state.recent_requests_snapshot();
174 let savings = s.state.savings_snapshot();
175
176 axum::Json(json!({
177 "version": env!("CARGO_PKG_VERSION"),
178 "started_ms": s.started_ms,
179 "accounts": accounts,
180 "pinned_account": s.state.get_pinned(),
181 "last_used_account": s.state.get_last_used(),
182 "recent_requests": recent_requests,
183 "savings": savings,
184 }))
185}
186
187async fn use_handler(
188 State(s): State<AppState>,
189 axum::Json(body): axum::Json<serde_json::Value>,
190) -> impl IntoResponse {
191 let account = body["account"].as_str().map(|s| s.to_owned());
192 if let Some(ref name) = account {
194 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
195 return axum::Json(json!({
196 "error": format!("unknown account '{name}'")
197 }));
198 }
199 let pinned = if name == "auto" { None } else { Some(name.clone()) };
200 s.state.set_pinned(pinned);
201 axum::Json(json!({ "pinned": name }))
202 } else {
203 s.state.set_pinned(None);
204 axum::Json(json!({ "pinned": null }))
205 }
206}
207
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of panicking.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
212
213async fn proxy_handler(
214 State(s): State<AppState>,
215 req: Request,
216) -> Result<Response, ProxyError> {
217 if let Some(ref expected) = s.config.server.remote_key {
219 let provided = req.headers()
220 .get("x-api-key")
221 .and_then(|v| v.to_str().ok())
222 .unwrap_or("");
223 if provided != expected {
224 return Err(ProxyError::Unauthorized);
225 }
226 }
227
228 let method = req.method().as_str().to_owned();
229 let path = req.uri().path().to_owned();
230 let headers = req.headers().clone();
231
232 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
233 .await
234 .map_err(|_| ProxyError::BodyRead)?;
235
236 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
237 .ok()
238 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
239 .unwrap_or_default();
240 let req_start_ms = now_ms();
241
242 let fp = router::fingerprint(&body_bytes);
243 let fp_ref = fp.as_deref();
244
245 let mut tried: HashSet<String> = HashSet::new();
246 let mut refreshed: HashSet<String> = HashSet::new();
248 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
250
251 loop {
252 let account = match router::pick_account(
253 &s.config.accounts, &s.state, fp_ref, &tried,
254 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
255 ) {
256 Some(a) => a,
257 None => {
258 let account_states = s.state.account_states();
262 let now = now_ms();
263 let soonest_ms = s.config.accounts.iter()
264 .filter_map(|a| {
265 let st = account_states.get(&a.name)?;
266 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
268 })
269 .min();
270
271 match soonest_ms {
272 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
273 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
275 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
276 tried.clear(); }
278 _ => return Err(ProxyError::AllAccountsUnavailable),
279 }
280 continue;
281 }
282 };
283
284 let account_name = account.name.clone();
285
286 let token = {
290 let creds = s.credentials.read().await;
291 let cred = creds.get(&account_name)
292 .cloned()
293 .or_else(|| account.credential.clone());
294 match cred {
295 Some(c) => c.access_token,
296 None => String::new(),
297 }
298 };
299
300 let response = s.forwarder
301 .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
302 .await
303 .map_err(|e| {
304 error!("Forward error: {:#}", e);
305 ProxyError::Upstream
306 })?;
307
308 match response.status().as_u16() {
309 200..=299 => {
310 s.state.set_last_used(&account_name);
311 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
312 s.state.update_rate_limits(&account_name, info);
313 }
314 return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
315 }
316 429 => {
317 let info = account.provider.parse_rate_limits(response.headers());
318 let cooldown_ms = info.as_ref()
321 .and_then(|i| i.reset_5h.or(i.reset_7d))
322 .map(|reset_secs| {
323 let reset_ms = reset_secs.saturating_mul(1_000);
324 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
326 .unwrap_or(60_000);
327 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
328 if let Some(info) = info {
329 s.state.update_rate_limits(&account_name, info);
330 }
331 s.state.set_cooldown(&account_name, cooldown_ms);
332 if cooldown_ms >= 5 * 60_000 {
333 let mins = cooldown_ms / 60_000;
334 notify(
335 "shunt: Rate Limited",
336 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
337 "Ping",
338 );
339 }
340 tried.insert(account_name);
341 }
342 529 => {
343 warn!(account = %account_name, "529 overloaded — cooling 30s");
344 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
345 s.state.update_rate_limits(&account_name, info);
346 }
347 s.state.set_cooldown(&account_name, 30_000);
348 tried.insert(account_name);
349 }
350 401 => {
351 if !refreshed.contains(&account_name) {
352 let cred = {
354 let creds = s.credentials.read().await;
355 creds.get(&account_name).cloned()
356 .or_else(|| account.credential.clone())
357 };
358 let Some(cred) = cred else {
359 tried.insert(account_name);
360 continue;
361 };
362 match tokio::time::timeout(
363 std::time::Duration::from_secs(10),
364 account.provider.refresh_token(&cred),
365 ).await {
366 Ok(Ok(fresh)) => {
367 warn!(account = %account_name, "401 — token refreshed, retrying");
368 {
369 let mut creds = s.credentials.write().await;
370 creds.insert(account_name.clone(), fresh.clone());
371 }
372 let name = account_name.clone();
374 let fresh = fresh.clone();
375 tokio::task::spawn_blocking(move || {
376 let mut store = CredentialsStore::load();
377 store.accounts.insert(name, fresh.clone());
378 store.save().ok();
379 if fresh.id_token.is_some() {
380 crate::oauth::write_codex_auth_file(&fresh);
381 }
382 });
383 refreshed.insert(account_name);
385 }
386 _ => {
387 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
389 s.state.set_cooldown(&account_name, 5 * 60_000);
390 tried.insert(account_name);
391 }
392 }
393 } else {
394 error!(account = %account_name, "401 after refresh — cooling 5min");
396 s.state.set_cooldown(&account_name, 5 * 60_000);
397 tried.insert(account_name);
398 }
399 }
400 403 => {
401 error!(account = %account_name, "403 forbidden — cooling 30min");
403 s.state.set_cooldown(&account_name, 30 * 60_000);
404 notify(
405 "shunt: Account Forbidden",
406 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
407 "Basso",
408 );
409 tried.insert(account_name);
410 }
411 _ => {
412 return Ok(response);
414 }
415 }
416 }
417}
418
419async fn tap_usage(
428 resp: Response,
429 state: &StateStore,
430 account: &str,
431 model: &str,
432 req_start_ms: u64,
433) -> Response {
434 use axum::body::Body;
435 use crate::state::RequestLog;
436
437 if quota::is_streaming_response(&resp) {
438 let state = state.clone();
439 let account = account.to_owned();
440 let model = model.to_owned();
441 let on_complete = Arc::new(move |input: u64, output: u64| {
442 state.record_usage(&account, input, output);
443 state.record_global(&model, input, output);
444 state.record_request(RequestLog {
445 ts_ms: req_start_ms,
446 account: account.clone(),
447 model: model.clone(),
448 status: 200,
449 input_tokens: input,
450 output_tokens: output,
451 duration_ms: now_ms().saturating_sub(req_start_ms),
452 });
453 });
454 let (parts, body) = resp.into_parts();
455 let wrapped = quota::wrap_streaming_body(body, on_complete);
456 return Response::from_parts(parts, wrapped);
457 }
458
459 let (parts, body) = resp.into_parts();
461 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
462 Ok(b) => b,
463 Err(_) => return Response::from_parts(parts, Body::empty()),
464 };
465 let (input, output) = quota::extract_usage_from_json(&bytes);
466 state.record_usage(account, input, output);
467 state.record_global(model, input, output);
468 state.record_request(RequestLog {
469 ts_ms: req_start_ms,
470 account: account.to_owned(),
471 model: model.to_owned(),
472 status: 200,
473 input_tokens: input,
474 output_tokens: output,
475 duration_ms: now_ms().saturating_sub(req_start_ms),
476 });
477 Response::from_parts(parts, Body::from(bytes))
478}
479
480
481pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
489 let client = reqwest::Client::builder()
490 .timeout(std::time::Duration::from_secs(20))
491 .build()
492 .unwrap_or_default();
493
494 for account in &config.accounts {
495 let rl = state.rate_limit_snapshot();
497 if let Some(r) = rl.get(&account.name) {
498 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
499 continue;
500 }
501 }
502
503 let creds = match account.credential.clone() {
505 Some(c) => c,
506 None => continue,
507 };
508
509 let Some((path, body)) = account.provider.prefetch_request() else {
510 if let Some(probe_path) = account.provider.auth_probe_get_path() {
512 auth_probe_get(&client, probe_path, account, &state).await;
513 }
514 continue;
515 };
516 let url = format!("{}{}", config.server.upstream_url, path);
517
518 let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
519
520 let r = match resp {
521 Ok(r) => r,
522 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
523 };
524
525 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
526 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
527 let fresh = match account.provider.refresh_token(&creds).await {
528 Ok(f) => f,
529 Err(e) => {
530 tracing::warn!(account = %account.name, "token refresh failed: {e}");
531 state.set_auth_failed(&account.name);
532 continue;
533 }
534 };
535 let mut store = crate::config::CredentialsStore::load();
536 store.accounts.insert(account.name.clone(), fresh.clone());
537 store.save().ok();
538 if fresh.id_token.is_some() {
539 crate::oauth::write_codex_auth_file(&fresh);
540 }
541
542 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
543 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
544 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
545 state.set_auth_failed(&account.name);
546 }
547 Ok(r2) => {
548 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
549 state.update_rate_limits(&account.name, info);
550 }
551 }
552 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
553 }
554 } else {
555 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
556 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
557 state.update_rate_limits(&account.name, info);
558 }
559 }
560 }
561}
562
563async fn prefetch_send(
565 client: &reqwest::Client,
566 url: &str,
567 provider: &crate::provider::Provider,
568 token: &str,
569 body: &serde_json::Value,
570) -> anyhow::Result<reqwest::Response> {
571 let mut headers = reqwest::header::HeaderMap::new();
572 provider.inject_auth_headers(&mut headers, token)?;
573 for (name, value) in provider.prefetch_extra_headers() {
574 headers.insert(
575 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
576 reqwest::header::HeaderValue::from_static(value),
577 );
578 }
579 Ok(client.post(url).headers(headers).json(body).send().await?)
580}
581
582async fn auth_probe_get(
586 client: &reqwest::Client,
587 path: &str,
588 account: &crate::config::AccountConfig,
589 state: &StateStore,
590) {
591 let creds = match account.credential.clone() {
592 Some(c) => c,
593 None => return,
594 };
595 let upstream = match account.provider {
596 crate::provider::Provider::OpenAI => "https://chatgpt.com",
597 crate::provider::Provider::Anthropic => "https://api.anthropic.com",
598 };
599 let url = format!("{}{}", upstream, path);
600
601 let do_get = |token: &str| -> reqwest::RequestBuilder {
602 let mut headers = reqwest::header::HeaderMap::new();
603 let _ = account.provider.inject_auth_headers(&mut headers, token);
604 client.get(&url).headers(headers)
605 };
606
607 let probe_token = &creds.access_token;
608 let resp = match do_get(probe_token).send().await {
609 Ok(r) => r,
610 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
611 };
612
613 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
614 tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
615 let fresh = match account.provider.refresh_token(&creds).await {
616 Ok(f) => f,
617 Err(e) => {
618 tracing::warn!(account = %account.name, "token refresh failed: {e}");
619 state.set_auth_failed(&account.name);
620 return;
621 }
622 };
623 let mut store = crate::config::CredentialsStore::load();
624 store.accounts.insert(account.name.clone(), fresh.clone());
625 store.save().ok();
626 if fresh.id_token.is_some() {
627 crate::oauth::write_codex_auth_file(&fresh);
628 }
629
630 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
631 match do_get(fresh_token).send().await {
632 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
633 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
634 state.set_auth_failed(&account.name);
635 }
636 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
637 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
638 }
639 } else {
640 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
641 }
645}
646
647fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
654 let now_ms = std::time::SystemTime::now()
655 .duration_since(std::time::UNIX_EPOCH)
656 .unwrap_or_default()
657 .as_millis() as u64;
658 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
659 .unwrap_or(cred.expires_at);
660 exp_ms < now_ms + threshold_mins * 60 * 1_000
661}
662
663async fn sync_live_creds_from_auth_json(
668 account_name: &str,
669 live_creds: &LiveCredentials,
670) {
671 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
672 let current_exp = live_creds.read().await
673 .get(account_name)
674 .map(|c| c.expires_at)
675 .unwrap_or(0);
676 if from_file.expires_at > current_exp {
677 tracing::info!(account = %account_name, "synced fresher token from auth.json");
678 live_creds.write().await.insert(account_name.to_owned(), from_file);
679 }
680}
681
682async fn do_proactive_refresh(
684 account: &crate::config::AccountConfig,
685 creds: &crate::oauth::OAuthCredential,
686 live_creds: &LiveCredentials,
687 state: &StateStore,
688) {
689 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
690 match account.provider.refresh_token(creds).await {
691 Ok(fresh) => {
692 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
693 {
694 let mut map = live_creds.write().await;
695 map.insert(account.name.clone(), fresh.clone());
696 }
697 let mut store = crate::config::CredentialsStore::load();
698 store.accounts.insert(account.name.clone(), fresh.clone());
699 store.save().ok();
700 if fresh.id_token.is_some() {
701 crate::oauth::write_codex_auth_file(&fresh);
702 }
703 state.clear_auth_failed(&account.name);
704 }
705 Err(e) => {
706 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
707 state.set_auth_failed(&account.name);
708 }
709 }
710}
711
712
713pub async fn openai_token_refresh_loop(
721 config: Arc<Config>,
722 state: StateStore,
723 live_creds: LiveCredentials,
724) {
725 for account in config.accounts.iter()
727 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
728 {
729 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
730 continue;
731 }
732 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
733
734 let creds = {
735 let map = live_creds.read().await;
736 map.get(&account.name).cloned().or_else(|| account.credential.clone())
737 };
738 if let Some(creds) = creds {
739 if access_token_expires_soon(&creds, 30) {
740 do_proactive_refresh(account, &creds, &live_creds, &state).await;
742 } else {
743 tracing::info!(account = %account.name, "access_token fresh at startup");
744 }
745 }
746 }
747
748 loop {
751 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
752 for account in config.accounts.iter()
753 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
754 {
755 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
756 }
757 }
758}
759
/// Errors the proxy handlers report to the client; rendered by the
/// `IntoResponse` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// Sending the request upstream failed.
    Upstream,
    /// No account could be picked and no cooldown ends within the wait window.
    AllAccountsUnavailable,
    /// A `remote_key` is configured and the request's `x-api-key` did not match it.
    Unauthorized,
}
770
771impl IntoResponse for ProxyError {
772 fn into_response(self) -> Response {
773 let (status, msg) = match self {
774 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
775 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
776 ProxyError::AllAccountsUnavailable => {
777 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
778 }
779 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
780 };
781
782 (status, axum::Json(json!({
783 "type": "error",
784 "error": {"type": "api_error", "message": msg}
785 }))).into_response()
786 }
787}
788
789pub async fn recovery_watcher(
798 config: Arc<Config>,
799 state: StateStore,
800 credentials: LiveCredentials,
801) {
802 use std::time::{Duration, Instant};
803 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
804 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
805
806 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
807 let mut last_notified: Option<Instant> = None;
808
809 loop {
810 tokio::time::sleep(CHECK_INTERVAL).await;
811
812 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
813 let failed = state.auth_failed_accounts(&name_refs);
814 if failed.is_empty() {
815 last_notified = None;
816 continue;
817 }
818
819 tracing::warn!(
820 accounts = ?failed,
821 "recovery: {} account(s) auth_failed, attempting token refresh",
822 failed.len()
823 );
824
825 let mut any_recovered = false;
826
827 for name in &failed {
828 let cred = {
829 let map = credentials.read().await;
830 map.get(*name).cloned()
831 };
832 let Some(cred) = cred else { continue };
833 if cred.refresh_token.is_empty() { continue; }
834
835 let provider = config.accounts.iter()
836 .find(|a| a.name == *name)
837 .map(|a| a.provider.clone())
838 .unwrap_or_default();
839
840 let result = tokio::time::timeout(
841 Duration::from_secs(20),
842 provider.refresh_token(&cred),
843 ).await;
844
845 match result {
846 Ok(Ok(fresh)) => {
847 tracing::info!(account = %name, "recovery: token refreshed — account back online");
848 {
849 let mut map = credentials.write().await;
850 map.insert(name.to_string(), fresh.clone());
851 }
852 let name_owned = name.to_string();
853 let fresh_owned = fresh.clone();
854 tokio::task::spawn_blocking(move || {
855 let mut store = crate::config::CredentialsStore::load();
856 store.accounts.insert(name_owned, fresh_owned.clone());
857 store.save().ok();
858 if fresh_owned.id_token.is_some() {
859 crate::oauth::write_codex_auth_file(&fresh_owned);
860 }
861 });
862 state.clear_auth_failed(name);
863 any_recovered = true;
864 }
865 Ok(Err(e)) => {
866 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
867 notify(
868 "shunt: Reauth Required",
869 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
870 "Basso",
871 );
872 }
873 Err(_) => {
874 tracing::error!(account = %name, "recovery: token refresh timed out");
875 notify(
876 "shunt: Reauth Required",
877 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
878 "Basso",
879 );
880 }
881 }
882 }
883
884 if any_recovered {
885 tracing::info!("recovery: at least one account is back online");
886 continue;
887 }
888
889 let still_failed = state.auth_failed_accounts(&name_refs);
891 if still_failed.len() == account_names.len() {
892 let should_notify = last_notified
893 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
894 .unwrap_or(true);
895 if should_notify {
896 error!(
897 "ALL accounts are offline (auth failed). \
898 Run `shunt add-account` to re-authorize."
899 );
900 notify(
901 "shunt: All Accounts Offline",
902 "All accounts need re-authorization. Run `shunt add-account`.",
903 "Basso",
904 );
905 last_notified = Some(Instant::now());
906 }
907 }
908 }
909}
910
911async fn post_cooldown_prefetch(
915 client: &reqwest::Client,
916 account: &crate::config::AccountConfig,
917 token: &str,
918 state: &StateStore,
919 upstream_url: &str,
920) {
921 let Some((path, body)) = account.provider.prefetch_request() else {
922 if let Some(probe_path) = account.provider.auth_probe_get_path() {
923 auth_probe_get(client, probe_path, account, state).await;
924 }
925 return;
926 };
927 let url = format!("{upstream_url}{path}");
928 match prefetch_send(client, &url, &account.provider, token, &body).await {
929 Ok(r) => {
930 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
931 state.update_rate_limits(&account.name, info);
932 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
933 }
934 }
935 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
936 }
937}
938
939pub async fn cooldown_watcher(
946 config: Arc<Config>,
947 state: StateStore,
948 credentials: LiveCredentials,
949) {
950 let client = reqwest::Client::builder()
951 .timeout(std::time::Duration::from_secs(20))
952 .build()
953 .unwrap_or_default();
954
955 let mut last_resumed: HashMap<String, u64> = HashMap::new();
958 let mut notify_on_resume: HashSet<String> = HashSet::new();
960
961 loop {
962 let states = state.account_states();
963 let now = now_ms();
964 let mut next_wake_ms: Option<u64> = None;
965
966 for account in &config.accounts {
967 let Some(st) = states.get(&account.name) else { continue };
968 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
970 if cdl == 0 { continue; } if cdl <= now {
973 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
975 if !handled {
976 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
977 let token = {
978 let creds = credentials.read().await;
979 creds.get(&account.name).map(|c| c.access_token.clone())
980 };
981 if let Some(token) = token {
982 post_cooldown_prefetch(
983 &client, account, &token, &state,
984 &config.server.upstream_url,
985 ).await;
986 }
987 if notify_on_resume.remove(&account.name) {
988 notify(
989 "shunt: Account Resumed",
990 &format!("Account '{}' is back online.", account.name),
991 "Glass",
992 );
993 }
994 last_resumed.insert(account.name.clone(), cdl);
995 }
996 } else {
997 let remaining = cdl - now;
999 if remaining >= 5 * 60_000 {
1000 notify_on_resume.insert(account.name.clone());
1001 }
1002 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1003 }
1004 }
1005
1006 let sleep_ms = next_wake_ms
1008 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1009 .unwrap_or(30_000);
1010 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1011 }
1012}
1013
1014use crate::notify::notify;
1015
/// Maps a requested (OpenAI or Claude) model name to the Claude model to call.
///
/// - Names starting with `claude-` select a family by substring: `opus` wins,
///   then `haiku`, otherwise Sonnet.
/// - Listed large OpenAI models map to Opus and listed "mini" models to Haiku.
/// - Anything else falls back to Sonnet.
fn map_model(openai_model: &str) -> &'static str {
    const OPUS: &str = "claude-opus-4-6";
    const SONNET: &str = "claude-sonnet-4-6";
    const HAIKU: &str = "claude-haiku-4-5-20251001";

    if openai_model.starts_with("claude-") {
        return if openai_model.contains("opus") {
            OPUS
        } else if openai_model.contains("haiku") {
            HAIKU
        } else {
            SONNET
        };
    }

    match openai_model {
        "gpt-4o" | "gpt-4.5" | "o1" | "o1-pro" | "o3" | "o3-pro" | "gpt-5" | "gpt-5.5" => OPUS,
        "gpt-4o-mini" | "gpt-4o-mini-2024-07-18" | "o1-mini" | "o3-mini" => HAIKU,
        _ => SONNET,
    }
}
1045
1046fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1048 let model = body["model"].as_str().unwrap_or("gpt-4o");
1049 let claude_model = map_model(model).to_owned();
1050
1051 let mut system: Option<String> = None;
1053 let mut messages = Vec::new();
1054 if let Some(arr) = body["messages"].as_array() {
1055 for msg in arr {
1056 let role = msg["role"].as_str().unwrap_or("");
1057 let content = msg["content"].as_str().unwrap_or("").to_owned();
1058 if role == "system" {
1059 system = Some(content);
1060 } else {
1061 messages.push(json!({ "role": role, "content": content }));
1062 }
1063 }
1064 }
1065
1066 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1067 let stream = body["stream"].as_bool().unwrap_or(false);
1068
1069 let mut req = json!({
1070 "model": claude_model,
1071 "messages": messages,
1072 "max_tokens": max_tokens,
1073 "stream": stream,
1074 });
1075
1076 if let Some(sys) = system {
1077 req["system"] = json!(sys);
1078 }
1079 if let Some(temp) = body.get("temperature") {
1080 req["temperature"] = temp.clone();
1081 }
1082 if let Some(sp) = body.get("stop") {
1083 req["stop_sequences"] = sp.clone();
1084 }
1085
1086 req
1087}
1088
1089fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1091 let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1092 let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1093 let content = body["content"]
1094 .as_array()
1095 .and_then(|arr| arr.iter().find_map(|b| b["text"].as_str()))
1096 .unwrap_or("")
1097 .to_owned();
1098 let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1099 let finish_reason = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1100 let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1101 let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1102
1103 json!({
1104 "id": id,
1105 "object": "chat.completion",
1106 "model": model,
1107 "choices": [{
1108 "index": 0,
1109 "message": { "role": "assistant", "content": content },
1110 "finish_reason": finish_reason,
1111 }],
1112 "usage": {
1113 "prompt_tokens": input_tokens,
1114 "completion_tokens": output_tokens,
1115 "total_tokens": input_tokens + output_tokens,
1116 }
1117 })
1118}
1119
1120fn uuid_v4() -> String {
1121 use crate::oauth::rand_bytes;
1122 let b: [u8; 16] = rand_bytes();
1123 format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1124 u32::from_be_bytes(b[0..4].try_into().unwrap()),
1125 u16::from_be_bytes(b[4..6].try_into().unwrap()),
1126 u16::from_be_bytes(b[6..8].try_into().unwrap()),
1127 u16::from_be_bytes(b[8..10].try_into().unwrap()),
1128 {
1129 let mut v = 0u64;
1130 for &x in &b[10..16] { v = (v << 8) | x as u64; }
1131 v
1132 }
1133 )
1134}
1135
1136async fn openai_models_handler() -> impl IntoResponse {
1138 axum::Json(json!({
1139 "object": "list",
1140 "data": [
1141 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1142 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1143 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1144 ]
1145 }))
1146}
1147
1148async fn openai_compat_handler(
1150 State(s): State<AppState>,
1151 req: Request,
1152) -> Result<Response, ProxyError> {
1153 let Some(ref anthropic_url) = s.anthropic_base_url else {
1154 return proxy_handler(State(s), req).await;
1156 };
1157
1158 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1159 .await
1160 .map_err(|_| ProxyError::BodyRead)?;
1161
1162 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1163 .unwrap_or(json!({}));
1164
1165 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1166 let anthropic_body = translate_to_anthropic(openai_body);
1167
1168 let client = reqwest::Client::builder()
1169 .timeout(std::time::Duration::from_secs(300))
1170 .build()
1171 .map_err(|_| ProxyError::Upstream)?;
1172
1173 let resp = client
1174 .post(format!("{anthropic_url}/v1/messages"))
1175 .header("content-type", "application/json")
1176 .header("anthropic-version", "2023-06-01")
1177 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1178 .header("x-shunt-compat", "openai")
1179 .json(&anthropic_body)
1180 .send()
1181 .await
1182 .map_err(|_| ProxyError::Upstream)?;
1183
1184 if !resp.status().is_success() {
1185 let status = resp.status();
1186 let body = resp.text().await.unwrap_or_default();
1187 let code = status.as_u16();
1188 return Ok(axum::response::Response::builder()
1189 .status(code)
1190 .header("content-type", "application/json")
1191 .body(axum::body::Body::from(body))
1192 .unwrap());
1193 }
1194
1195 if stream {
1196 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1198 let stream = translate_anthropic_stream(resp, chat_id);
1199 Ok(axum::response::Response::builder()
1200 .status(200)
1201 .header("content-type", "text/event-stream")
1202 .header("cache-control", "no-cache")
1203 .body(axum::body::Body::from_stream(stream))
1204 .unwrap())
1205 } else {
1206 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1207 let openai_resp = translate_from_anthropic(anthropic_resp);
1208 Ok(axum::Json(openai_resp).into_response())
1209 }
1210}
1211
1212fn translate_anthropic_stream(
1214 resp: reqwest::Response,
1215 chat_id: String,
1216) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1217 use futures_util::StreamExt;
1218
1219 let id = chat_id;
1220 let byte_stream = resp.bytes_stream();
1221
1222 async_stream::stream! {
1223 let mut buf = String::new();
1224 futures_util::pin_mut!(byte_stream);
1225
1226 let init = format!(
1228 "data: {}\n\n",
1229 serde_json::to_string(&json!({
1230 "id": id,
1231 "object": "chat.completion.chunk",
1232 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1233 })).unwrap()
1234 );
1235 yield Ok(bytes::Bytes::from(init));
1236
1237 while let Some(chunk) = byte_stream.next().await {
1238 let chunk = match chunk {
1239 Ok(c) => c,
1240 Err(_) => break,
1241 };
1242 buf.push_str(&String::from_utf8_lossy(&chunk));
1243
1244 while let Some(nl) = buf.find('\n') {
1246 let line = buf[..nl].trim_end_matches('\r').to_owned();
1247 buf = buf[nl + 1..].to_owned();
1248
1249 if !line.starts_with("data: ") { continue; }
1250 let data = &line["data: ".len()..];
1251 if data == "[DONE]" { continue; }
1252
1253 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1254 let event_type = event["type"].as_str().unwrap_or("");
1255
1256 let maybe_chunk = match event_type {
1257 "content_block_delta" => {
1258 let text = event["delta"]["text"].as_str().unwrap_or("");
1259 if text.is_empty() { continue; }
1260 Some(json!({
1261 "id": id,
1262 "object": "chat.completion.chunk",
1263 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1264 }))
1265 }
1266 "message_delta" => {
1267 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1268 let finish = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1269 Some(json!({
1270 "id": id,
1271 "object": "chat.completion.chunk",
1272 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1273 }))
1274 }
1275 _ => None,
1276 };
1277
1278 if let Some(c) = maybe_chunk {
1279 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1280 yield Ok(bytes::Bytes::from(out));
1281 }
1282 }
1283 }
1284
1285 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1286 }
1287}