// agent_code_lib/llm/provider.rs

//! LLM provider abstraction.
//!
//! Two wire formats cover the entire ecosystem:
//! - Anthropic Messages API (Claude models)
//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
//!
//! Each provider translates between our unified message types and
//! the provider-specific JSON format for requests and SSE streams.
use async_trait::async_trait;
use tokio::sync::mpsc;

use super::message::Message;
use super::stream::StreamEvent;
use crate::tools::ToolSchema;

/// Unified provider trait. Both Anthropic and OpenAI-compatible
/// endpoints implement this.
///
/// Implementations translate a [`ProviderRequest`] into their wire
/// format and stream the response back as [`StreamEvent`]s.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;

    /// Send a streaming request. Returns the receiving end of a channel
    /// that yields [`StreamEvent`]s as the provider produces them.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the request cannot be started
    /// (auth, rate limiting, network, etc. — see the error variants).
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}

/// Tool choice mode for controlling tool usage.
///
/// Derives `PartialEq`/`Eq` so callers can compare choices directly,
/// consistent with the other public enums in this module
/// ([`WireFormat`], [`ProviderKind`]).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// Model decides whether to use tools.
    #[default]
    Auto,
    /// Model must use a tool.
    Any,
    /// Model must not use tools.
    None,
    /// Model must use the specific tool named by the payload.
    Specific(String),
}

/// A provider-agnostic request.
///
/// Built by the caller and handed to a [`Provider`] implementation,
/// which translates it into the provider-specific wire format.
pub struct ProviderRequest {
    /// Conversation messages to send.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the messages.
    pub system_prompt: String,
    /// Schemas of the tools the model may call.
    pub tools: Vec<ToolSchema>,
    /// Model identifier (e.g. "claude-sonnet-4", "gpt-4.1-mini").
    pub model: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Optional sampling temperature; `None` presumably defers to the
    /// provider's default — confirm in the provider implementations.
    pub temperature: Option<f64>,
    /// Whether caching is enabled for this request.
    pub enable_caching: bool,
    /// Controls whether/how the model should use tools.
    pub tool_choice: ToolChoice,
    /// Metadata to send with the request (e.g., user_id for Anthropic).
    pub metadata: Option<serde_json::Value>,
}

/// Provider-level errors.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication failure (bad or missing credentials).
    Auth(String),
    /// The provider rate-limited us; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// The provider reported it is overloaded.
    Overloaded,
    /// The request exceeded the provider's size limits.
    RequestTooLarge(String),
    /// Network/transport failure.
    Network(String),
    /// The provider returned a response we could not interpret.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// `Debug` + `Display` are already implemented, so the marker impl makes
// `ProviderError` a first-class error: usable with `?` into
// `Box<dyn Error>`, `anyhow`-style wrappers, and error-reporting tools.
impl std::error::Error for ProviderError {}

86/// Detect the right provider from a model name or base URL.
87pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
88    let model_lower = model.to_lowercase();
89    let url_lower = base_url.to_lowercase();
90
91    // AWS Bedrock (Claude via AWS).
92    if url_lower.contains("bedrock") || url_lower.contains("amazonaws.com") {
93        return ProviderKind::Bedrock;
94    }
95    // Google Vertex AI (Claude via GCP).
96    if url_lower.contains("aiplatform.googleapis.com") {
97        return ProviderKind::Vertex;
98    }
99    if url_lower.contains("anthropic.com") {
100        return ProviderKind::Anthropic;
101    }
102    if url_lower.contains("openai.com") {
103        return ProviderKind::OpenAi;
104    }
105    if url_lower.contains("x.ai") || url_lower.contains("xai.") {
106        return ProviderKind::Xai;
107    }
108    if url_lower.contains("googleapis.com") || url_lower.contains("google") {
109        return ProviderKind::Google;
110    }
111    if url_lower.contains("deepseek.com") {
112        return ProviderKind::DeepSeek;
113    }
114    if url_lower.contains("groq.com") {
115        return ProviderKind::Groq;
116    }
117    if url_lower.contains("mistral.ai") {
118        return ProviderKind::Mistral;
119    }
120    if url_lower.contains("together.xyz") || url_lower.contains("together.ai") {
121        return ProviderKind::Together;
122    }
123    if url_lower.contains("bigmodel.cn")
124        || url_lower.contains("z.ai")
125        || url_lower.contains("zhipu")
126    {
127        return ProviderKind::Zhipu;
128    }
129    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
130        return ProviderKind::OpenAiCompatible;
131    }
132
133    // Detect from model name.
134    if model_lower.starts_with("claude")
135        || model_lower.contains("opus")
136        || model_lower.contains("sonnet")
137        || model_lower.contains("haiku")
138    {
139        return ProviderKind::Anthropic;
140    }
141    if model_lower.starts_with("gpt")
142        || model_lower.starts_with("o1")
143        || model_lower.starts_with("o3")
144    {
145        return ProviderKind::OpenAi;
146    }
147    if model_lower.starts_with("grok") {
148        return ProviderKind::Xai;
149    }
150    if model_lower.starts_with("gemini") {
151        return ProviderKind::Google;
152    }
153    if model_lower.starts_with("deepseek") {
154        return ProviderKind::DeepSeek;
155    }
156    if model_lower.starts_with("llama") && url_lower.contains("groq") {
157        return ProviderKind::Groq;
158    }
159    if model_lower.starts_with("mistral") || model_lower.starts_with("codestral") {
160        return ProviderKind::Mistral;
161    }
162    if model_lower.starts_with("glm") {
163        return ProviderKind::Zhipu;
164    }
165
166    ProviderKind::OpenAiCompatible
167}
168
/// The two wire formats that cover the entire LLM ecosystem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireFormat {
    /// Anthropic Messages API (Claude models, Bedrock, Vertex).
    Anthropic,
    /// OpenAI Chat Completions (GPT, Groq, Together, Ollama, DeepSeek, etc.).
    OpenAiCompatible,
}

/// Provider kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    Zhipu,
    OpenAiCompatible,
}

impl ProviderKind {
    /// Which wire format this provider uses.
    ///
    /// Only Anthropic itself and its cloud-hosted variants (Bedrock,
    /// Vertex) speak the Anthropic Messages format; every other kind
    /// is OpenAI Chat Completions compatible.
    pub fn wire_format(&self) -> WireFormat {
        match self {
            Self::OpenAi
            | Self::Xai
            | Self::Google
            | Self::DeepSeek
            | Self::Groq
            | Self::Mistral
            | Self::Together
            | Self::Zhipu
            | Self::OpenAiCompatible => WireFormat::OpenAiCompatible,
            Self::Anthropic | Self::Bedrock | Self::Vertex => WireFormat::Anthropic,
        }
    }

    /// The default base URL for this provider, or `None` for providers
    /// whose URL must come from user configuration (Bedrock, Vertex,
    /// and generic OpenAI-compatible endpoints).
    pub fn default_base_url(&self) -> Option<&str> {
        let url = match self {
            Self::Anthropic => "https://api.anthropic.com/v1",
            Self::OpenAi => "https://api.openai.com/v1",
            Self::Xai => "https://api.x.ai/v1",
            Self::Google => "https://generativelanguage.googleapis.com/v1beta/openai",
            Self::DeepSeek => "https://api.deepseek.com/v1",
            Self::Groq => "https://api.groq.com/openai/v1",
            Self::Mistral => "https://api.mistral.ai/v1",
            Self::Together => "https://api.together.xyz/v1",
            Self::Zhipu => "https://open.bigmodel.cn/api/paas/v4",
            // These require user-supplied URLs.
            Self::Bedrock | Self::Vertex | Self::OpenAiCompatible => return None,
        };
        Some(url)
    }

    /// The environment variable name conventionally used for this
    /// provider's API key.
    pub fn env_var_name(&self) -> &str {
        match self {
            // Bedrock/Vertex host Claude, so they share the Anthropic key.
            Self::Anthropic | Self::Bedrock | Self::Vertex => "ANTHROPIC_API_KEY",
            // Generic OpenAI-compatible endpoints reuse the OpenAI key name.
            Self::OpenAi | Self::OpenAiCompatible => "OPENAI_API_KEY",
            Self::Xai => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::DeepSeek => "DEEPSEEK_API_KEY",
            Self::Groq => "GROQ_API_KEY",
            Self::Mistral => "MISTRAL_API_KEY",
            Self::Together => "TOGETHER_API_KEY",
            Self::Zhipu => "ZHIPU_API_KEY",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq` + `Debug`, so `assert_eq!` is
    // used throughout: it prints both sides on failure, unlike
    // `assert!(matches!(..))`.

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(detect_provider("claude-sonnet-4", ""), ProviderKind::Anthropic);
        assert_eq!(detect_provider("claude-opus-4", ""), ProviderKind::Anthropic);
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(detect_provider("gpt-4.1-mini", ""), ProviderKind::OpenAi);
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(detect_provider("gemini-2.5-flash", ""), ProviderKind::Google);
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // URL says OpenAI but model says Claude — URL wins.
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }
}