1use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
/// A model backend that can serve streaming completions.
///
/// The `Send + Sync` supertraits allow one provider instance to be shared
/// across threads and async tasks. `#[async_trait]` lets `stream` be declared
/// as an `async fn` inside the trait.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Returns this provider's name.
    ///
    /// NOTE(review): presumably used for logging/diagnostics — confirm the
    /// expected format with existing implementations.
    fn name(&self) -> &str;

    /// Starts a request and returns the receiving end of a channel on which
    /// the response is delivered as a sequence of [`StreamEvent`]s.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] if the request cannot be started.
    /// NOTE(review): confirm whether failures that happen after streaming has
    /// begun are reported through `StreamEvent`s rather than this `Result`.
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// How the model may use the tools supplied in a `ProviderRequest`.
///
/// The variants mirror the `tool_choice` modes common to LLM APIs. The
/// per-variant notes below describe that conventional meaning. Each provider
/// implementation is responsible for translating them (NOTE(review): confirm
/// every backend supports all four).
///
/// `PartialEq`/`Eq` are derived (like `ProviderKind`) so callers and tests can
/// compare choices directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// The model decides whether to call a tool. This is the default.
    #[default]
    Auto,
    /// The model is expected to call some tool of its choosing.
    Any,
    /// The model is expected not to call any tool.
    None,
    /// The model is expected to call the tool with this name.
    Specific(String),
}
44
/// Everything a [`Provider`] needs to issue one completion request.
pub struct ProviderRequest {
    /// Conversation messages to send (presumably in chronological order —
    /// confirm with callers).
    pub messages: Vec<Message>,
    /// System prompt text. NOTE(review): confirm whether an empty string is
    /// treated as "no system prompt" by every provider.
    pub system_prompt: String,
    /// Schemas of the tools the model may call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier passed to the backend.
    pub model: String,
    /// Token limit for the response (presumably output tokens only — verify
    /// per provider).
    pub max_tokens: u32,
    /// Sampling temperature. `None` presumably leaves the backend's default
    /// in place.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching. Presumably ignored by backends that
    /// do not support it — TODO confirm.
    pub enable_caching: bool,
    /// How the model may use `tools`.
    pub tool_choice: ToolChoice,
    /// Optional extra JSON attached to the request. The schema is not defined
    /// here (NOTE(review): presumably provider-specific pass-through data).
    pub metadata: Option<serde_json::Value>,
}
59
/// Errors a `Provider` can report when a request cannot be served.
///
/// Implements [`std::error::Error`] so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` through generic
/// error-handling code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderError {
    /// Authentication or authorization failed; the payload describes why.
    Auth(String),
    /// The backend rejected the request because of rate limiting.
    RateLimited {
        /// Suggested wait before retrying, in milliseconds.
        retry_after_ms: u64,
    },
    /// The backend reported that it is overloaded.
    Overloaded,
    /// The request exceeded a size limit; the payload describes which one.
    RequestTooLarge(String),
    /// A transport-level failure occurred; the payload describes it.
    Network(String),
    /// The backend's response could not be interpreted.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (None) is
// correct; the impl exists so `ProviderError` interoperates with `dyn Error`.
impl std::error::Error for ProviderError {}
85
/// Guesses which [`ProviderKind`] should serve `model` at `base_url`.
///
/// Detection is case-insensitive and runs in two phases:
///
/// 1. **Endpoint rules.** `base_url` is checked first because it identifies
///    the server actually being contacted. Most rules are plain substring
///    tests on the whole URL, so a proxy URL that mentions a provider domain
///    in its path is recognised as well. Rule order matters:
///    - The cloud wrappers (Bedrock, Vertex) are checked before the provider
///      domains they can front.
///    - Local addresses (`localhost`, `127.0.0.1`, `[::1]`) short-circuit to
///      [`ProviderKind::OpenAiCompatible`] before any model-name heuristics
///      run.
/// 2. **Model-name rules.** If no endpoint rule matched, well-known model
///    families decide. These are `claude`/`opus`/`sonnet`/`haiku`,
///    `gpt`/`chatgpt`/`o1`/`o3`/`o4`, `grok`, `gemini`, `deepseek`, and
///    `mistral`/`codestral`.
///
/// Anything still unrecognised falls back to
/// [`ProviderKind::OpenAiCompatible`].
pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
    let model_lower = model.to_lowercase();
    let url_lower = base_url.to_lowercase();

    // ---- Phase 1: endpoint rules ------------------------------------------

    if contains_any(&url_lower, &["bedrock", "amazonaws.com"]) {
        return ProviderKind::Bedrock;
    }
    // Must precede the generic Google rule below, which would also match.
    if url_lower.contains("aiplatform.googleapis.com") {
        return ProviderKind::Vertex;
    }
    if url_lower.contains("anthropic.com") {
        return ProviderKind::Anthropic;
    }
    if url_lower.contains("openai.com") {
        return ProviderKind::OpenAi;
    }
    // `x.ai` is short enough to occur inside unrelated host names
    // (e.g. `api.box.ai`), so it must match as a whole domain name.
    if contains_domain(&url_lower, "x.ai") || url_lower.contains("xai.") {
        return ProviderKind::Xai;
    }
    // Also covers `googleapis.com`.
    if url_lower.contains("google") {
        return ProviderKind::Google;
    }
    if url_lower.contains("deepseek.com") {
        return ProviderKind::DeepSeek;
    }
    if url_lower.contains("groq.com") {
        return ProviderKind::Groq;
    }
    if url_lower.contains("mistral.ai") {
        return ProviderKind::Mistral;
    }
    if contains_any(&url_lower, &["together.xyz", "together.ai"]) {
        return ProviderKind::Together;
    }
    // Local servers are treated as OpenAI-compatible regardless of model name.
    if contains_any(&url_lower, &["localhost", "127.0.0.1", "[::1]"]) {
        return ProviderKind::OpenAiCompatible;
    }

    // ---- Phase 2: model-name rules ----------------------------------------

    if model_lower.starts_with("claude")
        || contains_any(&model_lower, &["opus", "sonnet", "haiku"])
    {
        return ProviderKind::Anthropic;
    }
    if starts_with_any(&model_lower, &["gpt", "chatgpt", "o1", "o3", "o4"]) {
        return ProviderKind::OpenAi;
    }
    if model_lower.starts_with("grok") {
        return ProviderKind::Xai;
    }
    if model_lower.starts_with("gemini") {
        return ProviderKind::Google;
    }
    if model_lower.starts_with("deepseek") {
        return ProviderKind::DeepSeek;
    }
    // Only reachable when the URL mentions "groq" without matching the
    // `groq.com` endpoint rule above (e.g. a proxy host named after Groq).
    if model_lower.starts_with("llama") && url_lower.contains("groq") {
        return ProviderKind::Groq;
    }
    if starts_with_any(&model_lower, &["mistral", "codestral"]) {
        return ProviderKind::Mistral;
    }

    ProviderKind::OpenAiCompatible
}

/// Returns `true` if `haystack` contains any of `needles` as a substring.
fn contains_any(haystack: &str, needles: &[&str]) -> bool {
    needles.iter().any(|&needle| haystack.contains(needle))
}

/// Returns `true` if `haystack` starts with any of `prefixes`.
fn starts_with_any(haystack: &str, prefixes: &[&str]) -> bool {
    prefixes.iter().any(|&prefix| haystack.starts_with(prefix))
}

/// Returns `true` if `domain` occurs in `haystack` as a complete host name or
/// as a dotted suffix of one. The character before the match must be absent,
/// `.`, `/` or `@`. The character after it must be absent, `/`, `:`, `?` or
/// `#`. So `x.ai` matches `https://api.x.ai/v1` but not `https://api.box.ai`.
fn contains_domain(haystack: &str, domain: &str) -> bool {
    haystack.match_indices(domain).any(|(start, matched)| {
        // `match_indices` yields char-boundary offsets, so slicing is safe.
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        matches!(before, None | Some('.' | '/' | '@'))
            && matches!(after, None | Some('/' | ':' | '?' | '#'))
    })
}

/// Backend families recognised by [`detect_provider`].
///
/// Each variant's comment summarises the rule that selects it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// URL contains `anthropic.com`, or the model name looks like Claude.
    Anthropic,
    /// URL contains `bedrock` or `amazonaws.com`.
    Bedrock,
    /// URL contains `aiplatform.googleapis.com`.
    Vertex,
    /// URL contains `openai.com`, or the model is `gpt*`/`chatgpt*`/`o1*`/`o3*`/`o4*`.
    OpenAi,
    /// URL is on the `x.ai` domain or contains `xai.`, or the model is `grok*`.
    Xai,
    /// URL contains `google`, or the model is `gemini*`.
    Google,
    /// URL contains `deepseek.com`, or the model is `deepseek*`.
    DeepSeek,
    /// URL contains `groq.com`.
    Groq,
    /// URL contains `mistral.ai`, or the model is `mistral*`/`codestral*`.
    Mistral,
    /// URL contains `together.xyz` or `together.ai`.
    Together,
    /// Local endpoints, and the fallback for anything unrecognised.
    OpenAiCompatible,
}
175
#[cfg(test)]
mod tests {
    use super::*;

    // `assert_eq!` (instead of `assert!(matches!(..))`) prints the actual
    // `ProviderKind` on failure, which makes rule-order regressions obvious.

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_google() {
        assert_eq!(
            detect_provider("any", "https://generativelanguage.googleapis.com/v1beta"),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
        assert_eq!(
            detect_provider("any", "https://api.together.ai/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
        assert_eq!(
            detect_provider("any", "http://127.0.0.1:8000/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_local_url_overrides_model() {
        assert_eq!(
            detect_provider("claude-sonnet-4", "http://localhost:8080/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(
            detect_provider("claude-sonnet-4", ""),
            ProviderKind::Anthropic
        );
        assert_eq!(
            detect_provider("claude-opus-4", ""),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(
            detect_provider("gpt-4.1-mini", ""),
            ProviderKind::OpenAi
        );
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(
            detect_provider("gemini-2.5-flash", ""),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_from_model_deepseek() {
        assert_eq!(
            detect_provider("deepseek-chat", ""),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_model_mistral() {
        assert_eq!(
            detect_provider("mistral-large-latest", ""),
            ProviderKind::Mistral
        );
        assert_eq!(
            detect_provider("codestral-latest", ""),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // A Claude model name must not override an explicit OpenAI endpoint.
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_cloud_wrappers_take_priority_over_model() {
        assert_eq!(
            detect_provider(
                "claude-sonnet-4",
                "https://bedrock-runtime.us-east-1.amazonaws.com"
            ),
            ProviderKind::Bedrock
        );
        assert_eq!(
            detect_provider(
                "claude-sonnet-4",
                "https://us-east5-aiplatform.googleapis.com/v1"
            ),
            ProviderKind::Vertex
        );
    }
}