1use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// Client for a single AI provider selected from the provider registry.
///
/// Holds the provider's registry entry, an HTTP client, the API key and the
/// model/sampling settings copied from [`AiConfig`]. The provider-agnostic
/// request logic lives behind the [`AiProvider`] trait, which this type
/// implements by exposing these fields.
#[derive(Debug)]
pub struct AiClient {
    /// Registry entry for the selected provider (name, API URL, API key env var).
    provider: &'static ProviderConfig,
    /// HTTP client, built with the request timeout from `AiConfig::timeout_seconds`.
    http: Client,
    /// API key for the provider, kept as a `SecretString` so it is not printed
    /// verbatim by the derived `Debug` impl.
    api_key: SecretString,
    /// Model identifier used for requests to this provider.
    model: String,
    /// Maximum number of tokens to request, copied from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Sampling temperature, copied from `AiConfig::temperature`.
    temperature: f32,
    /// Circuit breaker for this client, built from the config's
    /// `circuit_breaker_threshold` and `circuit_breaker_reset_seconds`.
    circuit_breaker: CircuitBreaker,
}
42
43impl AiClient {
44 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
62 let provider = get_provider(provider_name)
64 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
65
66 if provider_name == "openrouter"
68 && !config.allow_paid_models
69 && !super::is_free_model(&config.model)
70 {
71 anyhow::bail!(
72 "Model '{}' is not in the free tier.\n\
73 To use paid models, set `allow_paid_models = true` in your config file:\n\
74 {}\n\n\
75 Or use a free model like: mistralai/devstral-2512:free",
76 config.model,
77 crate::config::config_file_path().display()
78 );
79 }
80
81 let api_key = env::var(provider.api_key_env).with_context(|| {
83 format!(
84 "Missing {} environment variable.\n\
85 Set it with: export {}=your_api_key",
86 provider.api_key_env, provider.api_key_env
87 )
88 })?;
89
90 let http = Client::builder()
92 .timeout(Duration::from_secs(config.timeout_seconds))
93 .build()
94 .context("Failed to create HTTP client")?;
95
96 Ok(Self {
97 provider,
98 http,
99 api_key: SecretString::new(api_key.into()),
100 model: config.model.clone(),
101 max_tokens: config.max_tokens,
102 temperature: config.temperature,
103 circuit_breaker: CircuitBreaker::new(
104 config.circuit_breaker_threshold,
105 config.circuit_breaker_reset_seconds,
106 ),
107 })
108 }
109
110 pub fn with_api_key(
130 provider_name: &str,
131 api_key: SecretString,
132 model_name: &str,
133 config: &AiConfig,
134 ) -> Result<Self> {
135 let provider = get_provider(provider_name)
137 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
138
139 if provider_name == "openrouter"
141 && !config.allow_paid_models
142 && !super::is_free_model(model_name)
143 {
144 anyhow::bail!(
145 "Model '{}' is not in the free tier.\n\
146 To use paid models, set `allow_paid_models = true` in your config file:\n\
147 {}\n\n\
148 Or use a free model like: mistralai/devstral-2512:free",
149 model_name,
150 crate::config::config_file_path().display()
151 );
152 }
153
154 let http = Client::builder()
156 .timeout(Duration::from_secs(config.timeout_seconds))
157 .build()
158 .context("Failed to create HTTP client")?;
159
160 Ok(Self {
161 provider,
162 http,
163 api_key,
164 model: model_name.to_string(),
165 max_tokens: config.max_tokens,
166 temperature: config.temperature,
167 circuit_breaker: CircuitBreaker::new(
168 config.circuit_breaker_threshold,
169 config.circuit_breaker_reset_seconds,
170 ),
171 })
172 }
173
174 #[must_use]
176 pub fn circuit_breaker(&self) -> &CircuitBreaker {
177 &self.circuit_breaker
178 }
179}
180
181#[async_trait]
182impl AiProvider for AiClient {
183 fn name(&self) -> &str {
184 self.provider.name
185 }
186
187 fn api_url(&self) -> &str {
188 self.provider.api_url
189 }
190
191 fn api_key_env(&self) -> &str {
192 self.provider.api_key_env
193 }
194
195 fn http_client(&self) -> &Client {
196 &self.http
197 }
198
199 fn api_key(&self) -> &SecretString {
200 &self.api_key
201 }
202
203 fn model(&self) -> &str {
204 &self.model
205 }
206
207 fn max_tokens(&self) -> u32 {
208 self.max_tokens
209 }
210
211 fn temperature(&self) -> f32 {
212 self.temperature
213 }
214
215 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
216 Some(&self.circuit_breaker)
217 }
218
219 fn build_headers(&self) -> reqwest::header::HeaderMap {
220 let mut headers = reqwest::header::HeaderMap::new();
221 if let Ok(val) = "application/json".parse() {
222 headers.insert("Content-Type", val);
223 }
224
225 if self.provider.name == "openrouter" {
227 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
228 headers.insert("HTTP-Referer", val);
229 }
230 if let Ok(val) = "Aptu CLI".parse() {
231 headers.insert("X-Title", val);
232 }
233 }
234
235 headers
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter, free model, paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider should be rejected");
        assert!(
            err.to_string().contains("Unknown AI provider: nonexistent"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let err = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect_err("paid model should be rejected without allow_paid_models");
        // Pin the reason so this cannot pass merely because the provider lookup failed.
        assert!(
            err.to_string().contains("is not in the free tier"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_opted_in() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(
            result.is_ok(),
            "paid model should be accepted when allow_paid_models is set: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_build_headers_openrouter_attribution_only() {
        let config = test_config();
        for provider_config in all_providers() {
            let client = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            )
            .expect("client should build for a registered provider");
            let headers = client.build_headers();

            assert_eq!(
                headers.get("content-type").and_then(|v| v.to_str().ok()),
                Some("application/json"),
                "missing JSON content type for provider: {}",
                provider_config.name
            );

            let is_openrouter = provider_config.name == "openrouter";
            assert_eq!(
                headers.contains_key("HTTP-Referer"),
                is_openrouter,
                "unexpected HTTP-Referer presence for provider: {}",
                provider_config.name
            );
            assert_eq!(
                headers.contains_key("X-Title"),
                is_openrouter,
                "unexpected X-Title presence for provider: {}",
                provider_config.name
            );
        }
    }
}