1use async_trait::async_trait;
11use tokio::sync::mpsc;
12use tokio_util::sync::CancellationToken;
13
14use super::message::Message;
15use super::stream::StreamEvent;
16use crate::tools::ToolSchema;
17
/// Common interface over model backends that stream responses.
///
/// NOTE(review): the `Send + Sync` bounds suggest implementations are
/// shared across async tasks — confirm against the call sites.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Short identifier for this provider implementation.
    fn name(&self) -> &str;

    /// Start a streaming completion for `request`.
    ///
    /// On success, returns a channel receiver yielding [`StreamEvent`]s
    /// until the stream ends or the request's cancellation token fires.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] when the request cannot be started or
    /// the provider rejects it (auth, rate limiting, network, ...).
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
31
/// Policy for how the model may use the tools supplied in a request.
///
/// NOTE(review): variant semantics map onto the provider-side tool-choice
/// modes (`auto`/`any`/`none`/named tool); exact wire mapping lives in the
/// provider implementations.
///
/// `PartialEq`/`Eq` are derived (the only payload is a `String`) so the
/// policy can be compared directly in code and tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool (the default).
    #[default]
    Auto,
    /// Require the model to call some tool.
    Any,
    /// Forbid tool calls for this request.
    None,
    /// Require the model to call the named tool.
    Specific(String),
}
45
/// Provider-agnostic parameters for a single streaming completion request.
pub struct ProviderRequest {
    /// Conversation messages to send, in order.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the messages.
    pub system_prompt: String,
    /// Schemas of the tools the model is allowed to call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier, passed through to the provider.
    pub model: String,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature; `None` leaves the provider's default in place.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching — presumably only honored by
    /// providers that support it; confirm in the provider impls.
    pub enable_caching: bool,
    /// Policy for how the model may use the supplied tools.
    pub tool_choice: ToolChoice,
    /// Free-form provider-specific metadata, if any.
    pub metadata: Option<serde_json::Value>,
    /// Token used to cancel the in-flight request.
    pub cancel: CancellationToken,
}
67
/// Errors surfaced by a [`Provider`] implementation.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failure (e.g. bad or missing API key).
    Auth(String),
    /// The provider asked us to back off; retry after `retry_after_ms`.
    RateLimited { retry_after_ms: u64 },
    /// The provider reported it is temporarily overloaded.
    Overloaded,
    /// The request exceeded a provider size limit.
    RequestTooLarge(String),
    /// Transport-level failure (connection, TLS, timeout, ...).
    Network(String),
    /// The provider's response was malformed or could not be parsed.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// `Debug` + `Display` are already provided, so the blank impl suffices.
// This lets callers propagate a `ProviderError` with `?` into
// `Box<dyn std::error::Error>` (and `anyhow::Error`), which the type
// previously could not do despite implementing `Display`.
impl std::error::Error for ProviderError {}
93
94pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
96 let model_lower = model.to_lowercase();
97 let url_lower = base_url.to_lowercase();
98
99 if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
101 return ProviderKind::Bedrock;
102 }
103 if url_lower.contains("aiplatform.googleapis.com") {
105 return ProviderKind::Vertex;
106 }
107 if url_lower.contains("anthropic.com") {
108 return ProviderKind::Anthropic;
109 }
110 if url_lower.contains("openai.azure.com")
112 || url_lower.contains("azure.com") && url_lower.contains("openai")
113 {
114 return ProviderKind::AzureOpenAi;
115 }
116 if url_lower.contains("openai.com") {
117 return ProviderKind::OpenAi;
118 }
119 if url_lower.contains("x.ai") || url_lower.contains("xai.") {
120 return ProviderKind::Xai;
121 }
122 if url_lower.contains("googleapis.com") || url_lower.contains("google") {
123 return ProviderKind::Google;
124 }
125 if url_lower.contains("deepseek.com") {
126 return ProviderKind::DeepSeek;
127 }
128 if url_lower.contains("groq.com") {
129 return ProviderKind::Groq;
130 }
131 if url_lower.contains("mistral.ai") {
132 return ProviderKind::Mistral;
133 }
134 if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
135 return ProviderKind::Together;
136 }
137 if url_lower.contains("bigmodel.cn")
138 || url_lower.contains("z.ai")
139 || url_lower.contains("zhipu")
140 {
141 return ProviderKind::Zhipu;
142 }
143 if url_lower.contains("openrouter.ai") {
144 return ProviderKind::OpenRouter;
145 }
146 if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
147 return ProviderKind::Cohere;
148 }
149 if url_lower.contains("perplexity.ai") {
150 return ProviderKind::Perplexity;
151 }
152 if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
153 return ProviderKind::OpenAiCompatible;
154 }
155
156 if model_lower.starts_with("claude")
158 || model_lower.contains("opus")
159 || model_lower.contains("sonnet")
160 || model_lower.contains("haiku")
161 {
162 return ProviderKind::Anthropic;
163 }
164 if model_lower.starts_with("gpt")
165 || model_lower.starts_with("o1")
166 || model_lower.starts_with("o3")
167 {
168 return ProviderKind::OpenAi;
169 }
170 if model_lower.starts_with("grok") {
171 return ProviderKind::Xai;
172 }
173 if model_lower.starts_with("gemini") {
174 return ProviderKind::Google;
175 }
176 if model_lower.starts_with("deepseek") {
177 return ProviderKind::DeepSeek;
178 }
179 if model_lower.starts_with("llama") && url_lower.contains("groq") {
180 return ProviderKind::Groq;
181 }
182 if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
183 return ProviderKind::Mistral;
184 }
185 if model_lower.starts_with("glm") {
186 return ProviderKind::Zhipu;
187 }
188 if model_lower.starts_with("command") {
189 return ProviderKind::Cohere;
190 }
191 if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
192 return ProviderKind::Perplexity;
193 }
194
195 ProviderKind::OpenAiCompatible
196}
197
/// The HTTP wire protocol a provider speaks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API request/response shape.
    Anthropic,
    /// OpenAI Chat Completions-compatible shape.
    OpenAiCompatible,
}

/// Known provider backends this crate can talk to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    AzureOpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    OpenAiCompatible,
}

impl ProviderKind {
    /// Which wire protocol this provider speaks.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
            Self::OpenAi
            | Self::AzureOpenAi
            | Self::Xai
            | Self::Google
            | Self::DeepSeek
            | Self::Groq
            | Self::Mistral
            | Self::Together
            | Self::Zhipu
            | Self::OpenRouter
            | Self::Cohere
            | Self::Perplexity
            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
        }
    }

    /// Default API base URL for this provider.
    ///
    /// Returns `None` for providers whose endpoint is deployment-specific
    /// (Bedrock, Vertex, Azure OpenAI) or user-configured (generic
    /// OpenAI-compatible servers).
    ///
    /// The return lifetime is `'static` (the URLs are string literals), so
    /// callers need not keep the `ProviderKind` borrow alive — the elided
    /// signature previously tied the result to `&self` for no reason.
    pub fn default_base_url(&self) -> Option<&'static str> {
        match self {
            Self::Anthropic => Some("https://api.anthropic.com/v1"),
            Self::OpenAi => Some("https://api.openai.com/v1"),
            Self::Xai => Some("https://api.x.ai/v1"),
            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
            Self::Groq => Some("https://api.groq.com/openai/v1"),
            Self::Mistral => Some("https://api.mistral.ai/v1"),
            Self::Together => Some("https://api.together.xyz/v1"),
            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
            Self::Cohere => Some("https://api.cohere.com/v2"),
            Self::Perplexity => Some("https://api.perplexity.ai"),
            Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
        }
    }

    /// Environment variable conventionally holding this provider's API key.
    ///
    /// Like [`ProviderKind::default_base_url`], the result is a `'static`
    /// literal and is returned as such.
    pub fn env_var_name(&self) -> &'static str {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
            Self::OpenAi => "OPENAI_API_KEY",
            Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
            Self::Xai => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::DeepSeek => "DEEPSEEK_API_KEY",
            Self::Groq => "GROQ_API_KEY",
            Self::Mistral => "MISTRAL_API_KEY",
            Self::Together => "TOGETHER_API_KEY",
            Self::Zhipu => "ZHIPU_API_KEY",
            Self::OpenRouter => "OPENROUTER_API_KEY",
            Self::Cohere => "COHERE_API_KEY",
            Self::Perplexity => "PERPLEXITY_API_KEY",
            Self::OpenAiCompatible => "OPENAI_API_KEY",
        }
    }
}
291
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq`/`Eq`, so these tests use
    // `assert_eq!` instead of `assert!(matches!(..))`: on failure,
    // `assert_eq!` prints the actual variant, which makes a wrong
    // detection far easier to diagnose than a bare "assertion failed".

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_azure_openai() {
        assert_eq!(
            detect_provider(
                "any",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_azure_before_generic_openai() {
        // The Azure check must win even though the model name says "gpt".
        assert_eq!(
            detect_provider(
                "gpt-4",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_cohere() {
        assert_eq!(
            detect_provider("any", "https://api.cohere.com/v2"),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_url_perplexity() {
        assert_eq!(
            detect_provider("any", "https://api.perplexity.ai"),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_model_command_r() {
        assert_eq!(
            detect_provider("command-r-plus", ""),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_model_sonar() {
        assert_eq!(
            detect_provider("sonar-pro", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_url_openrouter() {
        assert_eq!(
            detect_provider("any", "https://openrouter.ai/api/v1"),
            ProviderKind::OpenRouter
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(
            detect_provider("claude-sonnet-4", ""),
            ProviderKind::Anthropic
        );
        assert_eq!(
            detect_provider("claude-opus-4", ""),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(
            detect_provider("gpt-4.1-mini", ""),
            ProviderKind::OpenAi
        );
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(
            detect_provider("gemini-2.5-flash", ""),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // An explicit base URL outranks whatever the model name suggests.
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_wire_format_anthropic_family() {
        assert_eq!(ProviderKind::Anthropic.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Bedrock.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Vertex.wire_format(), WireFormat::Anthropic);
    }

    #[test]
    fn test_wire_format_openai_compatible_family() {
        let openai_compat_providers = [
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
            ProviderKind::OpenAiCompatible,
        ];
        for p in openai_compat_providers {
            assert_eq!(
                p.wire_format(),
                WireFormat::OpenAiCompatible,
                "{p:?} should use OpenAiCompatible wire format"
            );
        }
    }

    #[test]
    fn test_default_base_url_returns_some_for_known_providers() {
        let providers_with_urls = [
            ProviderKind::Anthropic,
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
        ];
        for p in providers_with_urls {
            assert!(
                p.default_base_url().is_some(),
                "{p:?} should have a default base URL"
            );
        }
    }

    #[test]
    fn test_default_base_url_returns_none_for_user_configured() {
        assert!(ProviderKind::Bedrock.default_base_url().is_none());
        assert!(ProviderKind::Vertex.default_base_url().is_none());
        assert!(ProviderKind::AzureOpenAi.default_base_url().is_none());
        assert!(ProviderKind::OpenAiCompatible.default_base_url().is_none());
    }

    #[test]
    fn test_env_var_name_all_variants() {
        assert_eq!(ProviderKind::Anthropic.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::Bedrock.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::Vertex.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::OpenAi.env_var_name(), "OPENAI_API_KEY");
        assert_eq!(
            ProviderKind::AzureOpenAi.env_var_name(),
            "AZURE_OPENAI_API_KEY"
        );
        assert_eq!(ProviderKind::Xai.env_var_name(), "XAI_API_KEY");
        assert_eq!(ProviderKind::Google.env_var_name(), "GOOGLE_API_KEY");
        assert_eq!(ProviderKind::DeepSeek.env_var_name(), "DEEPSEEK_API_KEY");
        assert_eq!(ProviderKind::Groq.env_var_name(), "GROQ_API_KEY");
        assert_eq!(ProviderKind::Mistral.env_var_name(), "MISTRAL_API_KEY");
        assert_eq!(ProviderKind::Together.env_var_name(), "TOGETHER_API_KEY");
        assert_eq!(ProviderKind::Zhipu.env_var_name(), "ZHIPU_API_KEY");
        assert_eq!(
            ProviderKind::OpenRouter.env_var_name(),
            "OPENROUTER_API_KEY"
        );
        assert_eq!(ProviderKind::Cohere.env_var_name(), "COHERE_API_KEY");
        assert_eq!(
            ProviderKind::Perplexity.env_var_name(),
            "PERPLEXITY_API_KEY"
        );
        assert_eq!(
            ProviderKind::OpenAiCompatible.env_var_name(),
            "OPENAI_API_KEY"
        );
    }

    #[test]
    fn test_detect_from_url_zhipu_bigmodel() {
        assert_eq!(
            detect_provider("any", "https://open.bigmodel.cn/api/paas/v4"),
            ProviderKind::Zhipu
        );
    }

    #[test]
    fn test_detect_from_model_deepseek_chat() {
        assert_eq!(
            detect_provider("deepseek-chat", ""),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_model_mistral_large() {
        assert_eq!(
            detect_provider("mistral-large", ""),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_model_glm4() {
        assert_eq!(detect_provider("glm-4", ""), ProviderKind::Zhipu);
    }

    #[test]
    fn test_detect_from_model_llama3_with_groq_url() {
        assert_eq!(
            detect_provider("llama-3", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_model_codestral() {
        assert_eq!(
            detect_provider("codestral-latest", ""),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_model_pplx() {
        assert_eq!(
            detect_provider("pplx-70b-online", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_provider_error_display() {
        let err = ProviderError::Auth("bad token".into());
        assert_eq!(format!("{err}"), "auth: bad token");

        let err = ProviderError::RateLimited {
            retry_after_ms: 1000,
        };
        assert_eq!(format!("{err}"), "rate limited (retry in 1000ms)");

        let err = ProviderError::Overloaded;
        assert_eq!(format!("{err}"), "server overloaded");

        let err = ProviderError::RequestTooLarge("4MB limit".into());
        assert_eq!(format!("{err}"), "request too large: 4MB limit");

        let err = ProviderError::Network("timeout".into());
        assert_eq!(format!("{err}"), "network: timeout");

        let err = ProviderError::InvalidResponse("missing field".into());
        assert_eq!(format!("{err}"), "invalid response: missing field");
    }

    #[test]
    fn test_tool_choice_default_is_auto() {
        // `ToolChoice` is not guaranteed to derive PartialEq here, so a
        // pattern test is used instead of assert_eq!.
        let tc = ToolChoice::default();
        assert!(matches!(tc, ToolChoice::Auto));
    }
}