use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;
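
/// Common interface over the supported LLM backends (OpenAI-compatible APIs, Ollama, llama.cpp).
///
/// Implementations are `Send + Sync`, so a boxed client can be shared across async tasks.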
#[async_trait]
pub trait LLMClient: Send + Sync {
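    /// Generate a completion for a single user prompt.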
    async fn generate(&self, prompt: &str) -> Result<String>;

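    /// Generate a completion for a prompt together with an explicit system prompt.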
    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

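    /// Generate a completion from a conversation history of `(String, String)` message pairs.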
    async fn generate_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<String>;

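    /// Generate a completion, letting the model request calls to any of the provided tools.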
    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

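    /// Stream a completion for a single user prompt as incremental text chunks.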
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

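    /// Stream a completion with an explicit system prompt.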
    async fn stream_with_system(
        &self,
        system: &str,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

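    /// Stream a completion from a conversation history of `(String, String)` message pairs.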
    async fn stream_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

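    /// Name of the underlying model.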
    fn model_name(&self) -> &str;
}

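/// A complete (non-streaming) model response: the final text, any requested tool calls,
/// and the reason generation finished.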
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    pub finish_reason: String,
}

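/// Optional sampling parameters passed through to the backend.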
#[derive(Debug, Clone, Default)]
pub struct ModelParams {
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
}

impl ModelParams {
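    /// Build sampling parameters from a `ModelConfig` loaded from the TOML configuration.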
    pub fn from_model_config(config: &ModelConfig) -> Self {
        Self {
            temperature: Some(config.temperature),
            max_tokens: Some(config.max_tokens),
            top_p: config.top_p,
            frequency_penalty: config.frequency_penalty,
            presence_penalty: config.presence_penalty,
        }
    }
}

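/// An LLM provider plus the settings needed to construct its client.
///
/// Each variant is gated behind a Cargo feature, so unused backends are compiled out.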
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    #[cfg(feature = "openai")]
    OpenAI {
        api_key: String,
        api_base: String,
        model: String,
        params: ModelParams,
    },

    #[cfg(feature = "ollama")]
    Ollama {
        base_url: String,
        model: String,
        params: ModelParams,
    },

    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        model_path: String,
        params: ModelParams,
    },
}

impl Provider {
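    /// Instantiate the concrete client for this provider variant.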
    #[allow(unreachable_patterns)]
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
                params,
            } => Ok(Box::new(super::openai::OpenAIClient::with_params(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
                params.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama {
                base_url,
                model,
                params,
            } => Ok(Box::new(
                super::ollama::OllamaClient::with_params(
                    base_url.clone(),
                    model.clone(),
                    params.clone(),
                )
                .await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path, params } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::with_params(model_path.clone(), params.clone())?,
            )),
            _ => unreachable!("Provider variant not enabled"),
        }
    }

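    /// Resolve a provider from environment variables: llama.cpp (`LLAMACPP_MODEL_PATH`),
    /// then OpenAI (`OPENAI_API_KEY`), then Ollama defaults, considering only the backends
    /// whose features are enabled.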
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp {
                    model_path,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                    params: ModelParams::default(),
                });
            }
        }

        #[cfg(feature = "ollama")]
        {
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama {
                base_url,
                model,
                params: ModelParams::default(),
            });
        }

        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

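    /// Short identifier for the provider kind (`"openai"`, `"ollama"`, or `"llamacpp"`).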
    #[allow(unreachable_patterns)]
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",
            _ => unreachable!("Provider variant not enabled"),
        }
    }

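    /// Whether this provider requires an API key to authenticate.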
    #[allow(unreachable_patterns)]
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

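    /// Whether inference runs locally (llama.cpp, or an HTTP endpoint on localhost/127.0.0.1).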
    #[allow(unreachable_patterns)]
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,
            _ => unreachable!("Provider variant not enabled"),
        }
    }

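    /// Build a provider from a `ProviderConfig`, optionally overriding the configured default
    /// model; sampling parameters are left at their defaults.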
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        Self::from_config_with_params(provider_config, model_override, ModelParams::default())
    }

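    /// Build a provider from a `ProviderConfig`, an optional model override, and explicit
    /// sampling parameters. Fails if the configured backend's Cargo feature is not enabled.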
    #[allow(unused_variables)]
    pub fn from_config_with_params(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
        params: ModelParams,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
                params,
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                    params,
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
                params,
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),
        }
    }

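    /// Build a provider from a `ModelConfig` entry and its backing `ProviderConfig`,
    /// carrying over the model name and sampling parameters.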
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        let params = ModelParams::from_model_config(model_config);
        Self::from_config_with_params(provider_config, Some(&model_config.model), params)
    }
}

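/// Factory interface for creating [`LLMClient`] instances from a default or explicit [`Provider`].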
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    /// The provider used when no explicit provider is given.
    fn default_provider(&self) -> &Provider;

    /// Create a client for the default provider.
    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    /// Create a client for the given provider.
    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

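/// Default [`LLMClientFactoryTrait`] implementation backed by a single default [`Provider`].
///
/// A minimal usage sketch (not compiled here; assumes the `ollama` feature is enabled and an
/// Ollama server is reachable at its default URL):
///
/// ```ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// let reply = client.generate("Say hello in one sentence.").await?;
/// println!("{reply}");
/// ```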
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    /// Create a factory with an explicit default provider.
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

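    /// Create a factory whose default provider is resolved via [`Provider::from_env`].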
    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    /// The provider used by [`create_default`](Self::create_default).
    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    /// Create a client for the default provider.
    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    /// Create a client for the given provider, ignoring the default.
    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
    }

    #[test]
    fn test_factory_creation() {
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
            params: ModelParams::default(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
            params: ModelParams::default(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }
}