1use axum::http::HeaderMap;
7use serde::{Deserialize, Serialize};
8
9use crate::credential::Credential;
10use crate::oauth::OAuthCredential;
11use crate::state::RateLimitInfo;
12
/// How requests to a provider are authenticated.
///
/// Returned by `Provider::auth_kind` and used to decide where credentials
/// come from and whether any auth headers are injected at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthKind {
    /// Credentials are OAuth tokens read from a local credential store.
    OAuth,
    /// Credentials are a static API key read from an environment variable.
    ApiKey,
    /// No credentials are used; no auth headers are injected.
    None,
}
26
/// The request/response protocol family a provider is addressed with.
///
/// Returned by `Provider::wire_protocol`: only `Provider::Anthropic` maps to
/// [`WireProtocol::Anthropic`]; every other provider maps to
/// [`WireProtocol::OpenAICompat`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WireProtocol {
    /// The Anthropic protocol.
    Anthropic,
    /// The OpenAI-compatible protocol.
    OpenAICompat,
}
38
/// An upstream provider that requests can be sent to.
///
/// With serde, each variant is (de)serialized as its lowercase name unless it
/// carries an explicit `rename`. [`Provider::Anthropic`] is the `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    /// Anthropic; OAuth-authenticated. This is the default provider.
    #[default]
    Anthropic,
    /// OpenAI reached through `https://chatgpt.com`; OAuth-authenticated.
    /// `from_str` also accepts the alias `"codex"`.
    OpenAI,
    /// OpenAI reached through `https://api.openai.com` with an API key.
    /// Serialized as `"openai-api"`.
    #[serde(rename = "openai-api")]
    OpenAIApi,
    /// Ollama's hosted API; API-key authenticated. Serialized as `"ollama"`.
    #[serde(rename = "ollama")]
    OllamaCloud,
    /// Groq; API-key authenticated.
    Groq,
    /// Mistral; API-key authenticated.
    Mistral,
    /// Together; API-key authenticated.
    Together,
    /// OpenRouter; API-key authenticated.
    OpenRouter,
    /// DeepSeek; API-key authenticated.
    DeepSeek,
    /// Fireworks; API-key authenticated.
    Fireworks,
    /// Google Gemini; API-key authenticated.
    Gemini,
    /// A local server (default upstream `http://localhost:11434`); no
    /// authentication.
    Local,
}
74
75impl std::fmt::Display for Provider {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Provider::Anthropic => write!(f, "anthropic"),
79 Provider::OpenAI => write!(f, "openai"),
80 Provider::OpenAIApi => write!(f, "openai-api"),
81 Provider::OllamaCloud => write!(f, "ollama"),
82 Provider::Groq => write!(f, "groq"),
83 Provider::Mistral => write!(f, "mistral"),
84 Provider::Together => write!(f, "together"),
85 Provider::OpenRouter => write!(f, "openrouter"),
86 Provider::DeepSeek => write!(f, "deepseek"),
87 Provider::Fireworks => write!(f, "fireworks"),
88 Provider::Gemini => write!(f, "gemini"),
89 Provider::Local => write!(f, "local"),
90 }
91 }
92}
93
94impl Provider {
95 pub fn from_str(s: &str) -> Self {
96 match s.to_ascii_lowercase().as_str() {
97 "openai" | "codex" => Provider::OpenAI,
98 "openai-api" | "openai_api" => Provider::OpenAIApi,
99 "ollama" | "ollama-cloud" | "ollamacloud" => Provider::OllamaCloud,
100 "groq" => Provider::Groq,
101 "mistral" => Provider::Mistral,
102 "together" | "together-ai" => Provider::Together,
103 "openrouter" | "open-router" => Provider::OpenRouter,
104 "deepseek" | "deep-seek" => Provider::DeepSeek,
105 "fireworks" | "fireworks-ai" => Provider::Fireworks,
106 "gemini" | "google" => Provider::Gemini,
107 "local" => Provider::Local,
108 _ => Provider::Anthropic,
109 }
110 }
111
112 pub fn auth_kind(&self) -> AuthKind {
114 match self {
115 Provider::Anthropic | Provider::OpenAI => AuthKind::OAuth,
116 Provider::Local => AuthKind::None,
117 _ => AuthKind::ApiKey,
118 }
119 }
120
121 pub fn wire_protocol(&self) -> WireProtocol {
123 match self {
124 Provider::Anthropic => WireProtocol::Anthropic,
125 _ => WireProtocol::OpenAICompat,
126 }
127 }
128
129 pub fn default_model(&self) -> &'static str {
135 match self {
136 Provider::Anthropic => "claude-sonnet-4-6",
137 Provider::OpenAI => "gpt-4o",
138 Provider::OpenAIApi => "gpt-4o",
139 Provider::OllamaCloud => "llama3.3",
140 Provider::Groq => "llama-3.3-70b-versatile",
141 Provider::Mistral => "mistral-large-latest",
142 Provider::Together => "meta-llama/Llama-3.3-70B-Instruct-Turbo",
143 Provider::OpenRouter => "meta-llama/llama-3.3-70b-instruct",
144 Provider::DeepSeek => "deepseek-chat",
145 Provider::Fireworks => "accounts/fireworks/models/llama-v3p3-70b-instruct",
146 Provider::Gemini => "gemini-2.0-flash",
147 Provider::Local => "", }
149 }
150
151 pub fn accepts_claude_models(&self) -> bool {
154 matches!(self, Provider::Anthropic)
155 }
156
157 pub fn api_key_env_var(&self) -> Option<&'static str> {
160 match self {
161 Provider::OpenAIApi => Some("OPENAI_API_KEY"),
162 Provider::OllamaCloud => Some("OLLAMA_API_KEY"),
163 Provider::Groq => Some("GROQ_API_KEY"),
164 Provider::Mistral => Some("MISTRAL_API_KEY"),
165 Provider::Together => Some("TOGETHER_API_KEY"),
166 Provider::OpenRouter => Some("OPENROUTER_API_KEY"),
167 Provider::DeepSeek => Some("DEEPSEEK_API_KEY"),
168 Provider::Fireworks => Some("FIREWORKS_API_KEY"),
169 Provider::Gemini => Some("GEMINI_API_KEY"),
170 _ => None,
171 }
172 }
173
174 pub fn default_upstream_url(&self) -> &'static str {
176 match self {
177 Provider::Anthropic => "https://api.anthropic.com",
178 Provider::OpenAI => "https://chatgpt.com",
179 Provider::OpenAIApi => "https://api.openai.com",
180 Provider::OllamaCloud => "https://api.ollama.com",
181 Provider::Groq => "https://api.groq.com",
182 Provider::Mistral => "https://api.mistral.ai",
183 Provider::Together => "https://api.together.xyz",
184 Provider::OpenRouter => "https://openrouter.ai",
185 Provider::DeepSeek => "https://api.deepseek.com",
186 Provider::Fireworks => "https://api.fireworks.ai",
187 Provider::Gemini => "https://generativelanguage.googleapis.com",
188 Provider::Local => "http://localhost:11434",
189 }
190 }
191
192 pub fn default_port(&self) -> u16 {
194 match self {
195 Provider::Anthropic => 8082,
196 Provider::OpenAI => 8083,
197 Provider::OpenAIApi => 8084,
198 Provider::OllamaCloud => 8085,
199 Provider::Groq => 8086,
200 Provider::Mistral => 8087,
201 Provider::Together => 8088,
202 Provider::OpenRouter => 8089,
203 Provider::DeepSeek => 8090,
204 Provider::Fireworks => 8091,
205 Provider::Gemini => 8092,
206 Provider::Local => 8093,
207 }
208 }
209
210 pub fn inject_auth_headers(
215 &self,
216 headers: &mut reqwest::header::HeaderMap,
217 token: &str,
218 ) -> anyhow::Result<()> {
219 use reqwest::header::{HeaderName, HeaderValue};
220
221 if self.auth_kind() == AuthKind::None {
223 return Ok(());
224 }
225
226 headers.insert(
228 HeaderName::from_static("authorization"),
229 HeaderValue::from_str(&format!("Bearer {token}"))
230 .map_err(|_| anyhow::anyhow!("invalid access token"))?,
231 );
232
233 match self {
234 Provider::Anthropic => {
235 headers.insert(
237 HeaderName::from_static("anthropic-dangerous-direct-browser-access"),
238 HeaderValue::from_static("true"),
239 );
240
241 let beta_key = HeaderName::from_static("anthropic-beta");
244 let existing = headers
245 .get(&beta_key)
246 .and_then(|v| v.to_str().ok())
247 .unwrap_or("")
248 .to_owned();
249 let merged = if existing.split(',').any(|s| s.trim() == "oauth-2025-04-20") {
250 existing
251 } else if existing.is_empty() {
252 "oauth-2025-04-20".to_owned()
253 } else {
254 format!("{existing},oauth-2025-04-20")
255 };
256 headers.insert(beta_key, HeaderValue::from_str(&merged).unwrap());
257 }
258 Provider::OpenRouter => {
259 headers.insert(
261 HeaderName::from_static("http-referer"),
262 HeaderValue::from_static("https://github.com/shunt-proxy/shunt"),
263 );
264 }
265 _ => {}
267 }
268
269 Ok(())
270 }
271
272 pub fn prefetch_extra_headers(&self) -> &'static [(&'static str, &'static str)] {
276 match self {
277 Provider::Anthropic => &[("anthropic-version", "2023-06-01")],
278 _ => &[],
279 }
280 }
281
282 pub fn prefetch_request(&self) -> Option<(&'static str, serde_json::Value)> {
286 match self {
287 Provider::Anthropic => Some((
288 "/v1/messages",
289 serde_json::json!({
290 "model": "claude-haiku-4-5-20251001",
291 "max_tokens": 1,
292 "messages": [{"role": "user", "content": "hi"}]
293 }),
294 )),
295 _ => None,
298 }
299 }
300
301 pub fn auth_probe_get_path(&self) -> Option<&'static str> {
304 match self {
305 Provider::Anthropic => None, Provider::OpenAI => Some("/backend-api/me"),
307 Provider::OpenAIApi => Some("/v1/models"),
308 Provider::OllamaCloud => Some("/v1/models"),
309 Provider::Groq => Some("/openai/v1/models"),
310 Provider::Mistral => Some("/v1/models"),
311 Provider::Together => Some("/v1/models"),
312 Provider::OpenRouter => Some("/api/v1/models"),
313 Provider::DeepSeek => Some("/v1/models"),
314 Provider::Fireworks => Some("/v1/models"),
315 Provider::Gemini => Some("/v1beta/models"),
316 Provider::Local => None, }
318 }
319
320 pub fn parse_rate_limits(&self, headers: &HeaderMap) -> Option<RateLimitInfo> {
324 let now_ms = std::time::SystemTime::now()
325 .duration_since(std::time::UNIX_EPOCH)
326 .unwrap_or_default()
327 .as_millis() as u64;
328
329 match self {
330 Provider::Anthropic => parse_anthropic_rate_limits(headers, now_ms),
331 Provider::OpenAI
333 | Provider::OpenAIApi
334 | Provider::OllamaCloud
335 | Provider::Groq
336 | Provider::Mistral
337 | Provider::Together
338 | Provider::OpenRouter
339 | Provider::DeepSeek
340 | Provider::Fireworks => parse_openai_rate_limits(headers, now_ms),
341 Provider::Gemini | Provider::Local => None,
343 }
344 }
345
346 pub fn read_local_credentials(&self) -> Option<Credential> {
352 match self.auth_kind() {
353 AuthKind::OAuth => match self {
354 Provider::Anthropic => {
355 crate::oauth::read_claude_credentials().map(Credential::Oauth)
356 }
357 Provider::OpenAI => {
358 crate::oauth::read_codex_credentials().map(Credential::Oauth)
359 }
360 _ => None,
361 },
362 AuthKind::ApiKey => {
363 self.api_key_env_var()
365 .and_then(|var| std::env::var(var).ok())
366 .map(|key| Credential::Apikey { key })
367 }
368 AuthKind::None => None,
369 }
370 }
371
372 pub async fn refresh_token(&self, cred: &OAuthCredential) -> anyhow::Result<OAuthCredential> {
376 match self {
377 Provider::Anthropic => crate::oauth::refresh_token(cred).await,
378 Provider::OpenAI => crate::oauth::refresh_openai_token(cred).await,
379 _ => anyhow::bail!("provider {} does not support token refresh", self),
380 }
381 }
382}
383
384fn parse_anthropic_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
389 fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
390 h.get(name)?.to_str().ok()?.parse().ok()
391 }
392 fn hdr_f64(h: &HeaderMap, name: &str) -> Option<f64> {
393 h.get(name)?.to_str().ok()?.parse().ok()
394 }
395 fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
396 Some(h.get(name)?.to_str().ok()?.to_owned())
397 }
398
399 let utilization_5h = hdr_f64(headers, "anthropic-ratelimit-unified-5h-utilization");
400 let utilization_7d = hdr_f64(headers, "anthropic-ratelimit-unified-7d-utilization");
401
402 if utilization_5h.is_none() && utilization_7d.is_none() {
403 return None;
404 }
405
406 Some(RateLimitInfo {
407 utilization_5h,
408 reset_5h: hdr_u64(headers, "anthropic-ratelimit-unified-5h-reset"),
409 status_5h: hdr_str(headers, "anthropic-ratelimit-unified-5h-status"),
410 utilization_7d,
411 reset_7d: hdr_u64(headers, "anthropic-ratelimit-unified-7d-reset"),
412 status_7d: hdr_str(headers, "anthropic-ratelimit-unified-7d-status"),
413 overage_status: hdr_str(headers, "anthropic-ratelimit-unified-overage-status"),
414 overage_disabled_reason: hdr_str(headers, "anthropic-ratelimit-unified-overage-disabled-reason"),
415 representative_claim: hdr_str(headers, "anthropic-ratelimit-unified-representative-claim"),
416 updated_ms: now_ms,
417 })
418}
419
420fn parse_openai_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
425 fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
426 h.get(name)?.to_str().ok()?.parse().ok()
427 }
428 fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
429 Some(h.get(name)?.to_str().ok()?.to_owned())
430 }
431
432 let limit_tok = hdr_u64(headers, "x-ratelimit-limit-tokens");
434 let remaining_tok = hdr_u64(headers, "x-ratelimit-remaining-tokens");
435 let reset_tok_str = hdr_str(headers, "x-ratelimit-reset-tokens");
436
437 let utilization = match (limit_tok, remaining_tok) {
438 (Some(limit), Some(remaining)) if limit > 0 => {
439 Some(1.0_f64 - (remaining as f64 / limit as f64))
440 }
441 _ => None,
442 };
443
444 let reset_secs = reset_tok_str.as_deref().and_then(parse_openai_reset_duration);
446
447 if utilization.is_none() && reset_secs.is_none() {
448 return None;
449 }
450
451 Some(RateLimitInfo {
452 utilization_5h: utilization,
453 reset_5h: reset_secs,
454 status_5h: utilization.map(|u| if u >= 1.0 { "exhausted".into() } else { "allowed".into() }),
455 utilization_7d: None,
457 reset_7d: None,
458 status_7d: None,
459 overage_status: None,
460 overage_disabled_reason: None,
461 representative_claim: None,
462 updated_ms: now_ms,
463 })
464}
465
/// Parses a Go-style duration string such as `"1m30s"`, `"6m0s"`, `"7.66s"`,
/// `"17ms"` or `"1h2m3s"` (the format of `x-ratelimit-reset-*` headers) and
/// returns the Unix timestamp, in seconds, at which that duration will have
/// elapsed from now.
///
/// The input is a sequence of `<number><unit>` segments. Numbers may have a
/// fractional part; units are `h`, `m`, `s`, `ms`, `us`/`µs` and `ns`. The
/// total is rounded *up* to whole seconds so the returned instant is never
/// earlier than the real reset.
///
/// Returns `None` for empty input, a number without a unit, an unknown unit,
/// or a value too large to represent.
fn parse_openai_reset_duration(s: &str) -> Option<u64> {
    // (suffix, multiplier, divisor): seconds = value * multiplier / divisor.
    // Sub-second units divide (rather than multiply by 0.001 etc.) so that
    // e.g. "3000ms" is exactly 3.0 and does not round up to 4.
    // Two-character suffixes come first so "ms" is never read as minutes.
    const UNITS: [(&str, f64, f64); 7] = [
        ("ms", 1.0, 1e3),
        ("us", 1.0, 1e6),
        ("µs", 1.0, 1e6),
        ("ns", 1.0, 1e9),
        ("h", 3600.0, 1.0),
        ("m", 60.0, 1.0),
        ("s", 1.0, 1.0),
    ];

    let mut rest = s.trim();
    if rest.is_empty() {
        return None;
    }

    let mut total_secs = 0.0_f64;
    while !rest.is_empty() {
        // Numeric part: ASCII digits with an optional decimal point.
        let number_len = rest
            .find(|c: char| !(c.is_ascii_digit() || c == '.'))
            .unwrap_or(rest.len());
        let value: f64 = rest[..number_len].parse().ok()?;

        // Unit part: must be one of the known suffixes.
        let tail = &rest[number_len..];
        let (multiplier, divisor, remainder) = UNITS.iter().find_map(|&(suffix, mul, div)| {
            tail.strip_prefix(suffix).map(|after| (mul, div, after))
        })?;

        total_secs += value * multiplier / divisor;
        rest = remainder;
    }

    if !total_secs.is_finite() {
        return None;
    }
    // Float-to-int `as` saturates, so an absurdly large value becomes
    // `u64::MAX` and is then rejected by `checked_add`.
    let offset_secs = total_secs.ceil() as u64;

    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    now_secs.checked_add(offset_secs)
}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Every primary name and a sample of aliases map to the right variant;
    /// unknown names fall back to Anthropic.
    #[test]
    fn test_provider_from_str() {
        let cases = [
            ("anthropic", Provider::Anthropic),
            ("ANTHROPIC", Provider::Anthropic),
            ("openai", Provider::OpenAI),
            ("codex", Provider::OpenAI),
            ("openai-api", Provider::OpenAIApi),
            ("ollama", Provider::OllamaCloud),
            ("ollama-cloud", Provider::OllamaCloud),
            ("groq", Provider::Groq),
            ("mistral", Provider::Mistral),
            ("together", Provider::Together),
            ("openrouter", Provider::OpenRouter),
            ("deepseek", Provider::DeepSeek),
            ("fireworks", Provider::Fireworks),
            ("gemini", Provider::Gemini),
            ("local", Provider::Local),
            ("unknown", Provider::Anthropic),
        ];
        for (input, expected) in cases {
            assert_eq!(Provider::from_str(input), expected);
        }
    }

    /// `Display` yields the canonical lowercase identifier of each variant.
    #[test]
    fn test_provider_display() {
        let cases = [
            (Provider::Anthropic, "anthropic"),
            (Provider::OpenAI, "openai"),
            (Provider::OpenAIApi, "openai-api"),
            (Provider::OllamaCloud, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::Mistral, "mistral"),
            (Provider::Together, "together"),
            (Provider::OpenRouter, "openrouter"),
            (Provider::DeepSeek, "deepseek"),
            (Provider::Fireworks, "fireworks"),
            (Provider::Gemini, "gemini"),
            (Provider::Local, "local"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_auth_kind() {
        let cases = [
            (Provider::Anthropic, AuthKind::OAuth),
            (Provider::OpenAI, AuthKind::OAuth),
            (Provider::Local, AuthKind::None),
            (Provider::Groq, AuthKind::ApiKey),
            (Provider::OpenAIApi, AuthKind::ApiKey),
            (Provider::OllamaCloud, AuthKind::ApiKey),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.auth_kind(), expected);
        }
    }

    #[test]
    fn test_wire_protocol() {
        let cases = [
            (Provider::Anthropic, WireProtocol::Anthropic),
            (Provider::OpenAI, WireProtocol::OpenAICompat),
            (Provider::Groq, WireProtocol::OpenAICompat),
            (Provider::Local, WireProtocol::OpenAICompat),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.wire_protocol(), expected);
        }
    }

    #[test]
    fn test_api_key_env_var() {
        let cases = [
            (Provider::Groq, Some("GROQ_API_KEY")),
            (Provider::OpenAIApi, Some("OPENAI_API_KEY")),
            (Provider::Gemini, Some("GEMINI_API_KEY")),
            (Provider::Anthropic, None),
            (Provider::Local, None),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.api_key_env_var(), expected);
        }
    }

    /// Each input must resolve to `now + [lo, hi]` seconds; the window leaves
    /// one second of slack on either side for clock ticks.
    #[test]
    fn test_parse_openai_reset_duration_formats() {
        let now = unix_now_secs();

        let cases: [(&str, u64, u64, &str); 4] = [
            ("1m30s", 89, 91, "1m30s should be ~90s from now"),
            ("45s", 44, 46, "45s should be ~45s from now"),
            ("2m", 119, 121, "2m should be ~120s from now"),
            ("0s", 0, 1, "0s should be now"),
        ];
        for (input, lo, hi, message) in cases {
            let r = parse_openai_reset_duration(input).unwrap();
            assert!(r >= now + lo && r <= now + hi, "{}", message);
        }
    }

    #[test]
    fn test_parse_openai_reset_duration_invalid() {
        for input in ["bad", ""] {
            assert!(parse_openai_reset_duration(input).is_none());
        }
    }

    /// 75k of 100k tokens remaining means 25% utilization, status "allowed".
    #[test]
    fn test_openai_utilization_computation() {
        use axum::http::HeaderMap;

        let mut headers = HeaderMap::new();
        let pairs = [
            ("x-ratelimit-limit-tokens", "100000"),
            ("x-ratelimit-remaining-tokens", "75000"),
            ("x-ratelimit-reset-tokens", "45s"),
        ];
        for (name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }

        let info = Provider::OpenAI.parse_rate_limits(&headers).unwrap();
        let util = info.utilization_5h.unwrap();
        assert!((util - 0.25).abs() < 0.001, "utilization should be 0.25 (75k/100k remaining)");
        assert_eq!(info.status_5h.as_deref(), Some("allowed"));
        assert!(info.reset_5h.is_some());
    }

    #[test]
    fn test_anthropic_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::Anthropic.parse_rate_limits(&empty).is_none());
    }

    #[test]
    fn test_openai_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::OpenAI.parse_rate_limits(&empty).is_none());
    }
}