llm_stack_ollama/
factory.rs1use llm_stack::registry::{ProviderConfig, ProviderFactory};
4use llm_stack::{DynProvider, LlmError};
5
6use crate::{OllamaConfig, OllamaProvider};
7
/// Stateless factory that turns a generic `ProviderConfig` into an
/// `OllamaProvider`.
///
/// It is registered under the name `"ollama"`. Being a zero-sized unit
/// struct, it is freely `Copy`able and can be constructed via `Default`.
#[derive(Default, Debug, Copy, Clone)]
pub struct OllamaFactory;
31
32impl ProviderFactory for OllamaFactory {
33 fn name(&self) -> &'static str {
34 "ollama"
35 }
36
37 fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
38 if config.model.is_empty() {
39 return Err(LlmError::InvalidRequest(
40 "ollama provider requires model".into(),
41 ));
42 }
43
44 let mut ollama_config = OllamaConfig {
45 model: config.model.clone(),
46 client: config.client.clone(),
47 ..Default::default()
48 };
49
50 if let Some(base_url) = &config.base_url {
51 ollama_config.base_url.clone_from(base_url);
52 }
53
54 if let Some(timeout) = config.timeout {
55 ollama_config.timeout = Some(timeout);
56 }
57
58 Ok(Box::new(OllamaProvider::new(ollama_config)))
59 }
60}
61
62pub fn register_global() {
67 llm_stack::ProviderRegistry::global().register(Box::new(OllamaFactory));
68}
69
#[cfg(test)]
mod tests {
    //! Unit tests for `OllamaFactory`: its registry name, successful builds
    //! with and without optional settings, and rejection of an empty model.

    use super::*;
    use std::time::Duration;

    #[test]
    fn test_factory_name() {
        assert_eq!(OllamaFactory.name(), "ollama");
    }

    #[test]
    fn test_factory_build_success() {
        let cfg = ProviderConfig::new("ollama", "llama3.2")
            .base_url("http://remote:11434")
            .timeout(Duration::from_secs(60));

        let built = OllamaFactory.build(&cfg).unwrap();
        let meta = built.metadata();
        assert_eq!(meta.name, "ollama");
        assert_eq!(meta.model, "llama3.2");
    }

    #[test]
    fn test_factory_no_api_key_required() {
        // Only a provider name and model are supplied; no credentials at all.
        let cfg = ProviderConfig::new("ollama", "mistral");

        let built = OllamaFactory.build(&cfg).unwrap();
        assert_eq!(built.metadata().model, "mistral");
    }

    #[test]
    fn test_factory_empty_model() {
        let cfg = ProviderConfig::new("ollama", "");

        let outcome = OllamaFactory.build(&cfg);
        assert!(matches!(outcome, Err(LlmError::InvalidRequest(_))));
    }
}