1use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
/// Abstraction over an LLM backend capable of streaming responses.
///
/// Implementations must be `Send + Sync` so a provider can be shared
/// across async tasks.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Short identifier for this provider (e.g. for logging/diagnostics).
    fn name(&self) -> &str;

    /// Begin streaming a response for `request`.
    ///
    /// On success returns the receiving end of a channel that yields
    /// [`StreamEvent`]s as they arrive; on failure returns a
    /// [`ProviderError`] describing why the request could not be started.
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Tool-selection policy forwarded to the provider's `tool_choice` setting.
///
/// `PartialEq`/`Eq` are derived so callers can compare choices directly
/// (e.g. in tests or request deduplication).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool (the default).
    #[default]
    Auto,
    /// Require the model to call some tool (whichever it picks).
    Any,
    /// Forbid tool calls for this request.
    None,
    /// Force a call to the named tool.
    Specific(String),
}
44
/// Provider-agnostic description of a single model request.
pub struct ProviderRequest {
    /// Conversation messages to send.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the messages.
    pub system_prompt: String,
    /// Schemas of the tools the model may call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier (also used by [`detect_provider`] for routing).
    pub model: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature; `None` leaves the provider's default in place.
    pub temperature: Option<f64>,
    /// Whether to enable caching for this request.
    /// NOTE(review): exact caching semantics are provider-specific — confirm per backend.
    pub enable_caching: bool,
    /// Tool-selection policy (see [`ToolChoice`]).
    pub tool_choice: ToolChoice,
    /// Optional provider-specific metadata, passed through as JSON.
    pub metadata: Option<serde_json::Value>,
}
59
/// Errors surfaced by a [`Provider`] implementation.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication failure (e.g. missing or invalid API key).
    Auth(String),
    /// The provider rate-limited the request; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// The provider reported that its servers are overloaded.
    Overloaded,
    /// The request exceeded the provider's size limits.
    RequestTooLarge(String),
    /// Network-level failure while talking to the provider.
    Network(String),
    /// The provider returned a response that could not be parsed.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// With `Debug` + `Display` in place, implementing the standard error trait
// lets `ProviderError` compose with `?`, `Box<dyn Error>`, anyhow, etc.
impl std::error::Error for ProviderError {}
85
86pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88 let model_lower = model.to_lowercase();
89 let url_lower = base_url.to_lowercase();
90
91 if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93 return ProviderKind::Bedrock;
94 }
95 if url_lower.contains("aiplatform.googleapis.com") {
97 return ProviderKind::Vertex;
98 }
99 if url_lower.contains("anthropic.com") {
100 return ProviderKind::Anthropic;
101 }
102 if url_lower.contains("openai.com") {
103 return ProviderKind::OpenAi;
104 }
105 if url_lower.contains("x.ai") || url_lower.contains("xai.") {
106 return ProviderKind::Xai;
107 }
108 if url_lower.contains("googleapis.com") || url_lower.contains("google") {
109 return ProviderKind::Google;
110 }
111 if url_lower.contains("deepseek.com") {
112 return ProviderKind::DeepSeek;
113 }
114 if url_lower.contains("groq.com") {
115 return ProviderKind::Groq;
116 }
117 if url_lower.contains("mistral.ai") {
118 return ProviderKind::Mistral;
119 }
120 if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
121 return ProviderKind::Together;
122 }
123 if url_lower.contains("bigmodel.cn")
124 || url_lower.contains("z.ai")
125 || url_lower.contains("zhipu")
126 {
127 return ProviderKind::Zhipu;
128 }
129 if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
130 return ProviderKind::OpenAiCompatible;
131 }
132
133 if model_lower.starts_with("claude")
135 || model_lower.contains("opus")
136 || model_lower.contains("sonnet")
137 || model_lower.contains("haiku")
138 {
139 return ProviderKind::Anthropic;
140 }
141 if model_lower.starts_with("gpt")
142 || model_lower.starts_with("o1")
143 || model_lower.starts_with("o3")
144 {
145 return ProviderKind::OpenAi;
146 }
147 if model_lower.starts_with("grok") {
148 return ProviderKind::Xai;
149 }
150 if model_lower.starts_with("gemini") {
151 return ProviderKind::Google;
152 }
153 if model_lower.starts_with("deepseek") {
154 return ProviderKind::DeepSeek;
155 }
156 if model_lower.starts_with("llama") && url_lower.contains("groq") {
157 return ProviderKind::Groq;
158 }
159 if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
160 return ProviderKind::Mistral;
161 }
162 if model_lower.starts_with("glm") {
163 return ProviderKind::Zhipu;
164 }
165
166 ProviderKind::OpenAiCompatible
167}
168
/// Wire protocol a provider speaks (see [`ProviderKind::wire_format`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API format.
    Anthropic,
    /// OpenAI-compatible chat completions format.
    OpenAiCompatible,
}
177
/// The set of provider backends this module knows about.
///
/// Usually chosen via [`detect_provider`]; each kind's wire protocol,
/// default endpoint, and API-key environment variable are defined in the
/// `impl ProviderKind` block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    /// Anthropic-format models served via AWS Bedrock.
    Bedrock,
    /// Anthropic-format models served via Google Vertex AI.
    Vertex,
    OpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    /// Catch-all for self-hosted / unknown OpenAI-compatible servers.
    OpenAiCompatible,
}
194
195impl ProviderKind {
196 pub fn wire_format(&self) -> WireFormat {
198 match self {
199 Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
200 Self::OpenAi
201 | Self::Xai
202 | Self::Google
203 | Self::DeepSeek
204 | Self::Groq
205 | Self::Mistral
206 | Self::Together
207 | Self::Zhipu
208 | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
209 }
210 }
211
212 pub fn default_base_url(&self) -> Option<&str> {
216 match self {
217 Self::Anthropic => Some("https://api.anthropic.com/v1"),
218 Self::OpenAi => Some("https://api.openai.com/v1"),
219 Self::Xai => Some("https://api.x.ai/v1"),
220 Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
221 Self::DeepSeek => Some("https://api.deepseek.com/v1"),
222 Self::Groq => Some("https://api.groq.com/openai/v1"),
223 Self::Mistral => Some("https://api.mistral.ai/v1"),
224 Self::Together => Some("https://api.together.xyz/v1"),
225 Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
226 Self::Bedrock | Self::Vertex | Self::OpenAiCompatible => None,
228 }
229 }
230
231 pub fn env_var_name(&self) -> &str {
233 match self {
234 Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
235 Self::OpenAi => "OPENAI_API_KEY",
236 Self::Xai => "XAI_API_KEY",
237 Self::Google => "GOOGLE_API_KEY",
238 Self::DeepSeek => "DEEPSEEK_API_KEY",
239 Self::Groq => "GROQ_API_KEY",
240 Self::Mistral => "MISTRAL_API_KEY",
241 Self::Together => "TOGETHER_API_KEY",
242 Self::Zhipu => "ZHIPU_API_KEY",
243 Self::OpenAiCompatible => "OPENAI_API_KEY",
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq + Eq`, so `assert_eq!` gives the
    // same checks as `matches!` with far better failure output.

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(detect_provider("claude-sonnet-4", ""), ProviderKind::Anthropic);
        assert_eq!(detect_provider("claude-opus-4", ""), ProviderKind::Anthropic);
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(detect_provider("gpt-4.1-mini", ""), ProviderKind::OpenAi);
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(detect_provider("gemini-2.5-flash", ""), ProviderKind::Google);
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }
}