ares/llm/client.rs

use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

/// Generic LLM client trait for provider abstraction
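///
/// # Example
///
/// A hedged usage sketch: the helper names `summarize` and `stream_reply` are
/// illustrative and not part of this crate, but any boxed or borrowed
/// `LLMClient` can be driven generically in this way, regardless of the
/// backing provider.
///
/// ```rust,ignore
/// use futures::StreamExt;
///
/// async fn summarize(client: &dyn LLMClient, text: &str) -> Result<String> {
///     client
///         .generate_with_system("Summarize the following text.", text)
///         .await
/// }
///
/// async fn stream_reply(client: &dyn LLMClient, prompt: &str) -> Result<()> {
///     let mut chunks = client.stream(prompt).await?;
///     while let Some(chunk) = chunks.next().await {
///         print!("{}", chunk?);
///     }
///     Ok(())
/// }
/// ```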
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a completion from a prompt
    async fn generate(&self, prompt: &str) -> Result<String>;

    /// Generate with system prompt
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    /// Generate with conversation history
    async fn generate_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<String>;

    /// Generate with tool calling support
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Stream a completion
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with system prompt
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with conversation history
    async fn stream_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Get the model name/identifier
    fn model_name(&self) -> &str;
}

/// Token usage statistics from an LLM generation call
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt/input
    pub prompt_tokens: u32,
    /// Number of tokens in the completion/output
    pub completion_tokens: u32,
    /// Total tokens used (prompt + completion)
    pub total_tokens: u32,
}

impl TokenUsage {
    /// Create a new TokenUsage with the given values
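    ///
    /// `total_tokens` is always computed as `prompt_tokens + completion_tokens`.
    ///
    /// ```rust,ignore
    /// let usage = TokenUsage::new(100, 50);
    /// assert_eq!(usage.total_tokens, 150);
    /// ```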
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }
}

/// Response from an LLM generation call
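///
/// A hedged consumption sketch: a `finish_reason` of `"tool_calls"` signals
/// that the model requested tool invocations (the dispatch logic below is
/// illustrative only).
///
/// ```rust,ignore
/// let response = client.generate_with_tools(prompt, &tools).await?;
/// if response.finish_reason == "tool_calls" {
///     for call in &response.tool_calls {
///         // Dispatch `call.name` with `call.arguments` to the tool runtime.
///     }
/// }
/// ```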
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// The generated text content
    pub content: String,
    /// Any tool calls the model wants to make
    pub tool_calls: Vec<ToolCall>,
    /// Reason the generation finished (e.g., "stop", "tool_calls", "length")
    pub finish_reason: String,
    /// Token usage statistics (if provided by the model)
    pub usage: Option<TokenUsage>,
}

/// Model inference parameters
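///
/// All fields are optional; a `None` field is typically left to the backing
/// provider's default. A minimal construction sketch:
///
/// ```rust,ignore
/// let params = ModelParams {
///     temperature: Some(0.2),
///     max_tokens: Some(512),
///     ..ModelParams::default()
/// };
/// ```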
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Nucleus sampling parameter
    pub top_p: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (-2.0 to 2.0)
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
    /// Create params from a ModelConfig
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}

/// LLM Provider configuration
///
/// Each variant is feature-gated to ensure only enabled providers are available.
/// Use `Provider::from_env()` to automatically select based on environment variables.
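///
/// # Example
///
/// A hedged construction sketch, assuming the `ollama` feature is enabled:
///
/// ```rust,ignore
/// let provider = Provider::Ollama {
///     base_url: "http://localhost:11434".into(),
///     model: "ministral-3:3b".into(),
///     params: ModelParams::default(),
/// };
/// assert_eq!(provider.name(), "ollama");
/// assert!(provider.is_local());
/// ```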
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    /// OpenAI API and compatible endpoints (e.g., Azure OpenAI, local vLLM)
    #[cfg(feature = "openai")]
    OpenAI {
        /// API key for authentication
        api_key: String,
        /// Base URL for the API (default: <https://api.openai.com/v1>)
        api_base: String,
        /// Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// Ollama local inference server
    #[cfg(feature = "ollama")]
    Ollama {
        /// Base URL for Ollama API (default: <http://localhost:11434>)
        base_url: String,
        /// Model name (e.g., "ministral-3:3b", "mistral", "qwen3-vl:2b")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// LlamaCpp for direct GGUF model loading
    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        /// Path to the GGUF model file
        model_path: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// Anthropic Claude API
    #[cfg(feature = "anthropic")]
    Anthropic {
        /// API key for authentication
        api_key: String,
        /// Model identifier (e.g., "claude-3-5-sonnet-20241022")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },
}

impl Provider {
    /// Create an LLM client from this provider configuration
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The provider cannot be initialized
    /// - Required configuration is missing
    /// - There are network connectivity issues (for remote providers)
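    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let provider = Provider::from_env()?;
    /// let client = provider.create_client().await?;
    /// let reply = client.generate("Hello!").await?;
    /// ```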
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),

            #[cfg(feature = "anthropic")]
            Provider::Anthropic {
                api_key,
                model,
                params,
            } => Ok(Box::new(super::anthropic::AnthropicClient::with_params(
                api_key.clone(),
                model.clone(),
                params.clone(),
            ))),
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Create a provider from environment variables
    ///
    /// Provider priority (first match wins):
    /// 1. **LlamaCpp** - if `LLAMACPP_MODEL_PATH` is set
    /// 2. **OpenAI** - if `OPENAI_API_KEY` is set
    /// 3. **Anthropic** - if `ANTHROPIC_API_KEY` is set
    /// 4. **Ollama** - default fallback for local inference
    ///
    /// # Environment Variables
    ///
    /// ## LlamaCpp
    /// - `LLAMACPP_MODEL_PATH` - Path to GGUF model file (required)
    ///
    /// ## OpenAI
    /// - `OPENAI_API_KEY` - API key (required)
    /// - `OPENAI_API_BASE` - Base URL (default: <https://api.openai.com/v1>)
    /// - `OPENAI_MODEL` - Model name (default: gpt-4)
    ///
    /// ## Anthropic
    /// - `ANTHROPIC_API_KEY` - API key (required)
    /// - `ANTHROPIC_MODEL` - Model name (default: claude-3-5-sonnet-20241022)
    ///
    /// ## Ollama
    /// - `OLLAMA_URL` (or legacy `OLLAMA_BASE_URL`) - Server URL (default: <http://localhost:11434>)
    /// - `OLLAMA_MODEL` - Model name (default: ministral-3:3b)
    ///
    /// # Errors
    ///
    /// Returns an error if no LLM provider features are enabled or configured.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // Set environment variables
    /// std::env::set_var("OLLAMA_MODEL", "ministral-3:3b");
    ///
    /// let provider = Provider::from_env()?;
    /// let client = provider.create_client().await?;
    /// ```
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        // Check for LlamaCpp first (direct GGUF loading - highest priority when configured)
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        // Check for OpenAI (requires explicit API key configuration)
        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        // Check for Anthropic (requires explicit API key configuration)
        #[cfg(feature = "anthropic")]
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            if !api_key.is_empty() {
                let model = std::env::var("ANTHROPIC_MODEL")
                    .unwrap_or_else(|_| "claude-3-5-sonnet-20241022".into());
                return Ok(Provider::Anthropic {
                    api_key,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        // Ollama as default local inference (no API key required)
        #[cfg(feature = "ollama")]
        {
            // Accept both OLLAMA_URL (preferred) and legacy OLLAMA_BASE_URL
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        // No provider available
        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, anthropic, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

    /// Get the provider name as a string
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => "anthropic",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider requires an API key
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider is local (no remote network access required)
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Create a provider from TOML configuration
    ///
    /// # Arguments
    ///
    /// * `provider_config` - The provider configuration from ares.toml
    /// * `model_override` - Optional model name to override the provider default
    ///
    /// # Errors
    ///
    /// Returns an error if the provider type doesn't match an enabled feature
    /// or if required environment variables are not set.
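    ///
    /// # Example
    ///
    /// A hedged sketch, assuming the `ollama` feature is enabled and an Ollama
    /// provider entry exists in ares.toml:
    ///
    /// ```rust,ignore
    /// let provider_config = ProviderConfig::Ollama {
    ///     base_url: "http://localhost:11434".into(),
    ///     default_model: "ministral-3:3b".into(),
    /// };
    /// let provider = Provider::from_config(&provider_config, Some("mistral"))?;
    /// assert_eq!(provider.name(), "ollama");
    /// ```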
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

    /// Create a provider from TOML configuration with model parameters
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),

            #[cfg(feature = "anthropic")]
            ProviderConfig::Anthropic {
                api_key_env,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "Anthropic API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::Anthropic {
                    api_key,
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "anthropic"))]
            ProviderConfig::Anthropic { .. } => Err(AppError::Configuration(
                "Anthropic provider configured but 'anthropic' feature is not enabled".into(),
            )),
        }
    }

    /// Create a provider from a model configuration and its associated provider config
    ///
    /// This is the primary way to create providers from TOML config, as it resolves
    /// the model -> provider reference chain.
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}

/// Trait abstraction for LLM client factories (useful for mocking in tests)
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// Get the default provider configuration
    fn default_provider(&self) -> &Provider;

    /// Create an LLM client using the default provider
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create an LLM client using a specific provider
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

/// Configuration-based LLM client factory
///
/// Provides a convenient way to create LLM clients with a default provider
/// while allowing runtime provider switching.
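///
/// # Example
///
/// ```rust,ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// ```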
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a new factory with a specific default provider
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    /// Create a factory from environment variables
    ///
    /// Uses `Provider::from_env()` to determine the default provider.
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    /// Get the default provider configuration
    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    /// Create an LLM client using the default provider
    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    /// Create an LLM client using a specific provider
    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_llm_response_with_usage() {
        let usage = TokenUsage::new(100, 50);
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: Some(usage),
        };

        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
            usage: Some(TokenUsage::new(50, 25)),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
        assert_eq!(response.usage.as_ref().unwrap().total_tokens, 75);
    }

    #[test]
    fn test_factory_creation() {
        // This test just verifies the factory can be created
        // Actual provider tests require feature flags
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn test_anthropic_provider_properties() {
        let provider = Provider::Anthropic {
            api_key: "sk-ant-test".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "anthropic");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }
}