ares/llm/client.rs

use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

/// Generic LLM client trait for provider abstraction
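///
/// # Example
///
/// A minimal sketch of calling a boxed client (assumes a client was already
/// obtained from a [`Provider`]):
///
/// ```rust,ignore
/// // `client` is a `Box<dyn LLMClient>`.
/// let answer = client
///     .generate_with_system("You are a concise assistant.", "What is Rust?")
///     .await?;
/// println!("{answer}");
/// ```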
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a completion from a prompt
    async fn generate(&self, prompt: &str) -> Result<String>;

    /// Generate with system prompt
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    /// Generate with conversation history
    async fn generate_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<String>;

    /// Generate with tool calling support
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Stream a completion
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Get the model name/identifier
    fn model_name(&self) -> &str;
}

/// Response from an LLM generation call
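///
/// # Example
///
/// A sketch of dispatching on a response (assumes `client` and `tools` were
/// built elsewhere):
///
/// ```rust,ignore
/// let response = client.generate_with_tools("What is 1 + 2?", &tools).await?;
/// if response.finish_reason == "tool_calls" {
///     for call in &response.tool_calls {
///         println!("tool `{}` requested with args {}", call.name, call.arguments);
///     }
/// } else {
///     println!("{}", response.content);
/// }
/// ```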
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// The generated text content
    pub content: String,
    /// Any tool calls the model wants to make
    pub tool_calls: Vec<ToolCall>,
    /// Reason the generation finished (e.g., "stop", "tool_calls", "length")
    pub finish_reason: String,
}

/// LLM Provider configuration
///
/// Each variant is feature-gated to ensure only enabled providers are available.
/// Use `Provider::from_env()` to automatically select based on environment variables.
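///
/// # Example
///
/// Variants can also be constructed explicitly; a sketch for a local Ollama
/// server (assumes the `ollama` feature is enabled):
///
/// ```rust,ignore
/// let provider = Provider::Ollama {
///     base_url: "http://localhost:11434".to_string(),
///     model: "ministral-3:3b".to_string(),
/// };
/// assert_eq!(provider.name(), "ollama");
/// assert!(provider.is_local());
/// ```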
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    /// OpenAI API and compatible endpoints (e.g., Azure OpenAI, local vLLM)
    #[cfg(feature = "openai")]
    OpenAI {
        /// API key for authentication
        api_key: String,
        /// Base URL for the API (default: https://api.openai.com/v1)
        api_base: String,
        /// Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
        model: String,
    },

    /// Ollama local inference server
    #[cfg(feature = "ollama")]
    Ollama {
        /// Base URL for Ollama API (default: http://localhost:11434)
        base_url: String,
        /// Model name (e.g., "ministral-3:3b", "mistral", "qwen3-vl:2b")
        model: String,
    },

    /// LlamaCpp for direct GGUF model loading
    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        /// Path to the GGUF model file
        model_path: String,
    },
}

impl Provider {
    /// Create an LLM client from this provider configuration
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The provider cannot be initialized
    /// - Required configuration is missing
    /// - Network connectivity fails (for remote providers)
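    ///
    /// # Example
    ///
    /// A sketch, assuming `provider` was built via [`Provider::from_env`] or
    /// constructed explicitly:
    ///
    /// ```rust,ignore
    /// let client = provider.create_client().await?;
    /// let reply = client.generate("Hello").await?;
    /// ```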
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
            } => Ok(Box::new(super::openai::OpenAIClient::new(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, model } => Ok(Box::new(
                super::ollama::OllamaClient::new(base_url.clone(), model.clone()).await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::new(model_path.clone())?,
            )),
        }
    }

    /// Create a provider from environment variables
    ///
    /// Provider priority (first match wins):
    /// 1. **LlamaCpp** - if `LLAMACPP_MODEL_PATH` is set
    /// 2. **OpenAI** - if `OPENAI_API_KEY` is set
    /// 3. **Ollama** - default fallback for local inference
    ///
    /// # Environment Variables
    ///
    /// ## LlamaCpp
    /// - `LLAMACPP_MODEL_PATH` - Path to GGUF model file (required)
    ///
    /// ## OpenAI
    /// - `OPENAI_API_KEY` - API key (required)
    /// - `OPENAI_API_BASE` - Base URL (default: https://api.openai.com/v1)
    /// - `OPENAI_MODEL` - Model name (default: gpt-4)
    ///
    /// ## Ollama
    /// - `OLLAMA_URL` - Server URL (default: http://localhost:11434)
    /// - `OLLAMA_BASE_URL` - Legacy alias for `OLLAMA_URL`
    /// - `OLLAMA_MODEL` - Model name (default: ministral-3:3b)
    ///
    /// # Errors
    ///
    /// Returns an error if no LLM provider features are enabled or configured.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // Set environment variables
    /// std::env::set_var("OLLAMA_MODEL", "ministral-3:3b");
    ///
    /// let provider = Provider::from_env()?;
    /// let client = provider.create_client().await?;
    /// ```
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        // Check for LlamaCpp first (direct GGUF loading - highest priority when configured)
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp { model_path });
            }
        }

        // Check for OpenAI (requires explicit API key configuration)
        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                });
            }
        }

        // Ollama as default local inference (no API key required)
        #[cfg(feature = "ollama")]
        {
            // Accept both OLLAMA_URL (preferred) and legacy OLLAMA_BASE_URL
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama { base_url, model });
        }

        // No provider available
        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

    /// Get the provider name as a string
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",
        }
    }

    /// Check if this provider requires an API key
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,
        }
    }
    /// Check if this provider runs locally (no remote network access required)
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,
        }
    }

    /// Create a provider from TOML configuration
    ///
    /// # Arguments
    ///
    /// * `provider_config` - The provider configuration from ares.toml
    /// * `model_override` - Optional model name to override the provider default
    ///
    /// # Errors
    ///
    /// Returns an error if the provider type doesn't match an enabled feature
    /// or if required environment variables are not set.
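    ///
    /// # Example
    ///
    /// A sketch using an in-memory `ProviderConfig` (in practice this is parsed
    /// from ares.toml; assumes the `ollama` feature is enabled):
    ///
    /// ```rust,ignore
    /// let provider_config = ProviderConfig::Ollama {
    ///     base_url: "http://localhost:11434".to_string(),
    ///     default_model: "ministral-3:3b".to_string(),
    /// };
    /// // Override the provider's default model for this instance.
    /// let provider = Provider::from_config(&provider_config, Some("mistral"))?;
    /// assert_eq!(provider.name(), "ollama");
    /// ```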
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),
        }
    }

    /// Create a provider from a model configuration and its associated provider config
    ///
    /// This is the primary way to create providers from TOML config, as it resolves
    /// the model -> provider reference chain.
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        Self::from_config(provider_config, Some(&model_config.model))
    }
}

/// Trait abstraction for LLM client factories (useful for mocking in tests)
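///
/// # Example
///
/// A sketch of writing code against the trait rather than the concrete factory,
/// so tests can substitute a mock implementation (the `answer` helper is
/// hypothetical):
///
/// ```rust,ignore
/// async fn answer(factory: &dyn LLMClientFactoryTrait, prompt: &str) -> Result<String> {
///     let client = factory.create_default().await?;
///     client.generate(prompt).await
/// }
/// ```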
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// Get the default provider configuration
    fn default_provider(&self) -> &Provider;

    /// Create an LLM client using the default provider
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create an LLM client using a specific provider
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

/// Configuration-based LLM client factory
///
/// Provides a convenient way to create LLM clients with a default provider
/// while allowing runtime provider switching.
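///
/// # Example
///
/// A sketch, assuming environment variables select a provider as described in
/// [`Provider::from_env`]:
///
/// ```rust,ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// ```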
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a new factory with a specific default provider
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    /// Create a factory from environment variables
    ///
    /// Uses `Provider::from_env()` to determine the default provider.
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    /// Get the default provider configuration
    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    /// Create an LLM client using the default provider
    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    /// Create an LLM client using a specific provider
    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
    }

    #[test]
    fn test_factory_creation() {
        // This test just verifies the factory can be created
        // Actual provider tests require feature flags
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }
}