Skip to main content

limit_llm/
provider_factory.rs

1use crate::client::AnthropicClient;
2use crate::config::Config;
3use crate::error::LlmError;
4use crate::openai_provider::OpenAiProvider;
5use crate::providers::LlmProvider;
6use crate::zai_provider::{ThinkingConfig, ZaiProvider};
7use std::boxed::Box;
8
9pub struct ProviderFactory;
10
11impl ProviderFactory {
12    pub fn create_provider(config: &Config) -> Result<Box<dyn LlmProvider>, LlmError> {
13        let provider_config = config.providers.get(&config.provider).ok_or_else(|| {
14            LlmError::ConfigError(format!(
15                "Provider '{}' not found in config",
16                config.provider
17            ))
18        })?;
19
20        let api_key = provider_config.api_key_or_env(&config.provider)
21            .ok_or_else(|| LlmError::ConfigError(format!("No API key found for provider '{}'. Set api_key in config or {}_API_KEY env var",
22                config.provider, config.provider.to_uppercase())))?;
23
24        match config.provider.as_str() {
25            "anthropic" => Ok(Box::new(AnthropicClient::new(
26                api_key,
27                provider_config.base_url.as_deref(),
28                provider_config.timeout,
29                &provider_config.model,
30                provider_config.max_tokens,
31            ))),
32            "openai" => Ok(Box::new(OpenAiProvider::new(
33                api_key,
34                provider_config.base_url.as_deref(),
35                &provider_config.model,
36                provider_config.max_tokens,
37                provider_config.timeout,
38            ))),
39            "zai" => {
40                let thinking_config = ThinkingConfig {
41                    thinking_enabled: provider_config.thinking_enabled,
42                    clear_thinking: provider_config.clear_thinking,
43                };
44                Ok(Box::new(ZaiProvider::new(
45                    api_key,
46                    provider_config.base_url.as_deref(),
47                    &provider_config.model,
48                    provider_config.max_tokens,
49                    provider_config.timeout,
50                    thinking_config,
51                )))
52            }
53            _ => Err(LlmError::ConfigError(format!(
54                "Unknown provider: {}",
55                config.provider
56            ))),
57        }
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml_src` as a [`Config`] and runs the factory on it.
    fn create_from_toml(toml_src: &str) -> Result<Box<dyn LlmProvider>, LlmError> {
        let config: Config = toml::from_str(toml_src).expect("test config should parse");
        ProviderFactory::create_provider(&config)
    }

    /// Returns the message of a `ConfigError`, failing the test on any other outcome.
    fn expect_config_error(result: Result<Box<dyn LlmProvider>, LlmError>) -> String {
        match result {
            Err(LlmError::ConfigError(msg)) => msg,
            Err(other) => panic!("expected LlmError::ConfigError, got {:?}", other),
            Ok(_) => panic!("expected an error, but a provider was created"),
        }
    }

    #[test]
    fn test_zai_provider_factory() {
        let provider = create_from_toml(
            r#"
provider = "zai"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#,
        )
        .expect("zai provider should be created");
        assert_eq!(provider.provider_name(), "zai");
    }

    #[test]
    fn test_unknown_provider_is_rejected() {
        // The section and key are present, so the only problem is the provider name.
        let msg = expect_config_error(create_from_toml(
            r#"
provider = "foo"

[providers.foo]
api_key = "test-key"
model = "some-model"
"#,
        ));
        assert!(msg.contains("Unknown provider"), "unexpected message: {}", msg);
    }

    #[test]
    fn test_missing_provider_section_is_rejected() {
        // "anthropic" is supported, but only a zai section is configured.
        let msg = expect_config_error(create_from_toml(
            r#"
provider = "anthropic"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#,
        ));
        assert!(msg.contains("not found in config"), "unexpected message: {}", msg);
    }
}