// agent_code_lib/llm/provider.rs
//! LLM provider abstraction.
//!
//! Two wire formats cover the entire ecosystem:
//! - Anthropic Messages API (Claude models)
//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
//!
//! Each provider translates between our unified message types and
//! the provider-specific JSON format for requests and SSE streams.
use async_trait::async_trait;
use tokio::sync::mpsc;

use super::message::Message;
use super::stream::StreamEvent;
use crate::tools::ToolSchema;
16
/// Unified provider trait. Both Anthropic and OpenAI-compatible
/// endpoints implement this.
///
/// Implementors must be `Send + Sync` so a provider can be shared
/// across async tasks (e.g. behind an `Arc`).
#[async_trait]
pub trait Provider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;

    /// Send a streaming request. Returns a channel of events.
    ///
    /// The receiver yields [`StreamEvent`]s as they arrive from the
    /// provider's SSE stream; the channel closing signals end of stream.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] if the request cannot be initiated
    /// (auth failure, rate limiting, network failure, etc.).
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Tool choice mode for controlling tool usage.
///
/// Defaults to [`ToolChoice::Auto`] (the model decides on its own).
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model decides whether to use tools.
    #[default]
    Auto,
    /// Model must use a tool.
    Any,
    /// Model must not use tools.
    None,
    /// Model must use a specific tool, identified by name.
    Specific(String),
}
44
/// A provider-agnostic request.
///
/// Built once by the caller and translated to the provider-specific
/// JSON wire format by each [`Provider`] implementation.
pub struct ProviderRequest {
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the messages.
    pub system_prompt: String,
    /// Tool schemas the model may call; empty disables tool use.
    pub tools: Vec<ToolSchema>,
    /// Model identifier as understood by the target provider.
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` uses the provider's default.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching where the provider supports it.
    pub enable_caching: bool,
    /// Controls whether/how the model should use tools.
    pub tool_choice: ToolChoice,
    /// Metadata to send with the request (e.g., user_id for Anthropic).
    pub metadata: Option<serde_json::Value>,
}
59
/// Provider-level errors.
///
/// Implements [`std::error::Error`] so callers can box it as
/// `dyn Error`, wrap it with `#[from]`, or propagate it with `?`
/// into their own error types.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failure (bad or missing credentials).
    Auth(String),
    /// The provider rate-limited the request; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// The provider reported it is overloaded.
    Overloaded,
    /// The request exceeded the provider's size limit.
    RequestTooLarge(String),
    /// Transport-level failure (connection, timeout, etc.).
    Network(String),
    /// The provider returned a response that could not be parsed.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// `Debug` + `Display` are provided above, so the marker impl suffices;
// no variant wraps a source error, so `source()` defaults to `None`.
impl std::error::Error for ProviderError {}
85
/// Detect the right provider from a model name or base URL.
///
/// The URL is authoritative: every URL-based check runs before any
/// model-name check, so a Claude model name pointed at an OpenAI URL
/// resolves to OpenAI. Within each section the branches form an
/// ordered cascade — more specific substrings are tested before more
/// generic ones (Bedrock/Vertex before google, Azure before
/// openai.com) — so do not reorder branches casually.
///
/// Falls back to [`ProviderKind::OpenAiCompatible`] when nothing matches.
pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
    // Case-insensitive matching throughout.
    let model_lower = model.to_lowercase();
    let url_lower = base_url.to_lowercase();

    // ---- URL-based detection (takes priority over the model name) ----

    // AWS Bedrock (Claude via AWS).
    if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
        return ProviderKind::Bedrock;
    }
    // Google Vertex AI (Claude via GCP). Must precede the generic
    // googleapis.com / google check below.
    if url_lower.contains("aiplatform.googleapis.com") {
        return ProviderKind::Vertex;
    }
    if url_lower.contains("anthropic.com") {
        return ProviderKind::Anthropic;
    }
    // Azure OpenAI — must be checked before generic openai.com.
    // Operator precedence: `a || b && c` parses as `a || (b && c)`.
    if url_lower.contains("openai.azure.com")
        || url_lower.contains("azure.com") && url_lower.contains("openai")
    {
        return ProviderKind::AzureOpenAi;
    }
    if url_lower.contains("openai.com") {
        return ProviderKind::OpenAi;
    }
    if url_lower.contains("x.ai") || url_lower.contains("xai.") {
        return ProviderKind::Xai;
    }
    // Generic Google endpoints (Gemini API); Vertex was handled above.
    if url_lower.contains("googleapis.com") || url_lower.contains("google") {
        return ProviderKind::Google;
    }
    if url_lower.contains("deepseek.com") {
        return ProviderKind::DeepSeek;
    }
    if url_lower.contains("groq.com") {
        return ProviderKind::Groq;
    }
    if url_lower.contains("mistral.ai") {
        return ProviderKind::Mistral;
    }
    if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
        return ProviderKind::Together;
    }
    if url_lower.contains("bigmodel.cn")
        || url_lower.contains("z.ai")
        || url_lower.contains("zhipu")
    {
        return ProviderKind::Zhipu;
    }
    if url_lower.contains("openrouter.ai") {
        return ProviderKind::OpenRouter;
    }
    if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
        return ProviderKind::Cohere;
    }
    if url_lower.contains("perplexity.ai") {
        return ProviderKind::Perplexity;
    }
    // Local servers (e.g. Ollama) speak the OpenAI-compatible format.
    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
        return ProviderKind::OpenAiCompatible;
    }

    // ---- Model-name detection (only when no URL pattern matched) ----

    // Detect from model name.
    if model_lower.starts_with("claude")
        || model_lower.contains("opus")
        || model_lower.contains("sonnet")
        || model_lower.contains("haiku")
    {
        return ProviderKind::Anthropic;
    }
    if model_lower.starts_with("gpt")
        || model_lower.starts_with("o1")
        || model_lower.starts_with("o3")
    {
        return ProviderKind::OpenAi;
    }
    if model_lower.starts_with("grok") {
        return ProviderKind::Xai;
    }
    if model_lower.starts_with("gemini") {
        return ProviderKind::Google;
    }
    if model_lower.starts_with("deepseek") {
        return ProviderKind::DeepSeek;
    }
    // NOTE(review): only reachable when the URL mentions "groq" but not
    // "groq.com" (that case returns Groq in the URL section above) —
    // confirm this branch is still intended.
    if model_lower.starts_with("llama") && url_lower.contains("groq") {
        return ProviderKind::Groq;
    }
    if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
        return ProviderKind::Mistral;
    }
    if model_lower.starts_with("glm") {
        return ProviderKind::Zhipu;
    }
    if model_lower.starts_with("command") {
        return ProviderKind::Cohere;
    }
    if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
        return ProviderKind::Perplexity;
    }

    // Unknown: assume a generic OpenAI-compatible endpoint.
    ProviderKind::OpenAiCompatible
}
189
/// The two wire formats that cover the entire LLM ecosystem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API (Claude models, Bedrock, Vertex).
    Anthropic,
    /// OpenAI Chat Completions (GPT, Groq, Together, Ollama, DeepSeek, etc.).
    OpenAiCompatible,
}

/// Provider kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    AzureOpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenRouter,
    Cohere,
    Perplexity,
    OpenAiCompatible,
}

impl ProviderKind {
    /// Which wire format this provider uses.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
            Self::OpenAi
            | Self::AzureOpenAi
            | Self::Xai
            | Self::Google
            | Self::DeepSeek
            | Self::Groq
            | Self::Mistral
            | Self::Together
            | Self::Zhipu
            | Self::OpenRouter
            | Self::Cohere
            | Self::Perplexity
            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
        }
    }

    /// The default base URL for this provider, or `None` for providers
    /// whose URL must come from user configuration (Bedrock, Vertex,
    /// Azure, and generic OpenAI-compatible endpoints).
    ///
    /// The returned slice is `'static` (the URLs are string literals),
    /// so callers may hold it independently of this value's borrow.
    pub fn default_base_url(&self) -> Option<&'static str> {
        match self {
            Self::Anthropic => Some("https://api.anthropic.com/v1"),
            Self::OpenAi => Some("https://api.openai.com/v1"),
            Self::Xai => Some("https://api.x.ai/v1"),
            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
            Self::Groq => Some("https://api.groq.com/openai/v1"),
            Self::Mistral => Some("https://api.mistral.ai/v1"),
            Self::Together => Some("https://api.together.xyz/v1"),
            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
            Self::Cohere => Some("https://api.cohere.com/v2"),
            Self::Perplexity => Some("https://api.perplexity.ai"),
            // These require user-supplied URLs.
            Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
        }
    }

    /// The environment variable name conventionally used for this provider's API key.
    ///
    /// Returns a `'static` literal; Bedrock and Vertex reuse the
    /// Anthropic key variable, and generic OpenAI-compatible endpoints
    /// reuse the OpenAI one.
    pub fn env_var_name(&self) -> &'static str {
        match self {
            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
            Self::OpenAi => "OPENAI_API_KEY",
            Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
            Self::Xai => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::DeepSeek => "DEEPSEEK_API_KEY",
            Self::Groq => "GROQ_API_KEY",
            Self::Mistral => "MISTRAL_API_KEY",
            Self::Together => "TOGETHER_API_KEY",
            Self::Zhipu => "ZHIPU_API_KEY",
            Self::OpenRouter => "OPENROUTER_API_KEY",
            Self::Cohere => "COHERE_API_KEY",
            Self::Perplexity => "PERPLEXITY_API_KEY",
            Self::OpenAiCompatible => "OPENAI_API_KEY",
        }
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` and `WireFormat` derive `PartialEq + Eq`, so these
    // tests use `assert_eq!` rather than `assert!(matches!(..))`: on
    // failure `assert_eq!` prints both the actual and expected variant,
    // while `matches!` only reports "assertion failed".

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_azure_openai() {
        assert_eq!(
            detect_provider(
                "any",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_azure_before_generic_openai() {
        // Azure URL contains "openai" but should match Azure, not generic OpenAI.
        assert_eq!(
            detect_provider(
                "gpt-4",
                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
            ),
            ProviderKind::AzureOpenAi
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_cohere() {
        assert_eq!(
            detect_provider("any", "https://api.cohere.com/v2"),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_url_perplexity() {
        assert_eq!(
            detect_provider("any", "https://api.perplexity.ai"),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_model_command_r() {
        assert_eq!(
            detect_provider("command-r-plus", ""),
            ProviderKind::Cohere
        );
    }

    #[test]
    fn test_detect_from_model_sonar() {
        assert_eq!(
            detect_provider("sonar-pro", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_detect_from_url_openrouter() {
        assert_eq!(
            detect_provider("any", "https://openrouter.ai/api/v1"),
            ProviderKind::OpenRouter
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(
            detect_provider("claude-sonnet-4", ""),
            ProviderKind::Anthropic
        );
        assert_eq!(
            detect_provider("claude-opus-4", ""),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(
            detect_provider("gpt-4.1-mini", ""),
            ProviderKind::OpenAi
        );
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(
            detect_provider("gemini-2.5-flash", ""),
            ProviderKind::Google
        );
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // URL says OpenAI but model says Claude — URL wins.
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_wire_format_anthropic_family() {
        assert_eq!(ProviderKind::Anthropic.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Bedrock.wire_format(), WireFormat::Anthropic);
        assert_eq!(ProviderKind::Vertex.wire_format(), WireFormat::Anthropic);
    }

    #[test]
    fn test_wire_format_openai_compatible_family() {
        let openai_compat_providers = [
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
            ProviderKind::OpenAiCompatible,
        ];
        for p in openai_compat_providers {
            assert_eq!(
                p.wire_format(),
                WireFormat::OpenAiCompatible,
                "{p:?} should use OpenAiCompatible wire format"
            );
        }
    }

    #[test]
    fn test_default_base_url_returns_some_for_known_providers() {
        let providers_with_urls = [
            ProviderKind::Anthropic,
            ProviderKind::OpenAi,
            ProviderKind::Xai,
            ProviderKind::Google,
            ProviderKind::DeepSeek,
            ProviderKind::Groq,
            ProviderKind::Mistral,
            ProviderKind::Together,
            ProviderKind::Zhipu,
            ProviderKind::OpenRouter,
            ProviderKind::Cohere,
            ProviderKind::Perplexity,
        ];
        for p in providers_with_urls {
            assert!(
                p.default_base_url().is_some(),
                "{p:?} should have a default base URL"
            );
        }
    }

    #[test]
    fn test_default_base_url_returns_none_for_user_configured() {
        assert!(ProviderKind::Bedrock.default_base_url().is_none());
        assert!(ProviderKind::Vertex.default_base_url().is_none());
        assert!(ProviderKind::AzureOpenAi.default_base_url().is_none());
        assert!(ProviderKind::OpenAiCompatible.default_base_url().is_none());
    }

    #[test]
    fn test_env_var_name_all_variants() {
        assert_eq!(ProviderKind::Anthropic.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::Bedrock.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::Vertex.env_var_name(), "ANTHROPIC_API_KEY");
        assert_eq!(ProviderKind::OpenAi.env_var_name(), "OPENAI_API_KEY");
        assert_eq!(
            ProviderKind::AzureOpenAi.env_var_name(),
            "AZURE_OPENAI_API_KEY"
        );
        assert_eq!(ProviderKind::Xai.env_var_name(), "XAI_API_KEY");
        assert_eq!(ProviderKind::Google.env_var_name(), "GOOGLE_API_KEY");
        assert_eq!(ProviderKind::DeepSeek.env_var_name(), "DEEPSEEK_API_KEY");
        assert_eq!(ProviderKind::Groq.env_var_name(), "GROQ_API_KEY");
        assert_eq!(ProviderKind::Mistral.env_var_name(), "MISTRAL_API_KEY");
        assert_eq!(ProviderKind::Together.env_var_name(), "TOGETHER_API_KEY");
        assert_eq!(ProviderKind::Zhipu.env_var_name(), "ZHIPU_API_KEY");
        assert_eq!(
            ProviderKind::OpenRouter.env_var_name(),
            "OPENROUTER_API_KEY"
        );
        assert_eq!(ProviderKind::Cohere.env_var_name(), "COHERE_API_KEY");
        assert_eq!(
            ProviderKind::Perplexity.env_var_name(),
            "PERPLEXITY_API_KEY"
        );
        assert_eq!(
            ProviderKind::OpenAiCompatible.env_var_name(),
            "OPENAI_API_KEY"
        );
    }

    #[test]
    fn test_detect_from_url_zhipu_bigmodel() {
        assert_eq!(
            detect_provider("any", "https://open.bigmodel.cn/api/paas/v4"),
            ProviderKind::Zhipu
        );
    }

    #[test]
    fn test_detect_from_model_deepseek_chat() {
        assert_eq!(
            detect_provider("deepseek-chat", ""),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_model_mistral_large() {
        assert_eq!(
            detect_provider("mistral-large", ""),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_model_glm4() {
        assert_eq!(detect_provider("glm-4", ""), ProviderKind::Zhipu);
    }

    #[test]
    fn test_detect_from_model_llama3_with_groq_url() {
        assert_eq!(
            detect_provider("llama-3", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_model_codestral() {
        assert_eq!(
            detect_provider("codestral-latest", ""),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_model_pplx() {
        assert_eq!(
            detect_provider("pplx-70b-online", ""),
            ProviderKind::Perplexity
        );
    }

    #[test]
    fn test_provider_error_display() {
        let err = ProviderError::Auth("bad token".into());
        assert_eq!(format!("{err}"), "auth: bad token");

        let err = ProviderError::RateLimited {
            retry_after_ms: 1000,
        };
        assert_eq!(format!("{err}"), "rate limited (retry in 1000ms)");

        let err = ProviderError::Overloaded;
        assert_eq!(format!("{err}"), "server overloaded");

        let err = ProviderError::RequestTooLarge("4MB limit".into());
        assert_eq!(format!("{err}"), "request too large: 4MB limit");

        let err = ProviderError::Network("timeout".into());
        assert_eq!(format!("{err}"), "network: timeout");

        let err = ProviderError::InvalidResponse("missing field".into());
        assert_eq!(format!("{err}"), "invalid response: missing field");
    }

    #[test]
    fn test_tool_choice_default_is_auto() {
        // ToolChoice does not derive PartialEq, so `matches!` stays here.
        let tc = ToolChoice::default();
        assert!(matches!(tc, ToolChoice::Auto));
    }
}