// agent_code_lib/llm/provider.rs
//! LLM provider abstraction.
//!
//! Two wire formats cover the entire ecosystem:
//! - Anthropic Messages API (Claude models)
//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
//!
//! Each provider translates between our unified message types and
//! the provider-specific JSON format for requests and SSE streams.

use async_trait::async_trait;
use tokio::sync::mpsc;

use super::message::Message;
use super::stream::StreamEvent;
use crate::tools::ToolSchema;

/// Unified provider trait. Both Anthropic and OpenAI-compatible
/// endpoints implement this.
///
/// Implementations must be shareable across tasks (`Send + Sync`);
/// they translate a [`ProviderRequest`] into their wire format and
/// stream results back as [`StreamEvent`]s.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;

    /// Send a streaming request. Returns a channel of events.
    ///
    /// On success the receiver yields [`StreamEvent`]s as the provider
    /// produces them; request-level failures surface as [`ProviderError`].
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}

/// Tool choice mode for controlling tool usage.
///
/// Set on [`ProviderRequest::tool_choice`]; defaults to [`ToolChoice::Auto`].
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model decides whether to use tools.
    #[default]
    Auto,
    /// Model must use a tool.
    Any,
    /// Model must not use tools.
    None,
    /// Model must use a specific tool (the `String` is presumably the
    /// tool's name — confirm against the provider serialization code).
    Specific(String),
}

/// A provider-agnostic request.
///
/// Built in unified form; each provider translates it into its own
/// JSON wire format before sending (see module docs).
pub struct ProviderRequest {
    /// Conversation messages to send.
    pub messages: Vec<Message>,
    /// System prompt text.
    pub system_prompt: String,
    /// Schemas of the tools the model may call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier, in whatever form the target provider expects.
    pub model: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature; `None` presumably leaves the provider default.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching (where the provider supports it).
    pub enable_caching: bool,
    /// Controls whether/how the model should use tools.
    pub tool_choice: ToolChoice,
    /// Metadata to send with the request (e.g., user_id for Anthropic).
    pub metadata: Option<serde_json::Value>,
}

/// Provider-level errors.
///
/// Covers authentication, throttling, transport, and response-parsing
/// failures surfaced by any [`Provider`] implementation.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failed (e.g. bad or missing API key).
    Auth(String),
    /// The provider throttled the request; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// The provider reported it is temporarily overloaded.
    Overloaded,
    /// The request exceeded the provider's size limits.
    RequestTooLarge(String),
    /// Transport-level failure (connection, DNS, TLS, timeout, …).
    Network(String),
    /// The provider responded, but the payload could not be interpreted.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// Without this impl the type could not flow through `?` into
// `Box<dyn std::error::Error>` or `anyhow`-style callers. `Debug` and
// `Display` above satisfy the supertraits; the default methods suffice.
impl std::error::Error for ProviderError {}

86/// Detect the right provider from a model name or base URL.
87pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88    let model_lower = model.to_lowercase();
89    let url_lower = base_url.to_lowercase();
90
91    // AWS Bedrock (Claude via AWS).
92    if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93        return ProviderKind::Bedrock;
94    }
95    // Google Vertex AI (Claude via GCP).
96    if url_lower.contains("aiplatform.googleapis.com") {
97        return ProviderKind::Vertex;
98    }
99    if url_lower.contains("anthropic.com") {
100        return ProviderKind::Anthropic;
101    }
102    // Azure OpenAI — must be checked before generic openai.com.
103    if url_lower.contains("openai.azure.com")
104        || url_lower.contains("azure.com") && url_lower.contains("openai")
105    {
106        return ProviderKind::AzureOpenAi;
107    }
108    if url_lower.contains("openai.com") {
109        return ProviderKind::OpenAi;
110    }
111    if url_lower.contains("x.ai") || url_lower.contains("xai.") {
112        return ProviderKind::Xai;
113    }
114    if url_lower.contains("googleapis.com") || url_lower.contains("google") {
115        return ProviderKind::Google;
116    }
117    if url_lower.contains("deepseek.com") {
118        return ProviderKind::DeepSeek;
119    }
120    if url_lower.contains("groq.com") {
121        return ProviderKind::Groq;
122    }
123    if url_lower.contains("mistral.ai") {
124        return ProviderKind::Mistral;
125    }
126    if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
127        return ProviderKind::Together;
128    }
129    if url_lower.contains("bigmodel.cn")
130        || url_lower.contains("z.ai")
131        || url_lower.contains("zhipu")
132    {
133        return ProviderKind::Zhipu;
134    }
135    if url_lower.contains("openrouter.ai") {
136        return ProviderKind::OpenRouter;
137    }
138    if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
139        return ProviderKind::Cohere;
140    }
141    if url_lower.contains("perplexity.ai") {
142        return ProviderKind::Perplexity;
143    }
144    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
145        return ProviderKind::OpenAiCompatible;
146    }
147
148    // Detect from model name.
149    if model_lower.starts_with("claude")
150        || model_lower.contains("opus")
151        || model_lower.contains("sonnet")
152        || model_lower.contains("haiku")
153    {
154        return ProviderKind::Anthropic;
155    }
156    if model_lower.starts_with("gpt")
157        || model_lower.starts_with("o1")
158        || model_lower.starts_with("o3")
159    {
160        return ProviderKind::OpenAi;
161    }
162    if model_lower.starts_with("grok") {
163        return ProviderKind::Xai;
164    }
165    if model_lower.starts_with("gemini") {
166        return ProviderKind::Google;
167    }
168    if model_lower.starts_with("deepseek") {
169        return ProviderKind::DeepSeek;
170    }
171    if model_lower.starts_with("llama") && url_lower.contains("groq") {
172        return ProviderKind::Groq;
173    }
174    if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
175        return ProviderKind::Mistral;
176    }
177    if model_lower.starts_with("glm") {
178        return ProviderKind::Zhipu;
179    }
180    if model_lower.starts_with("command") {
181        return ProviderKind::Cohere;
182    }
183    if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
184        return ProviderKind::Perplexity;
185    }
186
187    ProviderKind::OpenAiCompatible
188}
189
/// The two wire formats that cover the entire LLM ecosystem.
///
/// Obtained from a [`ProviderKind`] via [`ProviderKind::wire_format`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API (Claude models, Bedrock, Vertex).
    Anthropic,
    /// OpenAI Chat Completions (GPT, Groq, Together, Ollama, DeepSeek, etc.).
    OpenAiCompatible,
}

/// Provider kinds.
///
/// Produced by [`detect_provider`]; each kind maps onto one of the two
/// [`WireFormat`]s via [`ProviderKind::wire_format`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic's first-party API.
    Anthropic,
    /// Claude via AWS Bedrock.
    Bedrock,
    /// Claude via Google Vertex AI.
    Vertex,
    /// OpenAI's first-party API.
    OpenAi,
    /// OpenAI via Microsoft Azure.
    AzureOpenAi,
    /// xAI (Grok models).
    Xai,
    /// Google (Gemini models).
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    /// Zhipu AI (GLM models).
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    /// Any other OpenAI-compatible endpoint (local servers, Ollama, …);
    /// the fallback when detection matches nothing.
    OpenAiCompatible,
}

220impl ProviderKind {
221    /// Which wire format this provider uses.
222    pub fn wire_format(&self) -> WireFormat {
223        match self {
224            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
225            Self::OpenAi
226            | Self::AzureOpenAi
227            | Self::Xai
228            | Self::Google
229            | Self::DeepSeek
230            | Self::Groq
231            | Self::Mistral
232            | Self::Together
233            | Self::Zhipu
234            | Self::OpenRouter
235            | Self::Cohere
236            | Self::Perplexity
237            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
238        }
239    }
240
241    /// The default base URL for this provider, or `None` for providers
242    /// whose URL must come from user configuration (Bedrock, Vertex,
243    /// and generic OpenAI-compatible endpoints).
244    pub fn default_base_url(&self) -> Option<&str> {
245        match self {
246            Self::Anthropic => Some("https://api.anthropic.com/v1"),
247            Self::OpenAi => Some("https://api.openai.com/v1"),
248            Self::Xai => Some("https://api.x.ai/v1"),
249            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
250            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
251            Self::Groq => Some("https://api.groq.com/openai/v1"),
252            Self::Mistral => Some("https://api.mistral.ai/v1"),
253            Self::Together => Some("https://api.together.xyz/v1"),
254            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
255            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
256            Self::Cohere => Some("https://api.cohere.com/v2"),
257            Self::Perplexity => Some("https://api.perplexity.ai"),
258            // These require user-supplied URLs.
259            Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
260        }
261    }
262
263    /// The environment variable name conventionally used for this provider's API key.
264    pub fn env_var_name(&self) -> &str {
265        match self {
266            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
267            Self::OpenAi => "OPENAI_API_KEY",
268            Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
269            Self::Xai => "XAI_API_KEY",
270            Self::Google => "GOOGLE_API_KEY",
271            Self::DeepSeek => "DEEPSEEK_API_KEY",
272            Self::Groq => "GROQ_API_KEY",
273            Self::Mistral => "MISTRAL_API_KEY",
274            Self::Together => "TOGETHER_API_KEY",
275            Self::Zhipu => "ZHIPU_API_KEY",
276            Self::OpenRouter => "OPENROUTER_API_KEY",
277            Self::Cohere => "COHERE_API_KEY",
278            Self::Perplexity => "PERPLEXITY_API_KEY",
279            Self::OpenAiCompatible => "OPENAI_API_KEY",
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: assert that `(model, url)` detects as `expected`.
    /// (`ProviderKind` derives `Debug + PartialEq`, so `assert_eq!` works.)
    fn check(model: &str, url: &str, expected: ProviderKind) {
        assert_eq!(detect_provider(model, url), expected);
    }

    #[test]
    fn test_detect_from_url_anthropic() {
        check("any", "https://api.anthropic.com/v1", ProviderKind::Anthropic);
    }

    #[test]
    fn test_detect_from_url_openai() {
        check("any", "https://api.openai.com/v1", ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        check(
            "any",
            "https://bedrock-runtime.us-east-1.amazonaws.com",
            ProviderKind::Bedrock,
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        check(
            "any",
            "https://us-central1-aiplatform.googleapis.com/v1",
            ProviderKind::Vertex,
        );
    }

    #[test]
    fn test_detect_from_url_azure_openai() {
        check(
            "any",
            "https://myresource.openai.azure.com/openai/deployments/gpt-4",
            ProviderKind::AzureOpenAi,
        );
    }

    #[test]
    fn test_detect_azure_before_generic_openai() {
        // Azure URL contains "openai" but should match Azure, not generic OpenAI.
        check(
            "gpt-4",
            "https://myresource.openai.azure.com/openai/deployments/gpt-4",
            ProviderKind::AzureOpenAi,
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        check("any", "https://api.x.ai/v1", ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        check("any", "https://api.deepseek.com/v1", ProviderKind::DeepSeek);
    }

    #[test]
    fn test_detect_from_url_groq() {
        check("any", "https://api.groq.com/openai/v1", ProviderKind::Groq);
    }

    #[test]
    fn test_detect_from_url_mistral() {
        check("any", "https://api.mistral.ai/v1", ProviderKind::Mistral);
    }

    #[test]
    fn test_detect_from_url_together() {
        check("any", "https://api.together.xyz/v1", ProviderKind::Together);
    }

    #[test]
    fn test_detect_from_url_cohere() {
        check("any", "https://api.cohere.com/v2", ProviderKind::Cohere);
    }

    #[test]
    fn test_detect_from_url_perplexity() {
        check("any", "https://api.perplexity.ai", ProviderKind::Perplexity);
    }

    #[test]
    fn test_detect_from_model_command_r() {
        check("command-r-plus", "", ProviderKind::Cohere);
    }

    #[test]
    fn test_detect_from_model_sonar() {
        check("sonar-pro", "", ProviderKind::Perplexity);
    }

    #[test]
    fn test_detect_from_url_openrouter() {
        check("any", "https://openrouter.ai/api/v1", ProviderKind::OpenRouter);
    }

    #[test]
    fn test_detect_from_url_localhost() {
        check("any", "http://localhost:11434/v1", ProviderKind::OpenAiCompatible);
    }

    #[test]
    fn test_detect_from_model_claude() {
        check("claude-sonnet-4", "", ProviderKind::Anthropic);
        check("claude-opus-4", "", ProviderKind::Anthropic);
    }

    #[test]
    fn test_detect_from_model_gpt() {
        check("gpt-4.1-mini", "", ProviderKind::OpenAi);
        check("o3-mini", "", ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        check("grok-3", "", ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        check("gemini-2.5-flash", "", ProviderKind::Google);
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        check(
            "some-random-model",
            "https://my-server.com",
            ProviderKind::OpenAiCompatible,
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // URL says OpenAI but model says Claude — URL wins.
        check("claude-sonnet", "https://api.openai.com/v1", ProviderKind::OpenAi);
    }
}