1use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// HTTP-backed AI client bound to a single provider from the registry.
///
/// Instances are created with [`AiClient::new`] (API key read from the
/// environment) or [`AiClient::with_api_key`] (API key supplied by the caller).
#[derive(Debug)]
pub struct AiClient {
    /// Static registry entry for the selected provider (name, API URL,
    /// API key environment variable).
    provider: &'static ProviderConfig,
    /// HTTP client, built with the request timeout from `AiConfig::timeout_seconds`.
    http: Client,
    /// API key, kept wrapped in a `SecretString`.
    api_key: SecretString,
    /// Model identifier used for requests.
    model: String,
    /// Token limit copied from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Temperature setting copied from `AiConfig::temperature`.
    temperature: f32,
    /// Attempt limit copied from `AiConfig::retry_max_attempts`.
    max_attempts: u32,
    /// Circuit breaker constructed from `AiConfig::circuit_breaker_threshold`
    /// and `AiConfig::circuit_breaker_reset_seconds`.
    circuit_breaker: CircuitBreaker,
}
44
45impl AiClient {
46 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
64 let provider = get_provider(provider_name)
66 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
67
68 if provider_name == "openrouter"
70 && !config.allow_paid_models
71 && !super::is_free_model(&config.model)
72 {
73 anyhow::bail!(
74 "Model '{}' is not in the free tier.\n\
75 To use paid models, set `allow_paid_models = true` in your config file:\n\
76 {}\n\n\
77 Or use a free model like: mistralai/devstral-2512:free",
78 config.model,
79 crate::config::config_file_path().display()
80 );
81 }
82
83 let api_key = env::var(provider.api_key_env).with_context(|| {
85 format!(
86 "Missing {} environment variable.\n\
87 Set it with: export {}=your_api_key",
88 provider.api_key_env, provider.api_key_env
89 )
90 })?;
91
92 let http = Client::builder()
94 .timeout(Duration::from_secs(config.timeout_seconds))
95 .build()
96 .context("Failed to create HTTP client")?;
97
98 Ok(Self {
99 provider,
100 http,
101 api_key: SecretString::new(api_key.into()),
102 model: config.model.clone(),
103 max_tokens: config.max_tokens,
104 temperature: config.temperature,
105 max_attempts: config.retry_max_attempts,
106 circuit_breaker: CircuitBreaker::new(
107 config.circuit_breaker_threshold,
108 config.circuit_breaker_reset_seconds,
109 ),
110 })
111 }
112
113 pub fn with_api_key(
133 provider_name: &str,
134 api_key: SecretString,
135 model_name: &str,
136 config: &AiConfig,
137 ) -> Result<Self> {
138 let provider = get_provider(provider_name)
140 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
141
142 if provider_name == "openrouter"
144 && !config.allow_paid_models
145 && !super::is_free_model(model_name)
146 {
147 anyhow::bail!(
148 "Model '{}' is not in the free tier.\n\
149 To use paid models, set `allow_paid_models = true` in your config file:\n\
150 {}\n\n\
151 Or use a free model like: mistralai/devstral-2512:free",
152 model_name,
153 crate::config::config_file_path().display()
154 );
155 }
156
157 let http = Client::builder()
159 .timeout(Duration::from_secs(config.timeout_seconds))
160 .build()
161 .context("Failed to create HTTP client")?;
162
163 Ok(Self {
164 provider,
165 http,
166 api_key,
167 model: model_name.to_string(),
168 max_tokens: config.max_tokens,
169 temperature: config.temperature,
170 max_attempts: config.retry_max_attempts,
171 circuit_breaker: CircuitBreaker::new(
172 config.circuit_breaker_threshold,
173 config.circuit_breaker_reset_seconds,
174 ),
175 })
176 }
177
178 #[must_use]
180 pub fn circuit_breaker(&self) -> &CircuitBreaker {
181 &self.circuit_breaker
182 }
183}
184
185#[async_trait]
186impl AiProvider for AiClient {
187 fn name(&self) -> &str {
188 self.provider.name
189 }
190
191 fn api_url(&self) -> &str {
192 self.provider.api_url
193 }
194
195 fn api_key_env(&self) -> &str {
196 self.provider.api_key_env
197 }
198
199 fn http_client(&self) -> &Client {
200 &self.http
201 }
202
203 fn api_key(&self) -> &SecretString {
204 &self.api_key
205 }
206
207 fn model(&self) -> &str {
208 &self.model
209 }
210
211 fn max_tokens(&self) -> u32 {
212 self.max_tokens
213 }
214
215 fn temperature(&self) -> f32 {
216 self.temperature
217 }
218
219 fn max_attempts(&self) -> u32 {
220 self.max_attempts
221 }
222
223 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
224 Some(&self.circuit_breaker)
225 }
226
227 fn build_headers(&self) -> reqwest::header::HeaderMap {
228 let mut headers = reqwest::header::HeaderMap::new();
229 if let Ok(val) = "application/json".parse() {
230 headers.insert("Content-Type", val);
231 }
232
233 if self.provider.name == "openrouter" {
235 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
236 headers.insert("HTTP-Referer", val);
237 }
238 if let Ok(val) = "Aptu CLI".parse() {
239 headers.insert("X-Title", val);
240 }
241 }
242
243 headers
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline configuration shared by the tests: OpenRouter, a model name
    /// with the `:free` suffix, and paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Every registered provider must accept an explicit key and model.
    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    /// An unregistered provider name is rejected, and the error names it.
    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider must be rejected");
        // Pin the failure reason so the test cannot pass on an unrelated error.
        assert!(
            err.to_string().contains("Unknown AI provider: nonexistent"),
            "unexpected error: {err}"
        );
    }

    /// OpenRouter with paid models disallowed rejects a paid model name.
    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let err = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect_err("paid model must be rejected when allow_paid_models is false");
        // Pin the failure reason so the test cannot pass on an unrelated error.
        assert!(
            err.to_string().contains("is not in the free tier"),
            "unexpected error: {err}"
        );
    }

    /// `max_attempts` reflects `retry_max_attempts` from the configuration.
    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    /// The JSON content type header is set regardless of provider.
    #[test]
    fn test_build_headers_sets_content_type() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        let headers = client.build_headers();
        assert_eq!(
            headers.get("content-type").and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }
}