use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

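/// Common asynchronous interface implemented by every LLM backend.
///
/// Covers one-shot generation, multi-turn history, tool calling, and streaming.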
#[async_trait]
pub trait LLMClient: Send + Sync {
    async fn generate(&self, prompt: &str) -> Result<String>;

    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    async fn generate_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<String>;

    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    fn model_name(&self) -> &str;
}

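/// Token counts reported by a provider for a single request.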
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

impl TokenUsage {
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }
}

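/// Result of a generation call, including any tool calls the model requested.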
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    pub finish_reason: String,
    pub usage: Option<TokenUsage>,
}

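/// Optional sampling parameters forwarded to the underlying provider.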
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}

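/// Configuration for a concrete LLM backend; each variant is gated behind its Cargo feature.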
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    #[cfg(feature = "openai")]
    OpenAI {
        api_key: String,
        api_base: String,
        model: String,
        params: ModelParams,
    },

    #[cfg(feature = "ollama")]
    Ollama {
        base_url: String,
        model: String,
        params: ModelParams,
    },

    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        model_path: String,
        params: ModelParams,
    },

    #[cfg(feature = "anthropic")]
    Anthropic {
        api_key: String,
        model: String,
        params: ModelParams,
    },
}

impl Provider {
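    /// Instantiate the boxed [`LLMClient`] for this provider configuration.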
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),

            #[cfg(feature = "anthropic")]
            Provider::Anthropic {
                api_key,
                model,
                params,
            } => Ok(Box::new(super::anthropic::AnthropicClient::with_params(
                api_key.clone(),
                model.clone(),
                params.clone(),
            ))),
            _ => unreachable!("Provider variant not enabled"),
        }
    }

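    /// Build a provider from environment variables, checking llama.cpp, OpenAI,
    /// and Anthropic settings before falling back to Ollama defaults.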
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "anthropic")]
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            if !api_key.is_empty() {
                let model = std::env::var("ANTHROPIC_MODEL")
                    .unwrap_or_else(|_| "claude-3-5-sonnet-20241022".into());
                return Ok(Provider::Anthropic {
                    api_key,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "ollama")]
        {
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, anthropic, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

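    /// Short lowercase identifier for this provider, e.g. "ollama".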
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => "anthropic",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

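    /// Whether this provider requires an API key to authenticate.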
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

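    /// Whether requests stay on the local machine (localhost endpoint or in-process model).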
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,

            #[cfg(feature = "anthropic")]
            Provider::Anthropic { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

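    /// Build a provider from a [`ProviderConfig`] using default generation parameters.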
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

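    /// Build a provider from a [`ProviderConfig`], applying an optional model override
    /// and explicit generation parameters. Fails if the matching Cargo feature is
    /// disabled or a required API key environment variable is unset.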
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),

            #[cfg(feature = "anthropic")]
            ProviderConfig::Anthropic {
                api_key_env,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "Anthropic API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::Anthropic {
                    api_key,
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "anthropic"))]
            ProviderConfig::Anthropic { .. } => Err(AppError::Configuration(
                "Anthropic provider configured but 'anthropic' feature is not enabled".into(),
            )),
        }
    }

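    /// Build a provider from a model entry and its provider config, deriving the
    /// generation parameters from the model settings.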
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}

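/// Abstraction over client creation, e.g. for dependency injection in tests.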
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    fn default_provider(&self) -> &Provider;

    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

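/// Default factory that builds clients from a stored [`Provider`].
///
/// Illustrative usage (assumes one provider feature is enabled and its
/// environment variables are set):
///
/// ```ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// let reply = client.generate("Hello").await?;
/// ```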
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: None,
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_llm_response_with_usage() {
        let usage = TokenUsage::new(100, 50);
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
            usage: Some(usage),
        };

        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
            usage: Some(TokenUsage::new(50, 25)),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
        assert_eq!(response.usage.as_ref().unwrap().total_tokens, 75);
    }

    #[test]
    fn test_factory_creation() {
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn test_anthropic_provider_properties() {
        let provider = Provider::Anthropic {
            api_key: "sk-ant-test".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "anthropic");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }
}