1use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
/// Common interface implemented by every model backend.
///
/// Implementations translate a [`ProviderRequest`] into the backend's
/// wire protocol and surface the response incrementally as
/// [`StreamEvent`]s.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Short identifier for this backend (used for logging/diagnostics).
    fn name(&self) -> &str;

    /// Start a streaming completion for `request`.
    ///
    /// On success, returns the receiving end of a channel that yields
    /// [`StreamEvent`]s until the provider closes the sender side.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] if the request cannot be issued or the
    /// backend rejects it.
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Policy controlling how the model may invoke tools for a request.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool (the default).
    #[default]
    Auto,
    /// Force the model to call some tool.
    Any,
    /// Disallow tool calls entirely.
    None,
    /// Force the model to call the tool with the given name.
    Specific(String),
}
44
/// Provider-agnostic description of a single model invocation.
///
/// Built by the caller and handed to [`Provider::stream`]; each backend
/// translates it into its own wire format.
pub struct ProviderRequest {
    /// Conversation messages to send.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the conversation.
    pub system_prompt: String,
    /// Schemas of the tools the model is allowed to invoke.
    pub tools: Vec<ToolSchema>,
    /// Model identifier (e.g. "claude-sonnet-4", "gpt-4.1-mini").
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` leaves the backend default.
    pub temperature: Option<f64>,
    /// Whether to request caching — provider-dependent; presumably
    /// prompt caching where the backend supports it (TODO confirm).
    pub enable_caching: bool,
    /// Tool-invocation policy for this request.
    pub tool_choice: ToolChoice,
    /// Optional free-form metadata forwarded as JSON — NOTE(review):
    /// exact backend interpretation not visible here; verify per provider.
    pub metadata: Option<serde_json::Value>,
}
59
/// Errors surfaced by a [`Provider`] call.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failure (e.g. bad or missing API key).
    Auth(String),
    /// Backend asked us to back off; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// Backend reported it is overloaded.
    Overloaded,
    /// Request exceeded the backend's size limits.
    RequestTooLarge(String),
    /// Transport-level failure (connect, TLS, timeout, ...).
    Network(String),
    /// Response arrived but could not be parsed or understood.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// A public error type should implement `std::error::Error` so it can be
// propagated with `?` into `Box<dyn Error>` / `anyhow::Error` and take
// part in error-source chains. `Debug` + `Display` above satisfy the
// supertrait bounds; no `source()` override is needed (no wrapped cause).
impl std::error::Error for ProviderError {}
85
86pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88 let model_lower = model.to_lowercase();
89 let url_lower = base_url.to_lowercase();
90
91 if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93 return ProviderKind::Bedrock;
94 }
95 if url_lower.contains("aiplatform.googleapis.com") {
97 return ProviderKind::Vertex;
98 }
99 if url_lower.contains("anthropic.com") {
100 return ProviderKind::Anthropic;
101 }
102 if url_lower.contains("openai.com") {
103 return ProviderKind::OpenAi;
104 }
105 if url_lower.contains("x.ai") || url_lower.contains("xai.") {
106 return ProviderKind::Xai;
107 }
108 if url_lower.contains("googleapis.com") || url_lower.contains("google") {
109 return ProviderKind::Google;
110 }
111 if url_lower.contains("deepseek.com") {
112 return ProviderKind::DeepSeek;
113 }
114 if url_lower.contains("groq.com") {
115 return ProviderKind::Groq;
116 }
117 if url_lower.contains("mistral.ai") {
118 return ProviderKind::Mistral;
119 }
120 if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
121 return ProviderKind::Together;
122 }
123 if url_lower.contains("bigmodel.cn")
124 || url_lower.contains("z.ai")
125 || url_lower.contains("zhipu")
126 {
127 return ProviderKind::Zhipu;
128 }
129 if url_lower.contains("openrouter.ai") {
130 return ProviderKind::OpenRouter;
131 }
132 if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
133 return ProviderKind::Cohere;
134 }
135 if url_lower.contains("perplexity.ai") {
136 return ProviderKind::Perplexity;
137 }
138 if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
139 return ProviderKind::OpenAiCompatible;
140 }
141
142 if model_lower.starts_with("claude")
144 || model_lower.contains("opus")
145 || model_lower.contains("sonnet")
146 || model_lower.contains("haiku")
147 {
148 return ProviderKind::Anthropic;
149 }
150 if model_lower.starts_with("gpt")
151 || model_lower.starts_with("o1")
152 || model_lower.starts_with("o3")
153 {
154 return ProviderKind::OpenAi;
155 }
156 if model_lower.starts_with("grok") {
157 return ProviderKind::Xai;
158 }
159 if model_lower.starts_with("gemini") {
160 return ProviderKind::Google;
161 }
162 if model_lower.starts_with("deepseek") {
163 return ProviderKind::DeepSeek;
164 }
165 if model_lower.starts_with("llama") && url_lower.contains("groq") {
166 return ProviderKind::Groq;
167 }
168 if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
169 return ProviderKind::Mistral;
170 }
171 if model_lower.starts_with("glm") {
172 return ProviderKind::Zhipu;
173 }
174 if model_lower.starts_with("command") {
175 return ProviderKind::Cohere;
176 }
177 if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
178 return ProviderKind::Perplexity;
179 }
180
181 ProviderKind::OpenAiCompatible
182}
183
/// The HTTP/JSON wire protocol a backend speaks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API request/response shape.
    Anthropic,
    /// OpenAI Chat Completions-compatible shape.
    OpenAiCompatible,
}

/// Every backend this crate knows how to talk to, plus a generic
/// OpenAI-compatible fallback for unrecognized/local servers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    OpenAiCompatible,
}

impl ProviderKind {
    /// Which wire protocol this backend expects.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
            Self::OpenAi
            | Self::Xai
            | Self::Google
            | Self::DeepSeek
            | Self::Groq
            | Self::Mistral
            | Self::Together
            | Self::Zhipu
            | Self::OpenRouter
            | Self::Cohere
            | Self::Perplexity
            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
        }
    }

    /// Default API base URL, or `None` for backends whose endpoint must
    /// be supplied by the caller (Bedrock, Vertex, local/compatible
    /// servers).
    ///
    /// The URLs are string literals, so the return borrow is `'static`
    /// rather than tied to `&self` — callers may keep the URL after the
    /// `ProviderKind` borrow ends.
    pub fn default_base_url(&self) -> Option<&'static str> {
        match self {
            Self::Anthropic => Some("https://api.anthropic.com/v1"),
            Self::OpenAi => Some("https://api.openai.com/v1"),
            Self::Xai => Some("https://api.x.ai/v1"),
            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
            Self::Groq => Some("https://api.groq.com/openai/v1"),
            Self::Mistral => Some("https://api.mistral.ai/v1"),
            Self::Together => Some("https://api.together.xyz/v1"),
            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
            Self::Cohere => Some("https://api.cohere.com/v2"),
            Self::Perplexity => Some("https://api.perplexity.ai"),
            Self::Bedrock | Self::Vertex | Self::OpenAiCompatible => None,
        }
    }

    /// Name of the environment variable conventionally holding this
    /// backend's API key. Also `'static`, for the same reason as
    /// [`Self::default_base_url`].
    pub fn env_var_name(&self) -> &'static str {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
            Self::OpenAi => "OPENAI_API_KEY",
            Self::Xai => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::DeepSeek => "DEEPSEEK_API_KEY",
            Self::Groq => "GROQ_API_KEY",
            Self::Mistral => "MISTRAL_API_KEY",
            Self::Together => "TOGETHER_API_KEY",
            Self::Zhipu => "ZHIPU_API_KEY",
            Self::OpenRouter => "OPENROUTER_API_KEY",
            Self::Cohere => "COHERE_API_KEY",
            Self::Perplexity => "PERPLEXITY_API_KEY",
            Self::OpenAiCompatible => "OPENAI_API_KEY",
        }
    }
}
274
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq` + `Debug`, so `assert_eq!` is
    // preferred over `assert!(matches!(..))`: on failure it prints the
    // actual variant instead of just "assertion failed".

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_cohere() {
        assert_eq!(
            detect_provider("any", "https://api.cohere.com/v2"),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_url_perplexity() {
        assert_eq!(
            detect_provider("any", "https://api.perplexity.ai"),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_model_command_r() {
        assert_eq!(
            detect_provider("command-r-plus", ""),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_model_sonar() {
        assert_eq!(
            detect_provider("sonar-pro", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_url_openrouter() {
        assert_eq!(
            detect_provider("any", "https://openrouter.ai/api/v1"),
            ProviderKind::OpenRouter
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(
            detect_provider("claude-sonnet-4", ""),
            ProviderKind::Anthropic
        );
        assert_eq!(
            detect_provider("claude-opus-4", ""),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(
            detect_provider("gpt-4.1-mini", ""),
            ProviderKind::OpenAi
        );
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(
            detect_provider("gemini-2.5-flash", ""),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }
}