llm_stack_ollama/
factory.rs1use llm_stack::registry::{ProviderConfig, ProviderFactory};
4use llm_stack::{DynProvider, LlmError};
5
6use crate::{OllamaConfig, OllamaProvider};
7
/// Factory type for the Ollama provider.
///
/// This is a zero-sized unit struct: it carries no state, so copies and the
/// `Default` value are all interchangeable.
#[derive(Clone, Copy, Debug, Default)]
pub struct OllamaFactory;
31
32impl ProviderFactory for OllamaFactory {
33 fn name(&self) -> &'static str {
34 "ollama"
35 }
36
37 fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
38 if config.model.is_empty() {
39 return Err(LlmError::InvalidRequest(
40 "ollama provider requires model".into(),
41 ));
42 }
43
44 let mut ollama_config = OllamaConfig {
45 model: config.model.clone(),
46 ..Default::default()
47 };
48
49 if let Some(base_url) = &config.base_url {
50 ollama_config.base_url.clone_from(base_url);
51 }
52
53 if let Some(timeout) = config.timeout {
54 ollama_config.timeout = Some(timeout);
55 }
56
57 Ok(Box::new(OllamaProvider::new(ollama_config)))
58 }
59}
60
61pub fn register_global() {
66 llm_stack::ProviderRegistry::global().register(Box::new(OllamaFactory));
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The factory reports the fixed name `"ollama"`.
    #[test]
    fn test_factory_name() {
        assert_eq!(OllamaFactory.name(), "ollama");
    }

    /// A config with model, base URL, and timeout builds a provider whose
    /// metadata carries the provider name and the requested model.
    #[test]
    fn test_factory_build_success() {
        let config = ProviderConfig::new("ollama", "llama3.2")
            .base_url("http://remote:11434")
            .timeout(Duration::from_secs(60));

        let provider = OllamaFactory.build(&config).unwrap();
        let metadata = provider.metadata();

        assert_eq!(metadata.name, "ollama");
        assert_eq!(metadata.model, "llama3.2");
    }

    /// A config carrying only a provider name and a model is enough to build.
    #[test]
    fn test_factory_no_api_key_required() {
        let config = ProviderConfig::new("ollama", "mistral");

        let provider = OllamaFactory.build(&config).unwrap();

        assert_eq!(provider.metadata().model, "mistral");
    }

    /// An empty model name is rejected with `LlmError::InvalidRequest`.
    #[test]
    fn test_factory_empty_model() {
        let config = ProviderConfig::new("ollama", "");

        let outcome = OllamaFactory.build(&config);

        assert!(matches!(outcome, Err(LlmError::InvalidRequest(_))));
    }
}