// agent_code_lib/llm/provider.rs
//! LLM provider abstraction.
//!
//! Two wire formats cover the entire ecosystem:
//! - Anthropic Messages API (Claude models)
//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
//!
//! Each provider translates between our unified message types and
//! the provider-specific JSON format for requests and SSE streams.

10use async_trait::async_trait;
11use tokio::sync::mpsc;
12use tokio_util::sync::CancellationToken;
13
14use super::message::Message;
15use super::stream::StreamEvent;
16use crate::tools::ToolSchema;
17
18/// Unified provider trait. Both Anthropic and OpenAI-compatible
19/// endpoints implement this.
20#[async_trait]
21pub trait Provider: Send + Sync {
22    /// Human-readable provider name.
23    fn name(&self) -> &str;
24
25    /// Send a streaming request. Returns a channel of events.
26    async fn stream(
27        &self,
28        request: &ProviderRequest,
29    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
30}
31
32/// Tool choice mode for controlling tool usage.
33#[derive(Debug, Clone, Default)]
34pub enum ToolChoice {
35    /// Model decides whether to use tools.
36    #[default]
37    Auto,
38    /// Model must use a tool.
39    Any,
40    /// Model must not use tools.
41    None,
42    /// Model must use a specific tool.
43    Specific(String),
44}
45
46/// A provider-agnostic request.
47pub struct ProviderRequest {
48    pub messages: Vec<Message>,
49    pub system_prompt: String,
50    pub tools: Vec<ToolSchema>,
51    pub model: String,
52    pub max_tokens: u32,
53    pub temperature: Option<f64>,
54    pub enable_caching: bool,
55    /// Controls whether/how the model should use tools.
56    pub tool_choice: ToolChoice,
57    /// Metadata to send with the request (e.g., user_id for Anthropic).
58    pub metadata: Option<serde_json::Value>,
59    /// Cancellation token for interrupting the in-flight streaming HTTP read.
60    /// Providers must race `byte_stream.next().await` against
61    /// `cancel.cancelled()` so that the spawned streaming task exits
62    /// promptly when the user presses Escape or Ctrl+C. Background callers
63    /// (memory extraction, consolidation) can pass `CancellationToken::new()`
64    /// for an uncancellable request.
65    pub cancel: CancellationToken,
66}
67
68/// Provider-level errors.
69#[derive(Debug)]
70pub enum ProviderError {
71    Auth(String),
72    RateLimited { retry_after_ms: u64 },
73    Overloaded,
74    RequestTooLarge(String),
75    Network(String),
76    InvalidResponse(String),
77}
78
79impl std::fmt::Display for ProviderError {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            Self::Auth(msg) => write!(f, "auth: {msg}"),
83            Self::RateLimited { retry_after_ms } => {
84                write!(f, "rate limited (retry in {retry_after_ms}ms)")
85            }
86            Self::Overloaded => write!(f, "server overloaded"),
87            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
88            Self::Network(msg) => write!(f, "network: {msg}"),
89            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
90        }
91    }
92}
93
94/// Detect the right provider from a model name or base URL.
95pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
96    let model_lower = model.to_lowercase();
97    let url_lower = base_url.to_lowercase();
98
99    // AWS Bedrock (Claude via AWS).
100    if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
101        return ProviderKind::Bedrock;
102    }
103    // Google Vertex AI (Claude via GCP).
104    if url_lower.contains("aiplatform.googleapis.com") {
105        return ProviderKind::Vertex;
106    }
107    if url_lower.contains("anthropic.com") {
108        return ProviderKind::Anthropic;
109    }
110    // Azure OpenAI — must be checked before generic openai.com.
111    if url_lower.contains("openai.azure.com")
112        || url_lower.contains("azure.com") && url_lower.contains("openai")
113    {
114        return ProviderKind::AzureOpenAi;
115    }
116    if url_lower.contains("openai.com") {
117        return ProviderKind::OpenAi;
118    }
119    if url_lower.contains("x.ai") || url_lower.contains("xai.") {
120        return ProviderKind::Xai;
121    }
122    if url_lower.contains("googleapis.com") || url_lower.contains("google") {
123        return ProviderKind::Google;
124    }
125    if url_lower.contains("deepseek.com") {
126        return ProviderKind::DeepSeek;
127    }
128    if url_lower.contains("groq.com") {
129        return ProviderKind::Groq;
130    }
131    if url_lower.contains("mistral.ai") {
132        return ProviderKind::Mistral;
133    }
134    if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
135        return ProviderKind::Together;
136    }
137    if url_lower.contains("bigmodel.cn")
138        || url_lower.contains("z.ai")
139        || url_lower.contains("zhipu")
140    {
141        return ProviderKind::Zhipu;
142    }
143    if url_lower.contains("openrouter.ai") {
144        return ProviderKind::OpenRouter;
145    }
146    if url_lower.contains("cohere.com") || url_lower.contains("cohere.ai") {
147        return ProviderKind::Cohere;
148    }
149    if url_lower.contains("perplexity.ai") {
150        return ProviderKind::Perplexity;
151    }
152    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
153        return ProviderKind::OpenAiCompatible;
154    }
155
156    // Detect from model name.
157    if model_lower.starts_with("claude")
158        || model_lower.contains("opus")
159        || model_lower.contains("sonnet")
160        || model_lower.contains("haiku")
161    {
162        return ProviderKind::Anthropic;
163    }
164    if model_lower.starts_with("gpt")
165        || model_lower.starts_with("o1")
166        || model_lower.starts_with("o3")
167    {
168        return ProviderKind::OpenAi;
169    }
170    if model_lower.starts_with("grok") {
171        return ProviderKind::Xai;
172    }
173    if model_lower.starts_with("gemini") {
174        return ProviderKind::Google;
175    }
176    if model_lower.starts_with("deepseek") {
177        return ProviderKind::DeepSeek;
178    }
179    if model_lower.starts_with("llama") && url_lower.contains("groq") {
180        return ProviderKind::Groq;
181    }
182    if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
183        return ProviderKind::Mistral;
184    }
185    if model_lower.starts_with("glm") {
186        return ProviderKind::Zhipu;
187    }
188    if model_lower.starts_with("command") {
189        return ProviderKind::Cohere;
190    }
191    if model_lower.starts_with("pplx") || model_lower.starts_with("sonar") {
192        return ProviderKind::Perplexity;
193    }
194
195    ProviderKind::OpenAiCompatible
196}
197
198/// The two wire formats that cover the entire LLM ecosystem.
199#[derive(Debug, Clone, Copy, PartialEq, Eq)]
200pub enum WireFormat {
201    /// Anthropic Messages API (Claude models, Bedrock, Vertex).
202    Anthropic,
203    /// OpenAI Chat Completions (GPT, Groq, Together, Ollama, DeepSeek, etc.).
204    OpenAiCompatible,
205}
206
207/// Provider kinds.
208#[derive(Debug, Clone, Copy, PartialEq, Eq)]
209pub enum ProviderKind {
210    Anthropic,
211    Bedrock,
212    Vertex,
213    OpenAi,
214    AzureOpenAi,
215    Xai,
216    Google,
217    DeepSeek,
218    Groq,
219    Mistral,
220    Together,
221    Zhipu,
222    OpenRouter,
223    Cohere,
224    Perplexity,
225    OpenAiCompatible,
226}
227
228impl ProviderKind {
229    /// Which wire format this provider uses.
230    pub fn wire_format(&self) -> WireFormat {
231        match self {
232            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
233            Self::OpenAi
234            | Self::AzureOpenAi
235            | Self::Xai
236            | Self::Google
237            | Self::DeepSeek
238            | Self::Groq
239            | Self::Mistral
240            | Self::Together
241            | Self::Zhipu
242            | Self::OpenRouter
243            | Self::Cohere
244            | Self::Perplexity
245            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
246        }
247    }
248
249    /// The default base URL for this provider, or `None` for providers
250    /// whose URL must come from user configuration (Bedrock, Vertex,
251    /// and generic OpenAI-compatible endpoints).
252    pub fn default_base_url(&self) -> Option<&str> {
253        match self {
254            Self::Anthropic => Some("https://api.anthropic.com/v1"),
255            Self::OpenAi => Some("https://api.openai.com/v1"),
256            Self::Xai => Some("https://api.x.ai/v1"),
257            Self::Google => Some("https://generativelanguage.googleapis.com/v1beta/openai"),
258            Self::DeepSeek => Some("https://api.deepseek.com/v1"),
259            Self::Groq => Some("https://api.groq.com/openai/v1"),
260            Self::Mistral => Some("https://api.mistral.ai/v1"),
261            Self::Together => Some("https://api.together.xyz/v1"),
262            Self::Zhipu => Some("https://open.bigmodel.cn/api/paas/v4"),
263            Self::OpenRouter => Some("https://openrouter.ai/api/v1"),
264            Self::Cohere => Some("https://api.cohere.com/v2"),
265            Self::Perplexity => Some("https://api.perplexity.ai"),
266            // These require user-supplied URLs.
267            Self::Bedrock | Self::Vertex | Self::AzureOpenAi | Self::OpenAiCompatible => None,
268        }
269    }
270
271    /// The environment variable name conventionally used for this provider's API key.
272    pub fn env_var_name(&self) -> &str {
273        match self {
274            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
275            Self::OpenAi => "OPENAI_API_KEY",
276            Self::AzureOpenAi => "AZURE_OPENAI_API_KEY",
277            Self::Xai => "XAI_API_KEY",
278            Self::Google => "GOOGLE_API_KEY",
279            Self::DeepSeek => "DEEPSEEK_API_KEY",
280            Self::Groq => "GROQ_API_KEY",
281            Self::Mistral => "MISTRAL_API_KEY",
282            Self::Together => "TOGETHER_API_KEY",
283            Self::Zhipu => "ZHIPU_API_KEY",
284            Self::OpenRouter => "OPENROUTER_API_KEY",
285            Self::Cohere => "COHERE_API_KEY",
286            Self::Perplexity => "PERPLEXITY_API_KEY",
287            Self::OpenAiCompatible => "OPENAI_API_KEY",
288        }
289    }
290}
291
292#[cfg(test)]
293mod tests {
294    use super::*;
295
296    #[test]
297    fn test_detect_from_url_anthropic() {
298        assert!(matches!(
299            detect_provider("any", "https://api.anthropic.com/v1"),
300            ProviderKind::Anthropic
301        ));
302    }
303
304    #[test]
305    fn test_detect_from_url_openai() {
306        assert!(matches!(
307            detect_provider("any", "https://api.openai.com/v1"),
308            ProviderKind::OpenAi
309        ));
310    }
311
312    #[test]
313    fn test_detect_from_url_bedrock() {
314        assert!(matches!(
315            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
316            ProviderKind::Bedrock
317        ));
318    }
319
320    #[test]
321    fn test_detect_from_url_vertex() {
322        assert!(matches!(
323            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
324            ProviderKind::Vertex
325        ));
326    }
327
328    #[test]
329    fn test_detect_from_url_azure_openai() {
330        assert!(matches!(
331            detect_provider(
332                "any",
333                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
334            ),
335            ProviderKind::AzureOpenAi
336        ));
337    }
338
339    #[test]
340    fn test_detect_azure_before_generic_openai() {
341        // Azure URL contains "openai" but should match Azure, not generic OpenAI.
342        assert!(matches!(
343            detect_provider(
344                "gpt-4",
345                "https://myresource.openai.azure.com/openai/deployments/gpt-4"
346            ),
347            ProviderKind::AzureOpenAi
348        ));
349    }
350
351    #[test]
352    fn test_detect_from_url_xai() {
353        assert!(matches!(
354            detect_provider("any", "https://api.x.ai/v1"),
355            ProviderKind::Xai
356        ));
357    }
358
359    #[test]
360    fn test_detect_from_url_deepseek() {
361        assert!(matches!(
362            detect_provider("any", "https://api.deepseek.com/v1"),
363            ProviderKind::DeepSeek
364        ));
365    }
366
367    #[test]
368    fn test_detect_from_url_groq() {
369        assert!(matches!(
370            detect_provider("any", "https://api.groq.com/openai/v1"),
371            ProviderKind::Groq
372        ));
373    }
374
375    #[test]
376    fn test_detect_from_url_mistral() {
377        assert!(matches!(
378            detect_provider("any", "https://api.mistral.ai/v1"),
379            ProviderKind::Mistral
380        ));
381    }
382
383    #[test]
384    fn test_detect_from_url_together() {
385        assert!(matches!(
386            detect_provider("any", "https://api.together.xyz/v1"),
387            ProviderKind::Together
388        ));
389    }
390
391    #[test]
392    fn test_detect_from_url_cohere() {
393        assert!(matches!(
394            detect_provider("any", "https://api.cohere.com/v2"),
395            ProviderKind::Cohere
396        ));
397    }
398
399    #[test]
400    fn test_detect_from_url_perplexity() {
401        assert!(matches!(
402            detect_provider("any", "https://api.perplexity.ai"),
403            ProviderKind::Perplexity
404        ));
405    }
406
407    #[test]
408    fn test_detect_from_model_command_r() {
409        assert!(matches!(
410            detect_provider("command-r-plus", ""),
411            ProviderKind::Cohere
412        ));
413    }
414
415    #[test]
416    fn test_detect_from_model_sonar() {
417        assert!(matches!(
418            detect_provider("sonar-pro", ""),
419            ProviderKind::Perplexity
420        ));
421    }
422
423    #[test]
424    fn test_detect_from_url_openrouter() {
425        assert!(matches!(
426            detect_provider("any", "https://openrouter.ai/api/v1"),
427            ProviderKind::OpenRouter
428        ));
429    }
430
431    #[test]
432    fn test_detect_from_url_localhost() {
433        assert!(matches!(
434            detect_provider("any", "http://localhost:11434/v1"),
435            ProviderKind::OpenAiCompatible
436        ));
437    }
438
439    #[test]
440    fn test_detect_from_model_claude() {
441        assert!(matches!(
442            detect_provider("claude-sonnet-4", ""),
443            ProviderKind::Anthropic
444        ));
445        assert!(matches!(
446            detect_provider("claude-opus-4", ""),
447            ProviderKind::Anthropic
448        ));
449    }
450
451    #[test]
452    fn test_detect_from_model_gpt() {
453        assert!(matches!(
454            detect_provider("gpt-4.1-mini", ""),
455            ProviderKind::OpenAi
456        ));
457        assert!(matches!(
458            detect_provider("o3-mini", ""),
459            ProviderKind::OpenAi
460        ));
461    }
462
463    #[test]
464    fn test_detect_from_model_grok() {
465        assert!(matches!(detect_provider("grok-3", ""), ProviderKind::Xai));
466    }
467
468    #[test]
469    fn test_detect_from_model_gemini() {
470        assert!(matches!(
471            detect_provider("gemini-2.5-flash", ""),
472            ProviderKind::Google
473        ));
474    }
475
476    #[test]
477    fn test_detect_unknown_defaults_openai_compat() {
478        assert!(matches!(
479            detect_provider("some-random-model", "https://my-server.com"),
480            ProviderKind::OpenAiCompatible
481        ));
482    }
483
484    #[test]
485    fn test_url_takes_priority_over_model() {
486        // URL says OpenAI but model says Claude — URL wins.
487        assert!(matches!(
488            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
489            ProviderKind::OpenAi
490        ));
491    }
492
493    #[test]
494    fn test_wire_format_anthropic_family() {
495        assert_eq!(ProviderKind::Anthropic.wire_format(), WireFormat::Anthropic);
496        assert_eq!(ProviderKind::Bedrock.wire_format(), WireFormat::Anthropic);
497        assert_eq!(ProviderKind::Vertex.wire_format(), WireFormat::Anthropic);
498    }
499
500    #[test]
501    fn test_wire_format_openai_compatible_family() {
502        let openai_compat_providers = [
503            ProviderKind::OpenAi,
504            ProviderKind::Xai,
505            ProviderKind::Google,
506            ProviderKind::DeepSeek,
507            ProviderKind::Groq,
508            ProviderKind::Mistral,
509            ProviderKind::Together,
510            ProviderKind::Zhipu,
511            ProviderKind::OpenRouter,
512            ProviderKind::Cohere,
513            ProviderKind::Perplexity,
514            ProviderKind::OpenAiCompatible,
515        ];
516        for p in openai_compat_providers {
517            assert_eq!(
518                p.wire_format(),
519                WireFormat::OpenAiCompatible,
520                "{p:?} should use OpenAiCompatible wire format"
521            );
522        }
523    }
524
525    #[test]
526    fn test_default_base_url_returns_some_for_known_providers() {
527        let providers_with_urls = [
528            ProviderKind::Anthropic,
529            ProviderKind::OpenAi,
530            ProviderKind::Xai,
531            ProviderKind::Google,
532            ProviderKind::DeepSeek,
533            ProviderKind::Groq,
534            ProviderKind::Mistral,
535            ProviderKind::Together,
536            ProviderKind::Zhipu,
537            ProviderKind::OpenRouter,
538            ProviderKind::Cohere,
539            ProviderKind::Perplexity,
540        ];
541        for p in providers_with_urls {
542            assert!(
543                p.default_base_url().is_some(),
544                "{p:?} should have a default base URL"
545            );
546        }
547    }
548
549    #[test]
550    fn test_default_base_url_returns_none_for_user_configured() {
551        assert!(ProviderKind::Bedrock.default_base_url().is_none());
552        assert!(ProviderKind::Vertex.default_base_url().is_none());
553        assert!(ProviderKind::AzureOpenAi.default_base_url().is_none());
554        assert!(ProviderKind::OpenAiCompatible.default_base_url().is_none());
555    }
556
557    #[test]
558    fn test_env_var_name_all_variants() {
559        assert_eq!(ProviderKind::Anthropic.env_var_name(), "ANTHROPIC_API_KEY");
560        assert_eq!(ProviderKind::Bedrock.env_var_name(), "ANTHROPIC_API_KEY");
561        assert_eq!(ProviderKind::Vertex.env_var_name(), "ANTHROPIC_API_KEY");
562        assert_eq!(ProviderKind::OpenAi.env_var_name(), "OPENAI_API_KEY");
563        assert_eq!(
564            ProviderKind::AzureOpenAi.env_var_name(),
565            "AZURE_OPENAI_API_KEY"
566        );
567        assert_eq!(ProviderKind::Xai.env_var_name(), "XAI_API_KEY");
568        assert_eq!(ProviderKind::Google.env_var_name(), "GOOGLE_API_KEY");
569        assert_eq!(ProviderKind::DeepSeek.env_var_name(), "DEEPSEEK_API_KEY");
570        assert_eq!(ProviderKind::Groq.env_var_name(), "GROQ_API_KEY");
571        assert_eq!(ProviderKind::Mistral.env_var_name(), "MISTRAL_API_KEY");
572        assert_eq!(ProviderKind::Together.env_var_name(), "TOGETHER_API_KEY");
573        assert_eq!(ProviderKind::Zhipu.env_var_name(), "ZHIPU_API_KEY");
574        assert_eq!(
575            ProviderKind::OpenRouter.env_var_name(),
576            "OPENROUTER_API_KEY"
577        );
578        assert_eq!(ProviderKind::Cohere.env_var_name(), "COHERE_API_KEY");
579        assert_eq!(
580            ProviderKind::Perplexity.env_var_name(),
581            "PERPLEXITY_API_KEY"
582        );
583        assert_eq!(
584            ProviderKind::OpenAiCompatible.env_var_name(),
585            "OPENAI_API_KEY"
586        );
587    }
588
589    #[test]
590    fn test_detect_from_url_zhipu_bigmodel() {
591        assert!(matches!(
592            detect_provider("any", "https://open.bigmodel.cn/api/paas/v4"),
593            ProviderKind::Zhipu
594        ));
595    }
596
597    #[test]
598    fn test_detect_from_model_deepseek_chat() {
599        assert!(matches!(
600            detect_provider("deepseek-chat", ""),
601            ProviderKind::DeepSeek
602        ));
603    }
604
605    #[test]
606    fn test_detect_from_model_mistral_large() {
607        assert!(matches!(
608            detect_provider("mistral-large", ""),
609            ProviderKind::Mistral
610        ));
611    }
612
613    #[test]
614    fn test_detect_from_model_glm4() {
615        assert!(matches!(detect_provider("glm-4", ""), ProviderKind::Zhipu));
616    }
617
618    #[test]
619    fn test_detect_from_model_llama3_with_groq_url() {
620        assert!(matches!(
621            detect_provider("llama-3", "https://api.groq.com/openai/v1"),
622            ProviderKind::Groq
623        ));
624    }
625
626    #[test]
627    fn test_detect_from_model_codestral() {
628        assert!(matches!(
629            detect_provider("codestral-latest", ""),
630            ProviderKind::Mistral
631        ));
632    }
633
634    #[test]
635    fn test_detect_from_model_pplx() {
636        assert!(matches!(
637            detect_provider("pplx-70b-online", ""),
638            ProviderKind::Perplexity
639        ));
640    }
641
642    #[test]
643    fn test_provider_error_display() {
644        let err = ProviderError::Auth("bad token".into());
645        assert_eq!(format!("{err}"), "auth: bad token");
646
647        let err = ProviderError::RateLimited {
648            retry_after_ms: 1000,
649        };
650        assert_eq!(format!("{err}"), "rate limited (retry in 1000ms)");
651
652        let err = ProviderError::Overloaded;
653        assert_eq!(format!("{err}"), "server overloaded");
654
655        let err = ProviderError::RequestTooLarge("4MB limit".into());
656        assert_eq!(format!("{err}"), "request too large: 4MB limit");
657
658        let err = ProviderError::Network("timeout".into());
659        assert_eq!(format!("{err}"), "network: timeout");
660
661        let err = ProviderError::InvalidResponse("missing field".into());
662        assert_eq!(format!("{err}"), "invalid response: missing field");
663    }
664
665    #[test]
666    fn test_tool_choice_default_is_auto() {
667        let tc = ToolChoice::default();
668        assert!(matches!(tc, ToolChoice::Auto));
669    }
670}