Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::{refresh_token, OAuthCredential};
17use crate::quota;
18use crate::router;
19use crate::state::{RateLimitInfo, StateStore};
20
21#[derive(Clone)]
22struct AppState {
23    config: Arc<Config>,
24    forwarder: Arc<Forwarder>,
25    state: StateStore,
26    /// Live credentials — can be refreshed at runtime without restarting.
27    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
28    /// Epoch-ms when this proxy instance started.
29    started_ms: u64,
30}
31
32pub fn create_app(config: Config) -> anyhow::Result<Router> {
33    create_app_with_state(config, StateStore::load(&state_path()))
34}
35
36pub fn create_app_with_state(config: Config, state: StateStore) -> anyhow::Result<Router> {
37    let forwarder = Forwarder::new(&config.server.upstream_url)?;
38
39    // Accounts with no credential are shown in status but skipped during routing.
40    // Mark them disabled immediately so the router ignores them.
41    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
42        state.set_auth_failed(&a.name);
43    }
44
45    let credentials = Arc::new(RwLock::new(
46        config.accounts.iter()
47            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
48            .collect::<HashMap<_, _>>(),
49    ));
50
51    let app_state = AppState {
52        config: Arc::new(config),
53        forwarder: Arc::new(forwarder),
54        state,
55        credentials,
56        started_ms: now_ms(),
57    };
58
59    let app = Router::new()
60        .route("/health", get(health))
61        .route("/status", get(status_handler))
62        .route("/use", post(use_handler))
63        .route("/v1/messages", post(proxy_handler))
64        .route("/v1/messages/count_tokens", post(proxy_handler))
65        .with_state(app_state);
66
67    Ok(app)
68}
69
70async fn health() -> impl IntoResponse {
71    axum::Json(json!({"status": "ok"}))
72}
73
74async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
75    let account_states = s.state.account_states();
76    let quotas = s.state.quota_snapshot();
77    let rate_limits = s.state.rate_limit_snapshot();
78
79    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
80        let st = account_states.get(&a.name);
81        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
82            "reauth_required"
83        } else if st.map(|s| s.disabled).unwrap_or(false) {
84            "disabled"
85        } else if s.state.is_available(&a.name) {
86            "available"
87        } else {
88            "cooling"
89        };
90
91        let quota = quotas.get(&a.name);
92        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
93        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
94        let tokens_used = quota.map(|q| json!({
95            "input": q.input_tokens,
96            "output": q.output_tokens,
97            "total": q.total_tokens(),
98        }));
99
100        let rl = rate_limits.get(&a.name);
101        let rate_limit = rl.map(|r| json!({
102            "utilization_5h": r.utilization_5h,
103            "reset_5h": r.reset_5h,
104            "status_5h": r.status_5h,
105            "utilization_7d": r.utilization_7d,
106            "reset_7d": r.reset_7d,
107            "status_7d": r.status_7d,
108            "representative_claim": r.representative_claim,
109            "updated_ms": r.updated_ms,
110        }));
111
112        let acc_state = account_states.get(&a.name);
113        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
114        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
115        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
116        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
117        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
118        let reset_5h = rl.and_then(|r| r.reset_5h);
119        let total_tokens = quota.map(|q| q.total_tokens()).unwrap_or(0);
120        let available = s.state.is_available(&a.name);
121
122        json!({
123            "name": a.name,
124            "email": email,
125            "plan": a.plan_type,
126            "plan_type": a.plan_type,
127            "status": avail_status,
128            "available": available,
129            "disabled": disabled,
130            "auth_failed": auth_failed,
131            "cooldown_until_ms": cooldown_until_ms,
132            "utilization_5h": utilization_5h,
133            "reset_5h": reset_5h,
134            "total_tokens": total_tokens,
135            "window_expires_ms": window_expires_ms,
136            "tokens_used": tokens_used,
137            "rate_limit": rate_limit,
138        })
139    }).collect();
140
141    let recent_requests = s.state.recent_requests_snapshot();
142
143    axum::Json(json!({
144        "version": env!("CARGO_PKG_VERSION"),
145        "started_ms": s.started_ms,
146        "accounts": accounts,
147        "pinned": s.state.get_pinned(),
148        "last_used": s.state.get_last_used(),
149        "pinned_account": s.state.get_pinned(),
150        "last_used_account": s.state.get_last_used(),
151        "recent_requests": recent_requests,
152    }))
153}
154
155async fn use_handler(
156    State(s): State<AppState>,
157    axum::Json(body): axum::Json<serde_json::Value>,
158) -> impl IntoResponse {
159    let account = body["account"].as_str().map(|s| s.to_owned());
160    // Validate the account name exists (unless clearing to auto)
161    if let Some(ref name) = account {
162        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
163            return axum::Json(json!({
164                "error": format!("unknown account '{name}'")
165            }));
166        }
167        let pinned = if name == "auto" { None } else { Some(name.clone()) };
168        s.state.set_pinned(pinned);
169        axum::Json(json!({ "pinned": name }))
170    } else {
171        s.state.set_pinned(None);
172        axum::Json(json!({ "pinned": null }))
173    }
174}
175
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
180
181async fn proxy_handler(
182    State(s): State<AppState>,
183    req: Request,
184) -> Result<Response, ProxyError> {
185    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
186    if let Some(ref expected) = s.config.server.remote_key {
187        let provided = req.headers()
188            .get("x-api-key")
189            .and_then(|v| v.to_str().ok())
190            .unwrap_or("");
191        if provided != expected {
192            return Err(ProxyError::Unauthorized);
193        }
194    }
195
196    let method = req.method().as_str().to_owned();
197    let path = req.uri().path().to_owned();
198    let headers = req.headers().clone();
199
200    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
201        .await
202        .map_err(|_| ProxyError::BodyRead)?;
203
204    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
205        .ok()
206        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
207        .unwrap_or_default();
208    let req_start_ms = now_ms();
209
210    let fp = router::fingerprint(&body_bytes);
211    let fp_ref = fp.as_deref();
212
213    let mut tried: HashSet<String> = HashSet::new();
214    // Track accounts we've already attempted a token refresh for this request.
215    let mut refreshed: HashSet<String> = HashSet::new();
216
217    loop {
218        let account = match router::pick_account(&s.config.accounts, &s.state, fp_ref, &tried) {
219            Some(a) => a,
220            None => return Err(ProxyError::AllAccountsUnavailable),
221        };
222
223        let account_name = account.name.clone();
224
225        // Use the live (possibly refreshed) token rather than the one baked into config.
226        let token = {
227            let creds = s.credentials.read().await;
228            creds.get(&account_name)
229                .map(|c| c.access_token.clone())
230                .or_else(|| account.credential.as_ref().map(|c| c.access_token.clone()))
231                .unwrap_or_default()
232        };
233
234        let response = s.forwarder
235            .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
236            .await
237            .map_err(|e| {
238                error!("Forward error: {:#}", e);
239                ProxyError::Upstream
240            })?;
241
242        match response.status().as_u16() {
243            200..=299 => {
244                s.state.set_last_used(&account_name);
245                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
246            }
247            429 => {
248                warn!(account = %account_name, "429 rate-limited — cooling 60s");
249                capture_rate_limit_headers(response.headers(), &s.state, &account_name);
250                s.state.set_cooldown(&account_name, 60_000);
251                tried.insert(account_name);
252            }
253            529 => {
254                warn!(account = %account_name, "529 overloaded — cooling 30s");
255                capture_rate_limit_headers(response.headers(), &s.state, &account_name);
256                s.state.set_cooldown(&account_name, 30_000);
257                tried.insert(account_name);
258            }
259            401 => {
260                if !refreshed.contains(&account_name) {
261                    // Access token invalidated (e.g. user logged out) — try refresh.
262                    let cred = {
263                        let creds = s.credentials.read().await;
264                        creds.get(&account_name).cloned()
265                            .or_else(|| account.credential.clone())
266                    };
267                    let Some(cred) = cred else {
268                        tried.insert(account_name);
269                        continue;
270                    };
271                    match tokio::time::timeout(
272                        std::time::Duration::from_secs(10),
273                        refresh_token(&cred),
274                    ).await {
275                        Ok(Ok(fresh)) => {
276                            warn!(account = %account_name, "401 — token refreshed, retrying");
277                            {
278                                let mut creds = s.credentials.write().await;
279                                creds.insert(account_name.clone(), fresh.clone());
280                            }
281                            // Persist to disk so the refreshed token survives a restart.
282                            let name = account_name.clone();
283                            let fresh = fresh.clone();
284                            tokio::task::spawn_blocking(move || {
285                                let mut store = CredentialsStore::load();
286                                store.accounts.insert(name, fresh);
287                                store.save().ok();
288                            });
289                            // Mark as refreshed but don't add to tried — retry this account.
290                            refreshed.insert(account_name);
291                        }
292                        _ => {
293                            // Refresh failed/timed out — cool down, don't permanently disable.
294                            error!(account = %account_name, "401 — token refresh failed, cooling 5min");
295                            s.state.set_cooldown(&account_name, 5 * 60_000);
296                            tried.insert(account_name);
297                        }
298                    }
299                } else {
300                    // Already refreshed once and still 401 — cool down this account.
301                    error!(account = %account_name, "401 after refresh — cooling 5min");
302                    s.state.set_cooldown(&account_name, 5 * 60_000);
303                    tried.insert(account_name);
304                }
305            }
306            403 => {
307                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
308                error!(account = %account_name, "403 forbidden — cooling 30min");
309                s.state.set_cooldown(&account_name, 30 * 60_000);
310                tried.insert(account_name);
311            }
312            _ => {
313                // 400, 404, 500, etc. — return as-is, no retry
314                return Ok(response);
315            }
316        }
317    }
318}
319
// ---------------------------------------------------------------------------
// Usage extraction
// ---------------------------------------------------------------------------
323
324/// Intercept a successful response to record token usage, then pass it through.
325///
326/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
327/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
328async fn tap_usage(
329    resp: Response,
330    state: &StateStore,
331    account: &str,
332    model: &str,
333    req_start_ms: u64,
334) -> Response {
335    use axum::body::Body;
336    use crate::state::RequestLog;
337
338    // Capture rate-limit headers before the response is consumed
339    capture_rate_limit_headers(resp.headers(), state, account);
340
341    if quota::is_streaming_response(&resp) {
342        let state = state.clone();
343        let account = account.to_owned();
344        let model = model.to_owned();
345        let on_complete = Arc::new(move |input: u64, output: u64| {
346            state.record_usage(&account, input, output);
347            state.record_request(RequestLog {
348                ts_ms: req_start_ms,
349                account: account.clone(),
350                model: model.clone(),
351                status: 200,
352                input_tokens: input,
353                output_tokens: output,
354                duration_ms: now_ms().saturating_sub(req_start_ms),
355            });
356        });
357        let (parts, body) = resp.into_parts();
358        let wrapped = quota::wrap_streaming_body(body, on_complete);
359        return Response::from_parts(parts, wrapped);
360    }
361
362    // Non-streaming: buffer, extract, rebuild
363    let (parts, body) = resp.into_parts();
364    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
365        Ok(b) => b,
366        Err(_) => return Response::from_parts(parts, Body::empty()),
367    };
368    let (input, output) = quota::extract_usage_from_json(&bytes);
369    state.record_usage(account, input, output);
370    state.record_request(RequestLog {
371        ts_ms: req_start_ms,
372        account: account.to_owned(),
373        model: model.to_owned(),
374        status: 200,
375        input_tokens: input,
376        output_tokens: output,
377        duration_ms: now_ms().saturating_sub(req_start_ms),
378    });
379    Response::from_parts(parts, Body::from(bytes))
380}
381
382fn capture_rate_limit_headers(headers: &axum::http::HeaderMap, state: &StateStore, account: &str) {
383    fn hdr_u64(headers: &axum::http::HeaderMap, name: &str) -> Option<u64> {
384        headers.get(name)?.to_str().ok()?.parse().ok()
385    }
386    fn hdr_f64(headers: &axum::http::HeaderMap, name: &str) -> Option<f64> {
387        headers.get(name)?.to_str().ok()?.parse().ok()
388    }
389    fn hdr_str(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
390        Some(headers.get(name)?.to_str().ok()?.to_owned())
391    }
392
393    // Claude Code OAuth uses anthropic-ratelimit-unified-* headers
394    let utilization_5h  = hdr_f64(headers, "anthropic-ratelimit-unified-5h-utilization");
395    let reset_5h        = hdr_u64(headers, "anthropic-ratelimit-unified-5h-reset");
396    let status_5h       = hdr_str(headers, "anthropic-ratelimit-unified-5h-status");
397    let utilization_7d  = hdr_f64(headers, "anthropic-ratelimit-unified-7d-utilization");
398    let reset_7d        = hdr_u64(headers, "anthropic-ratelimit-unified-7d-reset");
399    let status_7d       = hdr_str(headers, "anthropic-ratelimit-unified-7d-status");
400    let overage_status          = hdr_str(headers, "anthropic-ratelimit-unified-overage-status");
401    let overage_disabled_reason = hdr_str(headers, "anthropic-ratelimit-unified-overage-disabled-reason");
402    let representative_claim    = hdr_str(headers, "anthropic-ratelimit-unified-representative-claim");
403
404    if utilization_5h.is_some() || utilization_7d.is_some() {
405        state.update_rate_limits(account, RateLimitInfo {
406            utilization_5h,
407            reset_5h,
408            status_5h,
409            utilization_7d,
410            reset_7d,
411            status_7d,
412            overage_status,
413            overage_disabled_reason,
414            representative_claim,
415            updated_ms: now_ms(),
416        });
417    }
418}
419
// ---------------------------------------------------------------------------
// Rate limit prefetch
// ---------------------------------------------------------------------------
423
424/// For any account with no rate-limit data yet, make a cheap count_tokens
425/// call directly to the upstream API so we populate metrics without waiting
426/// for a real user request. Runs as a background task after startup.
427pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
428    let upstream = &config.server.upstream_url;
429    let url = format!("{upstream}/v1/messages");
430    let client = reqwest::Client::builder()
431        .timeout(std::time::Duration::from_secs(20))
432        .build()
433        .unwrap_or_default();
434
435    // Minimal 1-token message — cheapest way to get the unified rate limit headers
436    let body = json!({
437        "model": "claude-haiku-4-5-20251001",
438        "max_tokens": 1,
439        "messages": [{"role": "user", "content": "hi"}]
440    });
441
442    for account in &config.accounts {
443        // Skip if we already have data for this account
444        let rl = state.rate_limit_snapshot();
445        if let Some(r) = rl.get(&account.name) {
446            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
447                continue;
448            }
449        }
450
451        let creds = match account.credential.clone() {
452            Some(c) => c,
453            None => continue, // no credential — skip prefetch
454        };
455        let resp = client
456            .post(&url)
457            .header("authorization", format!("Bearer {}", creds.access_token))
458            .header("anthropic-version", "2023-06-01")
459            .header("anthropic-dangerous-direct-browser-access", "true")
460            .json(&body)
461            .send()
462            .await;
463
464        let r = match resp {
465            Ok(r) => r,
466            Err(e) => { tracing::warn!(account = %account.name, "prefetch request failed: {e}"); continue; }
467        };
468
469        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
470            // Token expired — try to refresh and retry once
471            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
472            let fresh = match crate::oauth::refresh_token(&creds).await {
473                Ok(f) => f,
474                Err(e) => {
475                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
476                    state.set_auth_failed(&account.name);
477                    continue;
478                }
479            };
480            // Persist updated token
481            let mut store = crate::config::CredentialsStore::load();
482            store.accounts.insert(account.name.clone(), fresh.clone());
483            store.save().ok();
484
485            let retry = client
486                .post(&url)
487                .header("authorization", format!("Bearer {}", fresh.access_token))
488                .header("anthropic-version", "2023-06-01")
489                .header("anthropic-dangerous-direct-browser-access", "true")
490                .json(&body)
491                .send()
492                .await;
493            match retry {
494                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
495                    tracing::error!(account = %account.name, "401 after refresh — credentials need re-authorization");
496                    state.set_auth_failed(&account.name);
497                }
498                Ok(r2) => {
499                    capture_rate_limit_headers(r2.headers(), &state, &account.name);
500                }
501                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
502            }
503        } else {
504            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
505            capture_rate_limit_headers(r.headers(), &state, &account.name);
506        }
507    }
508}
509
// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------
513
514enum ProxyError {
515    BodyRead,
516    Upstream,
517    AllAccountsUnavailable,
518    Unauthorized,
519}
520
521impl IntoResponse for ProxyError {
522    fn into_response(self) -> Response {
523        let (status, msg) = match self {
524            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
525            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
526            ProxyError::AllAccountsUnavailable => {
527                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
528            }
529            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
530        };
531
532        (status, axum::Json(json!({
533            "type": "error",
534            "error": {"type": "api_error", "message": msg}
535        }))).into_response()
536    }
537}