use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

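/// Common interface implemented by every LLM backend (OpenAI, Ollama,
/// llama.cpp, Anthropic), covering plain generation, tool calling, and
/// streaming.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a
/// doctest); it assumes a boxed client was already obtained, e.g. from
/// [`Provider::create_client`]:
///
/// ```ignore
/// let client: Box<dyn LLMClient> = provider.create_client().await?;
/// let answer = client.generate("Summarize this repository").await?;
/// println!("{answer} (from {})", client.model_name());
/// ```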
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a completion for a single user prompt.
    async fn generate(&self, prompt: &str) -> Result<String>;

    /// Generate a completion with an explicit system prompt.
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    /// Generate a completion from a prior conversation, given as
    /// role/content string pairs.
    async fn generate_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<String>;

    /// Generate a completion while exposing the given tools to the model.
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Generate a completion from full conversation history while exposing tools.
    async fn generate_with_tools_and_history(
        &self,
        messages: &[crate::llm::coordinator::ConversationMessage],
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    /// Stream a completion for a single prompt, chunk by chunk.
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion with an explicit system prompt.
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Stream a completion from a prior conversation, given as
    /// role/content string pairs.
    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    /// Name of the underlying model, e.g. for logging.
    fn model_name(&self) -> &str;
}

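/// Token accounting returned by providers that report usage.
///
/// # Example
///
/// A small illustrative snippet (`ignore`d, not run as a doctest):
///
/// ```ignore
/// let usage = TokenUsage::new(100, 50);
/// assert_eq!(usage.total_tokens, 150); // total is always prompt + completion
/// ```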
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

impl TokenUsage {
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }
}

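/// A full (non-streaming) model response: text content plus any tool calls
/// the model requested.
///
/// # Example
///
/// An illustrative handling sketch (`ignore`d, not compiled); `client`,
/// `prompt`, and `tools` are placeholders:
///
/// ```ignore
/// let response: LLMResponse = client.generate_with_tools(prompt, &tools).await?;
/// if response.tool_calls.is_empty() {
///     println!("{}", response.content);
/// } else {
///     for call in &response.tool_calls {
///         println!("model requested tool `{}`", call.name);
///     }
/// }
/// ```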
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Generated text (may be empty when the model only issued tool calls).
    pub content: String,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Vec<ToolCall>,
    /// Provider-reported finish reason, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
    /// Token usage, when the provider reports it.
    pub usage: Option<TokenUsage>,
}

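/// Optional sampling and decoding parameters passed to the selected backend;
/// fields left as `None` are presumably left to that backend's defaults.
///
/// # Example
///
/// A minimal construction sketch (`ignore`d, not compiled):
///
/// ```ignore
/// let params = ModelParams {
///     temperature: Some(0.2),
///     max_tokens: Some(1024),
///     ..ModelParams::default()
/// };
/// ```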
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
    /// Build parameters from a model entry in the TOML configuration.
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}

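/// A configured LLM backend. Variants are feature-gated so unused provider
/// code is compiled out.
///
/// # Example
///
/// Constructing a provider by hand and turning it into a client — an
/// illustrative sketch (`ignore`d, not compiled) that assumes the `ollama`
/// feature is enabled:
///
/// ```ignore
/// let provider = Provider::Ollama {
///     base_url: "http://localhost:11434".to_string(),
///     model: "ministral-3:3b".to_string(),
///     params: ModelParams::default(),
/// };
/// let client = provider.create_client().await?;
/// ```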
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    /// OpenAI or any OpenAI-compatible HTTP API.
    #[cfg(feature = "openai")]
    OpenAI {
        api_key: String,
        api_base: String,
        model: String,
        params: ModelParams,
    },

    /// A local or remote Ollama server.
    #[cfg(feature = "ollama")]
    Ollama {
        base_url: String,
        model: String,
        params: ModelParams,
    },

    /// An in-process llama.cpp model loaded from disk.
    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        model_path: String,
        params: ModelParams,
    },

    /// The Anthropic API.
    #[cfg(feature = "anthropic")]
    Anthropic {
        api_key: String,
        model: String,
        params: ModelParams,
    },
}

impl Provider {
    /// Instantiate the concrete client for this provider.
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),

            #[cfg(feature = "anthropic")]
            Provider::Anthropic {
                api_key,
                model,
                params,
            } => Ok(Box::new(super::anthropic::AnthropicClient::with_params(
                api_key.clone(),
                model.clone(),
                params.clone(),
            ))),
            _ => unreachable!("Provider variant not enabled"),
        }
    }

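    /// Build a provider from environment variables, checked in order:
    /// `LLAMACPP_MODEL_PATH`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, then a
    /// fallback to Ollama (which needs no key). Only variants whose Cargo
    /// feature is enabled are considered.
    ///
    /// # Example
    ///
    /// Illustrative only (`ignore`d, not compiled); assumes the `openai`
    /// feature is enabled and a placeholder key:
    ///
    /// ```ignore
    /// std::env::set_var("OPENAI_API_KEY", "sk-...");
    /// let provider = Provider::from_env()?;
    /// assert_eq!(provider.name(), "openai");
    /// ```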
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "anthropic")]
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            if !api_key.is_empty() {
                let model = std::env::var("ANTHROPIC_MODEL")
                    .unwrap_or_else(|_| "claude-3-5-sonnet-20241022".into());
                return Ok(Provider::Anthropic {
                    api_key,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "ollama")]
        {
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, llamacpp, anthropic) and set the appropriate environment variables.".into(),
        ))
    }

    /// Short identifier for the provider, e.g. `"ollama"` or `"openai"`.
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => "anthropic",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Whether this provider needs an API key (true for hosted APIs,
    /// false for Ollama and llama.cpp).
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

    /// Whether inference runs locally: always true for llama.cpp, never for
    /// Anthropic, and based on the configured URL for OpenAI and Ollama.
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

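    /// Build a provider from a `ProviderConfig`, optionally overriding the
    /// configured default model, with default [`ModelParams`].
    ///
    /// # Example
    ///
    /// A sketch only (`ignore`d, not compiled); it assumes the `ollama`
    /// feature is enabled and that the `ProviderConfig::Ollama` variant can be
    /// constructed directly here (its fields and the override model name are
    /// illustrative):
    ///
    /// ```ignore
    /// let config = ProviderConfig::Ollama {
    ///     base_url: "http://localhost:11434".to_string(),
    ///     default_model: "ministral-3:3b".to_string(),
    /// };
    /// let provider = Provider::from_config(&config, Some("some-other-model"))?;
    /// ```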
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

    /// Build a provider from a `ProviderConfig` with explicit model
    /// parameters. Returns a configuration error if the matching Cargo
    /// feature is not enabled or a required API key variable is unset.
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),

            #[cfg(feature = "anthropic")]
            ProviderConfig::Anthropic {
                api_key_env,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "Anthropic API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::Anthropic {
                    api_key,
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "anthropic"))]
            ProviderConfig::Anthropic { .. } => Err(AppError::Configuration(
                "Anthropic provider configured but 'anthropic' feature is not enabled".into(),
            )),
        }
    }

    /// Build a provider from a model entry plus its provider configuration,
    /// using the model's own sampling parameters.
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}

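/// Object-safe factory abstraction, useful for dependency injection and for
/// mocking client creation in tests.
///
/// # Example
///
/// A sketch of code written against the trait rather than the concrete
/// factory (`ignore`d, not compiled):
///
/// ```ignore
/// async fn run(factory: &dyn LLMClientFactoryTrait) -> Result<String> {
///     let client = factory.create_default().await?;
///     client.generate("ping").await
/// }
/// ```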
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// The provider used when no explicit provider is supplied.
    fn default_provider(&self) -> &Provider;

    /// Create a client for the default provider.
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create a client for the given provider.
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

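/// Default [`LLMClientFactoryTrait`] implementation that hands out clients
/// for a single configured provider.
///
/// # Example
///
/// End-to-end sketch (`ignore`d, not compiled); assumes the relevant provider
/// feature is enabled and its environment variables are set:
///
/// ```ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// let reply = client.generate("Hello!").await?;
/// ```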
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a factory with an explicit default provider.
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    /// Create a factory whose default provider is resolved via [`Provider::from_env`].
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_llm_response_with_usage() {
        let usage = TokenUsage::new(100, 50);
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: Some(usage),
        };

        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
            usage: Some(TokenUsage::new(50, 25)),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
        assert_eq!(response.usage.as_ref().unwrap().total_tokens, 75);
    }

    #[test]
    fn test_factory_creation() {
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn test_anthropic_provider_properties() {
        let provider = Provider::Anthropic {
            api_key: "sk-ant-test".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "anthropic");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }
}