use llm_stack::registry::{ProviderConfig, ProviderFactory};
use llm_stack::{DynProvider, LlmError};
use crate::{OllamaConfig, OllamaProvider};
/// [`ProviderFactory`] that builds [`OllamaProvider`] instances.
///
/// Stateless unit struct: use it directly as `OllamaFactory` (or via
/// `Default`) and hand it to a `ProviderRegistry`.
#[derive(Clone, Copy, Debug, Default)]
pub struct OllamaFactory;
impl ProviderFactory for OllamaFactory {
    /// Registry key for this factory: always `"ollama"`.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Builds an [`OllamaProvider`] from a generic [`ProviderConfig`].
    ///
    /// `config.model` is required and is copied verbatim into the provider
    /// config. `config.client` is passed through. `config.base_url` and
    /// `config.timeout` override the [`OllamaConfig`] defaults only when they
    /// are set. No API key is read.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidRequest`] if `config.model` is empty or
    /// contains only whitespace.
    fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
        // A whitespace-only name is not a usable model identifier; reject it
        // here rather than letting it surface as a failure on the first request.
        if config.model.trim().is_empty() {
            return Err(LlmError::InvalidRequest(
                "ollama provider requires model".into(),
            ));
        }

        let mut ollama_config = OllamaConfig {
            model: config.model.clone(),
            client: config.client.clone(),
            ..Default::default()
        };

        // Only override the defaults when the caller supplied a value.
        if let Some(base_url) = &config.base_url {
            ollama_config.base_url.clone_from(base_url);
        }
        if let Some(timeout) = config.timeout {
            ollama_config.timeout = Some(timeout);
        }

        Ok(Box::new(OllamaProvider::new(ollama_config)))
    }
}
/// Registers an [`OllamaFactory`] with the registry returned by
/// `llm_stack::ProviderRegistry::global()`.
///
/// NOTE(review): whether calling this more than once replaces or duplicates
/// the entry depends on `ProviderRegistry::register`, which is not visible
/// here — confirm before relying on repeated calls.
pub fn register_global() {
    let factory = Box::new(OllamaFactory);
    llm_stack::ProviderRegistry::global().register(factory);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The factory reports the registry key `"ollama"`.
    #[test]
    fn test_factory_name() {
        assert_eq!(OllamaFactory.name(), "ollama");
    }

    /// A fully specified config builds a provider whose metadata reflects it.
    #[test]
    fn test_factory_build_success() {
        let config = ProviderConfig::new("ollama", "llama3.2")
            .base_url("http://remote:11434")
            .timeout(Duration::from_secs(60));

        let provider = OllamaFactory
            .build(&config)
            .expect("config with model, base_url and timeout should build");

        let meta = provider.metadata();
        assert_eq!(meta.name, "ollama");
        assert_eq!(meta.model, "llama3.2");
    }

    /// Only a model is needed; no API key or other settings are required.
    #[test]
    fn test_factory_no_api_key_required() {
        let provider = OllamaFactory
            .build(&ProviderConfig::new("ollama", "mistral"))
            .expect("model-only config should build");

        assert_eq!(provider.metadata().model, "mistral");
    }

    /// An empty model name is rejected as an invalid request.
    #[test]
    fn test_factory_empty_model() {
        let config = ProviderConfig::new("ollama", "");

        let err = match OllamaFactory.build(&config) {
            Ok(_) => panic!("an empty model name must be rejected"),
            Err(err) => err,
        };

        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }
}