1use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// AI client for any provider registered in the provider registry.
///
/// Holds the resolved provider entry, a preconfigured HTTP client, the API key
/// and the generation/retry settings copied from [`AiConfig`]. Implements
/// [`AiProvider`].
#[derive(Debug)]
pub struct AiClient {
    /// Registry entry for the selected provider (name, API URL, key env var).
    provider: &'static ProviderConfig,
    /// HTTP client built with the request timeout from `AiConfig::timeout_seconds`.
    http: Client,
    /// API key for the provider; wiped explicitly in this type's `Drop` impl.
    api_key: SecretString,
    /// Model identifier, exposed through `AiProvider::model`.
    model: String,
    /// Token limit copied from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Sampling temperature copied from `AiConfig::temperature`.
    temperature: f32,
    /// Attempt limit copied from `AiConfig::retry_max_attempts`.
    max_attempts: u32,
    /// Circuit breaker built from the configured threshold and reset interval.
    circuit_breaker: CircuitBreaker,
    /// Optional extra guidance text copied from `AiConfig::custom_guidance`.
    custom_guidance: Option<String>,
}
46
47impl Drop for AiClient {
48 fn drop(&mut self) {
49 use zeroize::Zeroize;
50 self.api_key.zeroize();
53 }
54}
55
56impl AiClient {
57 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
75 let provider = get_provider(provider_name)
77 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
78
79 if provider_name == "openrouter"
81 && !config.allow_paid_models
82 && !super::is_free_model(&config.model)
83 {
84 anyhow::bail!(
85 "Model '{}' is not in the free tier.\n\
86 To use paid models, set `allow_paid_models = true` in your config file:\n\
87 {}\n\n\
88 Or use a free model like: google/gemma-3-12b-it:free",
89 config.model,
90 crate::config::config_file_path().display()
91 );
92 }
93
94 let api_key = env::var(provider.api_key_env).with_context(|| {
96 format!(
97 "Missing {} environment variable.\n\
98 Set it with: export {}=your_api_key",
99 provider.api_key_env, provider.api_key_env
100 )
101 })?;
102
103 let http = Client::builder()
105 .timeout(Duration::from_secs(config.timeout_seconds))
106 .build()
107 .context("Failed to create HTTP client")?;
108
109 Ok(Self {
110 provider,
111 http,
112 api_key: SecretString::new(api_key.into()),
113 model: config.model.clone(),
114 max_tokens: config.max_tokens,
115 temperature: config.temperature,
116 max_attempts: config.retry_max_attempts,
117 circuit_breaker: CircuitBreaker::new(
118 config.circuit_breaker_threshold,
119 config.circuit_breaker_reset_seconds,
120 ),
121 custom_guidance: config.custom_guidance.clone(),
122 })
123 }
124
125 pub fn with_api_key(
145 provider_name: &str,
146 api_key: SecretString,
147 model_name: &str,
148 config: &AiConfig,
149 ) -> Result<Self> {
150 let provider = get_provider(provider_name)
152 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
153
154 if provider_name == "openrouter"
156 && !config.allow_paid_models
157 && !super::is_free_model(model_name)
158 {
159 anyhow::bail!(
160 "Model '{}' is not in the free tier.\n\
161 To use paid models, set `allow_paid_models = true` in your config file:\n\
162 {}\n\n\
163 Or use a free model like: google/gemma-3-12b-it:free",
164 model_name,
165 crate::config::config_file_path().display()
166 );
167 }
168
169 let http = Client::builder()
171 .timeout(Duration::from_secs(config.timeout_seconds))
172 .build()
173 .context("Failed to create HTTP client")?;
174
175 Ok(Self {
176 provider,
177 http,
178 api_key,
179 model: model_name.to_string(),
180 max_tokens: config.max_tokens,
181 temperature: config.temperature,
182 max_attempts: config.retry_max_attempts,
183 circuit_breaker: CircuitBreaker::new(
184 config.circuit_breaker_threshold,
185 config.circuit_breaker_reset_seconds,
186 ),
187 custom_guidance: config.custom_guidance.clone(),
188 })
189 }
190
191 #[must_use]
193 pub fn circuit_breaker(&self) -> &CircuitBreaker {
194 &self.circuit_breaker
195 }
196}
197
198#[async_trait]
199impl AiProvider for AiClient {
200 fn name(&self) -> &str {
201 self.provider.name
202 }
203
204 fn api_url(&self) -> &str {
205 self.provider.api_url
206 }
207
208 fn api_key_env(&self) -> &str {
209 self.provider.api_key_env
210 }
211
212 fn http_client(&self) -> &Client {
213 &self.http
214 }
215
216 fn api_key(&self) -> &SecretString {
217 &self.api_key
218 }
219
220 fn model(&self) -> &str {
221 &self.model
222 }
223
224 fn max_tokens(&self) -> u32 {
225 self.max_tokens
226 }
227
228 fn temperature(&self) -> f32 {
229 self.temperature
230 }
231
232 fn max_attempts(&self) -> u32 {
233 self.max_attempts
234 }
235
236 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
237 Some(&self.circuit_breaker)
238 }
239
240 fn custom_guidance(&self) -> Option<&str> {
241 self.custom_guidance.as_deref()
242 }
243
244 fn build_headers(&self) -> reqwest::header::HeaderMap {
245 let mut headers = reqwest::header::HeaderMap::new();
246 if let Ok(val) = "application/json".parse() {
247 headers.insert("Content-Type", val);
248 }
249
250 if self.provider.name == "openrouter" {
252 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
253 headers.insert("HTTP-Referer", val);
254 }
255 if let Ok(val) = "Aptu CLI".parse() {
256 headers.insert("X-Title", val);
257 }
258 }
259
260 headers
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter, a `:free` model, paid models disabled.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        let err = result.expect_err("paid model must be rejected when allow_paid_models is false");
        assert!(
            err.to_string().contains("is not in the free tier"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_ok(), "paid model should be accepted when allowed");
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_build_headers_per_provider() {
        let config = test_config();
        for provider_config in all_providers() {
            let client = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("key"),
                "test-model:free",
                &config,
            )
            .expect("should create client");
            let headers = client.build_headers();

            assert_eq!(
                headers
                    .get(reqwest::header::CONTENT_TYPE)
                    .and_then(|v| v.to_str().ok()),
                Some("application/json"),
                "Content-Type for provider: {}",
                provider_config.name
            );

            let is_openrouter = provider_config.name == "openrouter";
            assert_eq!(
                headers.contains_key("HTTP-Referer"),
                is_openrouter,
                "HTTP-Referer for provider: {}",
                provider_config.name
            );
            assert_eq!(
                headers.contains_key("X-Title"),
                is_openrouter,
                "X-Title for provider: {}",
                provider_config.name
            );
        }
    }
}