// a3s_code_core/llm/factory.rs

//! LLM client factory

use super::anthropic::AnthropicClient;
use super::openai::OpenAiClient;
use super::types::SecretString;
use super::zhipu::ZhipuClient;
use super::LlmClient;
use crate::retry::RetryConfig;
use std::collections::HashMap;
use std::sync::Arc;

/// LLM client configuration.
///
/// Holds provider selection, credentials, transport options, and sampling
/// parameters. Consumed by `create_client_with_config` to build a concrete
/// client behind `Arc<dyn LlmClient>`.
#[derive(Clone, Default)]
pub struct LlmConfig {
    /// Provider key, e.g. "anthropic", "openai", "glm". Unrecognized values
    /// fall back to the OpenAI-compatible client in the factory.
    pub provider: String,
    /// Model identifier passed through to the provider client.
    pub model: String,
    /// API key, wrapped so it cannot leak via Debug output.
    pub api_key: SecretString,
    /// Override for the provider's default API endpoint.
    pub base_url: Option<String>,
    /// Extra HTTP headers. NOTE(review): the factory only forwards these to
    /// the OpenAI-based clients, not anthropic/zhipu — confirm intentional.
    pub headers: HashMap<String, String>,
    /// Header name under which `session_id` is sent (used only when both
    /// this and `session_id` are set — see `resolved_headers`).
    pub session_id_header: Option<String>,
    /// Session identifier; redacted in Debug output.
    pub session_id: Option<String>,
    /// Retry policy; None means `RetryConfig::default()` at client creation.
    pub retry_config: Option<RetryConfig>,
    /// Sampling temperature (0.0–1.0). None uses the provider default.
    pub temperature: Option<f32>,
    /// Maximum tokens to generate. None uses the client default.
    pub max_tokens: Option<usize>,
    /// Extended thinking budget in tokens (Anthropic only).
    pub thinking_budget: Option<usize>,
    /// When true, temperature is never sent to the API (e.g., o1 models).
    pub disable_temperature: bool,
}
// Manual Debug impl (instead of derive) so secrets never reach logs:
// `api_key` and `session_id` are redacted, and for `headers` only the key
// names are shown, since header values may carry credentials.
impl std::fmt::Debug for LlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlmConfig")
            .field("provider", &self.provider)
            .field("model", &self.model)
            // Never print the key itself, even though SecretString guards it.
            .field("api_key", &"[REDACTED]")
            .field("base_url", &self.base_url)
            // Header names only; values are intentionally omitted.
            .field("headers", &self.headers.keys().collect::<Vec<_>>())
            .field("session_id_header", &self.session_id_header)
            .field(
                "session_id",
                // Preserve Some/None so presence is still visible in logs.
                &self.session_id.as_ref().map(|_| "[REDACTED]"),
            )
            .field("retry_config", &self.retry_config)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("thinking_budget", &self.thinking_budget)
            .field("disable_temperature", &self.disable_temperature)
            .finish()
    }
}
55impl LlmConfig {
56    pub fn new(
57        provider: impl Into<String>,
58        model: impl Into<String>,
59        api_key: impl Into<String>,
60    ) -> Self {
61        Self {
62            provider: provider.into(),
63            model: model.into(),
64            api_key: SecretString::new(api_key.into()),
65            base_url: None,
66            headers: HashMap::new(),
67            session_id_header: None,
68            session_id: None,
69            retry_config: None,
70            temperature: None,
71            max_tokens: None,
72            thinking_budget: None,
73            disable_temperature: false,
74        }
75    }
76
77    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
78        self.base_url = Some(base_url.into());
79        self
80    }
81
82    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
83        self.headers = headers;
84        self
85    }
86
87    pub fn with_session_id_header(mut self, header_name: impl Into<String>) -> Self {
88        self.session_id_header = Some(header_name.into());
89        self
90    }
91
92    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
93        self.session_id = Some(session_id.into());
94        self
95    }
96
97    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
98        self.retry_config = Some(retry_config);
99        self
100    }
101
102    pub fn with_temperature(mut self, temperature: f32) -> Self {
103        self.temperature = Some(temperature);
104        self
105    }
106
107    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
108        self.max_tokens = Some(max_tokens);
109        self
110    }
111
112    pub fn with_thinking_budget(mut self, budget: usize) -> Self {
113        self.thinking_budget = Some(budget);
114        self
115    }
116
117    pub(crate) fn resolved_headers(&self) -> HashMap<String, String> {
118        let mut headers = self.headers.clone();
119        if let (Some(header_name), Some(session_id)) = (&self.session_id_header, &self.session_id) {
120            headers.insert(header_name.clone(), session_id.clone());
121        }
122        headers
123    }
124}
126/// Create LLM client with full configuration (supports custom base_url)
127pub fn create_client_with_config(config: LlmConfig) -> Arc<dyn LlmClient> {
128    let retry = config.retry_config.clone().unwrap_or_default();
129    let api_key = config.api_key.expose().to_string();
130    let headers = config.resolved_headers();
131
132    match config.provider.as_str() {
133        "anthropic" | "claude" => {
134            let mut client = AnthropicClient::new(api_key, config.model)
135                .with_provider_name(config.provider.clone())
136                .with_retry_config(retry);
137            if let Some(base_url) = config.base_url {
138                client = client.with_base_url(base_url);
139            }
140            if !config.disable_temperature {
141                if let Some(temp) = config.temperature {
142                    client = client.with_temperature(temp);
143                }
144            }
145            if let Some(max) = config.max_tokens {
146                client = client.with_max_tokens(max);
147            }
148            if let Some(budget) = config.thinking_budget {
149                client = client.with_thinking_budget(budget);
150            }
151            Arc::new(client)
152        }
153        "openai" | "gpt" => {
154            let mut client = OpenAiClient::new(api_key, config.model)
155                .with_provider_name(config.provider.clone())
156                .with_retry_config(retry);
157            if let Some(base_url) = config.base_url {
158                client = client.with_base_url(base_url);
159            }
160            if !headers.is_empty() {
161                client = client.with_headers(headers.clone());
162            }
163            if !config.disable_temperature {
164                if let Some(temp) = config.temperature {
165                    client = client.with_temperature(temp);
166                }
167            }
168            if let Some(max) = config.max_tokens {
169                client = client.with_max_tokens(max);
170            }
171            Arc::new(client)
172        }
173        "glm" | "zhipu" | "bigmodel" => {
174            let mut client = ZhipuClient::new(api_key, config.model).with_retry_config(retry);
175            if let Some(base_url) = config.base_url {
176                client = client.with_base_url(base_url);
177            }
178            if !config.disable_temperature {
179                if let Some(temp) = config.temperature {
180                    client = client.with_temperature(temp);
181                }
182            }
183            if let Some(max) = config.max_tokens {
184                client = client.with_max_tokens(max);
185            }
186            Arc::new(client)
187        }
188        // OpenAI-compatible providers (deepseek, groq, together, ollama, etc.)
189        _ => {
190            tracing::info!(
191                "Using OpenAI-compatible client for provider '{}'",
192                config.provider
193            );
194            let mut client = OpenAiClient::new(api_key, config.model)
195                .with_provider_name(config.provider.clone())
196                .with_retry_config(retry);
197            if let Some(base_url) = config.base_url {
198                client = client.with_base_url(base_url);
199            }
200            if !headers.is_empty() {
201                client = client.with_headers(headers.clone());
202            }
203            if !config.disable_temperature {
204                if let Some(temp) = config.temperature {
205                    client = client.with_temperature(temp);
206                }
207            }
208            if let Some(max) = config.max_tokens {
209                client = client.with_max_tokens(max);
210            }
211            Arc::new(client)
212        }
213    }
214}