1use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// AI client that can talk to any provider registered in the provider registry.
///
/// Provider-specific details (display name, API URL, API-key environment
/// variable) come from a static [`ProviderConfig`] entry; request tuning
/// (model, token limit, temperature, retries, circuit breaker) comes from
/// [`AiConfig`].
#[derive(Debug)]
pub struct AiClient {
    /// Static registry entry; supplies `name`, `api_url` and `api_key_env`.
    provider: &'static ProviderConfig,
    /// Shared HTTP client, built with the timeout from `AiConfig::timeout_seconds`.
    http: Client,
    /// Provider API key, held as a `SecretString` rather than a plain `String`.
    api_key: SecretString,
    /// Model identifier sent to the provider.
    model: String,
    /// Maximum number of tokens requested per completion (from `AiConfig::max_tokens`).
    max_tokens: u32,
    /// Sampling temperature (from `AiConfig::temperature`).
    temperature: f32,
    /// Maximum request attempts (from `AiConfig::retry_max_attempts`).
    max_attempts: u32,
    /// Circuit breaker built from `circuit_breaker_threshold` and
    /// `circuit_breaker_reset_seconds` in `AiConfig`.
    circuit_breaker: CircuitBreaker,
    /// Optional extra guidance text, exposed via `AiProvider::custom_guidance`.
    custom_guidance: Option<String>,
}
46
47impl AiClient {
48 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
66 let provider = get_provider(provider_name)
68 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
69
70 if provider_name == "openrouter"
72 && !config.allow_paid_models
73 && !super::is_free_model(&config.model)
74 {
75 anyhow::bail!(
76 "Model '{}' is not in the free tier.\n\
77 To use paid models, set `allow_paid_models = true` in your config file:\n\
78 {}\n\n\
79 Or use a free model like: google/gemma-3-12b-it:free",
80 config.model,
81 crate::config::config_file_path().display()
82 );
83 }
84
85 let api_key = env::var(provider.api_key_env).with_context(|| {
87 format!(
88 "Missing {} environment variable.\n\
89 Set it with: export {}=your_api_key",
90 provider.api_key_env, provider.api_key_env
91 )
92 })?;
93
94 let http = Client::builder()
96 .timeout(Duration::from_secs(config.timeout_seconds))
97 .build()
98 .context("Failed to create HTTP client")?;
99
100 Ok(Self {
101 provider,
102 http,
103 api_key: SecretString::new(api_key.into()),
104 model: config.model.clone(),
105 max_tokens: config.max_tokens,
106 temperature: config.temperature,
107 max_attempts: config.retry_max_attempts,
108 circuit_breaker: CircuitBreaker::new(
109 config.circuit_breaker_threshold,
110 config.circuit_breaker_reset_seconds,
111 ),
112 custom_guidance: config.custom_guidance.clone(),
113 })
114 }
115
116 pub fn with_api_key(
136 provider_name: &str,
137 api_key: SecretString,
138 model_name: &str,
139 config: &AiConfig,
140 ) -> Result<Self> {
141 let provider = get_provider(provider_name)
143 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
144
145 if provider_name == "openrouter"
147 && !config.allow_paid_models
148 && !super::is_free_model(model_name)
149 {
150 anyhow::bail!(
151 "Model '{}' is not in the free tier.\n\
152 To use paid models, set `allow_paid_models = true` in your config file:\n\
153 {}\n\n\
154 Or use a free model like: google/gemma-3-12b-it:free",
155 model_name,
156 crate::config::config_file_path().display()
157 );
158 }
159
160 let http = Client::builder()
162 .timeout(Duration::from_secs(config.timeout_seconds))
163 .build()
164 .context("Failed to create HTTP client")?;
165
166 Ok(Self {
167 provider,
168 http,
169 api_key,
170 model: model_name.to_string(),
171 max_tokens: config.max_tokens,
172 temperature: config.temperature,
173 max_attempts: config.retry_max_attempts,
174 circuit_breaker: CircuitBreaker::new(
175 config.circuit_breaker_threshold,
176 config.circuit_breaker_reset_seconds,
177 ),
178 custom_guidance: config.custom_guidance.clone(),
179 })
180 }
181
182 #[must_use]
184 pub fn circuit_breaker(&self) -> &CircuitBreaker {
185 &self.circuit_breaker
186 }
187}
188
189#[async_trait]
190impl AiProvider for AiClient {
191 fn name(&self) -> &str {
192 self.provider.name
193 }
194
195 fn api_url(&self) -> &str {
196 self.provider.api_url
197 }
198
199 fn api_key_env(&self) -> &str {
200 self.provider.api_key_env
201 }
202
203 fn http_client(&self) -> &Client {
204 &self.http
205 }
206
207 fn api_key(&self) -> &SecretString {
208 &self.api_key
209 }
210
211 fn model(&self) -> &str {
212 &self.model
213 }
214
215 fn max_tokens(&self) -> u32 {
216 self.max_tokens
217 }
218
219 fn temperature(&self) -> f32 {
220 self.temperature
221 }
222
223 fn max_attempts(&self) -> u32 {
224 self.max_attempts
225 }
226
227 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
228 Some(&self.circuit_breaker)
229 }
230
231 fn custom_guidance(&self) -> Option<&str> {
232 self.custom_guidance.as_deref()
233 }
234
235 fn build_headers(&self) -> reqwest::header::HeaderMap {
236 let mut headers = reqwest::header::HeaderMap::new();
237 if let Ok(val) = "application/json".parse() {
238 headers.insert("Content-Type", val);
239 }
240
241 if self.provider.name == "openrouter" {
243 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
244 headers.insert("HTTP-Referer", val);
245 }
246 if let Ok(val) = "Aptu CLI".parse() {
247 headers.insert("X-Title", val);
248 }
249 }
250
251 headers
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter with a free-tier model and paid models disabled.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider must be rejected");
        assert!(err.to_string().contains("Unknown AI provider: nonexistent"));
    }

    #[test]
    fn test_new_unknown_provider_error() {
        // Provider lookup runs before the API-key env var is read, so this
        // fails the same way regardless of the environment.
        let config = test_config();
        let err = AiClient::new("nonexistent", &config)
            .expect_err("unknown provider must be rejected");
        assert!(err.to_string().contains("Unknown AI provider: nonexistent"));
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        // `with_api_key` validates `model_name`, not `config.model`.
        let mut config = test_config();
        config.allow_paid_models = false;
        let err = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect_err("paid model must be rejected without opt-in");
        assert!(err.to_string().contains("is not in the free tier"));
    }

    #[test]
    fn test_new_openrouter_rejects_paid_model() {
        // The free-tier guard runs before the env var lookup, so this does not
        // depend on the OpenRouter API key being set.
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let err = AiClient::new("openrouter", &config)
            .expect_err("paid model must be rejected without opt-in");
        assert!(err.to_string().contains("is not in the free tier"));
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect("paid model should be accepted when allow_paid_models = true");
        assert_eq!(client.model(), "anthropic/claude-3");
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }
}