// agent_code_lib/llm/provider.rs
//! LLM provider abstraction.
//!
//! Two wire formats cover the entire ecosystem:
//! - Anthropic Messages API (Claude models)
//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
//!
//! Each provider translates between our unified message types and
//! the provider-specific JSON format for requests and SSE streams.

10use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
17/// Unified provider trait. Both Anthropic and OpenAI-compatible
18/// endpoints implement this.
19#[async_trait]
20pub trait Provider: Send + Sync {
21    /// Human-readable provider name.
22    fn name(&self) -> &str;
23
24    /// Send a streaming request. Returns a channel of events.
25    async fn stream(
26        &self,
27        request: &ProviderRequest,
28    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
29}
30
31/// Tool choice mode for controlling tool usage.
32#[derive(Debug, Clone, Default)]
33pub enum ToolChoice {
34    /// Model decides whether to use tools.
35    #[default]
36    Auto,
37    /// Model must use a tool.
38    Any,
39    /// Model must not use tools.
40    None,
41    /// Model must use a specific tool.
42    Specific(String),
43}
44
45/// A provider-agnostic request.
46pub struct ProviderRequest {
47    pub messages: Vec<Message>,
48    pub system_prompt: String,
49    pub tools: Vec<ToolSchema>,
50    pub model: String,
51    pub max_tokens: u32,
52    pub temperature: Option<f64>,
53    pub enable_caching: bool,
54    /// Controls whether/how the model should use tools.
55    pub tool_choice: ToolChoice,
56    /// Metadata to send with the request (e.g., user_id for Anthropic).
57    pub metadata: Option<serde_json::Value>,
58}
59
60/// Provider-level errors.
61#[derive(Debug)]
62pub enum ProviderError {
63    Auth(String),
64    RateLimited { retry_after_ms: u64 },
65    Overloaded,
66    RequestTooLarge(String),
67    Network(String),
68    InvalidResponse(String),
69}
70
71impl std::fmt::Display for ProviderError {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            Self::Auth(msg) => write!(f, "auth: {msg}"),
75            Self::RateLimited { retry_after_ms } => {
76                write!(f, "rate limited (retry in {retry_after_ms}ms)")
77            }
78            Self::Overloaded => write!(f, "server overloaded"),
79            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
80            Self::Network(msg) => write!(f, "network: {msg}"),
81            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
82        }
83    }
84}
85
86/// Detect the right provider from a model name or base URL.
87pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88    let model_lower = model.to_lowercase();
89    let url_lower = base_url.to_lowercase();
90
91    // AWS Bedrock (Claude via AWS).
92    if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93        return ProviderKind::Bedrock;
94    }
95    // Google Vertex AI (Claude via GCP).
96    if url_lower.contains("aiplatform.googleapis.com") {
97        return ProviderKind::Vertex;
98    }
99    if url_lower.contains("anthropic.com") {
100        return ProviderKind::Anthropic;
101    }
102    if url_lower.contains("openai.com") {
103        return ProviderKind::OpenAi;
104    }
105    if url_lower.contains("x.ai") || url_lower.contains("xai.") {
106        return ProviderKind::Xai;
107    }
108    if url_lower.contains("googleapis.com") || url_lower.contains("google") {
109        return ProviderKind::Google;
110    }
111    if url_lower.contains("deepseek.com") {
112        return ProviderKind::DeepSeek;
113    }
114    if url_lower.contains("groq.com") {
115        return ProviderKind::Groq;
116    }
117    if url_lower.contains("mistral.ai") {
118        return ProviderKind::Mistral;
119    }
120    if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
121        return ProviderKind::Together;
122    }
123    if url_lower.contains("bigmodel.cn")
124        || url_lower.contains("z.ai")
125        || url_lower.contains("zhipu")
126    {
127        return ProviderKind::Zhipu;
128    }
129    if url_lower.contains("openrouter.ai") {
130        return ProviderKind::OpenRouter;
131    }
132    if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
133        return ProviderKind::Cohere;
134    }
135    if url_lower.contains("perplexity.ai") {
136        return ProviderKind::Perplexity;
137    }
138    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
139        return ProviderKind::OpenAiCompatible;
140    }
141
142    // Detect from model name.
143    if model_lower.starts_with("claude")
144        || model_lower.contains("opus")
145        || model_lower.contains("sonnet")
146        || model_lower.contains("haiku")
147    {
148        return ProviderKind::Anthropic;
149    }
150    if model_lower.starts_with("gpt")
151        || model_lower.starts_with("o1")
152        || model_lower.starts_with("o3")
153    {
154        return ProviderKind::OpenAi;
155    }
156    if model_lower.starts_with("grok") {
157        return ProviderKind::Xai;
158    }
159    if model_lower.starts_with("gemini") {
160        return ProviderKind::Google;
161    }
162    if model_lower.starts_with("deepseek") {
163        return ProviderKind::DeepSeek;
164    }
165    if model_lower.starts_with("llama") && url_lower.contains("groq") {
166        return ProviderKind::Groq;
167    }
168    if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
169        return ProviderKind::Mistral;
170    }
171    if model_lower.starts_with("glm") {
172        return ProviderKind::Zhipu;
173    }
174    if model_lower.starts_with("command") {
175        return ProviderKind::Cohere;
176    }
177    if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
178        return ProviderKind::Perplexity;
179    }
180
181    ProviderKind::OpenAiCompatible
182}
183
184/// The two wire formats that cover the entire LLM ecosystem.
185#[derive(Debug, Clone, Copy, PartialEq, Eq)]
186pub enum WireFormat {
187    /// Anthropic Messages API (Claude models, Bedrock, Vertex).
188    Anthropic,
189    /// OpenAI Chat Completions (GPT, Groq, Together, Ollama, DeepSeek, etc.).
190    OpenAiCompatible,
191}
192
193/// Provider kinds.
194#[derive(Debug, Clone, Copy, PartialEq, Eq)]
195pub enum ProviderKind {
196    Anthropic,
197    Bedrock,
198    Vertex,
199    OpenAi,
200    Xai,
201    Google,
202    DeepSeek,
203    Groq,
204    Mistral,
205    Together,
206    Zhipu,
207    OpenRouter,
208    Cohere,
209    Perplexity,
210    OpenAiCompatible,
211}
212
213impl ProviderKind {
214    /// Which wire format this provider uses.
215    pub fn wire_format(&self) -> WireFormat {
216        match self {
217            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
218            Self::OpenAi
219            | Self::Xai
220            | Self::Google
221            | Self::DeepSeek
222            | Self::Groq
223            | Self::Mistral
224            | Self::Together
225            | Self::Zhipu
226            | Self::OpenRouter
227            | Self::Cohere
228            | Self::Perplexity
229            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
230        }
231    }
232
233    /// The default base URL for this provider, or `None` for providers
234    /// whose URL must come from user configuration (Bedrock, Vertex,
235    /// and generic OpenAI-compatible endpoints).
236    pub fn default_base_url(&self) -> Option<&str> {
237        match self {
238            Self::Anthropic => Some("https://api.anthropic.com/v1"),
239            Self::OpenAi => Some("https://api.openai.com/v1"),
240            Self::Xai => Some("https://api.x.ai/v1"),
241            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
242            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
243            Self::Groq => Some("https://api.groq.com/openai/v1"),
244            Self::Mistral => Some("https://api.mistral.ai/v1"),
245            Self::Together => Some("https://api.together.xyz/v1"),
246            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
247            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
248            Self::Cohere => Some("https://api.cohere.com/v2"),
249            Self::Perplexity => Some("https://api.perplexity.ai"),
250            // These require user-supplied URLs.
251            Self::Bedrock | Self::Vertex | Self::OpenAiCompatible => None,
252        }
253    }
254
255    /// The environment variable name conventionally used for this provider's API key.
256    pub fn env_var_name(&self) -> &str {
257        match self {
258            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
259            Self::OpenAi => "OPENAI_API_KEY",
260            Self::Xai => "XAI_API_KEY",
261            Self::Google => "GOOGLE_API_KEY",
262            Self::DeepSeek => "DEEPSEEK_API_KEY",
263            Self::Groq => "GROQ_API_KEY",
264            Self::Mistral => "MISTRAL_API_KEY",
265            Self::Together => "TOGETHER_API_KEY",
266            Self::Zhipu => "ZHIPU_API_KEY",
267            Self::OpenRouter => "OPENROUTER_API_KEY",
268            Self::Cohere => "COHERE_API_KEY",
269            Self::Perplexity => "PERPLEXITY_API_KEY",
270            Self::OpenAiCompatible => "OPENAI_API_KEY",
271        }
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use super::*;
278
279    #[test]
280    fn test_detect_from_url_anthropic() {
281        assert!(matches!(
282            detect_provider("any", "https://api.anthropic.com/v1"),
283            ProviderKind::Anthropic
284        ));
285    }
286
287    #[test]
288    fn test_detect_from_url_openai() {
289        assert!(matches!(
290            detect_provider("any", "https://api.openai.com/v1"),
291            ProviderKind::OpenAi
292        ));
293    }
294
295    #[test]
296    fn test_detect_from_url_bedrock() {
297        assert!(matches!(
298            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
299            ProviderKind::Bedrock
300        ));
301    }
302
303    #[test]
304    fn test_detect_from_url_vertex() {
305        assert!(matches!(
306            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
307            ProviderKind::Vertex
308        ));
309    }
310
311    #[test]
312    fn test_detect_from_url_xai() {
313        assert!(matches!(
314            detect_provider("any", "https://api.x.ai/v1"),
315            ProviderKind::Xai
316        ));
317    }
318
319    #[test]
320    fn test_detect_from_url_deepseek() {
321        assert!(matches!(
322            detect_provider("any", "https://api.deepseek.com/v1"),
323            ProviderKind::DeepSeek
324        ));
325    }
326
327    #[test]
328    fn test_detect_from_url_groq() {
329        assert!(matches!(
330            detect_provider("any", "https://api.groq.com/openai/v1"),
331            ProviderKind::Groq
332        ));
333    }
334
335    #[test]
336    fn test_detect_from_url_mistral() {
337        assert!(matches!(
338            detect_provider("any", "https://api.mistral.ai/v1"),
339            ProviderKind::Mistral
340        ));
341    }
342
343    #[test]
344    fn test_detect_from_url_together() {
345        assert!(matches!(
346            detect_provider("any", "https://api.together.xyz/v1"),
347            ProviderKind::Together
348        ));
349    }
350
351    #[test]
352    fn test_detect_from_url_cohere() {
353        assert!(matches!(
354            detect_provider("any", "https://api.cohere.com/v2"),
355            ProviderKind::Cohere
356        ));
357    }
358
359    #[test]
360    fn test_detect_from_url_perplexity() {
361        assert!(matches!(
362            detect_provider("any", "https://api.perplexity.ai"),
363            ProviderKind::Perplexity
364        ));
365    }
366
367    #[test]
368    fn test_detect_from_model_command_r() {
369        assert!(matches!(
370            detect_provider("command-r-plus", ""),
371            ProviderKind::Cohere
372        ));
373    }
374
375    #[test]
376    fn test_detect_from_model_sonar() {
377        assert!(matches!(
378            detect_provider("sonar-pro", ""),
379            ProviderKind::Perplexity
380        ));
381    }
382
383    #[test]
384    fn test_detect_from_url_openrouter() {
385        assert!(matches!(
386            detect_provider("any", "https://openrouter.ai/api/v1"),
387            ProviderKind::OpenRouter
388        ));
389    }
390
391    #[test]
392    fn test_detect_from_url_localhost() {
393        assert!(matches!(
394            detect_provider("any", "http://localhost:11434/v1"),
395            ProviderKind::OpenAiCompatible
396        ));
397    }
398
399    #[test]
400    fn test_detect_from_model_claude() {
401        assert!(matches!(
402            detect_provider("claude-sonnet-4", ""),
403            ProviderKind::Anthropic
404        ));
405        assert!(matches!(
406            detect_provider("claude-opus-4", ""),
407            ProviderKind::Anthropic
408        ));
409    }
410
411    #[test]
412    fn test_detect_from_model_gpt() {
413        assert!(matches!(
414            detect_provider("gpt-4.1-mini", ""),
415            ProviderKind::OpenAi
416        ));
417        assert!(matches!(
418            detect_provider("o3-mini", ""),
419            ProviderKind::OpenAi
420        ));
421    }
422
423    #[test]
424    fn test_detect_from_model_grok() {
425        assert!(matches!(detect_provider("grok-3", ""), ProviderKind::Xai));
426    }
427
428    #[test]
429    fn test_detect_from_model_gemini() {
430        assert!(matches!(
431            detect_provider("gemini-2.5-flash", ""),
432            ProviderKind::Google
433        ));
434    }
435
436    #[test]
437    fn test_detect_unknown_defaults_openai_compat() {
438        assert!(matches!(
439            detect_provider("some-random-model", "https://my-server.com"),
440            ProviderKind::OpenAiCompatible
441        ));
442    }
443
444    #[test]
445    fn test_url_takes_priority_over_model() {
446        // URL says OpenAI but model says Claude — URL wins.
447        assert!(matches!(
448            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
449            ProviderKind::OpenAi
450        ));
451    }
452}