ares/llm/client.rs

use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

/// Generic LLM client trait for provider abstraction
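///
/// # Example
///
/// A minimal usage sketch (assumes an already constructed `Box<dyn LLMClient>`):
///
/// ```rust,ignore
/// let reply = client.generate("Summarize the build log").await?;
/// println!("{reply}");
/// ```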
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a completion from a prompt
    async fn generate(&self, prompt: &str) -> Result<String>;

    /// Generate with system prompt
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    /// Generate with conversation history
    async fn generate_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<String>;

    /// Generate with tool calling support
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Generate with conversation history AND tool definitions.
    ///
    /// This is the core method for multi-turn tool calling, combining:
    /// - `generate_with_history()` - conversation context
    /// - `generate_with_tools()` - tool calling capability
    ///
    /// # Arguments
    ///
    /// * `messages` - Conversation history as ConversationMessage structs
    /// * `tools` - Available tool definitions
    ///
    /// # Returns
    ///
    /// An LLMResponse containing the model's reply and any tool calls requested.
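    ///
    /// # Example
    ///
    /// A dispatch sketch (assumes `messages` and `tools` are already prepared;
    /// error handling elided):
    ///
    /// ```rust,ignore
    /// let response = client.generate_with_tools_and_history(&messages, &tools).await?;
    /// if response.finish_reason == "tool_calls" {
    ///     for call in &response.tool_calls {
    ///         // Execute each requested tool and append the result as a new message.
    ///     }
    /// }
    /// ```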
    async fn generate_with_tools_and_history(
        &self,
        messages: &[crate::llm::coordinator::ConversationMessage],
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Stream a completion
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with system prompt
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with conversation history
    async fn stream_with_history(
        &self,
        messages: &[(String, String)], // (role, content) pairs
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Get the model name/identifier
    fn model_name(&self) -> &str;
}

/// Token usage statistics from an LLM generation call
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Number of tokens in the prompt/input
    pub prompt_tokens: u32,
    /// Number of tokens in the completion/output
    pub completion_tokens: u32,
    /// Total tokens used (prompt + completion)
    pub total_tokens: u32,
}

impl TokenUsage {
    /// Create a new TokenUsage with the given values
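    ///
    /// `total_tokens` is computed as the sum of the two inputs.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let usage = TokenUsage::new(100, 50);
    /// assert_eq!(usage.total_tokens, 150);
    /// ```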
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }
}

/// Response from an LLM generation call
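///
/// A typical handling sketch (assumes a `response` value returned by a client):
///
/// ```rust,ignore
/// if response.finish_reason == "tool_calls" {
///     for call in &response.tool_calls {
///         println!("model requested tool: {}", call.name);
///     }
/// } else {
///     println!("{}", response.content);
/// }
/// ```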
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// The generated text content
    pub content: String,
    /// Any tool calls the model wants to make
    pub tool_calls: Vec<ToolCall>,
    /// Reason the generation finished (e.g., "stop", "tool_calls", "length")
    pub finish_reason: String,
    /// Token usage statistics (if provided by the model)
    pub usage: Option<TokenUsage>,
}

/// Model inference parameters
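///
/// All fields are optional; a sketch of partial construction using the derived
/// `Default` (the values shown are illustrative):
///
/// ```rust,ignore
/// let params = ModelParams {
///     temperature: Some(0.2),
///     max_tokens: Some(1024),
///     ..ModelParams::default()
/// };
/// ```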
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Nucleus sampling parameter
    pub top_p: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0)
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (-2.0 to 2.0)
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
    /// Create params from a ModelConfig
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}

/// LLM Provider configuration
///
/// Each variant is feature-gated to ensure only enabled providers are available.
/// Use `Provider::from_env()` to automatically select based on environment variables.
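///
/// # Example
///
/// Constructing an Ollama provider directly (requires the `ollama` feature; the
/// values shown are illustrative):
///
/// ```rust,ignore
/// let provider = Provider::Ollama {
///     base_url: "http://localhost:11434".to_string(),
///     model: "ministral-3:3b".to_string(),
///     params: ModelParams::default(),
/// };
/// ```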
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    /// OpenAI API and compatible endpoints (e.g., Azure OpenAI, local vLLM)
    #[cfg(feature = "openai")]
    OpenAI {
        /// API key for authentication
        api_key: String,
        /// Base URL for the API (default: <https://api.openai.com/v1>)
        api_base: String,
        /// Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// Ollama local inference server
    #[cfg(feature = "ollama")]
    Ollama {
        /// Base URL for the Ollama API (default: <http://localhost:11434>)
        base_url: String,
        /// Model name (e.g., "ministral-3:3b", "mistral", "qwen3-vl:2b")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// LlamaCpp for direct GGUF model loading
    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        /// Path to the GGUF model file
        model_path: String,
        /// Model inference parameters
        params: ModelParams,
    },

    /// Anthropic Claude API
    #[cfg(feature = "anthropic")]
    Anthropic {
        /// API key for authentication
        api_key: String,
        /// Model identifier (e.g., "claude-3-5-sonnet-20241022")
        model: String,
        /// Model inference parameters
        params: ModelParams,
    },
}

impl Provider {
    /// Create an LLM client from this provider configuration
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The provider cannot be initialized
    /// - Required configuration is missing
    /// - Network connectivity issues occur (for remote providers)
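    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `provider` was obtained from `Provider::from_env()`
    /// or `Provider::from_config()`):
    ///
    /// ```rust,ignore
    /// let client = provider.create_client().await?;
    /// let text = client.generate("ping").await?;
    /// ```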
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),

            #[cfg(feature = "anthropic")]
            Provider::Anthropic {
                api_key,
                model,
                params,
            } => Ok(Box::new(super::anthropic::AnthropicClient::with_params(
                api_key.clone(),
                model.clone(),
                params.clone(),
            ))),
            _ => unreachable!("Provider variant not enabled"),
        }
    }
    /// Create a provider from environment variables
    ///
    /// Provider priority (first match wins):
    /// 1. **LlamaCpp** - if `LLAMACPP_MODEL_PATH` is set
    /// 2. **OpenAI** - if `OPENAI_API_KEY` is set
    /// 3. **Anthropic** - if `ANTHROPIC_API_KEY` is set
    /// 4. **Ollama** - default fallback for local inference
    ///
    /// # Environment Variables
    ///
    /// ## LlamaCpp
    /// - `LLAMACPP_MODEL_PATH` - Path to GGUF model file (required)
    ///
    /// ## OpenAI
    /// - `OPENAI_API_KEY` - API key (required)
    /// - `OPENAI_API_BASE` - Base URL (default: <https://api.openai.com/v1>)
    /// - `OPENAI_MODEL` - Model name (default: gpt-4)
    ///
    /// ## Anthropic
    /// - `ANTHROPIC_API_KEY` - API key (required)
    /// - `ANTHROPIC_MODEL` - Model name (default: claude-3-5-sonnet-20241022)
    ///
    /// ## Ollama
    /// - `OLLAMA_URL` (or legacy `OLLAMA_BASE_URL`) - Server URL (default: <http://localhost:11434>)
    /// - `OLLAMA_MODEL` - Model name (default: ministral-3:3b)
    ///
    /// # Errors
    ///
    /// Returns an error if no LLM provider features are enabled or configured.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // Set environment variables
    /// std::env::set_var("OLLAMA_MODEL", "ministral-3:3b");
    ///
    /// let provider = Provider::from_env()?;
    /// let client = provider.create_client().await?;
    /// ```
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        // Check for LlamaCpp first (direct GGUF loading - highest priority when configured)
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        // Check for OpenAI (requires explicit API key configuration)
        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        // Check for Anthropic (requires explicit API key configuration)
        #[cfg(feature = "anthropic")]
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            if !api_key.is_empty() {
                let model = std::env::var("ANTHROPIC_MODEL")
                    .unwrap_or_else(|_| "claude-3-5-sonnet-20241022".into());
                return Ok(Provider::Anthropic {
                    api_key,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        // Ollama as default local inference (no API key required)
        #[cfg(feature = "ollama")]
        {
            // Accept both OLLAMA_URL (preferred) and legacy OLLAMA_BASE_URL
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        // No provider available
        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, anthropic, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

    /// Get the provider name as a string
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => "anthropic",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider requires an API key
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Check if this provider is local (no network required)
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Create a provider from TOML configuration
    ///
    /// # Arguments
    ///
    /// * `provider_config` - The provider configuration from ares.toml
    /// * `model_override` - Optional model name to override the provider default
    ///
    /// # Errors
    ///
    /// Returns an error if the provider type doesn't match an enabled feature
    /// or if required environment variables are not set.
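    ///
    /// # Example
    ///
    /// A sketch assuming `provider_config` was parsed from ares.toml; the model
    /// name override shown is illustrative:
    ///
    /// ```rust,ignore
    /// let provider = Provider::from_config(&provider_config, Some("mistral"))?;
    /// ```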
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

    /// Create a provider from TOML configuration with model parameters
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),

            #[cfg(feature = "anthropic")]
            ProviderConfig::Anthropic {
                api_key_env,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "Anthropic API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::Anthropic {
                    api_key,
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "anthropic"))]
            ProviderConfig::Anthropic { .. } => Err(AppError::Configuration(
                "Anthropic provider configured but 'anthropic' feature is not enabled".into(),
            )),
        }
    }

    /// Create a provider from a model configuration and its associated provider config
    ///
    /// This is the primary way to create providers from TOML config, as it resolves
    /// the model -> provider reference chain.
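    ///
    /// # Example
    ///
    /// A sketch assuming `model_config` and `provider_config` come from the parsed
    /// ares.toml:
    ///
    /// ```rust,ignore
    /// let provider = Provider::from_model_config(&model_config, &provider_config)?;
    /// let client = provider.create_client().await?;
    /// ```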
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}

/// Trait abstraction for LLM client factories (useful for mocking in tests)
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// Get the default provider configuration
    fn default_provider(&self) -> &Provider;

    /// Create an LLM client using the default provider
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create an LLM client using a specific provider
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

/// Configuration-based LLM client factory
///
/// Provides a convenient way to create LLM clients with a default provider
/// while allowing runtime provider switching.
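///
/// # Example
///
/// A minimal sketch (assumes provider features and environment variables are
/// configured):
///
/// ```rust,ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// ```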
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a new factory with a specific default provider
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    /// Create a factory from environment variables
    ///
    /// Uses `Provider::from_env()` to determine the default provider.
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    /// Get the default provider configuration
    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    /// Create an LLM client using the default provider
    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    /// Create an LLM client using a specific provider
    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_llm_response_with_usage() {
        let usage = TokenUsage::new(100, 50);
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: Some(usage),
        };

        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
            usage: Some(TokenUsage::new(50, 25)),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
        assert_eq!(response.usage.as_ref().unwrap().total_tokens, 75);
    }

    #[test]
    fn test_factory_creation() {
        // This test just verifies the factory can be created
        // Actual provider tests require feature flags
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn test_anthropic_provider_properties() {
        let provider = Provider::Anthropic {
            api_key: "sk-ant-test".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "anthropic");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }
}