1use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::{ExposeSecret, SecretString};
15use serde::{Deserialize, Serialize};
16
17use super::circuit_breaker::CircuitBreaker;
18use super::provider::AiProvider;
19use super::registry::{PROVIDER_ANTHROPIC, ProviderConfig, get_provider};
20use crate::config::AiConfig;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum AuthMethod {
26 ApiKey,
28 OAuth,
30}
31
32impl std::fmt::Display for AuthMethod {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 AuthMethod::ApiKey => write!(f, "api-key"),
36 AuthMethod::OAuth => write!(f, "oauth"),
37 }
38 }
39}
40
41#[derive(Debug, Deserialize)]
43pub struct ClaudeCredentials {
44 pub access_token: String,
46}
47
/// Configured client for a single AI provider; implements [`AiProvider`].
///
/// Bundles the provider's registry entry, an HTTP client, the credential and
/// the per-request generation settings copied from [`AiConfig`]. Built via
/// [`AiClient::new`], [`AiClient::with_api_key`],
/// [`AiClient::from_claude_credentials`] or [`AiClient::from_keyring_oauth`].
#[derive(Debug)]
pub struct AiClient {
    /// Static registry entry for the provider (name, API URL, key env var).
    provider: &'static ProviderConfig,
    /// HTTP client built with the configured request timeout.
    http: Client,
    /// API key or OAuth access token; `SecretString` keeps it out of `Debug`.
    api_key: SecretString,
    /// Model identifier sent to the provider.
    model: String,
    /// Maximum number of tokens to request per completion.
    max_tokens: u32,
    /// Sampling temperature.
    temperature: f32,
    /// Retry attempts, taken from `retry_max_attempts` in the config.
    max_attempts: u32,
    /// Breaker built from the configured threshold and reset interval.
    circuit_breaker: CircuitBreaker,
    /// Optional extra guidance from the config, exposed via `custom_guidance()`.
    custom_guidance: Option<String>,
    /// How the credential in `api_key` was obtained.
    auth_method: AuthMethod,
}
75
76impl Drop for AiClient {
77 fn drop(&mut self) {
78 use zeroize::Zeroize;
79 self.api_key.zeroize();
82 }
83}
84
85impl AiClient {
86 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
104 let provider = get_provider(provider_name)
106 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
107
108 if provider_name == "openrouter"
110 && !config.allow_paid_models
111 && !super::is_free_model(&config.model)
112 {
113 anyhow::bail!(
114 "Model '{}' is not in the free tier.\n\
115 To use paid models, set `allow_paid_models = true` in your config file:\n\
116 {}\n\n\
117 Or use a free model like: google/gemma-3-12b-it:free",
118 config.model,
119 crate::config::config_file_path().display()
120 );
121 }
122
123 let api_key = env::var(provider.api_key_env).with_context(|| {
125 format!(
126 "Missing {} environment variable.\n\
127 Set it with: export {}=your_api_key",
128 provider.api_key_env, provider.api_key_env
129 )
130 })?;
131
132 let http = Client::builder()
134 .timeout(Duration::from_secs(config.timeout_seconds))
135 .build()
136 .context("Failed to create HTTP client")?;
137
138 Ok(Self {
139 provider,
140 http,
141 api_key: SecretString::new(api_key.into()),
142 model: config.model.clone(),
143 max_tokens: config.max_tokens,
144 temperature: config.temperature,
145 max_attempts: config.retry_max_attempts,
146 circuit_breaker: CircuitBreaker::new(
147 config.circuit_breaker_threshold,
148 config.circuit_breaker_reset_seconds,
149 ),
150 custom_guidance: config.custom_guidance.clone(),
151 auth_method: AuthMethod::ApiKey,
152 })
153 }
154
155 pub fn with_api_key(
175 provider_name: &str,
176 api_key: SecretString,
177 model_name: &str,
178 config: &AiConfig,
179 ) -> Result<Self> {
180 let provider = get_provider(provider_name)
182 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
183
184 if provider_name == "openrouter"
186 && !config.allow_paid_models
187 && !super::is_free_model(model_name)
188 {
189 anyhow::bail!(
190 "Model '{}' is not in the free tier.\n\
191 To use paid models, set `allow_paid_models = true` in your config file:\n\
192 {}\n\n\
193 Or use a free model like: google/gemma-3-12b-it:free",
194 model_name,
195 crate::config::config_file_path().display()
196 );
197 }
198
199 let http = Client::builder()
201 .timeout(Duration::from_secs(config.timeout_seconds))
202 .build()
203 .context("Failed to create HTTP client")?;
204
205 Ok(Self {
206 provider,
207 http,
208 api_key,
209 model: model_name.to_string(),
210 max_tokens: config.max_tokens,
211 temperature: config.temperature,
212 max_attempts: config.retry_max_attempts,
213 circuit_breaker: CircuitBreaker::new(
214 config.circuit_breaker_threshold,
215 config.circuit_breaker_reset_seconds,
216 ),
217 custom_guidance: config.custom_guidance.clone(),
218 auth_method: AuthMethod::ApiKey,
219 })
220 }
221
222 pub fn from_claude_credentials(config: &AiConfig) -> Result<Option<Self>> {
237 let Some(home) = dirs::home_dir() else {
239 return Ok(None);
240 };
241
242 let creds_path = home.join(".claude").join("credentials.json");
243
244 if !creds_path.exists() {
246 return Ok(None);
247 }
248
249 let creds_content =
251 std::fs::read_to_string(&creds_path).context("Failed to read credentials file")?;
252
253 let creds: ClaudeCredentials =
254 serde_json::from_str(&creds_content).context("Failed to parse credentials JSON")?;
255
256 if creds.access_token.is_empty() {
258 return Ok(None);
259 }
260
261 #[cfg(feature = "keyring")]
263 {
264 use keyring_core::Entry;
265 let entry = Entry::new("aptu", "anthropic_oauth_token")
266 .context("Failed to create keyring entry")?;
267 entry
268 .set_password(&creds.access_token)
269 .context("Failed to store token in keyring")?;
270 }
271
272 let client = Self::with_api_key(
274 PROVIDER_ANTHROPIC,
275 SecretString::from(creds.access_token),
276 &config.model,
277 config,
278 )?;
279
280 let mut client = client;
282 client.auth_method = AuthMethod::OAuth;
283 Ok(Some(client))
284 }
285
286 #[must_use]
293 pub fn claude_credentials_path() -> Option<std::path::PathBuf> {
294 let home = dirs::home_dir()?;
295 let creds_path = home.join(".claude").join("credentials.json");
296 if creds_path.exists() {
297 Some(creds_path)
298 } else {
299 None
300 }
301 }
302
303 pub fn from_keyring_oauth(config: &AiConfig) -> Result<Option<Self>> {
308 #[cfg(feature = "keyring")]
309 {
310 use keyring_core::Entry;
311 let entry = Entry::new("aptu", "anthropic_oauth_token")
312 .context("Failed to create keyring entry")?;
313
314 match entry.get_password() {
315 Ok(token) => {
316 let client = Self::with_api_key(
317 PROVIDER_ANTHROPIC,
318 SecretString::from(token),
319 &config.model,
320 config,
321 )?;
322
323 let mut client = client;
324 client.auth_method = AuthMethod::OAuth;
325 Ok(Some(client))
326 }
327 Err(_) => Ok(None),
328 }
329 }
330
331 #[cfg(not(feature = "keyring"))]
332 {
333 let _ = config;
334 Ok(None)
335 }
336 }
337
338 #[must_use]
340 pub fn auth_method(&self) -> AuthMethod {
341 self.auth_method
342 }
343
344 #[must_use]
346 pub fn circuit_breaker(&self) -> &CircuitBreaker {
347 &self.circuit_breaker
348 }
349}
350
351#[async_trait]
352impl AiProvider for AiClient {
353 fn name(&self) -> &str {
354 self.provider.name
355 }
356
357 fn api_url(&self) -> &str {
358 self.provider.api_url
359 }
360
361 fn api_key_env(&self) -> &str {
362 self.provider.api_key_env
363 }
364
365 fn http_client(&self) -> &Client {
366 &self.http
367 }
368
369 fn api_key(&self) -> &SecretString {
370 &self.api_key
371 }
372
373 fn model(&self) -> &str {
374 &self.model
375 }
376
377 fn max_tokens(&self) -> u32 {
378 self.max_tokens
379 }
380
381 fn temperature(&self) -> f32 {
382 self.temperature
383 }
384
385 fn max_attempts(&self) -> u32 {
386 self.max_attempts
387 }
388
389 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
390 Some(&self.circuit_breaker)
391 }
392
393 fn custom_guidance(&self) -> Option<&str> {
394 self.custom_guidance.as_deref()
395 }
396
397 fn build_headers(&self) -> reqwest::header::HeaderMap {
398 let mut headers = reqwest::header::HeaderMap::new();
399 if let Ok(val) = "application/json".parse() {
400 headers.insert("Content-Type", val);
401 }
402
403 if self.provider.name == super::registry::PROVIDER_ANTHROPIC {
405 if let Ok(val) = self.api_key().expose_secret().parse() {
406 headers.insert("x-api-key", val);
407 }
408 if let Ok(val) = "2023-06-01".parse() {
409 headers.insert("anthropic-version", val);
410 }
411 return headers;
412 }
413
414 if self.provider.name == "openrouter" {
416 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
417 headers.insert("HTTP-Referer", val);
418 }
419 if let Ok(val) = "Aptu CLI".parse() {
420 headers.insert("X-Title", val);
421 }
422 }
423
424 headers
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter, a free-tier model, paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-sonnet-4-6".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_build_headers_anthropic_has_api_key_and_version() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_api_key"),
            "test-model",
            &config,
        )
        .expect("should create anthropic client");

        let headers = client.build_headers();

        let header_str = |k| headers.get(k).and_then(|v| v.to_str().ok());
        assert_eq!(header_str("x-api-key"), Some("test_api_key"));
        assert_eq!(header_str("anthropic-version"), Some("2023-06-01"));
    }

    #[test]
    fn test_build_headers_non_anthropic_unaffected() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("test_key"),
            "test-model:free",
            &config,
        )
        .expect("should create openrouter client");

        let headers = client.build_headers();

        assert!(!headers.contains_key("anthropic-version"));
        assert!(headers.contains_key("http-referer"));
        assert!(headers.contains_key("x-title"));
    }

    #[test]
    fn test_from_claude_credentials_missing_file() {
        // This inspects the real home directory. If this machine has a Claude
        // credentials file, skip: otherwise the test would parse real
        // credentials and, with the `keyring` feature, write the token into
        // the real system keyring.
        if AiClient::claude_credentials_path().is_some() {
            return;
        }
        let config = test_config();
        let result = AiClient::from_claude_credentials(&config)
            .expect("a missing credentials file must not be an error");
        assert!(result.is_none());
    }

    #[test]
    fn test_from_claude_credentials_malformed_json() {
        let result: Result<ClaudeCredentials, _> = serde_json::from_str("{ invalid json }");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_missing_access_token() {
        let malformed = r#"{"other_field": "value"}"#;
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_empty_token() {
        let empty_token = r#"{"access_token": ""}"#;
        let creds: ClaudeCredentials = serde_json::from_str(empty_token).expect("should parse");
        assert!(creds.access_token.is_empty());
    }

    #[test]
    fn test_auth_method_api_key() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_key"),
            "test-model",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.auth_method(), AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::ApiKey.to_string(), "api-key");
        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
    }
}