
ares/llm/client.rs

use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

/// Generic LLM client trait for provider abstraction
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a completion from a prompt
    async fn generate(&self, prompt: &str) -> Result<String>;

    /// Generate with system prompt
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    /// Generate with conversation history
    async fn generate_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<String>;

    /// Generate with tool calling support
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Stream a completion
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with system prompt
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with conversation history
    async fn stream_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Get the model name/identifier
    fn model_name(&self) -> &str;
}
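
// Illustrative use of the `LLMClient` trait object (a sketch, not part of the
// API surface; the helper name `summarize` and the prompts are made up here):
//
//     async fn summarize(client: &dyn LLMClient, text: &str) -> Result<String> {
//         client
//             .generate_with_system("You are a concise summarizer.", text)
//             .await
//     }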

/// Response from an LLM generation call
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// The generated text content
    pub content: String,
    /// Any tool calls the model wants to make
    pub tool_calls: Vec<ToolCall>,
    /// Reason the generation finished (e.g., "stop", "tool_calls", "length")
    pub finish_reason: String,
}
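
// Inspecting an `LLMResponse` (a sketch; `response` is assumed to come from a
// `generate_with_tools` call like the one declared above):
//
//     if response.finish_reason == "tool_calls" {
//         for call in &response.tool_calls {
//             println!("model requested tool `{}` with {}", call.name, call.arguments);
//         }
//     } else {
//         println!("{}", response.content);
//     }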

/// Model inference parameters
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Nucleus sampling parameter
    pub top_p: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (-2.0 to 2.0)
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
    /// Create params from a ModelConfig
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}
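
// Building `ModelParams` by hand (a sketch; the values chosen here are
// arbitrary and only meant to show the struct-update pattern with `Default`):
//
//     let params = ModelParams {
//         temperature: Some(0.2),
//         max_tokens: Some(512),
//         ..ModelParams::default()
//     };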

/// LLM Provider configuration
///
/// Each variant is feature-gated to ensure only enabled providers are available.
/// Use `Provider::from_env()` to automatically select based on environment variables.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    /// OpenAI API and compatible endpoints (e.g., Azure OpenAI, local vLLM)
    #[cfg(feature = "openai")]
    OpenAI {
        /// API key for authentication
        api_key: String,
        /// Base URL for the API (default: <https://api.openai.com/v1>)
        api_base: String,
        /// Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// Ollama local inference server
    #[cfg(feature = "ollama")]
    Ollama {
        /// Base URL for Ollama API (default: http://localhost:11434)
        base_url: String,
        /// Model name (e.g., "ministral-3:3b", "mistral", "qwen3-vl:2b")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// LlamaCpp for direct GGUF model loading
    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        /// Path to the GGUF model file
        model_path: String,
        /// Model inference parameters
        params: ModelParams,
    },
}
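
// Constructing a provider explicitly instead of via `from_env()` (a sketch;
// requires the `ollama` feature, and the URL/model values are only examples):
//
//     let provider = Provider::Ollama {
//         base_url: "http://localhost:11434".to_string(),
//         model: "ministral-3:3b".to_string(),
//         params: ModelParams::default(),
//     };
//     let client = provider.create_client().await?;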

impl Provider {
    /// Create an LLM client from this provider configuration
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The provider cannot be initialized
    /// - Required configuration is missing
    /// - Network connectivity issues prevent reaching a remote provider
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Create a provider from environment variables
    ///
    /// Provider priority (first match wins):
    /// 1. **LlamaCpp** - if `LLAMACPP_MODEL_PATH` is set
    /// 2. **OpenAI** - if `OPENAI_API_KEY` is set
    /// 3. **Ollama** - default fallback for local inference
    ///
    /// # Environment Variables
    ///
    /// ## LlamaCpp
    /// - `LLAMACPP_MODEL_PATH` - Path to GGUF model file (required)
    ///
    /// ## OpenAI
    /// - `OPENAI_API_KEY` - API key (required)
    /// - `OPENAI_API_BASE` - Base URL (default: <https://api.openai.com/v1>)
    /// - `OPENAI_MODEL` - Model name (default: gpt-4)
    ///
    /// ## Ollama
    /// - `OLLAMA_URL` (preferred) or `OLLAMA_BASE_URL` (legacy) - Server URL (default: http://localhost:11434)
    /// - `OLLAMA_MODEL` - Model name (default: ministral-3:3b)
    ///
    /// # Errors
    ///
    /// Returns an error if no LLM provider features are enabled or configured.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // Set environment variables
    /// std::env::set_var("OLLAMA_MODEL", "ministral-3:3b");
    ///
    /// let provider = Provider::from_env()?;
    /// let client = provider.create_client().await?;
    /// ```
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        // Check for LlamaCpp first (direct GGUF loading - highest priority when configured)
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        // Check for OpenAI (requires explicit API key configuration)
        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        // Ollama as default local inference (no API key required)
        #[cfg(feature = "ollama")]
        {
            // Accept both OLLAMA_URL (preferred) and legacy OLLAMA_BASE_URL
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        // No provider available
        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

    /// Get the provider name as a string
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider requires an API key
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider is local (no network required)
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Create a provider from TOML configuration
    ///
    /// # Arguments
    ///
    /// * `provider_config` - The provider configuration from ares.toml
    /// * `model_override` - Optional model name to override the provider default
    ///
    /// # Errors
    ///
    /// Returns an error if the provider type doesn't match an enabled feature
    /// or if required environment variables are not set.
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

    /// Create a provider from TOML configuration with model parameters
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),
        }
    }
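
    // Illustrative call into `from_config` (a sketch; it assumes an ares.toml
    // deserialized elsewhere into a `ProviderConfig::Ollama` with these
    // example values, and that the `ollama` feature is enabled):
    //
    //     let cfg = ProviderConfig::Ollama {
    //         base_url: "http://localhost:11434".to_string(),
    //         default_model: "ministral-3:3b".to_string(),
    //     };
    //     let provider = Provider::from_config(&cfg, Some("mistral"))?;
    //     assert_eq!(provider.name(), "ollama");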

    /// Create a provider from a model configuration and its associated provider config
    ///
    /// This is the primary way to create providers from TOML config, as it resolves
    /// the model -> provider reference chain.
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}
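
// Resolving a model entry from ares.toml (a sketch; `model_cfg` and
// `provider_cfg` are assumed to have been loaded and looked up elsewhere):
//
//     let provider = Provider::from_model_config(&model_cfg, &provider_cfg)?;
//     let client = provider.create_client().await?;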

/// Trait abstraction for LLM client factories (useful for mocking in tests)
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// Get the default provider configuration
    fn default_provider(&self) -> &Provider;

    /// Create an LLM client using the default provider
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create an LLM client using a specific provider
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}
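
// How a test might supply its own factory through this trait (a sketch; the
// `FixedProviderFactory` type is hypothetical and not defined in this crate):
//
//     struct FixedProviderFactory(Provider);
//
//     #[async_trait]
//     impl LLMClientFactoryTrait for FixedProviderFactory {
//         fn default_provider(&self) -> &Provider {
//             &self.0
//         }
//         async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
//             self.0.create_client().await
//         }
//         async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
//             provider.create_client().await
//         }
//     }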

/// Configuration-based LLM client factory
///
/// Provides a convenient way to create LLM clients with a default provider
/// while allowing runtime provider switching.
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a new factory with a specific default provider
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    /// Create a factory from environment variables
    ///
    /// Uses `Provider::from_env()` to determine the default provider.
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    /// Get the default provider configuration
    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    /// Create an LLM client using the default provider
    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    /// Create an LLM client using a specific provider
    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}
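
// Typical end-to-end use of the factory (a sketch; assumes the environment
// variables documented on `Provider::from_env` are set appropriately):
//
//     let factory = LLMClientFactory::from_env()?;
//     let client = factory.create_default().await?;
//     let reply = client.generate("Hello!").await?;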

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
    }

    #[test]
    fn test_factory_creation() {
        // This test just verifies the factory can be created
        // Actual provider tests require feature flags
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }
}