// agent_code_lib/llm/provider.rs

1//! LLM provider abstraction.
2//!
3//! Two wire formats cover the entire ecosystem:
4//! - Anthropic Messages API (Claude models)
5//! - OpenAI Chat Completions (GPT, plus Groq, Together, Ollama, DeepSeek, etc.)
6//!
7//! Each provider translates between our unified message types and
8//! the provider-specific JSON format for requests and SSE streams.
9
10use async_trait::async_trait;
11use tokio::sync::mpsc;
12
13use super::message::Message;
14use super::stream::StreamEvent;
15use crate::tools::ToolSchema;
16
/// Unified provider trait. Both Anthropic and OpenAI-compatible
/// endpoints implement this.
///
/// Implementations translate the unified [`ProviderRequest`] into their
/// provider-specific wire format and surface the response as a stream
/// of [`StreamEvent`]s.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Human-readable provider name.
    // NOTE(review): consumers of this name are outside this file —
    // presumably logging/diagnostics; confirm before relying on format.
    fn name(&self) -> &str;

    /// Send a streaming request. Returns a channel of events.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the request cannot be issued or
    /// the provider rejects it before streaming begins.
    async fn stream(
        &self,
        request: &ProviderRequest,
    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError>;
}
30
/// Tool choice mode for controlling tool usage.
///
/// Mirrors the `tool_choice` concept in provider APIs; each provider
/// maps these variants onto its own wire format (mapping lives outside
/// this file).
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model decides whether to use tools.
    #[default]
    Auto,
    /// Model must use a tool.
    Any,
    /// Model must not use tools.
    None,
    /// Model must use a specific tool.
    // NOTE(review): payload is presumably the tool name as advertised
    // in `ToolSchema` — confirm against the provider encoders.
    Specific(String),
}
44
/// A provider-agnostic request.
///
/// Built once by the caller and handed to [`Provider::stream`]; each
/// provider translates it into its own request JSON.
pub struct ProviderRequest {
    /// Conversation messages.
    // NOTE(review): assumed oldest-first — confirm at the call site.
    pub messages: Vec<Message>,
    /// System prompt sent alongside the messages.
    pub system_prompt: String,
    /// Tool schemas the model is allowed to call.
    pub tools: Vec<ToolSchema>,
    /// Provider-specific model identifier (e.g. "claude-sonnet-4").
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature; `None` leaves the provider default.
    pub temperature: Option<f64>,
    /// Whether to request prompt caching where the provider supports it.
    pub enable_caching: bool,
    /// Controls whether/how the model should use tools.
    pub tool_choice: ToolChoice,
    /// Metadata to send with the request (e.g., user_id for Anthropic).
    pub metadata: Option<serde_json::Value>,
}
59
/// Provider-level errors.
#[derive(Debug)]
pub enum ProviderError {
    /// Authentication/authorization failure (e.g. bad API key).
    Auth(String),
    /// Provider asked us to back off; retry after the given delay.
    RateLimited { retry_after_ms: u64 },
    /// Provider reported it is temporarily overloaded.
    Overloaded,
    /// The request exceeded the provider's size limits.
    RequestTooLarge(String),
    /// Transport-level failure (connect, TLS, reset, …).
    Network(String),
    /// The provider returned a payload we could not interpret.
    InvalidResponse(String),
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Auth(msg) => write!(f, "auth: {msg}"),
            Self::RateLimited { retry_after_ms } => {
                write!(f, "rate limited (retry in {retry_after_ms}ms)")
            }
            Self::Overloaded => write!(f, "server overloaded"),
            Self::RequestTooLarge(msg) => write!(f, "request too large: {msg}"),
            Self::Network(msg) => write!(f, "network: {msg}"),
            Self::InvalidResponse(msg) => write!(f, "invalid response: {msg}"),
        }
    }
}

// `Debug` + `Display` are already here; this marker impl lets callers
// box the error as `dyn std::error::Error` and use `?`-based conversion
// (e.g. into anyhow) instead of hand-wrapping every variant.
impl std::error::Error for ProviderError {}
85
/// Detect the right provider from a model name or base URL.
///
/// URL hints take priority over model-name hints: the URL says where
/// the request actually goes, so a Claude model name behind an OpenAI
/// base URL still resolves to OpenAI.
pub fn detect_provider(model: &str, base_url: &str) -> ProviderKind {
    let model = model.to_lowercase();
    let url = base_url.to_lowercase();

    // URL substring rules, checked in priority order — first hit wins,
    // so the more specific hosts (Bedrock, Vertex) come before the
    // generic vendor domains.
    const URL_RULES: &[(&str, ProviderKind)] = &[
        // AWS Bedrock (Claude via AWS).
        ("bedrock", ProviderKind::Bedrock),
        ("amazonaws.com", ProviderKind::Bedrock),
        // Google Vertex AI (Claude via GCP).
        ("aiplatform.googleapis.com", ProviderKind::Vertex),
        ("anthropic.com", ProviderKind::Anthropic),
        ("openai.com", ProviderKind::OpenAi),
        ("x.ai", ProviderKind::Xai),
        ("xai.", ProviderKind::Xai),
        ("googleapis.com", ProviderKind::Google),
        ("google", ProviderKind::Google),
        ("deepseek.com", ProviderKind::DeepSeek),
        ("groq.com", ProviderKind::Groq),
        ("mistral.ai", ProviderKind::Mistral),
        ("together.xyz", ProviderKind::Together),
        ("together.ai", ProviderKind::Together),
        // Local endpoints (Ollama, vLLM, …) speak the OpenAI dialect.
        ("localhost", ProviderKind::OpenAiCompatible),
        ("127.0.0.1", ProviderKind::OpenAiCompatible),
    ];
    if let Some(&(_, kind)) = URL_RULES.iter().find(|&&(pat, _)| url.contains(pat)) {
        return kind;
    }

    // No URL hint — fall back to the model name.
    if model.starts_with("claude")
        || ["opus", "sonnet", "haiku"].iter().any(|m| model.contains(m))
    {
        return ProviderKind::Anthropic;
    }
    if ["gpt", "o1", "o3"].iter().any(|p| model.starts_with(p)) {
        return ProviderKind::OpenAi;
    }
    if model.starts_with("grok") {
        return ProviderKind::Xai;
    }
    if model.starts_with("gemini") {
        return ProviderKind::Google;
    }
    if model.starts_with("deepseek") {
        return ProviderKind::DeepSeek;
    }
    if model.starts_with("llama") && url.contains("groq") {
        return ProviderKind::Groq;
    }
    if model.starts_with("mistral") || model.starts_with("codestral") {
        return ProviderKind::Mistral;
    }

    // Unknown: assume a generic OpenAI-compatible endpoint.
    ProviderKind::OpenAiCompatible
}

/// Provider kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Anthropic,
    Bedrock,
    Vertex,
    OpenAi,
    Xai,
    Google,
    DeepSeek,
    Groq,
    Mistral,
    Together,
    OpenAiCompatible,
}
175
#[cfg(test)]
mod tests {
    use super::*;

    // `ProviderKind` derives `PartialEq` + `Debug`, so plain equality
    // assertions give better failure output than `matches!`.

    #[test]
    fn test_detect_from_url_anthropic() {
        assert_eq!(
            detect_provider("any", "https://api.anthropic.com/v1"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn test_detect_from_url_openai() {
        assert_eq!(
            detect_provider("any", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }

    #[test]
    fn test_detect_from_url_bedrock() {
        assert_eq!(
            detect_provider("any", "https://bedrock-runtime.us-east-1.amazonaws.com"),
            ProviderKind::Bedrock
        );
    }

    #[test]
    fn test_detect_from_url_vertex() {
        assert_eq!(
            detect_provider("any", "https://us-central1-aiplatform.googleapis.com/v1"),
            ProviderKind::Vertex
        );
    }

    #[test]
    fn test_detect_from_url_xai() {
        assert_eq!(
            detect_provider("any", "https://api.x.ai/v1"),
            ProviderKind::Xai
        );
    }

    #[test]
    fn test_detect_from_url_deepseek() {
        assert_eq!(
            detect_provider("any", "https://api.deepseek.com/v1"),
            ProviderKind::DeepSeek
        );
    }

    #[test]
    fn test_detect_from_url_groq() {
        assert_eq!(
            detect_provider("any", "https://api.groq.com/openai/v1"),
            ProviderKind::Groq
        );
    }

    #[test]
    fn test_detect_from_url_mistral() {
        assert_eq!(
            detect_provider("any", "https://api.mistral.ai/v1"),
            ProviderKind::Mistral
        );
    }

    #[test]
    fn test_detect_from_url_together() {
        assert_eq!(
            detect_provider("any", "https://api.together.xyz/v1"),
            ProviderKind::Together
        );
    }

    #[test]
    fn test_detect_from_url_localhost() {
        assert_eq!(
            detect_provider("any", "http://localhost:11434/v1"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_detect_from_model_claude() {
        assert_eq!(detect_provider("claude-sonnet-4", ""), ProviderKind::Anthropic);
        assert_eq!(detect_provider("claude-opus-4", ""), ProviderKind::Anthropic);
    }

    #[test]
    fn test_detect_from_model_gpt() {
        assert_eq!(detect_provider("gpt-4.1-mini", ""), ProviderKind::OpenAi);
        assert_eq!(detect_provider("o3-mini", ""), ProviderKind::OpenAi);
    }

    #[test]
    fn test_detect_from_model_grok() {
        assert_eq!(detect_provider("grok-3", ""), ProviderKind::Xai);
    }

    #[test]
    fn test_detect_from_model_gemini() {
        assert_eq!(detect_provider("gemini-2.5-flash", ""), ProviderKind::Google);
    }

    #[test]
    fn test_detect_unknown_defaults_openai_compat() {
        assert_eq!(
            detect_provider("some-random-model", "https://my-server.com"),
            ProviderKind::OpenAiCompatible
        );
    }

    #[test]
    fn test_url_takes_priority_over_model() {
        // URL says OpenAI but model says Claude — URL wins.
        assert_eq!(
            detect_provider("claude-sonnet", "https://api.openai.com/v1"),
            ProviderKind::OpenAi
        );
    }
}