use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use crate::utils::toml_config::{ModelConfig, ProviderConfig};
use async_trait::async_trait;

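/// Common interface implemented by every LLM backend.
///
/// All methods are async (via `async_trait`) so implementations can be used
/// behind a `Box<dyn LLMClient>` regardless of the underlying provider.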
#[async_trait]
pub trait LLMClient: Send + Sync {
    async fn generate(&self, prompt: &str) -> Result<String>;

    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String>;

    async fn generate_with_history(
        &self,
        messages: &[(String, String)],
    ) -> Result<String>;

    async fn generate_with_tools(
        &self,
        prompt: &str,
        tools: &[ToolDefinition],
    ) -> Result<LLMResponse>;

    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>>;

    fn model_name(&self) -> &str;
}

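/// Result of a tool-enabled generation: the assistant's text, any tool calls
/// the model requested, and the provider-reported finish reason.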
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    pub finish_reason: String,
}

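/// Supported LLM providers. Variants are gated behind Cargo features so that
/// unused backends (and their dependencies) are compiled out.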
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Provider {
    #[cfg(feature = "openai")]
    OpenAI {
        api_key: String,
        api_base: String,
        model: String,
    },

    #[cfg(feature = "ollama")]
    Ollama {
        base_url: String,
        model: String,
    },

    #[cfg(feature = "llamacpp")]
    LlamaCpp {
        model_path: String,
    },
}

impl Provider {
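    /// Instantiate the concrete client for this provider.
    ///
    /// Construction is fallible: the Ollama client is created asynchronously
    /// and llama.cpp client construction can return an error.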
    pub async fn create_client(&self) -> Result<Box<dyn LLMClient>> {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI {
                api_key,
                api_base,
                model,
            } => Ok(Box::new(super::openai::OpenAIClient::new(
                api_key.clone(),
                api_base.clone(),
                model.clone(),
            ))),

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, model } => Ok(Box::new(
                super::ollama::OllamaClient::new(base_url.clone(), model.clone()).await?,
            )),

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { model_path } => Ok(Box::new(
                super::llamacpp::LlamaCppClient::new(model_path.clone())?,
            )),
        }
    }

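    /// Build a provider from environment variables.
    ///
    /// Checked in order (only for features that are enabled):
    /// - `LLAMACPP_MODEL_PATH`
    /// - `OPENAI_API_KEY` (with optional `OPENAI_API_BASE` and `OPENAI_MODEL`)
    /// - `OLLAMA_URL` or `OLLAMA_BASE_URL` and `OLLAMA_MODEL`, falling back to
    ///   `http://localhost:11434` and `ministral-3:3b`
    ///
    /// Returns a configuration error if no provider can be selected.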
    #[allow(unreachable_code)]
    pub fn from_env() -> Result<Self> {
        #[cfg(feature = "llamacpp")]
        if let Ok(model_path) = std::env::var("LLAMACPP_MODEL_PATH") {
            if !model_path.is_empty() {
                return Ok(Provider::LlamaCpp { model_path });
            }
        }

        #[cfg(feature = "openai")]
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            if !api_key.is_empty() {
                let api_base = std::env::var("OPENAI_API_BASE")
                    .unwrap_or_else(|_| "https://api.openai.com/v1".into());
                let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4".into());
                return Ok(Provider::OpenAI {
                    api_key,
                    api_base,
                    model,
                });
            }
        }

        #[cfg(feature = "ollama")]
        {
            let base_url = std::env::var("OLLAMA_URL")
                .or_else(|_| std::env::var("OLLAMA_BASE_URL"))
                .unwrap_or_else(|_| "http://localhost:11434".into());
            let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "ministral-3:3b".into());
            return Ok(Provider::Ollama { base_url, model });
        }

        #[allow(unreachable_code)]
        Err(AppError::Configuration(
            "No LLM provider configured. Enable a feature (ollama, openai, llamacpp) and set the appropriate environment variables.".into(),
        ))
    }

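    /// Short identifier for the provider: `"openai"`, `"ollama"`, or `"llamacpp"`.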
    pub fn name(&self) -> &'static str {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => "openai",

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => "ollama",

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => "llamacpp",
        }
    }

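    /// Whether this provider needs an API key (currently only OpenAI).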
    pub fn requires_api_key(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { .. } => true,

            #[cfg(feature = "ollama")]
            Provider::Ollama { .. } => false,

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => false,
        }
    }

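    /// Whether the provider runs locally. For HTTP backends this is inferred
    /// from the configured URL pointing at `localhost` or `127.0.0.1`.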
    pub fn is_local(&self) -> bool {
        match self {
            #[cfg(feature = "openai")]
            Provider::OpenAI { api_base, .. } => {
                api_base.contains("localhost") || api_base.contains("127.0.0.1")
            }

            #[cfg(feature = "ollama")]
            Provider::Ollama { base_url, .. } => {
                base_url.contains("localhost") || base_url.contains("127.0.0.1")
            }

            #[cfg(feature = "llamacpp")]
            Provider::LlamaCpp { .. } => true,
        }
    }

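    /// Build a provider from a `ProviderConfig`, optionally overriding the
    /// configured default model. Returns a configuration error if the config
    /// names a provider whose Cargo feature is not enabled, or if the OpenAI
    /// API key environment variable is missing.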
    #[allow(unused_variables)]
    pub fn from_config(
        provider_config: &ProviderConfig,
        model_override: Option<&str>,
    ) -> Result<Self> {
        match provider_config {
            #[cfg(feature = "ollama")]
            ProviderConfig::Ollama {
                base_url,
                default_model,
            } => Ok(Provider::Ollama {
                base_url: base_url.clone(),
                model: model_override
                    .map(String::from)
                    .unwrap_or_else(|| default_model.clone()),
            }),

            #[cfg(not(feature = "ollama"))]
            ProviderConfig::Ollama { .. } => Err(AppError::Configuration(
                "Ollama provider configured but 'ollama' feature is not enabled".into(),
            )),

            #[cfg(feature = "openai")]
            ProviderConfig::OpenAI {
                api_key_env,
                api_base,
                default_model,
            } => {
                let api_key = std::env::var(api_key_env).map_err(|_| {
                    AppError::Configuration(format!(
                        "OpenAI API key environment variable '{}' is not set",
                        api_key_env
                    ))
                })?;
                Ok(Provider::OpenAI {
                    api_key,
                    api_base: api_base.clone(),
                    model: model_override
                        .map(String::from)
                        .unwrap_or_else(|| default_model.clone()),
                })
            }

            #[cfg(not(feature = "openai"))]
            ProviderConfig::OpenAI { .. } => Err(AppError::Configuration(
                "OpenAI provider configured but 'openai' feature is not enabled".into(),
            )),

            #[cfg(feature = "llamacpp")]
            ProviderConfig::LlamaCpp { model_path, .. } => Ok(Provider::LlamaCpp {
                model_path: model_path.clone(),
            }),

            #[cfg(not(feature = "llamacpp"))]
            ProviderConfig::LlamaCpp { .. } => Err(AppError::Configuration(
                "LlamaCpp provider configured but 'llamacpp' feature is not enabled".into(),
            )),
        }
    }

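    /// Convenience wrapper: build a provider from a `ProviderConfig`, using the
    /// model named in `model_config` as the override.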
    pub fn from_model_config(
        model_config: &ModelConfig,
        provider_config: &ProviderConfig,
    ) -> Result<Self> {
        Self::from_config(provider_config, Some(&model_config.model))
    }
}

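/// Object-safe abstraction over LLM client creation; `LLMClientFactory`
/// provides the default implementation.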
#[async_trait]
pub trait LLMClientFactoryTrait: Send + Sync {
    fn default_provider(&self) -> &Provider;

    async fn create_default(&self) -> Result<Box<dyn LLMClient>>;

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>>;
}

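/// Factory that hands out `LLMClient` instances for a default [`Provider`].
///
/// A minimal usage sketch (marked `ignore` since it needs a provider configured
/// via the environment and omits crate paths):
///
/// ```ignore
/// let factory = LLMClientFactory::from_env()?;
/// let client = factory.create_default().await?;
/// let reply = client.generate("Say hello").await?;
/// println!("{reply}");
/// ```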
pub struct LLMClientFactory {
    default_provider: Provider,
}

impl LLMClientFactory {
    pub fn new(default_provider: Provider) -> Self {
        Self { default_provider }
    }

    pub fn from_env() -> Result<Self> {
        Ok(Self {
            default_provider: Provider::from_env()?,
        })
    }

    pub fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    pub async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    pub async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[async_trait]
impl LLMClientFactoryTrait for LLMClientFactory {
    fn default_provider(&self) -> &Provider {
        &self.default_provider
    }

    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        self.default_provider.create_client().await
    }

    async fn create_with_provider(&self, provider: Provider) -> Result<Box<dyn LLMClient>> {
        provider.create_client().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            finish_reason: "stop".to_string(),
        };

        assert_eq!(response.content, "Hello");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.finish_reason, "stop");
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            },
        ];

        let response = LLMResponse {
            content: "".to_string(),
            tool_calls,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(response.tool_calls.len(), 2);
        assert_eq!(response.tool_calls[0].name, "calculator");
        assert_eq!(response.finish_reason, "tool_calls");
    }

    #[test]
    fn test_factory_creation() {
        #[cfg(feature = "ollama")]
        {
            let factory = LLMClientFactory::new(Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
            });
            assert_eq!(factory.default_provider().name(), "ollama");
        }
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_ollama_provider_properties() {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "ministral-3:3b".to_string(),
        };

        assert_eq!(provider.name(), "ollama");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_provider_properties() {
        let provider = Provider::OpenAI {
            api_key: "sk-test".to_string(),
            api_base: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
        };

        assert_eq!(provider.name(), "openai");
        assert!(provider.requires_api_key());
        assert!(!provider.is_local());
    }

    #[cfg(feature = "openai")]
    #[test]
    fn test_openai_local_provider() {
        let provider = Provider::OpenAI {
            api_key: "test".to_string(),
            api_base: "http://localhost:8000/v1".to_string(),
            model: "local-model".to_string(),
        };

        assert!(provider.is_local());
    }

    #[cfg(feature = "llamacpp")]
    #[test]
    fn test_llamacpp_provider_properties() {
        let provider = Provider::LlamaCpp {
            model_path: "/path/to/model.gguf".to_string(),
        };

        assert_eq!(provider.name(), "llamacpp");
        assert!(!provider.requires_api_key());
        assert!(provider.is_local());
    }
}