1use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
#[async_trait]
/// A backend capable of streaming chat completions.
///
/// Implementations must be `Send + Sync` so a provider can be shared
/// across async tasks (e.g. behind an `Arc`).
pub trait Provider: Send + Sync {
    /// Short identifier for this provider (useful for logging/diagnostics).
    fn name(&self) -> &str;

    /// Starts a streaming completion for `request`.
    ///
    /// On success, returns the receiving end of a channel that yields
    /// [`StreamEvent`]s as they arrive; the stream ends when the sender
    /// side is dropped by the implementation.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] if the request cannot be initiated
    /// (auth failure, rate limiting, network problems, ...).
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Constraint on how the model may use tools for a request.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Let the model decide whether (and which) tool to call. The default.
    #[default]
    Auto,
    /// Force a tool call, letting the model pick which tool
    /// (exact semantics follow the target provider's tool-choice API).
    Any,
    /// Disallow tool calls for this request.
    None,
    /// Force a call to the tool with the given name.
    Specific(String),
}
44
/// A provider-agnostic completion request, translated to the target
/// provider's wire format by each [`Provider`] implementation.
pub struct ProviderRequest {
    /// Conversation messages to send.
    pub messages: Vec<Message>,
    /// System prompt accompanying the messages.
    pub system_prompt: String,
    /// Schemas of the tools the model is allowed to call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier, as understood by the target provider.
    pub model: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature; `None` means use the provider's default.
    pub temperature: Option<f64>,
    /// Opt into provider-side caching (e.g. prompt caching).
    /// NOTE(review): exact semantics are provider-specific — confirm per backend.
    pub enable_caching: bool,
    /// Tool-usage constraint for this request.
    pub tool_choice: ToolChoice,
    /// Extra provider-specific data; presumably passed through to the
    /// request body verbatim — verify against each implementation.
    pub metadata: Option<serde_json::Value>,
}
59
/// Errors surfaced by a [`Provider`] implementation.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication or authorization failed (e.g. bad/missing API key).
    Auth(String),
    /// Rate limited by the provider; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// Provider reported itself overloaded.
    Overloaded,
    /// The request exceeded a provider size limit.
    RequestTooLarge(String),
    /// Transport-level failure (connection, timeout, ...).
    Network(String),
    /// The provider's response was malformed or could not be parsed.
    InvalidResponse(String),
}
70
71impl std::fmt::Display for ProviderError {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Self::Auth(msg) => write!(f, "auth: {msg}"),
75 Self::RateLimited { retry_after_ms } => {
76 write!(f, "rate limited (retry in {retry_after_ms}ms)")
77 }
78 Self::Overloaded => write!(f, "server overloaded"),
79 Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
80 Self::Network(msg) => write!(f, "network: {msg}"),
81 Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
82 }
83 }
84}
85
86pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88 let model_lower = model.to_lowercase();
89 let url_lower = base_url.to_lowercase();
90
91 if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93 return ProviderKind::Bedrock;
94 }
95 if url_lower.contains("aiplatform.googleapis.com") {
97 return ProviderKind::Vertex;
98 }
99 if url_lower.contains("anthropic.com") {
100 return ProviderKind::Anthropic;
101 }
102 if url_lower.contains("openai.azure.com")
104 || url_lower.contains("azure.com") && url_lower.contains("openai")
105 {
106 return ProviderKind::AzureOpenAi;
107 }
108 if url_lower.contains("openai.com") {
109 return ProviderKind::OpenAi;
110 }
111 if url_lower.contains("x.ai") || url_lower.contains("xai.") {
112 return ProviderKind::Xai;
113 }
114 if url_lower.contains("googleapis.com") || url_lower.contains("google") {
115 return ProviderKind::Google;
116 }
117 if url_lower.contains("deepseek.com") {
118 return ProviderKind::DeepSeek;
119 }
120 if url_lower.contains("groq.com") {
121 return ProviderKind::Groq;
122 }
123 if url_lower.contains("mistral.ai") {
124 return ProviderKind::Mistral;
125 }
126 if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
127 return ProviderKind::Together;
128 }
129 if url_lower.contains("bigmodel.cn")
130 || url_lower.contains("z.ai")
131 || url_lower.contains("zhipu")
132 {
133 return ProviderKind::Zhipu;
134 }
135 if url_lower.contains("openrouter.ai") {
136 return ProviderKind::OpenRouter;
137 }
138 if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
139 return ProviderKind::Cohere;
140 }
141 if url_lower.contains("perplexity.ai") {
142 return ProviderKind::Perplexity;
143 }
144 if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
145 return ProviderKind::OpenAiCompatible;
146 }
147
148 if model_lower.starts_with("claude")
150 || model_lower.contains("opus")
151 || model_lower.contains("sonnet")
152 || model_lower.contains("haiku")
153 {
154 return ProviderKind::Anthropic;
155 }
156 if model_lower.starts_with("gpt")
157 || model_lower.starts_with("o1")
158 || model_lower.starts_with("o3")
159 {
160 return ProviderKind::OpenAi;
161 }
162 if model_lower.starts_with("grok") {
163 return ProviderKind::Xai;
164 }
165 if model_lower.starts_with("gemini") {
166 return ProviderKind::Google;
167 }
168 if model_lower.starts_with("deepseek") {
169 return ProviderKind::DeepSeek;
170 }
171 if model_lower.starts_with("llama") && url_lower.contains("groq") {
172 return ProviderKind::Groq;
173 }
174 if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
175 return ProviderKind::Mistral;
176 }
177 if model_lower.starts_with("glm") {
178 return ProviderKind::Zhipu;
179 }
180 if model_lower.starts_with("command") {
181 return ProviderKind::Cohere;
182 }
183 if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
184 return ProviderKind::Perplexity;
185 }
186
187 ProviderKind::OpenAiCompatible
188}
189
/// The request/response wire protocol a provider speaks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages-style API shape (also used by Bedrock/Vertex here).
    Anthropic,
    /// OpenAI Chat Completions-compatible API shape.
    OpenAiCompatible,
}
198
/// Every provider backend this module knows how to detect and configure.
///
/// Returned by [`detect_provider`]; each kind maps to a [`WireFormat`],
/// an optional default base URL, and an API-key environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    AzureOpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    /// Catch-all for any OpenAI-compatible endpoint (local servers, proxies).
    OpenAiCompatible,
}
219
220impl ProviderKind {
221 pub fn wire_format(&self) -> WireFormat {
223 match self {
224 Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
225 Self::OpenAi
226 | Self::AzureOpenAi
227 | Self::Xai
228 | Self::Google
229 | Self::DeepSeek
230 | Self::Groq
231 | Self::Mistral
232 | Self::Together
233 | Self::Zhipu
234 | Self::OpenRouter
235 | Self::Cohere
236 | Self::Perplexity
237 | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
238 }
239 }
240
241 pub fn default_base_url(&self) -> Option<&str> {
245 match self {
246 Self::Anthropic => Some("https://api.anthropic.com/v1"),
247 Self::OpenAi => Some("https://api.openai.com/v1"),
248 Self::Xai => Some("https://api.x.ai/v1"),
249 Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
250 Self::DeepSeek => Some("https://api.deepseek.com/v1"),
251 Self::Groq => Some("https://api.groq.com/openai/v1"),
252 Self::Mistral => Some("https://api.mistral.ai/v1"),
253 Self::Together => Some("https://api.together.xyz/v1"),
254 Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
255 Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
256 Self::Cohere => Some("https://api.cohere.com/v2"),
257 Self::Perplexity => Some("https://api.perplexity.ai"),
258 Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
260 }
261 }
262
263 pub fn env_var_name(&self) -> &str {
265 match self {
266 Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
267 Self::OpenAi => "OPENAI_API_KEY",
268 Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
269 Self::Xai => "XAI_API_KEY",
270 Self::Google => "GOOGLE_API_KEY",
271 Self::DeepSeek => "DEEPSEEK_API_KEY",
272 Self::Groq => "GROQ_API_KEY",
273 Self::Mistral => "MISTRAL_API_KEY",
274 Self::Together => "TOGETHER_API_KEY",
275 Self::Zhipu => "ZHIPU_API_KEY",
276 Self::OpenRouter => "OPENROUTER_API_KEY",
277 Self::Cohere => "COHERE_API_KEY",
278 Self::Perplexity => "PERPLEXITY_API_KEY",
279 Self::OpenAiCompatible => "OPENAI_API_KEY",
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq`/`Eq`, so `assert_eq!` is used
    // throughout instead of `assert!(matches!(..))` — on failure it prints
    // the actual value, which `matches!` hides. Near-identical detection
    // tests are table-driven to remove copy-paste duplication.

    /// URL-based detection: each known host fragment maps to its provider.
    /// The model name is neutral so only the URL can decide.
    #[test]
    fn test_detect_from_url() {
        let cases = [
            ("https://api.anthropic.com/v1", ProviderKind::Anthropic),
            ("https://api.openai.com/v1", ProviderKind::OpenAi),
            (
                "https://bedrock-runtime.us-east-1.amazonaws.com",
                ProviderKind::Bedrock,
            ),
            (
                "https://us-central1-aiplatform.googleapis.com/v1",
                ProviderKind::Vertex,
            ),
            (
                "https://myresource.openai.azure.com/openai/deployments/gpt-4",
                ProviderKind::AzureOpenAi,
            ),
            ("https://api.x.ai/v1", ProviderKind::Xai),
            ("https://api.deepseek.com/v1", ProviderKind::DeepSeek),
            ("https://api.groq.com/openai/v1", ProviderKind::Groq),
            ("https://api.mistral.ai/v1", ProviderKind::Mistral),
            ("https://api.together.xyz/v1", ProviderKind::Together),
            ("https://api.cohere.com/v2", ProviderKind::Cohere),
            ("https://api.perplexity.ai", ProviderKind::Perplexity),
            ("https://openrouter.ai/api/v1", ProviderKind::OpenRouter),
            ("https://open.bigmodel.cn/api/paas/v4", ProviderKind::Zhipu),
            ("http://localhost:11434/v1", ProviderKind::OpenAiCompatible),
        ];
        for (url, expected) in cases {
            assert_eq!(detect_provider("any", url), expected, "url: {url}");
        }
    }

    /// Azure must win even when the model name would suggest plain OpenAI.
    #[test]
    fn test_detect_azure_before_generic_openai() {
        assert_eq!(
            detect_provider(
                "gpt-4",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    /// Model-name fallbacks, used when the URL carries no signal.
    #[test]
    fn test_detect_from_model() {
        let cases = [
            ("claude-sonnet-4", ProviderKind::Anthropic),
            ("claude-opus-4", ProviderKind::Anthropic),
            ("gpt-4.1-mini", ProviderKind::OpenAi),
            ("o3-mini", ProviderKind::OpenAi),
            ("grok-3", ProviderKind::Xai),
            ("gemini-2.5-flash", ProviderKind::Google),
            ("deepseek-chat", ProviderKind::DeepSeek),
            ("mistral-large", ProviderKind::Mistral),
            ("codestral-latest", ProviderKind::Mistral),
            ("glm-4", ProviderKind::Zhipu),
            ("command-r-plus", ProviderKind::Cohere),
            ("pplx-70b-online", ProviderKind::Perplexity),
            ("sonar-pro", ProviderKind::Perplexity),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_provider(model, ""), expected, "model: {model}");
        }
    }

    /// Llama models are only attributed to Groq when the URL points there.
    #[test]
    fn test_detect_from_model_llama3_with_groq_url() {
        assert_eq!(
            detect_provider("llama-3", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    /// Unknown model + unknown host falls back to the OpenAI-compatible kind.
    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    /// A recognized URL overrides a conflicting model name.
    #[test]
    fn test_url_takes_priority_over_model() {
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_wire_format_anthropic_family() {
        assert_eq!(ProviderKind::Anthropic.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Bedrock.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Vertex.wire_format(), WireFormat::Anthropic);
    }

    #[test]
    fn test_wire_format_openai_compatible_family() {
        let openai_compat_providers = [
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
            ProviderKind::OpenAiCompatible,
        ];
        for p in openai_compat_providers {
            assert_eq!(
                p.wire_format(),
                WireFormat::OpenAiCompatible,
                "{p:?} should use OpenAiCompatible wire format"
            );
        }
    }

    #[test]
    fn test_default_base_url_returns_some_for_known_providers() {
        let providers_with_urls = [
            ProviderKind::Anthropic,
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
        ];
        for p in providers_with_urls {
            assert!(
                p.default_base_url().is_some(),
                "{p:?} should have a default base URL"
            );
        }
    }

    /// Deployment-specific backends have no universal default endpoint.
    #[test]
    fn test_default_base_url_returns_none_for_user_configured() {
        assert!(ProviderKind::Bedrock.default_base_url().is_none());
        assert!(ProviderKind::Vertex.default_base_url().is_none());
        assert!(ProviderKind::AzureOpenAi.default_base_url().is_none());
        assert!(ProviderKind::OpenAiCompatible.default_base_url().is_none());
    }

    #[test]
    fn test_env_var_name_all_variants() {
        let cases = [
            (ProviderKind::Anthropic, "ANTHROPIC_API_KEY"),
            (ProviderKind::Bedrock, "ANTHROPIC_API_KEY"),
            (ProviderKind::Vertex, "ANTHROPIC_API_KEY"),
            (ProviderKind::OpenAi, "OPENAI_API_KEY"),
            (ProviderKind::AzureOpenAi, "AZURE_OPENAI_API_KEY"),
            (ProviderKind::Xai, "XAI_API_KEY"),
            (ProviderKind::Google, "GOOGLE_API_KEY"),
            (ProviderKind::DeepSeek, "DEEPSEEK_API_KEY"),
            (ProviderKind::Groq, "GROQ_API_KEY"),
            (ProviderKind::Mistral, "MISTRAL_API_KEY"),
            (ProviderKind::Together, "TOGETHER_API_KEY"),
            (ProviderKind::Zhipu, "ZHIPU_API_KEY"),
            (ProviderKind::OpenRouter, "OPENROUTER_API_KEY"),
            (ProviderKind::Cohere, "COHERE_API_KEY"),
            (ProviderKind::Perplexity, "PERPLEXITY_API_KEY"),
            (ProviderKind::OpenAiCompatible, "OPENAI_API_KEY"),
        ];
        for (p, expected) in cases {
            assert_eq!(p.env_var_name(), expected, "{p:?}");
        }
    }

    #[test]
    fn test_provider_error_display() {
        let err = ProviderError::Auth("bad token".into());
        assert_eq!(format!("{err}"), "auth: bad token");

        let err = ProviderError::RateLimited {
            retry_after_ms: 1000,
        };
        assert_eq!(format!("{err}"), "rate limited (retry in 1000ms)");

        let err = ProviderError::Overloaded;
        assert_eq!(format!("{err}"), "server overloaded");

        let err = ProviderError::RequestTooLarge("4MB limit".into());
        assert_eq!(format!("{err}"), "request too large: 4MB limit");

        let err = ProviderError::Network("timeout".into());
        assert_eq!(format!("{err}"), "network: timeout");

        let err = ProviderError::InvalidResponse("missing field".into());
        assert_eq!(format!("{err}"), "invalid response: missing field");
    }

    #[test]
    fn test_tool_choice_default_is_auto() {
        // ToolChoice does not derive PartialEq, so matches! is used here.
        let tc = ToolChoice::default();
        assert!(matches!(tc, ToolChoice::Auto));
    }
}