a3s_code_core/llm/
factory.rs1use super::anthropic::AnthropicClient;
4use super::openai::OpenAiClient;
5use super::types::SecretString;
6use super::zhipu::ZhipuClient;
7use super::LlmClient;
8use crate::retry::RetryConfig;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Clone, Default)]
14pub struct LlmConfig {
15 pub provider: String,
16 pub model: String,
17 pub api_key: SecretString,
18 pub base_url: Option<String>,
19 pub headers: HashMap<String, String>,
20 pub session_id_header: Option<String>,
21 pub session_id: Option<String>,
22 pub retry_config: Option<RetryConfig>,
23 pub temperature: Option<f32>,
25 pub max_tokens: Option<usize>,
27 pub thinking_budget: Option<usize>,
29 pub disable_temperature: bool,
31}
32
33impl std::fmt::Debug for LlmConfig {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("LlmConfig")
36 .field("provider", &self.provider)
37 .field("model", &self.model)
38 .field("api_key", &"[REDACTED]")
39 .field("base_url", &self.base_url)
40 .field("headers", &self.headers.keys().collect::<Vec<_>>())
41 .field("session_id_header", &self.session_id_header)
42 .field(
43 "session_id",
44 &self.session_id.as_ref().map(|_| "[REDACTED]"),
45 )
46 .field("retry_config", &self.retry_config)
47 .field("temperature", &self.temperature)
48 .field("max_tokens", &self.max_tokens)
49 .field("thinking_budget", &self.thinking_budget)
50 .field("disable_temperature", &self.disable_temperature)
51 .finish()
52 }
53}
54
55impl LlmConfig {
56 pub fn new(
57 provider: impl Into<String>,
58 model: impl Into<String>,
59 api_key: impl Into<String>,
60 ) -> Self {
61 Self {
62 provider: provider.into(),
63 model: model.into(),
64 api_key: SecretString::new(api_key.into()),
65 base_url: None,
66 headers: HashMap::new(),
67 session_id_header: None,
68 session_id: None,
69 retry_config: None,
70 temperature: None,
71 max_tokens: None,
72 thinking_budget: None,
73 disable_temperature: false,
74 }
75 }
76
77 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
78 self.base_url = Some(base_url.into());
79 self
80 }
81
82 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
83 self.headers = headers;
84 self
85 }
86
87 pub fn with_session_id_header(mut self, header_name: impl Into<String>) -> Self {
88 self.session_id_header = Some(header_name.into());
89 self
90 }
91
92 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
93 self.session_id = Some(session_id.into());
94 self
95 }
96
97 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
98 self.retry_config = Some(retry_config);
99 self
100 }
101
102 pub fn with_temperature(mut self, temperature: f32) -> Self {
103 self.temperature = Some(temperature);
104 self
105 }
106
107 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
108 self.max_tokens = Some(max_tokens);
109 self
110 }
111
112 pub fn with_thinking_budget(mut self, budget: usize) -> Self {
113 self.thinking_budget = Some(budget);
114 self
115 }
116
117 pub(crate) fn resolved_headers(&self) -> HashMap<String, String> {
118 let mut headers = self.headers.clone();
119 if let (Some(header_name), Some(session_id)) = (&self.session_id_header, &self.session_id) {
120 headers.insert(header_name.clone(), session_id.clone());
121 }
122 headers
123 }
124}
125
126pub fn create_client_with_config(config: LlmConfig) -> Arc<dyn LlmClient> {
128 let retry = config.retry_config.clone().unwrap_or_default();
129 let api_key = config.api_key.expose().to_string();
130 let headers = config.resolved_headers();
131
132 match config.provider.as_str() {
133 "anthropic" | "claude" => {
134 let mut client = AnthropicClient::new(api_key, config.model)
135 .with_provider_name(config.provider.clone())
136 .with_retry_config(retry);
137 if let Some(base_url) = config.base_url {
138 client = client.with_base_url(base_url);
139 }
140 if !config.disable_temperature {
141 if let Some(temp) = config.temperature {
142 client = client.with_temperature(temp);
143 }
144 }
145 if let Some(max) = config.max_tokens {
146 client = client.with_max_tokens(max);
147 }
148 if let Some(budget) = config.thinking_budget {
149 client = client.with_thinking_budget(budget);
150 }
151 Arc::new(client)
152 }
153 "openai" | "gpt" => {
154 let mut client = OpenAiClient::new(api_key, config.model)
155 .with_provider_name(config.provider.clone())
156 .with_retry_config(retry);
157 if let Some(base_url) = config.base_url {
158 client = client.with_base_url(base_url);
159 }
160 if !headers.is_empty() {
161 client = client.with_headers(headers.clone());
162 }
163 if !config.disable_temperature {
164 if let Some(temp) = config.temperature {
165 client = client.with_temperature(temp);
166 }
167 }
168 if let Some(max) = config.max_tokens {
169 client = client.with_max_tokens(max);
170 }
171 Arc::new(client)
172 }
173 "glm" | "zhipu" | "bigmodel" => {
174 let mut client = ZhipuClient::new(api_key, config.model).with_retry_config(retry);
175 if let Some(base_url) = config.base_url {
176 client = client.with_base_url(base_url);
177 }
178 if !config.disable_temperature {
179 if let Some(temp) = config.temperature {
180 client = client.with_temperature(temp);
181 }
182 }
183 if let Some(max) = config.max_tokens {
184 client = client.with_max_tokens(max);
185 }
186 Arc::new(client)
187 }
188 _ => {
190 tracing::info!(
191 "Using OpenAI-compatible client for provider '{}'",
192 config.provider
193 );
194 let mut client = OpenAiClient::new(api_key, config.model)
195 .with_provider_name(config.provider.clone())
196 .with_retry_config(retry);
197 if let Some(base_url) = config.base_url {
198 client = client.with_base_url(base_url);
199 }
200 if !headers.is_empty() {
201 client = client.with_headers(headers.clone());
202 }
203 if !config.disable_temperature {
204 if let Some(temp) = config.temperature {
205 client = client.with_temperature(temp);
206 }
207 }
208 if let Some(max) = config.max_tokens {
209 client = client.with_max_tokens(max);
210 }
211 Arc::new(client)
212 }
213 }
214}