1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::{refresh_token, OAuthCredential};
17use crate::quota;
18use crate::router;
19use crate::state::{RateLimitInfo, StateStore};
20
21#[derive(Clone)]
22struct AppState {
23 config: Arc<Config>,
24 forwarder: Arc<Forwarder>,
25 state: StateStore,
26 credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
28 started_ms: u64,
30}
31
32pub fn create_app(config: Config) -> anyhow::Result<Router> {
33 create_app_with_state(config, StateStore::load(&state_path()))
34}
35
36pub fn create_app_with_state(config: Config, state: StateStore) -> anyhow::Result<Router> {
37 let forwarder = Forwarder::new(&config.server.upstream_url)?;
38
39 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
42 state.set_auth_failed(&a.name);
43 }
44
45 let credentials = Arc::new(RwLock::new(
46 config.accounts.iter()
47 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
48 .collect::<HashMap<_, _>>(),
49 ));
50
51 let app_state = AppState {
52 config: Arc::new(config),
53 forwarder: Arc::new(forwarder),
54 state,
55 credentials,
56 started_ms: now_ms(),
57 };
58
59 let app = Router::new()
60 .route("/health", get(health))
61 .route("/status", get(status_handler))
62 .route("/use", post(use_handler))
63 .route("/v1/messages", post(proxy_handler))
64 .route("/v1/messages/count_tokens", post(proxy_handler))
65 .with_state(app_state);
66
67 Ok(app)
68}
69
70async fn health() -> impl IntoResponse {
71 axum::Json(json!({"status": "ok"}))
72}
73
74async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
75 let account_states = s.state.account_states();
76 let quotas = s.state.quota_snapshot();
77 let rate_limits = s.state.rate_limit_snapshot();
78
79 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
80 let st = account_states.get(&a.name);
81 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
82 "reauth_required"
83 } else if st.map(|s| s.disabled).unwrap_or(false) {
84 "disabled"
85 } else if s.state.is_available(&a.name) {
86 "available"
87 } else {
88 "cooling"
89 };
90
91 let quota = quotas.get(&a.name);
92 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
93 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
94 let tokens_used = quota.map(|q| json!({
95 "input": q.input_tokens,
96 "output": q.output_tokens,
97 "total": q.total_tokens(),
98 }));
99
100 let rl = rate_limits.get(&a.name);
101 let rate_limit = rl.map(|r| json!({
102 "utilization_5h": r.utilization_5h,
103 "reset_5h": r.reset_5h,
104 "status_5h": r.status_5h,
105 "utilization_7d": r.utilization_7d,
106 "reset_7d": r.reset_7d,
107 "status_7d": r.status_7d,
108 "representative_claim": r.representative_claim,
109 "updated_ms": r.updated_ms,
110 }));
111
112 let acc_state = account_states.get(&a.name);
113 let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
114 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
115 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
116 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
117 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
118 let reset_5h = rl.and_then(|r| r.reset_5h);
119 let total_tokens = quota.map(|q| q.total_tokens()).unwrap_or(0);
120 let available = s.state.is_available(&a.name);
121
122 json!({
123 "name": a.name,
124 "email": email,
125 "plan": a.plan_type,
126 "plan_type": a.plan_type,
127 "status": avail_status,
128 "available": available,
129 "disabled": disabled,
130 "auth_failed": auth_failed,
131 "cooldown_until_ms": cooldown_until_ms,
132 "utilization_5h": utilization_5h,
133 "reset_5h": reset_5h,
134 "total_tokens": total_tokens,
135 "window_expires_ms": window_expires_ms,
136 "tokens_used": tokens_used,
137 "rate_limit": rate_limit,
138 })
139 }).collect();
140
141 let recent_requests = s.state.recent_requests_snapshot();
142
143 axum::Json(json!({
144 "version": env!("CARGO_PKG_VERSION"),
145 "started_ms": s.started_ms,
146 "accounts": accounts,
147 "pinned": s.state.get_pinned(),
148 "last_used": s.state.get_last_used(),
149 "pinned_account": s.state.get_pinned(),
150 "last_used_account": s.state.get_last_used(),
151 "recent_requests": recent_requests,
152 }))
153}
154
155async fn use_handler(
156 State(s): State<AppState>,
157 axum::Json(body): axum::Json<serde_json::Value>,
158) -> impl IntoResponse {
159 let account = body["account"].as_str().map(|s| s.to_owned());
160 if let Some(ref name) = account {
162 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
163 return axum::Json(json!({
164 "error": format!("unknown account '{name}'")
165 }));
166 }
167 let pinned = if name == "auto" { None } else { Some(name.clone()) };
168 s.state.set_pinned(pinned);
169 axum::Json(json!({ "pinned": name }))
170 } else {
171 s.state.set_pinned(None);
172 axum::Json(json!({ "pinned": null }))
173 }
174}
175
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
180
181async fn proxy_handler(
182 State(s): State<AppState>,
183 req: Request,
184) -> Result<Response, ProxyError> {
185 if let Some(ref expected) = s.config.server.remote_key {
187 let provided = req.headers()
188 .get("x-api-key")
189 .and_then(|v| v.to_str().ok())
190 .unwrap_or("");
191 if provided != expected {
192 return Err(ProxyError::Unauthorized);
193 }
194 }
195
196 let method = req.method().as_str().to_owned();
197 let path = req.uri().path().to_owned();
198 let headers = req.headers().clone();
199
200 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
201 .await
202 .map_err(|_| ProxyError::BodyRead)?;
203
204 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
205 .ok()
206 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
207 .unwrap_or_default();
208 let req_start_ms = now_ms();
209
210 let fp = router::fingerprint(&body_bytes);
211 let fp_ref = fp.as_deref();
212
213 let mut tried: HashSet<String> = HashSet::new();
214 let mut refreshed: HashSet<String> = HashSet::new();
216
217 loop {
218 let account = match router::pick_account(&s.config.accounts, &s.state, fp_ref, &tried) {
219 Some(a) => a,
220 None => return Err(ProxyError::AllAccountsUnavailable),
221 };
222
223 let account_name = account.name.clone();
224
225 let token = {
227 let creds = s.credentials.read().await;
228 creds.get(&account_name)
229 .map(|c| c.access_token.clone())
230 .or_else(|| account.credential.as_ref().map(|c| c.access_token.clone()))
231 .unwrap_or_default()
232 };
233
234 let response = s.forwarder
235 .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
236 .await
237 .map_err(|e| {
238 error!("Forward error: {:#}", e);
239 ProxyError::Upstream
240 })?;
241
242 match response.status().as_u16() {
243 200..=299 => {
244 s.state.set_last_used(&account_name);
245 return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
246 }
247 429 => {
248 warn!(account = %account_name, "429 rate-limited — cooling 60s");
249 capture_rate_limit_headers(response.headers(), &s.state, &account_name);
250 s.state.set_cooldown(&account_name, 60_000);
251 tried.insert(account_name);
252 }
253 529 => {
254 warn!(account = %account_name, "529 overloaded — cooling 30s");
255 capture_rate_limit_headers(response.headers(), &s.state, &account_name);
256 s.state.set_cooldown(&account_name, 30_000);
257 tried.insert(account_name);
258 }
259 401 => {
260 if !refreshed.contains(&account_name) {
261 let cred = {
263 let creds = s.credentials.read().await;
264 creds.get(&account_name).cloned()
265 .or_else(|| account.credential.clone())
266 };
267 let Some(cred) = cred else {
268 tried.insert(account_name);
269 continue;
270 };
271 match tokio::time::timeout(
272 std::time::Duration::from_secs(10),
273 refresh_token(&cred),
274 ).await {
275 Ok(Ok(fresh)) => {
276 warn!(account = %account_name, "401 — token refreshed, retrying");
277 {
278 let mut creds = s.credentials.write().await;
279 creds.insert(account_name.clone(), fresh.clone());
280 }
281 let name = account_name.clone();
283 let fresh = fresh.clone();
284 tokio::task::spawn_blocking(move || {
285 let mut store = CredentialsStore::load();
286 store.accounts.insert(name, fresh);
287 store.save().ok();
288 });
289 refreshed.insert(account_name);
291 }
292 _ => {
293 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
295 s.state.set_cooldown(&account_name, 5 * 60_000);
296 tried.insert(account_name);
297 }
298 }
299 } else {
300 error!(account = %account_name, "401 after refresh — cooling 5min");
302 s.state.set_cooldown(&account_name, 5 * 60_000);
303 tried.insert(account_name);
304 }
305 }
306 403 => {
307 error!(account = %account_name, "403 forbidden — cooling 30min");
309 s.state.set_cooldown(&account_name, 30 * 60_000);
310 tried.insert(account_name);
311 }
312 _ => {
313 return Ok(response);
315 }
316 }
317 }
318}
319
320async fn tap_usage(
329 resp: Response,
330 state: &StateStore,
331 account: &str,
332 model: &str,
333 req_start_ms: u64,
334) -> Response {
335 use axum::body::Body;
336 use crate::state::RequestLog;
337
338 capture_rate_limit_headers(resp.headers(), state, account);
340
341 if quota::is_streaming_response(&resp) {
342 let state = state.clone();
343 let account = account.to_owned();
344 let model = model.to_owned();
345 let on_complete = Arc::new(move |input: u64, output: u64| {
346 state.record_usage(&account, input, output);
347 state.record_request(RequestLog {
348 ts_ms: req_start_ms,
349 account: account.clone(),
350 model: model.clone(),
351 status: 200,
352 input_tokens: input,
353 output_tokens: output,
354 duration_ms: now_ms().saturating_sub(req_start_ms),
355 });
356 });
357 let (parts, body) = resp.into_parts();
358 let wrapped = quota::wrap_streaming_body(body, on_complete);
359 return Response::from_parts(parts, wrapped);
360 }
361
362 let (parts, body) = resp.into_parts();
364 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
365 Ok(b) => b,
366 Err(_) => return Response::from_parts(parts, Body::empty()),
367 };
368 let (input, output) = quota::extract_usage_from_json(&bytes);
369 state.record_usage(account, input, output);
370 state.record_request(RequestLog {
371 ts_ms: req_start_ms,
372 account: account.to_owned(),
373 model: model.to_owned(),
374 status: 200,
375 input_tokens: input,
376 output_tokens: output,
377 duration_ms: now_ms().saturating_sub(req_start_ms),
378 });
379 Response::from_parts(parts, Body::from(bytes))
380}
381
382fn capture_rate_limit_headers(headers: &axum::http::HeaderMap, state: &StateStore, account: &str) {
383 fn hdr_u64(headers: &axum::http::HeaderMap, name: &str) -> Option<u64> {
384 headers.get(name)?.to_str().ok()?.parse().ok()
385 }
386 fn hdr_f64(headers: &axum::http::HeaderMap, name: &str) -> Option<f64> {
387 headers.get(name)?.to_str().ok()?.parse().ok()
388 }
389 fn hdr_str(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
390 Some(headers.get(name)?.to_str().ok()?.to_owned())
391 }
392
393 let utilization_5h = hdr_f64(headers, "anthropic-ratelimit-unified-5h-utilization");
395 let reset_5h = hdr_u64(headers, "anthropic-ratelimit-unified-5h-reset");
396 let status_5h = hdr_str(headers, "anthropic-ratelimit-unified-5h-status");
397 let utilization_7d = hdr_f64(headers, "anthropic-ratelimit-unified-7d-utilization");
398 let reset_7d = hdr_u64(headers, "anthropic-ratelimit-unified-7d-reset");
399 let status_7d = hdr_str(headers, "anthropic-ratelimit-unified-7d-status");
400 let overage_status = hdr_str(headers, "anthropic-ratelimit-unified-overage-status");
401 let overage_disabled_reason = hdr_str(headers, "anthropic-ratelimit-unified-overage-disabled-reason");
402 let representative_claim = hdr_str(headers, "anthropic-ratelimit-unified-representative-claim");
403
404 if utilization_5h.is_some() || utilization_7d.is_some() {
405 state.update_rate_limits(account, RateLimitInfo {
406 utilization_5h,
407 reset_5h,
408 status_5h,
409 utilization_7d,
410 reset_7d,
411 status_7d,
412 overage_status,
413 overage_disabled_reason,
414 representative_claim,
415 updated_ms: now_ms(),
416 });
417 }
418}
419
420pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
428 let upstream = &config.server.upstream_url;
429 let url = format!("{upstream}/v1/messages");
430 let client = reqwest::Client::builder()
431 .timeout(std::time::Duration::from_secs(20))
432 .build()
433 .unwrap_or_default();
434
435 let body = json!({
437 "model": "claude-haiku-4-5-20251001",
438 "max_tokens": 1,
439 "messages": [{"role": "user", "content": "hi"}]
440 });
441
442 for account in &config.accounts {
443 let rl = state.rate_limit_snapshot();
445 if let Some(r) = rl.get(&account.name) {
446 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
447 continue;
448 }
449 }
450
451 let creds = match account.credential.clone() {
452 Some(c) => c,
453 None => continue, };
455 let resp = client
456 .post(&url)
457 .header("authorization", format!("Bearer {}", creds.access_token))
458 .header("anthropic-version", "2023-06-01")
459 .header("anthropic-dangerous-direct-browser-access", "true")
460 .json(&body)
461 .send()
462 .await;
463
464 let r = match resp {
465 Ok(r) => r,
466 Err(e) => { tracing::warn!(account = %account.name, "prefetch request failed: {e}"); continue; }
467 };
468
469 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
470 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
472 let fresh = match crate::oauth::refresh_token(&creds).await {
473 Ok(f) => f,
474 Err(e) => {
475 tracing::warn!(account = %account.name, "token refresh failed: {e}");
476 state.set_auth_failed(&account.name);
477 continue;
478 }
479 };
480 let mut store = crate::config::CredentialsStore::load();
482 store.accounts.insert(account.name.clone(), fresh.clone());
483 store.save().ok();
484
485 let retry = client
486 .post(&url)
487 .header("authorization", format!("Bearer {}", fresh.access_token))
488 .header("anthropic-version", "2023-06-01")
489 .header("anthropic-dangerous-direct-browser-access", "true")
490 .json(&body)
491 .send()
492 .await;
493 match retry {
494 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
495 tracing::error!(account = %account.name, "401 after refresh — credentials need re-authorization");
496 state.set_auth_failed(&account.name);
497 }
498 Ok(r2) => {
499 capture_rate_limit_headers(r2.headers(), &state, &account.name);
500 }
501 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
502 }
503 } else {
504 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
505 capture_rate_limit_headers(r.headers(), &state, &account.name);
506 }
507 }
508}
509
/// Errors the proxy routes report to API clients instead of an upstream response.
/// Converted to an HTTP response by its `IntoResponse` impl.
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// Relaying the request to the upstream API failed.
    Upstream,
    /// `router::pick_account` found no account left to try.
    AllAccountsUnavailable,
    /// `server.remote_key` is configured and the request's `x-api-key` did not match.
    Unauthorized,
}
520
521impl IntoResponse for ProxyError {
522 fn into_response(self) -> Response {
523 let (status, msg) = match self {
524 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
525 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
526 ProxyError::AllAccountsUnavailable => {
527 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
528 }
529 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
530 };
531
532 (status, axum::Json(json!({
533 "type": "error",
534 "error": {"type": "api_error", "message": msg}
535 }))).into_response()
536 }
537}