Skip to main content

limit_llm/
provider_factory.rs

1use crate::client::AnthropicClient;
2use crate::config::Config;
3use crate::error::LlmError;
4use crate::local_provider::LocalProvider;
5use crate::openai_provider::OpenAiProvider;
6use crate::providers::LlmProvider;
7use crate::zai_provider::{ThinkingConfig, ZaiProvider};
8use std::boxed::Box;
9
10pub struct ProviderFactory;
11
12impl ProviderFactory {
13    pub fn create_provider(config: &Config) -> Result<Box<dyn LlmProvider>, LlmError> {
14        let provider_config = config.providers.get(&config.provider).ok_or_else(|| {
15            LlmError::ConfigError(format!(
16                "Provider '{}' not found in config",
17                config.provider
18            ))
19        })?;
20
21        let api_key = provider_config.api_key_or_env(&config.provider)
22            .ok_or_else(|| LlmError::ConfigError(format!("No API key found for provider '{}'. Set api_key in config or {}_API_KEY env var",
23                config.provider, config.provider.to_uppercase())))?;
24
25        match config.provider.as_str() {
26            "anthropic" => Ok(Box::new(AnthropicClient::new(
27                api_key,
28                provider_config.base_url.as_deref(),
29                provider_config.timeout,
30                &provider_config.model,
31                provider_config.max_tokens,
32            ))),
33            "openai" => Ok(Box::new(OpenAiProvider::new(
34                api_key,
35                provider_config.base_url.as_deref(),
36                &provider_config.model,
37                provider_config.max_tokens,
38                provider_config.timeout,
39            ))),
40            "zai" => {
41                let thinking_config = ThinkingConfig {
42                    thinking_enabled: provider_config.thinking_enabled,
43                    clear_thinking: provider_config.clear_thinking,
44                };
45                Ok(Box::new(ZaiProvider::new(
46                    api_key,
47                    provider_config.base_url.as_deref(),
48                    &provider_config.model,
49                    provider_config.max_tokens,
50                    provider_config.timeout,
51                    thinking_config,
52                )))
53            }
54            // Local LLM providers (all use LocalProvider with different defaults)
55            "local" | "ollama" | "lmstudio" | "vllm" => {
56                // Local providers don't require API key, use placeholder if empty
57                let _ = api_key; // Suppress unused warning
58                Ok(Box::new(LocalProvider::new(
59                    provider_config.base_url.as_deref(),
60                    &provider_config.model,
61                    provider_config.max_tokens,
62                    provider_config.timeout,
63                )))
64            }
65            _ => Err(LlmError::ConfigError(format!(
66                "Unknown provider: {}",
67                config.provider
68            ))),
69        }
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal `zai` configuration must produce a provider reporting the name `"zai"`.
    #[test]
    fn test_zai_provider_factory() {
        let parsed: Config = toml::from_str(
            r#"
provider = "zai"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#,
        )
        .unwrap();

        let built = ProviderFactory::create_provider(&parsed).unwrap();

        assert_eq!(built.provider_name(), "zai");
    }
}