1use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
/// A streaming LLM backend (Anthropic, OpenAI, a local server, …).
///
/// Implementations are shared across tasks, hence the `Send + Sync` bound.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Stable identifier for this provider implementation.
    fn name(&self) -> &str;

    /// Start a streaming completion for `request`.
    ///
    /// On success, returns the receiving half of a channel that yields
    /// [`StreamEvent`]s until the response finishes or the sender side is
    /// dropped.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] when the request cannot be started
    /// (auth failure, rate limit, network error, …).
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Controls how the model may use tools for a single request.
///
/// NOTE(review): variant semantics below follow the usual provider-API
/// meaning of these names — confirm against each wire-format encoder.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model decides on its own whether (and which) tool to call.
    #[default]
    Auto,
    /// Model must call some tool (any of the ones provided).
    Any,
    /// Tool calls are disallowed for this request.
    None,
    /// Model must call the tool with the given name.
    Specific(String),
}
44
/// A provider-agnostic completion request, translated to the concrete wire
/// format by each [`Provider`] implementation.
pub struct ProviderRequest {
    /// Conversation messages to send with the request.
    pub messages: Vec<Message>,
    /// System prompt text, sent separately from `messages`.
    pub system_prompt: String,
    /// Schemas of the tools the model is allowed to call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier (e.g. "claude-sonnet-4", "gpt-4.1-mini").
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` presumably means "use the provider's
    /// default" — confirm in the request encoders.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching (assumption from the name; actual
    /// effect depends on per-provider support — TODO confirm).
    pub enable_caching: bool,
    /// Tool-usage policy for this request.
    pub tool_choice: ToolChoice,
    /// Extra provider-specific request payload; exact schema depends on the
    /// provider implementation.
    pub metadata: Option<serde_json::Value>,
}
59
/// Errors surfaced by a [`Provider`] when a request fails.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failure (e.g. bad or missing API key).
    Auth(String),
    /// Rate limit hit; caller may retry after the indicated delay.
    RateLimited { retry_after_ms: u64 },
    /// The upstream service reported it is temporarily overloaded.
    Overloaded,
    /// The request exceeded an upstream size limit.
    RequestTooLarge(String),
    /// Transport-level failure (connection, DNS, TLS, …).
    Network(String),
    /// The provider returned a response that could not be understood.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// `Debug` + `Display` are provided above, so the empty marker impl is enough
// for `ProviderError` to participate in `Box<dyn Error>`, `anyhow`, and
// `?`-based error conversion at call sites.
impl std::error::Error for ProviderError {}
85
86pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88 let model_lower = model.to_lowercase();
89 let url_lower = base_url.to_lowercase();
90
91 if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93 return ProviderKind::Bedrock;
94 }
95 if url_lower.contains("aiplatform.googleapis.com") {
97 return ProviderKind::Vertex;
98 }
99 if url_lower.contains("anthropic.com") {
100 return ProviderKind::Anthropic;
101 }
102 if url_lower.contains("openai.azure.com")
104 || url_lower.contains("azure.com") && url_lower.contains("openai")
105 {
106 return ProviderKind::AzureOpenAi;
107 }
108 if url_lower.contains("openai.com") {
109 return ProviderKind::OpenAi;
110 }
111 if url_lower.contains("x.ai") || url_lower.contains("xai.") {
112 return ProviderKind::Xai;
113 }
114 if url_lower.contains("googleapis.com") || url_lower.contains("google") {
115 return ProviderKind::Google;
116 }
117 if url_lower.contains("deepseek.com") {
118 return ProviderKind::DeepSeek;
119 }
120 if url_lower.contains("groq.com") {
121 return ProviderKind::Groq;
122 }
123 if url_lower.contains("mistral.ai") {
124 return ProviderKind::Mistral;
125 }
126 if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
127 return ProviderKind::Together;
128 }
129 if url_lower.contains("bigmodel.cn")
130 || url_lower.contains("z.ai")
131 || url_lower.contains("zhipu")
132 {
133 return ProviderKind::Zhipu;
134 }
135 if url_lower.contains("openrouter.ai") {
136 return ProviderKind::OpenRouter;
137 }
138 if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
139 return ProviderKind::Cohere;
140 }
141 if url_lower.contains("perplexity.ai") {
142 return ProviderKind::Perplexity;
143 }
144 if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
145 return ProviderKind::OpenAiCompatible;
146 }
147
148 if model_lower.starts_with("claude")
150 || model_lower.contains("opus")
151 || model_lower.contains("sonnet")
152 || model_lower.contains("haiku")
153 {
154 return ProviderKind::Anthropic;
155 }
156 if model_lower.starts_with("gpt")
157 || model_lower.starts_with("o1")
158 || model_lower.starts_with("o3")
159 {
160 return ProviderKind::OpenAi;
161 }
162 if model_lower.starts_with("grok") {
163 return ProviderKind::Xai;
164 }
165 if model_lower.starts_with("gemini") {
166 return ProviderKind::Google;
167 }
168 if model_lower.starts_with("deepseek") {
169 return ProviderKind::DeepSeek;
170 }
171 if model_lower.starts_with("llama") && url_lower.contains("groq") {
172 return ProviderKind::Groq;
173 }
174 if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
175 return ProviderKind::Mistral;
176 }
177 if model_lower.starts_with("glm") {
178 return ProviderKind::Zhipu;
179 }
180 if model_lower.starts_with("command") {
181 return ProviderKind::Cohere;
182 }
183 if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
184 return ProviderKind::Perplexity;
185 }
186
187 ProviderKind::OpenAiCompatible
188}
189
/// The on-the-wire request/response shape a provider speaks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API shape (also used by Bedrock and Vertex).
    Anthropic,
    /// OpenAI Chat Completions-compatible shape.
    OpenAiCompatible,
}

/// A known LLM provider (or a generic OpenAI-compatible endpoint).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    AzureOpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    /// Any server that speaks the OpenAI Chat Completions API
    /// (local inference servers, unknown hosts, …).
    OpenAiCompatible,
}

impl ProviderKind {
    /// The wire format requests to this provider must be encoded in.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
            Self::OpenAi
            | Self::AzureOpenAi
            | Self::Xai
            | Self::Google
            | Self::DeepSeek
            | Self::Groq
            | Self::Mistral
            | Self::Together
            | Self::Zhipu
            | Self::OpenRouter
            | Self::Cohere
            | Self::Perplexity
            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
        }
    }

    /// Default API base URL for this provider, or `None` for providers whose
    /// endpoint is deployment-specific (Bedrock, Vertex, Azure) or
    /// user-supplied (generic OpenAI-compatible servers).
    ///
    /// Returns `&'static str` (not a borrow of `self`) since every URL is a
    /// string literal — callers may hold the result indefinitely.
    pub fn default_base_url(&self) -> Option<&'static str> {
        match self {
            Self::Anthropic => Some("https://api.anthropic.com/v1"),
            Self::OpenAi => Some("https://api.openai.com/v1"),
            Self::Xai => Some("https://api.x.ai/v1"),
            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
            Self::Groq => Some("https://api.groq.com/openai/v1"),
            Self::Mistral => Some("https://api.mistral.ai/v1"),
            Self::Together => Some("https://api.together.xyz/v1"),
            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
            Self::Cohere => Some("https://api.cohere.com/v2"),
            Self::Perplexity => Some("https://api.perplexity.ai"),
            Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
        }
    }

    /// Name of the environment variable conventionally holding this
    /// provider's API key. Also `&'static str` — all names are literals.
    pub fn env_var_name(&self) -> &'static str {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
            Self::OpenAi => "OPENAI_API_KEY",
            Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
            Self::Xai => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::DeepSeek => "DEEPSEEK_API_KEY",
            Self::Groq => "GROQ_API_KEY",
            Self::Mistral => "MISTRAL_API_KEY",
            Self::Together => "TOGETHER_API_KEY",
            Self::Zhipu => "ZHIPU_API_KEY",
            Self::OpenRouter => "OPENROUTER_API_KEY",
            Self::Cohere => "COHERE_API_KEY",
            Self::Perplexity => "PERPLEXITY_API_KEY",
            Self::OpenAiCompatible => "OPENAI_API_KEY",
        }
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq + Eq + Debug`, so `assert_eq!` is
    // used instead of `assert!(matches!(..))`: on failure it prints the
    // actual variant rather than just "assertion failed".

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_azure_openai() {
        assert_eq!(
            detect_provider(
                "any",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_azure_before_generic_openai() {
        assert_eq!(
            detect_provider(
                "gpt-4",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_cohere() {
        assert_eq!(
            detect_provider("any", "https://api.cohere.com/v2"),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_url_perplexity() {
        assert_eq!(
            detect_provider("any", "https://api.perplexity.ai"),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_model_command_r() {
        assert_eq!(
            detect_provider("command-r-plus", ""),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_model_sonar() {
        assert_eq!(
            detect_provider("sonar-pro", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_url_openrouter() {
        assert_eq!(
            detect_provider("any", "https://openrouter.ai/api/v1"),
            ProviderKind::OpenRouter
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(
            detect_provider("claude-sonnet-4", ""),
            ProviderKind::Anthropic
        );
        assert_eq!(
            detect_provider("claude-opus-4", ""),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(
            detect_provider("gpt-4.1-mini", ""),
            ProviderKind::OpenAi
        );
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(
            detect_provider("gemini-2.5-flash", ""),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }
}