Skip to main content

limit_llm/
provider_factory.rs

1use crate::client::AnthropicClient;
2use crate::config::Config;
3use crate::error::LlmError;
4use crate::openai_provider::OpenAiProvider;
5use crate::providers::LlmProvider;
6use std::boxed::Box;
7
8pub struct ProviderFactory;
9
10impl ProviderFactory {
11    pub fn create_provider(config: &Config) -> Result<Box<dyn LlmProvider>, LlmError> {
12        let provider_config = config.providers.get(&config.provider).ok_or_else(|| {
13            LlmError::ConfigError(format!(
14                "Provider '{}' not found in config",
15                config.provider
16            ))
17        })?;
18
19        let api_key = provider_config.api_key_or_env(&config.provider)
20            .ok_or_else(|| LlmError::ConfigError(format!("No API key found for provider '{}'. Set api_key in config or {}_API_KEY env var",
21                config.provider, config.provider.to_uppercase())))?;
22
23        match config.provider.as_str() {
24            "anthropic" => Ok(Box::new(AnthropicClient::new(
25                api_key,
26                provider_config.base_url.as_deref(),
27                provider_config.timeout,
28                &provider_config.model,
29                provider_config.max_tokens,
30            ))),
31            "openai" => Ok(Box::new(OpenAiProvider::new(
32                api_key,
33                provider_config.base_url.as_deref(),
34                &provider_config.model,
35                provider_config.max_tokens,
36                provider_config.timeout,
37            ))),
38            "zai" => Ok(Box::new(OpenAiProvider::new(
39                api_key,
40                provider_config
41                    .base_url
42                    .as_deref()
43                    .or(Some("https://api.z.ai/api/coding/paas/v4/chat/completions")),
44                &provider_config.model,
45                provider_config.max_tokens,
46                provider_config.timeout,
47            ))),
48            _ => Err(LlmError::ConfigError(format!(
49                "Unknown provider: {}",
50                config.provider
51            ))),
52        }
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal `zai` config (API key and model only, no `base_url`) must
    /// produce a provider that reports the configured model name.
    #[test]
    fn test_zai_provider_factory() {
        const ZAI_TOML: &str = r#"
provider = "zai"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#;

        let parsed = toml::from_str::<Config>(ZAI_TOML).unwrap();
        let built = ProviderFactory::create_provider(&parsed).unwrap();

        // Uses OpenAiProvider for now
        assert_eq!(built.provider_name(), "openai");
        assert_eq!(built.model_name(), "glm-4.7");
    }
}