limit_llm/
provider_factory.rs1use crate::client::AnthropicClient;
2use crate::config::Config;
3use crate::error::LlmError;
4use crate::openai_provider::OpenAiProvider;
5use crate::providers::LlmProvider;
6use std::boxed::Box;
7
8pub struct ProviderFactory;
9
10impl ProviderFactory {
11 pub fn create_provider(config: &Config) -> Result<Box<dyn LlmProvider>, LlmError> {
12 let provider_config = config.providers.get(&config.provider).ok_or_else(|| {
13 LlmError::ConfigError(format!(
14 "Provider '{}' not found in config",
15 config.provider
16 ))
17 })?;
18
19 let api_key = provider_config.api_key_or_env(&config.provider)
20 .ok_or_else(|| LlmError::ConfigError(format!("No API key found for provider '{}'. Set api_key in config or {}_API_KEY env var",
21 config.provider, config.provider.to_uppercase())))?;
22
23 match config.provider.as_str() {
24 "anthropic" => Ok(Box::new(AnthropicClient::new(
25 api_key,
26 provider_config.base_url.as_deref(),
27 provider_config.timeout,
28 &provider_config.model,
29 provider_config.max_tokens,
30 ))),
31 "openai" => Ok(Box::new(OpenAiProvider::new(
32 api_key,
33 provider_config.base_url.as_deref(),
34 &provider_config.model,
35 provider_config.max_tokens,
36 provider_config.timeout,
37 ))),
38 "zai" => Ok(Box::new(OpenAiProvider::new(
39 api_key,
40 provider_config
41 .base_url
42 .as_deref()
43 .or(Some("https://api.z.ai/api/coding/paas/v4/chat/completions")),
44 &provider_config.model,
45 provider_config.max_tokens,
46 provider_config.timeout,
47 ))),
48 _ => Err(LlmError::ConfigError(format!(
49 "Unknown provider: {}",
50 config.provider
51 ))),
52 }
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `config_toml` as a [`Config`] and runs it through the factory.
    fn build(config_toml: &str) -> Result<Box<dyn LlmProvider>, LlmError> {
        let config: Config = toml::from_str(config_toml).expect("test config should parse");
        ProviderFactory::create_provider(&config)
    }

    /// Returns the message of an `LlmError::ConfigError`. Fails the test if the factory
    /// succeeded or returned any other error.
    fn config_error_message(result: Result<Box<dyn LlmProvider>, LlmError>) -> String {
        match result {
            Err(LlmError::ConfigError(msg)) => msg,
            _ => panic!("expected Err(LlmError::ConfigError(_))"),
        }
    }

    #[test]
    fn test_zai_provider_factory() {
        let config_content = r#"
provider = "zai"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#;
        let provider = build(config_content).unwrap();
        // Z.ai is served through the OpenAI-compatible client.
        assert_eq!(provider.provider_name(), "openai");
        assert_eq!(provider.model_name(), "glm-4.7");
    }

    #[test]
    fn test_openai_provider_factory() {
        let config_content = r#"
provider = "openai"

[providers.openai]
api_key = "test-openai-key"
model = "gpt-4o"
"#;
        let provider = build(config_content).unwrap();
        assert_eq!(provider.provider_name(), "openai");
        assert_eq!(provider.model_name(), "gpt-4o");
    }

    #[test]
    fn test_unknown_provider_is_rejected() {
        // A section and a key are present, so the only possible failure is the name itself.
        let config_content = r#"
provider = "mystery"

[providers.mystery]
api_key = "test-key"
model = "some-model"
"#;
        let msg = config_error_message(build(config_content));
        assert!(msg.contains("Unknown provider"), "unexpected message: {}", msg);
    }

    #[test]
    fn test_missing_provider_section_is_rejected() {
        let config_content = r#"
provider = "anthropic"

[providers.zai]
api_key = "test-zai-key"
model = "glm-4.7"
"#;
        let msg = config_error_message(build(config_content));
        assert!(msg.contains("not found in config"), "unexpected message: {}", msg);
    }
}