1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24 config: Arc<Config>,
25 forwarder: Arc<Forwarder>,
26 state: StateStore,
27 credentials: Arc<RwLock<HashMap<String, Credential>>>,
29 refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37 started_ms: u64,
39 anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46 Ok(app)
47}
48
/// Shared, lock-protected map from account name to that account's current credential.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
51
52fn build_app_state(
56 config: Config,
57 state: StateStore,
58 anthropic_base_url: Option<String>,
59) -> anyhow::Result<(AppState, LiveCredentials)> {
60 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
61
62 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
63 state.set_auth_failed(&a.name);
64 }
65
66 let credentials: LiveCredentials = Arc::new(RwLock::new(
67 config.accounts.iter()
68 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
69 .collect::<HashMap<_, _>>(),
70 ));
71
72 let app_state = AppState {
73 config: Arc::new(config),
74 forwarder: Arc::new(forwarder),
75 state,
76 credentials: Arc::clone(&credentials),
77 refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
78 started_ms: now_ms(),
79 anthropic_base_url,
80 };
81
82 Ok((app_state, credentials))
83}
84
85pub fn create_proxy_app(
86 config: Config,
87 state: StateStore,
88 anthropic_base_url: Option<String>,
89) -> anyhow::Result<(Router, LiveCredentials)> {
90 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
91
92 let app = Router::new()
93 .route("/v1/messages", post(proxy_handler))
94 .route("/v1/messages/count_tokens", post(proxy_handler))
95 .route("/v1/chat/completions", post(openai_compat_handler))
96 .route("/v1/models", get(openai_models_handler))
97 .fallback(proxy_handler)
98 .with_state(app_state);
99
100 Ok((app, credentials))
101}
102
103pub fn create_control_app(
106 config: Config,
107 state: StateStore,
108) -> anyhow::Result<Router> {
109 let (app_state, _) = build_app_state(config, state, None)?;
110
111 let app = Router::new()
112 .route("/health", get(health))
113 .route("/status", get(status_handler))
114 .route("/use", post(use_handler))
115 .with_state(app_state);
116
117 Ok(app)
118}
119
120pub fn create_app_with_state(
124 config: Config,
125 state: StateStore,
126 anthropic_base_url: Option<String>,
127) -> anyhow::Result<(Router, LiveCredentials)> {
128 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
129
130 let app = Router::new()
131 .route("/health", get(health))
133 .route("/status", get(status_handler))
134 .route("/use", post(use_handler))
135 .route("/v1/messages", post(proxy_handler))
137 .route("/v1/messages/count_tokens", post(proxy_handler))
138 .route("/v1/chat/completions", post(openai_compat_handler))
139 .route("/v1/models", get(openai_models_handler))
140 .fallback(proxy_handler)
141 .with_state(app_state);
142
143 Ok((app, credentials))
144}
145
146async fn health() -> impl IntoResponse {
147 axum::Json(json!({"status": "ok"}))
148}
149
150async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
151 let account_states = s.state.account_states();
152 let quotas = s.state.quota_snapshot();
153 let rate_limits = s.state.rate_limit_snapshot();
154
155 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
156 let st = account_states.get(&a.name);
157 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
158 "reauth_required"
159 } else if st.map(|s| s.disabled).unwrap_or(false) {
160 "disabled"
161 } else if s.state.is_available(&a.name) {
162 "available"
163 } else {
164 "cooling"
165 };
166
167 let quota = quotas.get(&a.name);
168 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
169 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
170 let tokens_used = quota.map(|q| json!({
171 "input": q.input_tokens,
172 "output": q.output_tokens,
173 "total": q.total_tokens(),
174 }));
175
176 let rl = rate_limits.get(&a.name);
177 let rate_limit = rl.map(|r| json!({
178 "utilization_5h": r.utilization_5h,
179 "reset_5h": r.reset_5h,
180 "status_5h": r.status_5h,
181 "utilization_7d": r.utilization_7d,
182 "reset_7d": r.reset_7d,
183 "status_7d": r.status_7d,
184 "representative_claim": r.representative_claim,
185 "updated_ms": r.updated_ms,
186 }));
187
188 let acc_state = account_states.get(&a.name);
189 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
190 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
191 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
192 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
193 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
194 let reset_5h = rl.and_then(|r| r.reset_5h);
195 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
196 let reset_7d = rl.and_then(|r| r.reset_7d);
197 let available = s.state.is_available(&a.name);
198
199 json!({
200 "name": a.name,
201 "email": email,
202 "plan_type": a.plan_type,
203 "provider": a.provider.to_string(),
204 "status": avail_status,
205 "available": available,
206 "disabled": disabled,
207 "auth_failed": auth_failed,
208 "cooldown_until_ms": cooldown_until_ms,
209 "utilization_5h": utilization_5h,
210 "reset_5h": reset_5h,
211 "utilization_7d": utilization_7d,
212 "reset_7d": reset_7d,
213 "window_expires_ms": window_expires_ms,
214 "tokens_used": tokens_used,
215 "rate_limit": rate_limit,
216 })
217 }).collect();
218
219 let recent_requests = s.state.recent_requests_snapshot();
220 let savings = s.state.savings_snapshot();
221
222 axum::Json(json!({
223 "version": env!("CARGO_PKG_VERSION"),
224 "started_ms": s.started_ms,
225 "accounts": accounts,
226 "pinned_account": s.state.get_pinned(),
227 "last_used_account": s.state.get_last_used(),
228 "recent_requests": recent_requests,
229 "savings": savings,
230 }))
231}
232
233async fn use_handler(
234 State(s): State<AppState>,
235 axum::Json(body): axum::Json<serde_json::Value>,
236) -> Response {
237 let account = body["account"].as_str().map(|s| s.to_owned());
238 if let Some(ref name) = account {
240 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
241 return (StatusCode::BAD_REQUEST, axum::Json(json!({
242 "error": format!("unknown account '{name}'")
243 }))).into_response();
244 }
245 let pinned = if name == "auto" { None } else { Some(name.clone()) };
246 s.state.set_pinned(pinned);
247 axum::Json(json!({ "pinned": name })).into_response()
248 } else {
249 s.state.set_pinned(None);
250 axum::Json(json!({ "pinned": null })).into_response()
251 }
252}
253
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
258
259async fn proxy_handler(
260 State(s): State<AppState>,
261 req: Request,
262) -> Result<Response, ProxyError> {
263 if let Some(ref expected) = s.config.server.remote_key {
265 let provided = req.headers()
266 .get("x-api-key")
267 .and_then(|v| v.to_str().ok())
268 .unwrap_or("");
269 if provided != expected {
270 return Err(ProxyError::Unauthorized);
271 }
272 }
273
274 let method = req.method().as_str().to_owned();
275 let path = req.uri().path().to_owned();
276 let headers = req.headers().clone();
277
278 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
279 .await
280 .map_err(|_| ProxyError::BodyRead)?;
281
282 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
283 .ok()
284 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
285 .unwrap_or_default();
286 let req_start_ms = now_ms();
287
288 let fp = router::fingerprint(&body_bytes);
289 let fp_ref = fp.as_deref();
290
291 let mut tried: HashSet<String> = HashSet::new();
292 let mut refreshed: HashSet<String> = HashSet::new();
294 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
296
297 loop {
298 let account = match router::pick_account(
299 &s.config.accounts, &s.state, fp_ref, &tried,
300 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
301 ) {
302 Some(a) => a,
303 None => {
304 let account_states = s.state.account_states();
308 let now = now_ms();
309 let soonest_ms = s.config.accounts.iter()
310 .filter_map(|a| {
311 let st = account_states.get(&a.name)?;
312 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
314 })
315 .min();
316
317 match soonest_ms {
318 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
319 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
321 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
322 tried.clear(); }
324 _ => return Err(ProxyError::AllAccountsUnavailable),
325 }
326 continue;
327 }
328 };
329
330 let account_name = account.name.clone();
331
332 let token = {
337 let creds = s.credentials.read().await;
338 let cred = creds.get(&account_name)
339 .cloned()
340 .or_else(|| account.credential.clone());
341 match cred {
342 Some(c) => c.bearer_token().to_owned(),
343 None => String::new(),
344 }
345 };
346
347 let req_is_anthropic = path.starts_with("/v1/messages");
351 let acct_is_anthropic = account.provider.wire_protocol()
352 == crate::provider::WireProtocol::Anthropic;
353 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
356
357 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
358 (path.clone(), body_bytes.clone(), headers.clone())
360 } else if req_is_anthropic && acct_is_chatgpt {
361 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
363 let translated = translate_anthropic_req_to_chatgpt(&val);
364 let mut h = headers.clone();
365 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
366 h.remove(*name);
367 }
368 (
369 "/backend-api/conversation".to_owned(),
370 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
371 h,
372 )
373 } else if req_is_anthropic {
374 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
376 let translated = translate_anthropic_req_to_openai(val);
377 let mut h = headers.clone();
378 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
379 h.remove(*name);
380 }
381 (
382 "/v1/chat/completions".to_owned(),
383 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
384 h,
385 )
386 } else {
387 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
389 let translated = translate_to_anthropic(val);
390 (
391 "/v1/messages".to_owned(),
392 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
393 headers.clone(),
394 )
395 };
396
397 let upstream = account.upstream_url.as_deref()
400 .unwrap_or(&s.config.server.upstream_url);
401
402 if req_is_anthropic && acct_is_chatgpt {
405 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
406 let sentinel_client = reqwest::Client::builder()
407 .timeout(std::time::Duration::from_secs(3))
408 .build()
409 .unwrap_or_default();
410 let sentinel_opt = tokio::time::timeout(
411 std::time::Duration::from_secs(3),
412 fetch_sentinel_token(&sentinel_client, upstream, &token),
413 ).await.ok().flatten();
414 if let Some(sentinel) = sentinel_opt {
415 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
416 b"openai-sentinel-chat-requirements-token",
417 ) {
418 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
419 fwd_headers.insert(name, val);
420 }
421 }
422 }
423 }
424
425 let response = if acct_is_chatgpt {
428 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
429 match tokio::time::timeout(
430 std::time::Duration::from_secs(15),
431 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
432 ).await {
433 Ok(Ok(r)) => r,
434 Ok(Err(e)) => {
435 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
436 s.state.set_cooldown(&account_name, 5 * 60_000);
437 tried.insert(account_name);
438 continue;
439 }
440 Err(_) => {
441 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
442 s.state.set_cooldown(&account_name, 5 * 60_000);
443 tried.insert(account_name);
444 continue;
445 }
446 }
447 } else {
448 s.forwarder
449 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
450 .await
451 .map_err(|e| {
452 error!("Forward error: {:#}", e);
453 ProxyError::Upstream
454 })?
455 };
456
457 match response.status().as_u16() {
458 200..=299 => {
459 s.state.set_last_used(&account_name);
460 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
461 s.state.update_rate_limits(&account_name, info);
462 }
463 let response = if req_is_anthropic == acct_is_anthropic {
465 response
466 } else if req_is_anthropic && acct_is_chatgpt {
467 translate_response_chatgpt_to_anthropic(response, &model).await
469 } else if req_is_anthropic {
470 translate_response_openai_to_anthropic(response, &model).await
472 } else {
473 translate_response_anthropic_to_openai(response).await
475 };
476 return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
477 }
478 429 => {
479 let info = account.provider.parse_rate_limits(response.headers());
480 let cooldown_ms = info.as_ref()
483 .and_then(|i| i.reset_5h.or(i.reset_7d))
484 .map(|reset_secs| {
485 let reset_ms = reset_secs.saturating_mul(1_000);
486 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
488 .unwrap_or(60_000);
489 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
490 if let Some(info) = info {
491 s.state.update_rate_limits(&account_name, info);
492 }
493 s.state.set_cooldown(&account_name, cooldown_ms);
494 if cooldown_ms >= 5 * 60_000 {
495 let mins = cooldown_ms / 60_000;
496 notify(
497 "shunt: Rate Limited",
498 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
499 "Ping",
500 );
501 }
502 tried.insert(account_name);
503 }
504 529 => {
505 warn!(account = %account_name, "529 overloaded — cooling 30s");
506 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
507 s.state.update_rate_limits(&account_name, info);
508 }
509 s.state.set_cooldown(&account_name, 30_000);
510 tried.insert(account_name);
511 }
512 401 => {
513 if !refreshed.contains(&account_name) {
514 let account_lock = {
522 let mut locks = s.refresh_locks.lock().unwrap();
523 locks.entry(account_name.clone())
524 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
525 .clone()
526 };
527 let _guard = account_lock.lock().await;
528
529 let cred_before = {
532 let creds = s.credentials.read().await;
533 creds.get(&account_name).cloned()
534 .or_else(|| account.credential.clone())
535 };
536 let Some(cred) = cred_before else {
537 tried.insert(account_name);
538 continue;
539 };
540
541 let token_before = cred.access_token().to_owned();
543 let already_refreshed = {
544 let creds = s.credentials.read().await;
545 creds.get(&account_name)
546 .map(|c| c.access_token() != token_before)
547 .unwrap_or(false)
548 };
549
550 if already_refreshed {
551 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
553 refreshed.insert(account_name);
554 } else if let Some(oauth_cred) = cred.as_oauth() {
555 match tokio::time::timeout(
557 std::time::Duration::from_secs(10),
558 account.provider.refresh_token(oauth_cred),
559 ).await {
560 Ok(Ok(fresh)) => {
561 warn!(account = %account_name, "401 — token refreshed, retrying");
562 {
563 let mut creds = s.credentials.write().await;
564 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
565 }
566 let name = account_name.clone();
568 let fresh = fresh.clone();
569 tokio::task::spawn_blocking(move || {
570 let mut store = CredentialsStore::load();
571 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
572 store.save().ok();
573 if fresh.id_token.is_some() {
574 crate::oauth::write_codex_auth_file(&fresh);
575 }
576 });
577 refreshed.insert(account_name);
579 }
580 _ => {
581 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
583 s.state.set_cooldown(&account_name, 5 * 60_000);
584 tried.insert(account_name);
585 }
586 }
587 } else {
588 error!(account = %account_name, "401 — API key rejected, cooling 5min");
590 s.state.set_cooldown(&account_name, 5 * 60_000);
591 tried.insert(account_name);
592 }
593 } else {
594 error!(account = %account_name, "401 after refresh — cooling 5min");
596 s.state.set_cooldown(&account_name, 5 * 60_000);
597 tried.insert(account_name);
598 }
599 }
600 403 => {
601 if acct_is_anthropic {
605 error!(account = %account_name, "403 forbidden — cooling 30min");
606 s.state.set_cooldown(&account_name, 30 * 60_000);
607 notify(
608 "shunt: Account Forbidden",
609 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
610 "Basso",
611 );
612 } else {
613 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
614 s.state.set_cooldown(&account_name, 5 * 60_000);
615 }
616 tried.insert(account_name);
617 }
618 _ => {
619 return Ok(response);
621 }
622 }
623 }
624}
625
626async fn tap_usage(
635 resp: Response,
636 state: &StateStore,
637 account: &str,
638 model: &str,
639 req_start_ms: u64,
640) -> Response {
641 use axum::body::Body;
642 use crate::state::RequestLog;
643
644 if quota::is_streaming_response(&resp) {
645 let state = state.clone();
646 let account = account.to_owned();
647 let model = model.to_owned();
648 let on_complete = Arc::new(move |input: u64, output: u64| {
649 state.record_usage(&account, input, output);
650 state.record_global(&model, input, output);
651 state.record_request(RequestLog {
652 ts_ms: req_start_ms,
653 account: account.clone(),
654 model: model.clone(),
655 status: 200,
656 input_tokens: input,
657 output_tokens: output,
658 duration_ms: now_ms().saturating_sub(req_start_ms),
659 });
660 });
661 let (parts, body) = resp.into_parts();
662 let wrapped = quota::wrap_streaming_body(body, on_complete);
663 return Response::from_parts(parts, wrapped);
664 }
665
666 let (parts, body) = resp.into_parts();
668 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
669 Ok(b) => b,
670 Err(_) => return Response::from_parts(parts, Body::empty()),
671 };
672 let (input, output) = quota::extract_usage_from_json(&bytes);
673 state.record_usage(account, input, output);
674 state.record_global(model, input, output);
675 state.record_request(RequestLog {
676 ts_ms: req_start_ms,
677 account: account.to_owned(),
678 model: model.to_owned(),
679 status: 200,
680 input_tokens: input,
681 output_tokens: output,
682 duration_ms: now_ms().saturating_sub(req_start_ms),
683 });
684 Response::from_parts(parts, Body::from(bytes))
685}
686
687
688pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
696 let client = reqwest::Client::builder()
697 .timeout(std::time::Duration::from_secs(20))
698 .build()
699 .unwrap_or_default();
700
701 for account in &config.accounts {
702 let rl = state.rate_limit_snapshot();
704 if let Some(r) = rl.get(&account.name) {
705 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
706 continue;
707 }
708 }
709
710 let cred = match account.credential.clone() {
712 Some(c) => c,
713 None => continue,
714 };
715
716 let Some((path, body)) = account.provider.prefetch_request() else {
717 if let Some(probe_path) = account.provider.auth_probe_get_path() {
719 auth_probe_get(&client, probe_path, account, &state).await;
720 }
721 continue;
722 };
723 let url = format!("{}{}", config.server.upstream_url, path);
724
725 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
726
727 let r = match resp {
728 Ok(r) => r,
729 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
730 };
731
732 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
733 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
734 let Some(oauth_cred) = cred.as_oauth() else {
735 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
737 state.set_auth_failed(&account.name);
738 continue;
739 };
740 let fresh = match account.provider.refresh_token(oauth_cred).await {
741 Ok(f) => f,
742 Err(e) => {
743 tracing::warn!(account = %account.name, "token refresh failed: {e}");
744 state.set_auth_failed(&account.name);
745 continue;
746 }
747 };
748 let mut store = crate::config::CredentialsStore::load();
749 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
750 store.save().ok();
751 if fresh.id_token.is_some() {
752 crate::oauth::write_codex_auth_file(&fresh);
753 }
754 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
756
757 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
758 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
759 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
760 state.set_auth_failed(&account.name);
761 }
762 Ok(r2) => {
763 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
764 state.update_rate_limits(&account.name, info);
765 }
766 }
767 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
768 }
769 } else {
770 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
771 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
772 state.update_rate_limits(&account.name, info);
773 }
774 }
775 }
776}
777
778async fn prefetch_send(
780 client: &reqwest::Client,
781 url: &str,
782 provider: &crate::provider::Provider,
783 token: &str,
784 body: &serde_json::Value,
785) -> anyhow::Result<reqwest::Response> {
786 let mut headers = reqwest::header::HeaderMap::new();
787 provider.inject_auth_headers(&mut headers, token)?;
788 for (name, value) in provider.prefetch_extra_headers() {
789 headers.insert(
790 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
791 reqwest::header::HeaderValue::from_static(value),
792 );
793 }
794 Ok(client.post(url).headers(headers).json(body).send().await?)
795}
796
797async fn auth_probe_get(
801 client: &reqwest::Client,
802 path: &str,
803 account: &crate::config::AccountConfig,
804 state: &StateStore,
805) {
806 let cred = match account.credential.clone() {
807 Some(c) => c,
808 None => return,
809 };
810 let upstream = account.upstream_url.as_deref()
811 .unwrap_or_else(|| account.provider.default_upstream_url());
812 let url = format!("{}{}", upstream, path);
813
814 let do_get = |token: &str| -> reqwest::RequestBuilder {
815 let mut headers = reqwest::header::HeaderMap::new();
816 let _ = account.provider.inject_auth_headers(&mut headers, token);
817 client.get(&url).headers(headers)
818 };
819
820 let resp = match do_get(cred.bearer_token()).send().await {
821 Ok(r) => r,
822 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
823 };
824
825 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
826 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
827 let Some(oauth_cred) = cred.as_oauth() else {
828 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
830 state.set_auth_failed(&account.name);
831 return;
832 };
833 let fresh = match account.provider.refresh_token(oauth_cred).await {
834 Ok(f) => f,
835 Err(e) => {
836 tracing::warn!(account = %account.name, "token refresh failed: {e}");
837 state.set_auth_failed(&account.name);
838 return;
839 }
840 };
841 let mut store = crate::config::CredentialsStore::load();
842 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
843 store.save().ok();
844 if fresh.id_token.is_some() {
845 crate::oauth::write_codex_auth_file(&fresh);
846 }
847
848 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
849 match do_get(fresh_token).send().await {
850 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
851 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
852 state.set_auth_failed(&account.name);
853 }
854 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
855 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
856 }
857 } else {
858 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
859 }
863}
864
865fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
872 let now_ms = std::time::SystemTime::now()
873 .duration_since(std::time::UNIX_EPOCH)
874 .unwrap_or_default()
875 .as_millis() as u64;
876 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
877 .unwrap_or(cred.expires_at);
878 exp_ms < now_ms + threshold_mins * 60 * 1_000
879}
880
881async fn sync_live_creds_from_auth_json(
886 account_name: &str,
887 live_creds: &LiveCredentials,
888) {
889 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
890 let current_exp = live_creds.read().await
891 .get(account_name)
892 .and_then(|c| c.as_oauth())
893 .map(|c| c.expires_at)
894 .unwrap_or(0);
895 if from_file.expires_at > current_exp {
896 tracing::info!(account = %account_name, "synced fresher token from auth.json");
897 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
898 }
899}
900
901async fn do_proactive_refresh(
903 account: &crate::config::AccountConfig,
904 creds: &crate::oauth::OAuthCredential,
905 live_creds: &LiveCredentials,
906 state: &StateStore,
907) {
908 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
909 match account.provider.refresh_token(creds).await {
910 Ok(fresh) => {
911 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
912 {
913 let mut map = live_creds.write().await;
914 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
915 }
916 let mut store = crate::config::CredentialsStore::load();
917 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
918 store.save().ok();
919 if fresh.id_token.is_some() {
920 crate::oauth::write_codex_auth_file(&fresh);
921 }
922 state.clear_auth_failed(&account.name);
923 }
924 Err(e) => {
925 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
926 state.set_auth_failed(&account.name);
927 }
928 }
929}
930
931
932pub async fn openai_token_refresh_loop(
940 config: Arc<Config>,
941 state: StateStore,
942 live_creds: LiveCredentials,
943) {
944 for account in config.accounts.iter()
946 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
947 {
948 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
949 continue;
950 }
951 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
952
953 let creds = {
954 let map = live_creds.read().await;
955 map.get(&account.name).cloned().or_else(|| account.credential.clone())
956 };
957 if let Some(creds) = creds {
958 if let Some(oauth) = creds.as_oauth() {
959 if access_token_expires_soon(oauth, 30) {
960 do_proactive_refresh(account, oauth, &live_creds, &state).await;
962 } else {
963 tracing::info!(account = %account.name, "access_token fresh at startup");
964 }
965 }
966 }
967 }
968
969 loop {
972 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
973 for account in config.accounts.iter()
974 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
975 {
976 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
977 }
978 }
979}
980
/// Failures `proxy_handler` reports to the client; each maps to an HTTP status
/// in the `IntoResponse` impl.
enum ProxyError {
    /// The request body could not be read (400).
    BodyRead,
    /// Forwarding to the upstream failed (502).
    Upstream,
    /// No account could serve the request (503).
    AllAccountsUnavailable,
    /// The `x-api-key` header did not match the configured remote key (401).
    Unauthorized,
}
991
992impl IntoResponse for ProxyError {
993 fn into_response(self) -> Response {
994 let (status, msg) = match self {
995 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
996 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
997 ProxyError::AllAccountsUnavailable => {
998 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
999 }
1000 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1001 };
1002
1003 (status, axum::Json(json!({
1004 "type": "error",
1005 "error": {"type": "api_error", "message": msg}
1006 }))).into_response()
1007 }
1008}
1009
1010pub async fn recovery_watcher(
1019 config: Arc<Config>,
1020 state: StateStore,
1021 credentials: LiveCredentials,
1022) {
1023 use std::time::{Duration, Instant};
1024 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1025 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1026
1027 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1028 let mut last_notified: Option<Instant> = None;
1029
1030 loop {
1031 tokio::time::sleep(CHECK_INTERVAL).await;
1032
1033 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1034 let failed = state.auth_failed_accounts(&name_refs);
1035 if failed.is_empty() {
1036 last_notified = None;
1037 continue;
1038 }
1039
1040 tracing::warn!(
1041 accounts = ?failed,
1042 "recovery: {} account(s) auth_failed, attempting token refresh",
1043 failed.len()
1044 );
1045
1046 let mut any_recovered = false;
1047
1048 for name in &failed {
1049 let cred = {
1050 let map = credentials.read().await;
1051 map.get(*name).cloned()
1052 };
1053 let Some(cred) = cred else { continue };
1054 if !cred.has_refresh_token() { continue; }
1055 let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1056
1057 let provider = config.accounts.iter()
1058 .find(|a| a.name == *name)
1059 .map(|a| a.provider.clone())
1060 .unwrap_or_default();
1061
1062 let result = tokio::time::timeout(
1063 Duration::from_secs(20),
1064 provider.refresh_token(&oauth_cred),
1065 ).await;
1066
1067 match result {
1068 Ok(Ok(fresh)) => {
1069 tracing::info!(account = %name, "recovery: token refreshed — account back online");
1070 {
1071 let mut map = credentials.write().await;
1072 map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1073 }
1074 let name_owned = name.to_string();
1075 let fresh_owned = fresh.clone();
1076 tokio::task::spawn_blocking(move || {
1077 let mut store = crate::config::CredentialsStore::load();
1078 store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1079 store.save().ok();
1080 if fresh_owned.id_token.is_some() {
1081 crate::oauth::write_codex_auth_file(&fresh_owned);
1082 }
1083 });
1084 state.clear_auth_failed(name);
1085 any_recovered = true;
1086 }
1087 Ok(Err(e)) => {
1088 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1089 notify(
1090 "shunt: Reauth Required",
1091 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1092 "Basso",
1093 );
1094 }
1095 Err(_) => {
1096 tracing::error!(account = %name, "recovery: token refresh timed out");
1097 notify(
1098 "shunt: Reauth Required",
1099 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1100 "Basso",
1101 );
1102 }
1103 }
1104 }
1105
1106 if any_recovered {
1107 tracing::info!("recovery: at least one account is back online");
1108 continue;
1109 }
1110
1111 let still_failed = state.auth_failed_accounts(&name_refs);
1113 if still_failed.len() == account_names.len() {
1114 let should_notify = last_notified
1115 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1116 .unwrap_or(true);
1117 if should_notify {
1118 error!(
1119 "ALL accounts are offline (auth failed). \
1120 Run `shunt add-account` to re-authorize."
1121 );
1122 notify(
1123 "shunt: All Accounts Offline",
1124 "All accounts need re-authorization. Run `shunt add-account`.",
1125 "Basso",
1126 );
1127 last_notified = Some(Instant::now());
1128 }
1129 }
1130 }
1131}
1132
1133async fn post_cooldown_prefetch(
1137 client: &reqwest::Client,
1138 account: &crate::config::AccountConfig,
1139 token: &str,
1140 state: &StateStore,
1141 upstream_url: &str,
1142) {
1143 let Some((path, body)) = account.provider.prefetch_request() else {
1144 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1145 auth_probe_get(client, probe_path, account, state).await;
1146 }
1147 return;
1148 };
1149 let url = format!("{upstream_url}{path}");
1150 match prefetch_send(client, &url, &account.provider, token, &body).await {
1151 Ok(r) => {
1152 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1153 state.update_rate_limits(&account.name, info);
1154 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1155 }
1156 }
1157 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1158 }
1159}
1160
1161pub async fn cooldown_watcher(
1172 config: Arc<Config>,
1173 state: StateStore,
1174 credentials: LiveCredentials,
1175) {
1176 const STALE_RL_MS: u64 = 60 * 60_000;
1178
1179 let client = reqwest::Client::builder()
1180 .timeout(std::time::Duration::from_secs(20))
1181 .build()
1182 .unwrap_or_default();
1183
1184 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1187 let mut notify_on_resume: HashSet<String> = HashSet::new();
1189 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1191
1192 loop {
1193 let states = state.account_states();
1194 let rl_snapshot = state.rate_limit_snapshot();
1195 let now = now_ms();
1196 let mut next_wake_ms: Option<u64> = None;
1197
1198 for account in &config.accounts {
1199 let Some(st) = states.get(&account.name) else { continue };
1200 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1202
1203 if cdl > 0 && cdl <= now {
1204 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1206 if !handled {
1207 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1208 let token = {
1209 let creds = credentials.read().await;
1210 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1211 };
1212 if let Some(token) = token {
1213 post_cooldown_prefetch(
1214 &client, account, &token, &state,
1215 &config.server.upstream_url,
1216 ).await;
1217 }
1218 if notify_on_resume.remove(&account.name) {
1219 notify(
1220 "shunt: Account Resumed",
1221 &format!("Account '{}' is back online.", account.name),
1222 "Glass",
1223 );
1224 }
1225 last_resumed.insert(account.name.clone(), cdl);
1226 last_stale_prefetch.insert(account.name.clone(), now);
1227 }
1228 } else if cdl > now {
1229 let remaining = cdl - now;
1231 if remaining >= 5 * 60_000 {
1232 notify_on_resume.insert(account.name.clone());
1233 }
1234 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1235 } else {
1236 let rl_age = rl_snapshot
1238 .get(&account.name)
1239 .map(|r| now.saturating_sub(r.updated_ms))
1240 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1242 let fetched_ago = now.saturating_sub(last_fetched);
1243
1244 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1245 tracing::debug!(
1246 account = %account.name,
1247 age_min = rl_age / 60_000,
1248 "rate-limit data stale — refreshing"
1249 );
1250 let token = {
1251 let creds = credentials.read().await;
1252 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1253 };
1254 if let Some(token) = token {
1255 post_cooldown_prefetch(
1256 &client, account, &token, &state,
1257 &config.server.upstream_url,
1258 ).await;
1259 }
1260 last_stale_prefetch.insert(account.name.clone(), now);
1261 }
1262 }
1263 }
1264
1265 let sleep_ms = next_wake_ms
1267 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1268 .unwrap_or(30_000);
1269 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1270 }
1271}
1272
1273use crate::notify::notify;
1274
/// Maps an OpenAI model name to the Claude model that should serve it.
///
/// Names that already start with `claude-` are returned unchanged. Known
/// flagship OpenAI names map to Opus, known "mini" names map to Haiku, and
/// every other name maps to Sonnet.
fn map_model(openai_model: &str) -> String {
    const OPUS_ALIASES: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: [&str; 4] = ["gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini"];

    let target = if openai_model.starts_with("claude-") {
        openai_model
    } else if OPUS_ALIASES.contains(&openai_model) {
        "claude-opus-4-6"
    } else if HAIKU_ALIASES.contains(&openai_model) {
        "claude-haiku-4-5-20251001"
    } else {
        "claude-sonnet-4-6"
    };
    target.to_owned()
}
1301
1302fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1304 let model = body["model"].as_str().unwrap_or("gpt-4o");
1305 let claude_model = map_model(model);
1306
1307 let mut system: Option<String> = None;
1309 let mut messages = Vec::new();
1310 if let Some(arr) = body["messages"].as_array() {
1311 for msg in arr {
1312 let role = msg["role"].as_str().unwrap_or("");
1313 if role == "system" {
1314 let content = msg["content"].as_str()
1316 .map(|s| s.to_owned())
1317 .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1318 system = Some(content);
1319 } else if role == "tool" {
1320 let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1322 let content = msg["content"].as_str().unwrap_or("").to_owned();
1323 messages.push(json!({
1324 "role": "user",
1325 "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1326 }));
1327 } else {
1328 if let Some(tool_calls) = msg["tool_calls"].as_array() {
1330 let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1331 if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1332 content_blocks.push(json!({"type": "text", "text": text}));
1333 }
1334 for tc in tool_calls {
1335 content_blocks.push(json!({
1336 "type": "tool_use",
1337 "id": tc["id"].as_str().unwrap_or(""),
1338 "name": tc["function"]["name"].as_str().unwrap_or(""),
1339 "input": serde_json::from_str::<serde_json::Value>(
1340 tc["function"]["arguments"].as_str().unwrap_or("{}")
1341 ).unwrap_or(json!({})),
1342 }));
1343 }
1344 messages.push(json!({"role": "assistant", "content": content_blocks}));
1345 } else {
1346 let content = msg["content"].as_str().unwrap_or("").to_owned();
1347 messages.push(json!({ "role": role, "content": content }));
1348 }
1349 }
1350 }
1351 }
1352
1353 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1354 let stream = body["stream"].as_bool().unwrap_or(false);
1355
1356 let mut req = json!({
1357 "model": claude_model,
1358 "messages": messages,
1359 "max_tokens": max_tokens,
1360 "stream": stream,
1361 });
1362
1363 if let Some(sys) = system {
1364 req["system"] = json!(sys);
1365 }
1366 if let Some(temp) = body.get("temperature") {
1367 req["temperature"] = temp.clone();
1368 }
1369 if let Some(sp) = body.get("stop") {
1370 req["stop_sequences"] = sp.clone();
1371 }
1372
1373 if let Some(tools) = body["tools"].as_array() {
1375 let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1376 let func = &t["function"];
1377 Some(json!({
1378 "name": func["name"].as_str()?,
1379 "description": func["description"].as_str().unwrap_or(""),
1380 "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1381 }))
1382 }).collect();
1383 if !claude_tools.is_empty() {
1384 req["tools"] = json!(claude_tools);
1385 }
1386 }
1387
1388 req
1389}
1390
1391fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1393 let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1394 let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1395
1396 let mut text_content = String::new();
1398 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1399 if let Some(blocks) = body["content"].as_array() {
1400 for (idx, block) in blocks.iter().enumerate() {
1401 match block["type"].as_str() {
1402 Some("text") => {
1403 text_content.push_str(block["text"].as_str().unwrap_or(""));
1404 }
1405 Some("tool_use") => {
1406 let args = match &block["input"] {
1407 serde_json::Value::String(s) => s.clone(),
1408 v => serde_json::to_string(v).unwrap_or_default(),
1409 };
1410 tool_calls.push(json!({
1411 "id": block["id"].as_str().unwrap_or(""),
1412 "type": "function",
1413 "index": idx,
1414 "function": {
1415 "name": block["name"].as_str().unwrap_or(""),
1416 "arguments": args,
1417 }
1418 }));
1419 }
1420 _ => {}
1421 }
1422 }
1423 }
1424
1425 let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1426 let finish_reason = match stop_reason {
1427 "end_turn" => "stop",
1428 "tool_use" => "tool_calls",
1429 "max_tokens" => "length",
1430 other => other,
1431 };
1432
1433 let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1434 let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1435
1436 let mut message = json!({"role": "assistant", "content": text_content});
1437 if !tool_calls.is_empty() {
1438 message["tool_calls"] = json!(tool_calls);
1439 }
1440
1441 json!({
1442 "id": id,
1443 "object": "chat.completion",
1444 "model": model,
1445 "choices": [{
1446 "index": 0,
1447 "message": message,
1448 "finish_reason": finish_reason,
1449 }],
1450 "usage": {
1451 "prompt_tokens": input_tokens,
1452 "completion_tokens": output_tokens,
1453 "total_tokens": input_tokens + output_tokens,
1454 }
1455 })
1456}
1457
1458fn uuid_v4() -> String {
1459 use crate::oauth::rand_bytes;
1460 let b: [u8; 16] = rand_bytes();
1461 format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1462 u32::from_be_bytes(b[0..4].try_into().unwrap()),
1463 u16::from_be_bytes(b[4..6].try_into().unwrap()),
1464 u16::from_be_bytes(b[6..8].try_into().unwrap()),
1465 u16::from_be_bytes(b[8..10].try_into().unwrap()),
1466 {
1467 let mut v = 0u64;
1468 for &x in &b[10..16] { v = (v << 8) | x as u64; }
1469 v
1470 }
1471 )
1472}
1473
1474async fn openai_models_handler() -> impl IntoResponse {
1476 axum::Json(json!({
1477 "object": "list",
1478 "data": [
1479 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1480 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1481 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1482 ]
1483 }))
1484}
1485
1486async fn openai_compat_handler(
1488 State(s): State<AppState>,
1489 req: Request,
1490) -> Result<Response, ProxyError> {
1491 let Some(ref anthropic_url) = s.anthropic_base_url else {
1492 return proxy_handler(State(s), req).await;
1494 };
1495
1496 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1497 .await
1498 .map_err(|_| ProxyError::BodyRead)?;
1499
1500 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1501 .unwrap_or(json!({}));
1502
1503 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1504 let anthropic_body = translate_to_anthropic(openai_body);
1505
1506 let client = reqwest::Client::builder()
1507 .timeout(std::time::Duration::from_secs(300))
1508 .build()
1509 .map_err(|_| ProxyError::Upstream)?;
1510
1511 let resp = client
1512 .post(format!("{anthropic_url}/v1/messages"))
1513 .header("content-type", "application/json")
1514 .header("anthropic-version", "2023-06-01")
1515 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1516 .header("x-shunt-compat", "openai")
1517 .json(&anthropic_body)
1518 .send()
1519 .await
1520 .map_err(|_| ProxyError::Upstream)?;
1521
1522 if !resp.status().is_success() {
1523 let status = resp.status();
1524 let body = resp.text().await.unwrap_or_default();
1525 let code = status.as_u16();
1526 return Ok(axum::response::Response::builder()
1527 .status(code)
1528 .header("content-type", "application/json")
1529 .body(axum::body::Body::from(body))
1530 .unwrap());
1531 }
1532
1533 if stream {
1534 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1536 let stream = translate_anthropic_stream(resp, chat_id);
1537 Ok(axum::response::Response::builder()
1538 .status(200)
1539 .header("content-type", "text/event-stream")
1540 .header("cache-control", "no-cache")
1541 .body(axum::body::Body::from_stream(stream))
1542 .unwrap())
1543 } else {
1544 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1545 let openai_resp = translate_from_anthropic(anthropic_resp);
1546 Ok(axum::Json(openai_resp).into_response())
1547 }
1548}
1549
1550fn translate_anthropic_stream(
1553 resp: reqwest::Response,
1554 chat_id: String,
1555) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1556 use futures_util::StreamExt;
1557
1558 let id = chat_id;
1559 let byte_stream = resp.bytes_stream();
1560
1561 async_stream::stream! {
1562 let mut buf = String::new();
1563 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1565 let mut tool_call_count: usize = 0;
1566 futures_util::pin_mut!(byte_stream);
1567
1568 let init = format!(
1570 "data: {}\n\n",
1571 serde_json::to_string(&json!({
1572 "id": id,
1573 "object": "chat.completion.chunk",
1574 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1575 })).unwrap()
1576 );
1577 yield Ok(bytes::Bytes::from(init));
1578
1579 while let Some(chunk) = byte_stream.next().await {
1580 let chunk = match chunk {
1581 Ok(c) => c,
1582 Err(_) => break,
1583 };
1584 buf.push_str(&String::from_utf8_lossy(&chunk));
1585
1586 while let Some(nl) = buf.find('\n') {
1588 let line = buf[..nl].trim_end_matches('\r').to_owned();
1589 buf = buf[nl + 1..].to_owned();
1590
1591 if !line.starts_with("data: ") { continue; }
1592 let data = &line["data: ".len()..];
1593 if data == "[DONE]" { continue; }
1594
1595 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1596 let event_type = event["type"].as_str().unwrap_or("");
1597
1598 let maybe_chunk = match event_type {
1599 "content_block_start" => {
1600 let block_idx = event["index"].as_u64().unwrap_or(0);
1601 let cb = &event["content_block"];
1602 if cb["type"].as_str() == Some("tool_use") {
1603 let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1604 let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1605 let oai_idx = tool_call_count;
1606 tool_call_count += 1;
1607 tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1608 Some(json!({
1609 "id": id,
1610 "object": "chat.completion.chunk",
1611 "choices": [{"index": 0, "delta": {
1612 "tool_calls": [{
1613 "index": oai_idx,
1614 "id": tool_id,
1615 "type": "function",
1616 "function": {"name": tool_name, "arguments": ""}
1617 }]
1618 }, "finish_reason": null}]
1619 }))
1620 } else {
1621 None
1622 }
1623 }
1624 "content_block_delta" => {
1625 let block_idx = event["index"].as_u64().unwrap_or(0);
1626 let delta = &event["delta"];
1627 match delta["type"].as_str() {
1628 Some("text_delta") => {
1629 let text = delta["text"].as_str().unwrap_or("");
1630 if text.is_empty() { continue; }
1631 Some(json!({
1632 "id": id,
1633 "object": "chat.completion.chunk",
1634 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1635 }))
1636 }
1637 Some("input_json_delta") => {
1638 let args = delta["partial_json"].as_str().unwrap_or("");
1639 if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1640 Some(json!({
1641 "id": id,
1642 "object": "chat.completion.chunk",
1643 "choices": [{"index": 0, "delta": {
1644 "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1645 }, "finish_reason": null}]
1646 }))
1647 } else {
1648 None
1649 }
1650 }
1651 _ => None,
1652 }
1653 }
1654 "message_delta" => {
1655 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1656 let finish = match stop_reason {
1657 "end_turn" => "stop",
1658 "tool_use" => "tool_calls",
1659 "max_tokens" => "length",
1660 other => other,
1661 };
1662 Some(json!({
1663 "id": id,
1664 "object": "chat.completion.chunk",
1665 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1666 }))
1667 }
1668 _ => None,
1669 };
1670
1671 if let Some(c) = maybe_chunk {
1672 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1673 yield Ok(bytes::Bytes::from(out));
1674 }
1675 }
1676 }
1677
1678 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1679 }
1680}
1681
/// Picks the OpenAI model used in place of a Claude model name.
///
/// Names containing `haiku` (and not `opus`) map to `gpt-4o-mini`; every other
/// name maps to `gpt-4o`.
fn map_model_to_openai(claude_model: &str) -> &str {
    let mentions_opus = claude_model.contains("opus");
    let mentions_haiku = claude_model.contains("haiku");

    // "opus" is checked first, so it wins when both substrings are present.
    if mentions_haiku && !mentions_opus {
        "gpt-4o-mini"
    } else {
        "gpt-4o"
    }
}
1694
1695async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1702 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1703 let resp = client
1704 .get(&url)
1705 .header("Authorization", format!("Bearer {}", token))
1706 .send()
1707 .await
1708 .ok()?;
1709 if !resp.status().is_success() {
1710 return None;
1711 }
1712 let json: serde_json::Value = resp.json().await.ok()?;
1713 if json["proofofwork"]["required"].as_bool() == Some(true) {
1714 return None;
1715 }
1716 json["token"].as_str().map(ToOwned::to_owned)
1717}
1718
/// Picks the ChatGPT model used in place of a Claude model name.
///
/// Names containing `haiku` (and not `opus`) map to `gpt-4o-mini`; every other
/// name maps to `gpt-4o`.
fn map_model_to_chatgpt(model: &str) -> &str {
    match (model.contains("opus"), model.contains("haiku")) {
        // Only a haiku-without-opus name selects the small model.
        (false, true) => "gpt-4o-mini",
        _ => "gpt-4o",
    }
}
1729
1730fn extract_text_from_anthropic_content(content: &serde_json::Value) -> String {
1733 if let Some(s) = content.as_str() {
1734 return s.to_owned();
1735 }
1736 if let Some(arr) = content.as_array() {
1737 let mut text = String::new();
1738 for block in arr {
1739 match block["type"].as_str() {
1740 Some("text") => text.push_str(block["text"].as_str().unwrap_or("")),
1741 Some("tool_use") => {
1742 let name = block["name"].as_str().unwrap_or("tool");
1743 let args = serde_json::to_string(&block["input"]).unwrap_or_default();
1744 text.push_str(&format!("[Tool: {}({})]", name, args));
1745 }
1746 Some("tool_result") => {
1747 let result = block["content"].as_str()
1748 .map(|s| s.to_owned())
1749 .unwrap_or_else(|| serde_json::to_string(&block["content"]).unwrap_or_default());
1750 text.push_str(&result);
1751 }
1752 _ => {}
1753 }
1754 }
1755 return text;
1756 }
1757 String::new()
1758}
1759
1760fn translate_anthropic_req_to_chatgpt(body: &serde_json::Value) -> serde_json::Value {
1763 let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1764 let model = map_model_to_chatgpt(claude_model);
1765 let system_prompt = body["system"].as_str().unwrap_or("").to_owned();
1766
1767 let mut messages: Vec<serde_json::Value> = Vec::new();
1768 if let Some(arr) = body["messages"].as_array() {
1769 for msg in arr {
1770 let role = msg["role"].as_str().unwrap_or("user");
1771 let text = extract_text_from_anthropic_content(&msg["content"]);
1772 messages.push(json!({
1773 "id": uuid_v4(),
1774 "author": {"role": role},
1775 "content": {"content_type": "text", "parts": [text]},
1776 "metadata": {}
1777 }));
1778 }
1779 }
1780
1781 json!({
1782 "action": "next",
1783 "messages": messages,
1784 "model": model,
1785 "parent_message_id": uuid_v4(),
1786 "system_prompt": system_prompt,
1787 "history_and_training_disabled": true,
1788 "supports_modapi": false,
1789 })
1790}
1791
1792fn translate_chatgpt_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
1794 let id = format!("msg_{}", &uuid_v4()[..8]);
1795 let text = body["message"]["content"]["parts"][0]
1796 .as_str()
1797 .unwrap_or("")
1798 .to_owned();
1799 json!({
1800 "id": id,
1801 "type": "message",
1802 "role": "assistant",
1803 "model": model,
1804 "content": [{"type": "text", "text": text}],
1805 "stop_reason": "end_turn",
1806 "stop_sequence": null,
1807 "usage": {"input_tokens": 0, "output_tokens": 0}
1808 })
1809}
1810
1811async fn translate_response_chatgpt_to_anthropic(resp: Response, model: &str) -> Response {
1814 use axum::body::Body;
1815 let msg_id = format!("msg_{}", &uuid_v4()[..8]);
1816 let model = model.to_owned();
1817
1818 if quota::is_streaming_response(&resp) {
1819 let (mut parts, body) = resp.into_parts();
1820 parts.headers.insert(
1821 axum::http::header::CONTENT_TYPE,
1822 axum::http::HeaderValue::from_static("text/event-stream"),
1823 );
1824 let stream = translate_chatgpt_stream_to_anthropic(body, model, msg_id);
1825 Response::from_parts(parts, Body::from_stream(stream))
1826 } else {
1827 let (mut parts, body) = resp.into_parts();
1828 let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1829 let chatgpt_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1830 let anthropic_val = translate_chatgpt_resp_to_anthropic(chatgpt_val, &model);
1831 let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
1832 parts.headers.insert(
1833 axum::http::header::CONTENT_TYPE,
1834 axum::http::HeaderValue::from_static("application/json"),
1835 );
1836 Response::from_parts(parts, Body::from(out))
1837 }
1838}
1839
1840fn translate_chatgpt_stream_to_anthropic(
1845 body: axum::body::Body,
1846 model: String,
1847 msg_id: String,
1848) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1849 use futures_util::StreamExt;
1850
1851 async_stream::stream! {
1852 let start_evt = format!(
1853 "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
1854 serde_json::to_string(&json!({
1855 "type": "message_start",
1856 "message": {
1857 "id": msg_id, "type": "message", "role": "assistant",
1858 "content": [], "model": model, "stop_reason": null,
1859 "usage": {"input_tokens": 0, "output_tokens": 0}
1860 }
1861 })).unwrap()
1862 );
1863 yield Ok(bytes::Bytes::from(start_evt));
1864
1865 let mut buf = String::new();
1866 let mut content_block_open = false;
1867 let mut prev_len: usize = 0;
1868 let byte_stream = body.into_data_stream();
1869 futures_util::pin_mut!(byte_stream);
1870
1871 'outer: while let Some(chunk) = byte_stream.next().await {
1872 let chunk = match chunk { Ok(c) => c, Err(_) => break };
1873 buf.push_str(&String::from_utf8_lossy(&chunk));
1874
1875 while let Some(nl) = buf.find('\n') {
1876 let line = buf[..nl].trim_end_matches('\r').to_owned();
1877 buf = buf[nl + 1..].to_owned();
1878 if !line.starts_with("data: ") { continue; }
1879 let data = &line["data: ".len()..];
1880 if data == "[DONE]" { break 'outer; }
1881 let Ok(val) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1882
1883 let text = match val["message"]["content"]["parts"][0].as_str() {
1884 Some(t) => t.to_owned(),
1885 None => continue,
1886 };
1887
1888 let delta = text[prev_len..].to_owned();
1889 if !delta.is_empty() {
1890 if !content_block_open {
1891 content_block_open = true;
1892 yield Ok(bytes::Bytes::from(format!(
1893 "event: content_block_start\ndata: {}\n\n",
1894 serde_json::to_string(&json!({
1895 "type": "content_block_start", "index": 0,
1896 "content_block": {"type": "text", "text": ""}
1897 })).unwrap()
1898 )));
1899 }
1900 yield Ok(bytes::Bytes::from(format!(
1901 "event: content_block_delta\ndata: {}\n\n",
1902 serde_json::to_string(&json!({
1903 "type": "content_block_delta", "index": 0,
1904 "delta": {"type": "text_delta", "text": delta}
1905 })).unwrap()
1906 )));
1907 prev_len = text.len();
1908 }
1909
1910 if val["message"]["end_turn"].as_bool() == Some(true) {
1911 break 'outer;
1912 }
1913 }
1914 }
1915
1916 if content_block_open {
1917 yield Ok(bytes::Bytes::from(format!(
1918 "event: content_block_stop\ndata: {}\n\n",
1919 serde_json::to_string(&json!({"type": "content_block_stop", "index": 0})).unwrap()
1920 )));
1921 }
1922 yield Ok(bytes::Bytes::from(format!(
1923 "event: message_delta\ndata: {}\n\n",
1924 serde_json::to_string(&json!({
1925 "type": "message_delta",
1926 "delta": {"stop_reason": "end_turn", "stop_sequence": null},
1927 "usage": {"output_tokens": 0}
1928 })).unwrap()
1929 )));
1930 yield Ok(bytes::Bytes::from(
1931 "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
1932 ));
1933 }
1934}
1935
1936fn translate_anthropic_req_to_openai(body: serde_json::Value) -> serde_json::Value {
1939 let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1940 let model = map_model_to_openai(claude_model);
1941 let stream = body["stream"].as_bool().unwrap_or(false);
1942 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1943
1944 let mut messages: Vec<serde_json::Value> = Vec::new();
1945
1946 if let Some(sys) = body["system"].as_str().filter(|s| !s.is_empty()) {
1948 messages.push(json!({"role": "system", "content": sys}));
1949 }
1950
1951 if let Some(arr) = body["messages"].as_array() {
1952 for msg in arr {
1953 let role = msg["role"].as_str().unwrap_or("user");
1954
1955 if let Some(blocks) = msg["content"].as_array() {
1956 let has_tool_result = blocks.iter().any(|b| b["type"] == "tool_result");
1958 if has_tool_result {
1959 for b in blocks {
1960 if b["type"] == "tool_result" {
1961 let content = b["content"].as_str()
1962 .map(|s| s.to_owned())
1963 .unwrap_or_else(|| serde_json::to_string(&b["content"]).unwrap_or_default());
1964 messages.push(json!({
1965 "role": "tool",
1966 "tool_call_id": b["tool_use_id"].as_str().unwrap_or(""),
1967 "content": content,
1968 }));
1969 }
1970 }
1971 continue;
1972 }
1973
1974 let mut text = String::new();
1976 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1977 for b in blocks {
1978 match b["type"].as_str() {
1979 Some("text") => text.push_str(b["text"].as_str().unwrap_or("")),
1980 Some("tool_use") => {
1981 let args = match &b["input"] {
1982 serde_json::Value::String(s) => s.clone(),
1983 v => serde_json::to_string(v).unwrap_or_default(),
1984 };
1985 tool_calls.push(json!({
1986 "id": b["id"].as_str().unwrap_or(""),
1987 "type": "function",
1988 "function": {"name": b["name"].as_str().unwrap_or(""), "arguments": args},
1989 }));
1990 }
1991 _ => {}
1992 }
1993 }
1994 let mut m = json!({"role": role, "content": text});
1995 if !tool_calls.is_empty() {
1996 m["tool_calls"] = json!(tool_calls);
1997 }
1998 messages.push(m);
1999 } else if let Some(s) = msg["content"].as_str() {
2000 messages.push(json!({"role": role, "content": s}));
2001 }
2002 }
2003 }
2004
2005 let mut req = json!({
2006 "model": model,
2007 "messages": messages,
2008 "max_tokens": max_tokens,
2009 "stream": stream,
2010 });
2011
2012 if stream {
2014 req["stream_options"] = json!({"include_usage": true});
2015 }
2016 if let Some(t) = body.get("temperature") { req["temperature"] = t.clone(); }
2017 if let Some(sp) = body.get("stop_sequences") { req["stop"] = sp.clone(); }
2018
2019 if let Some(tools) = body["tools"].as_array() {
2021 let oai: Vec<serde_json::Value> = tools.iter().map(|t| json!({
2022 "type": "function",
2023 "function": {
2024 "name": t["name"].as_str().unwrap_or(""),
2025 "description": t["description"].as_str().unwrap_or(""),
2026 "parameters": t.get("input_schema").cloned()
2027 .unwrap_or(json!({"type": "object", "properties": {}})),
2028 }
2029 })).collect();
2030 if !oai.is_empty() { req["tools"] = json!(oai); }
2031 }
2032
2033 if let Some(tc) = body.get("tool_choice") {
2034 req["tool_choice"] = match tc["type"].as_str() {
2035 Some("any") => json!({"type": "required"}),
2036 Some("tool") => json!({"type": "function", "function": {"name": tc["name"]}}),
2037 _ => json!("auto"),
2038 };
2039 }
2040
2041 req
2042}
2043
2044fn translate_openai_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
2046 let id = format!("msg_{}", &uuid_v4()[..8]);
2047 let choice = &body["choices"][0];
2048 let msg = &choice["message"];
2049
2050 let mut content: Vec<serde_json::Value> = Vec::new();
2051 if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
2052 content.push(json!({"type": "text", "text": text}));
2053 }
2054 if let Some(tcs) = msg["tool_calls"].as_array() {
2055 for tc in tcs {
2056 content.push(json!({
2057 "type": "tool_use",
2058 "id": tc["id"].as_str().unwrap_or(""),
2059 "name": tc["function"]["name"].as_str().unwrap_or(""),
2060 "input": serde_json::from_str::<serde_json::Value>(
2061 tc["function"]["arguments"].as_str().unwrap_or("{}")
2062 ).unwrap_or(json!({})),
2063 }));
2064 }
2065 }
2066
2067 let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
2068 "stop" => "end_turn",
2069 "tool_calls" => "tool_use",
2070 "length" => "max_tokens",
2071 other => other,
2072 };
2073
2074 json!({
2075 "id": id,
2076 "type": "message",
2077 "role": "assistant",
2078 "model": model,
2079 "content": content,
2080 "stop_reason": stop_reason,
2081 "stop_sequence": null,
2082 "usage": {
2083 "input_tokens": body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
2084 "output_tokens": body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
2085 }
2086 })
2087}
2088
2089async fn translate_response_openai_to_anthropic(resp: Response, model: &str) -> Response {
2092 use axum::body::Body;
2093 let msg_id = format!("msg_{}", &uuid_v4()[..8]);
2094 let model = model.to_owned();
2095
2096 if quota::is_streaming_response(&resp) {
2097 let (mut parts, body) = resp.into_parts();
2098 parts.headers.insert(
2099 axum::http::header::CONTENT_TYPE,
2100 axum::http::HeaderValue::from_static("text/event-stream"),
2101 );
2102 let stream = translate_openai_stream_to_anthropic(body, model, msg_id);
2103 Response::from_parts(parts, Body::from_stream(stream))
2104 } else {
2105 let (mut parts, body) = resp.into_parts();
2106 let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
2107 let openai_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
2108 let anthropic_val = translate_openai_resp_to_anthropic(openai_val, &model);
2109 let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
2110 parts.headers.insert(
2111 axum::http::header::CONTENT_TYPE,
2112 axum::http::HeaderValue::from_static("application/json"),
2113 );
2114 Response::from_parts(parts, Body::from(out))
2115 }
2116}
2117
2118async fn translate_response_anthropic_to_openai(resp: Response) -> Response {
2120 use axum::body::Body;
2121 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
2122
2123 if quota::is_streaming_response(&resp) {
2124 let (parts, body) = resp.into_parts();
2125 let stream = translate_body_anthropic_to_openai(body, chat_id);
2126 Response::from_parts(parts, Body::from_stream(stream))
2127 } else {
2128 let (mut parts, body) = resp.into_parts();
2129 let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
2130 let anthropic_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
2131 let openai_val = translate_from_anthropic(anthropic_val);
2132 let out = serde_json::to_vec(&openai_val).unwrap_or_default();
2133 parts.headers.insert(
2134 axum::http::header::CONTENT_TYPE,
2135 axum::http::HeaderValue::from_static("application/json"),
2136 );
2137 Response::from_parts(parts, Body::from(out))
2138 }
2139}
2140
2141fn translate_openai_stream_to_anthropic(
2146 body: axum::body::Body,
2147 model: String,
2148 msg_id: String,
2149) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
2150 use futures_util::StreamExt;
2151
2152 async_stream::stream! {
2153 let start_evt = format!(
2155 "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
2156 serde_json::to_string(&json!({
2157 "type": "message_start",
2158 "message": {
2159 "id": msg_id, "type": "message", "role": "assistant",
2160 "content": [], "model": model, "stop_reason": null,
2161 "usage": {"input_tokens": 0, "output_tokens": 0}
2162 }
2163 })).unwrap()
2164 );
2165 yield Ok(bytes::Bytes::from(start_evt));
2166
2167 let mut buf = String::new();
2168 let mut content_block_open = false;
2169 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
2170 let mut tool_call_count: usize = 0;
2171 let mut output_tokens: u64 = 0;
2172 let mut input_tokens: u64 = 0;
2173 let byte_stream = body.into_data_stream();
2174 futures_util::pin_mut!(byte_stream);
2175
2176 while let Some(chunk) = byte_stream.next().await {
2177 let chunk = match chunk { Ok(c) => c, Err(_) => break };
2178 buf.push_str(&String::from_utf8_lossy(&chunk));
2179
2180 while let Some(nl) = buf.find('\n') {
2181 let line = buf[..nl].trim_end_matches('\r').to_owned();
2182 buf = buf[nl + 1..].to_owned();
2183 if !line.starts_with("data: ") { continue; }
2184 let data = &line["data: ".len()..];
2185 if data == "[DONE]" { continue; }
2186 let Ok(ev) = serde_json::from_str::<serde_json::Value>(data) else { continue };
2187
2188 if let Some(u) = ev.get("usage") {
2190 input_tokens = u["prompt_tokens"].as_u64().unwrap_or(input_tokens);
2191 output_tokens = u["completion_tokens"].as_u64().unwrap_or(output_tokens);
2192 }
2193
2194 let choice = &ev["choices"][0];
2195 let delta = &choice["delta"];
2196 let finish = choice["finish_reason"].as_str();
2197
2198 if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
2200 if !content_block_open {
2201 content_block_open = true;
2202 let cb = format!(
2203 "event: content_block_start\ndata: {}\n\n",
2204 serde_json::to_string(&json!({
2205 "type": "content_block_start", "index": 0,
2206 "content_block": {"type": "text", "text": ""}
2207 })).unwrap()
2208 );
2209 yield Ok(bytes::Bytes::from(cb));
2210 }
2211 let d = format!(
2212 "event: content_block_delta\ndata: {}\n\n",
2213 serde_json::to_string(&json!({
2214 "type": "content_block_delta", "index": 0,
2215 "delta": {"type": "text_delta", "text": text}
2216 })).unwrap()
2217 );
2218 yield Ok(bytes::Bytes::from(d));
2219 }
2220
2221 if let Some(tcs) = delta["tool_calls"].as_array() {
2223 for tc in tcs {
2224 let oai_idx = tc["index"].as_u64().unwrap_or(0);
2225 if let Some(id) = tc["id"].as_str() {
2227 let name = tc["function"]["name"].as_str().unwrap_or("").to_owned();
2228 let my_idx = tool_call_count;
2229 tool_call_count += 1;
2230 tool_blocks.insert(oai_idx, (my_idx, id.to_owned(), name.clone()));
2231 let cb = format!(
2232 "event: content_block_start\ndata: {}\n\n",
2233 serde_json::to_string(&json!({
2234 "type": "content_block_start",
2235 "index": my_idx + 1, "content_block": {"type": "tool_use", "id": id, "name": name, "input": {}}
2237 })).unwrap()
2238 );
2239 yield Ok(bytes::Bytes::from(cb));
2240 }
2241 if let Some(args_chunk) = tc["function"]["arguments"].as_str() {
2243 if let Some(&(my_idx, _, _)) = tool_blocks.get(&oai_idx) {
2244 let d = format!(
2245 "event: content_block_delta\ndata: {}\n\n",
2246 serde_json::to_string(&json!({
2247 "type": "content_block_delta",
2248 "index": my_idx + 1,
2249 "delta": {"type": "input_json_delta", "partial_json": args_chunk}
2250 })).unwrap()
2251 );
2252 yield Ok(bytes::Bytes::from(d));
2253 }
2254 }
2255 }
2256 }
2257
2258 if let Some(fr) = finish {
2260 let stop_reason = match fr {
2261 "stop" => "end_turn",
2262 "tool_calls" => "tool_use",
2263 "length" => "max_tokens",
2264 other => other,
2265 };
2266
2267 if content_block_open {
2269 yield Ok(bytes::Bytes::from(format!(
2270 "event: content_block_stop\ndata: {}\n\n",
2271 serde_json::to_string(&json!({"type":"content_block_stop","index":0})).unwrap()
2272 )));
2273 }
2274 for (_, (my_idx, _, _)) in &tool_blocks {
2275 yield Ok(bytes::Bytes::from(format!(
2276 "event: content_block_stop\ndata: {}\n\n",
2277 serde_json::to_string(&json!({"type":"content_block_stop","index": my_idx + 1})).unwrap()
2278 )));
2279 }
2280
2281 yield Ok(bytes::Bytes::from(format!(
2282 "event: message_delta\ndata: {}\n\n",
2283 serde_json::to_string(&json!({
2284 "type": "message_delta",
2285 "delta": {"stop_reason": stop_reason, "stop_sequence": null},
2286 "usage": {"output_tokens": output_tokens}
2287 })).unwrap()
2288 )));
2289 yield Ok(bytes::Bytes::from(
2290 "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
2291 ));
2292 }
2293 }
2294 }
2295 }
2296}
2297
2298fn translate_body_anthropic_to_openai(
2302 body: axum::body::Body,
2303 chat_id: String,
2304) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
2305 use futures_util::StreamExt;
2306
2307 async_stream::stream! {
2308 let id = chat_id;
2309
2310 let init = format!(
2312 "data: {}\n\n",
2313 serde_json::to_string(&json!({
2314 "id": id, "object": "chat.completion.chunk",
2315 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
2316 })).unwrap()
2317 );
2318 yield Ok(bytes::Bytes::from(init));
2319
2320 let mut buf = String::new();
2321 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
2322 let mut tool_call_count: usize = 0;
2323 let byte_stream = body.into_data_stream();
2324 futures_util::pin_mut!(byte_stream);
2325
2326 while let Some(chunk) = byte_stream.next().await {
2327 let chunk = match chunk { Ok(c) => c, Err(_) => break };
2328 buf.push_str(&String::from_utf8_lossy(&chunk));
2329
2330 while let Some(nl) = buf.find('\n') {
2331 let line = buf[..nl].trim_end_matches('\r').to_owned();
2332 buf = buf[nl + 1..].to_owned();
2333 if !line.starts_with("data: ") { continue; }
2334 let data = &line["data: ".len()..];
2335 if data == "[DONE]" { continue; }
2336 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
2337 let event_type = event["type"].as_str().unwrap_or("");
2338
2339 let maybe_chunk = match event_type {
2340 "content_block_start" => {
2341 let block_idx = event["index"].as_u64().unwrap_or(0);
2342 let cb = &event["content_block"];
2343 if cb["type"].as_str() == Some("tool_use") {
2344 let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
2345 let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
2346 let oai_idx = tool_call_count;
2347 tool_call_count += 1;
2348 tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
2349 Some(json!({
2350 "id": id, "object": "chat.completion.chunk",
2351 "choices": [{"index": 0, "delta": {
2352 "tool_calls": [{"index": oai_idx, "id": tool_id, "type": "function",
2353 "function": {"name": tool_name, "arguments": ""}}]
2354 }, "finish_reason": null}]
2355 }))
2356 } else { None }
2357 }
2358 "content_block_delta" => {
2359 let block_idx = event["index"].as_u64().unwrap_or(0);
2360 let delta = &event["delta"];
2361 match delta["type"].as_str() {
2362 Some("text_delta") => {
2363 let text = delta["text"].as_str().unwrap_or("");
2364 if text.is_empty() { continue; }
2365 Some(json!({
2366 "id": id, "object": "chat.completion.chunk",
2367 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
2368 }))
2369 }
2370 Some("input_json_delta") => {
2371 let args = delta["partial_json"].as_str().unwrap_or("");
2372 tool_blocks.get(&block_idx).map(|(oai_idx, _, _)| json!({
2373 "id": id, "object": "chat.completion.chunk",
2374 "choices": [{"index": 0, "delta": {
2375 "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
2376 }, "finish_reason": null}]
2377 }))
2378 }
2379 _ => None,
2380 }
2381 }
2382 "message_delta" => {
2383 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
2384 let finish = match stop_reason {
2385 "end_turn" => "stop",
2386 "tool_use" => "tool_calls",
2387 "max_tokens" => "length",
2388 other => other,
2389 };
2390 Some(json!({
2391 "id": id, "object": "chat.completion.chunk",
2392 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
2393 }))
2394 }
2395 _ => None,
2396 };
2397
2398 if let Some(c) = maybe_chunk {
2399 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
2400 yield Ok(bytes::Bytes::from(out));
2401 }
2402 }
2403 }
2404 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
2405 }
2406}