1use std::env;
9
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Serialize};
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{PROVIDER_ANTHROPIC, ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum AuthMethod {
25 ApiKey,
27 OAuth,
29}
30
31impl std::fmt::Display for AuthMethod {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 AuthMethod::ApiKey => write!(f, "api-key"),
35 AuthMethod::OAuth => write!(f, "oauth"),
36 }
37 }
38}
39
40#[derive(Debug, Deserialize)]
42pub struct ClaudeCredentials {
43 pub access_token: String,
45}
46
47fn build_http_client(timeout_seconds: u64) -> Result<Client> {
51 #[cfg(not(target_arch = "wasm32"))]
52 let http = Client::builder()
53 .timeout(std::time::Duration::from_secs(timeout_seconds))
54 .build()
55 .context("Failed to create HTTP client")?;
56 #[cfg(target_arch = "wasm32")]
57 let http = Client::builder()
58 .build()
59 .context("Failed to create HTTP client")?;
60 Ok(http)
61}
62
63#[derive(Debug)]
68pub struct AiClient {
69 provider: &'static ProviderConfig,
71 http: Client,
73 api_key: SecretString,
75 model: String,
77 max_tokens: u32,
79 temperature: f32,
81 max_attempts: u32,
83 circuit_breaker: CircuitBreaker,
85 custom_guidance: Option<String>,
87 auth_method: AuthMethod,
89}
90
91impl Drop for AiClient {
92 fn drop(&mut self) {
93 use zeroize::Zeroize;
94 self.api_key.zeroize();
97 }
98}
99
100impl AiClient {
101 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
119 let provider = get_provider(provider_name)
121 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
122
123 if provider_name == "openrouter"
125 && !config.allow_paid_models
126 && !super::is_free_model(&config.model)
127 {
128 anyhow::bail!(
129 "Model '{}' is not in the free tier.\n\
130 To use paid models, set `allow_paid_models = true` in your config file:\n\
131 {}\n\n\
132 Or use a free model like: google/gemma-3-12b-it:free",
133 config.model,
134 crate::config::config_file_path().display()
135 );
136 }
137
138 let api_key = env::var(provider.api_key_env).with_context(|| {
140 format!(
141 "Missing {} environment variable.\n\
142 Set it with: export {}=your_api_key",
143 provider.api_key_env, provider.api_key_env
144 )
145 })?;
146
147 let http = build_http_client(config.timeout_seconds)?;
149
150 Ok(Self {
151 provider,
152 http,
153 api_key: SecretString::new(api_key.into()),
154 model: config.model.clone(),
155 max_tokens: config.max_tokens,
156 temperature: config.temperature,
157 max_attempts: config.retry_max_attempts,
158 circuit_breaker: CircuitBreaker::new(
159 config.circuit_breaker_threshold,
160 config.circuit_breaker_reset_seconds,
161 ),
162 custom_guidance: config.custom_guidance.clone(),
163 auth_method: AuthMethod::ApiKey,
164 })
165 }
166
167 pub fn with_api_key(
187 provider_name: &str,
188 api_key: SecretString,
189 model_name: &str,
190 config: &AiConfig,
191 ) -> Result<Self> {
192 let provider = get_provider(provider_name)
194 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
195
196 if provider_name == "openrouter"
198 && !config.allow_paid_models
199 && !super::is_free_model(model_name)
200 {
201 anyhow::bail!(
202 "Model '{}' is not in the free tier.\n\
203 To use paid models, set `allow_paid_models = true` in your config file:\n\
204 {}\n\n\
205 Or use a free model like: google/gemma-3-12b-it:free",
206 model_name,
207 crate::config::config_file_path().display()
208 );
209 }
210
211 let http = build_http_client(config.timeout_seconds)?;
213
214 Ok(Self {
215 provider,
216 http,
217 api_key,
218 model: model_name.to_string(),
219 max_tokens: config.max_tokens,
220 temperature: config.temperature,
221 max_attempts: config.retry_max_attempts,
222 circuit_breaker: CircuitBreaker::new(
223 config.circuit_breaker_threshold,
224 config.circuit_breaker_reset_seconds,
225 ),
226 custom_guidance: config.custom_guidance.clone(),
227 auth_method: AuthMethod::ApiKey,
228 })
229 }
230
231 pub fn from_claude_credentials(config: &AiConfig) -> Result<Option<Self>> {
246 let Some(home) = dirs::home_dir() else {
248 return Ok(None);
249 };
250
251 let creds_path = home.join(".claude").join("credentials.json");
252
253 if !creds_path.exists() {
255 return Ok(None);
256 }
257
258 let creds_content =
260 std::fs::read_to_string(&creds_path).context("Failed to read credentials file")?;
261
262 let creds: ClaudeCredentials =
263 serde_json::from_str(&creds_content).context("Failed to parse credentials JSON")?;
264
265 if creds.access_token.is_empty() {
267 return Ok(None);
268 }
269
270 #[cfg(feature = "keyring")]
272 {
273 use keyring_core::Entry;
274 let entry = Entry::new("aptu", "anthropic_oauth_token")
275 .context("Failed to create keyring entry")?;
276 entry
277 .set_password(&creds.access_token)
278 .context("Failed to store token in keyring")?;
279 }
280
281 let client = Self::with_api_key(
283 PROVIDER_ANTHROPIC,
284 SecretString::from(creds.access_token),
285 &config.model,
286 config,
287 )?;
288
289 let mut client = client;
291 client.auth_method = AuthMethod::OAuth;
292 Ok(Some(client))
293 }
294
295 #[must_use]
302 pub fn claude_credentials_path() -> Option<std::path::PathBuf> {
303 let home = dirs::home_dir()?;
304 let creds_path = home.join(".claude").join("credentials.json");
305 if creds_path.exists() {
306 Some(creds_path)
307 } else {
308 None
309 }
310 }
311
312 pub fn from_keyring_oauth(config: &AiConfig) -> Result<Option<Self>> {
317 #[cfg(feature = "keyring")]
318 {
319 use keyring_core::Entry;
320 let entry = Entry::new("aptu", "anthropic_oauth_token")
321 .context("Failed to create keyring entry")?;
322
323 match entry.get_password() {
324 Ok(token) => {
325 let client = Self::with_api_key(
326 PROVIDER_ANTHROPIC,
327 SecretString::from(token),
328 &config.model,
329 config,
330 )?;
331
332 let mut client = client;
333 client.auth_method = AuthMethod::OAuth;
334 Ok(Some(client))
335 }
336 Err(_) => Ok(None),
337 }
338 }
339
340 #[cfg(not(feature = "keyring"))]
341 {
342 let _ = config;
343 Ok(None)
344 }
345 }
346
347 #[must_use]
349 pub fn auth_method(&self) -> AuthMethod {
350 self.auth_method
351 }
352
353 #[must_use]
355 pub fn circuit_breaker(&self) -> &CircuitBreaker {
356 &self.circuit_breaker
357 }
358}
359
360#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
361#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
362impl AiProvider for AiClient {
363 fn name(&self) -> &str {
364 self.provider.name
365 }
366
367 fn api_url(&self) -> &str {
368 self.provider.api_url
369 }
370
371 fn api_key_env(&self) -> &str {
372 self.provider.api_key_env
373 }
374
375 fn http_client(&self) -> &Client {
376 &self.http
377 }
378
379 fn api_key(&self) -> &SecretString {
380 &self.api_key
381 }
382
383 fn model(&self) -> &str {
384 &self.model
385 }
386
387 fn max_tokens(&self) -> u32 {
388 self.max_tokens
389 }
390
391 fn temperature(&self) -> f32 {
392 self.temperature
393 }
394
395 fn max_attempts(&self) -> u32 {
396 self.max_attempts
397 }
398
399 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
400 Some(&self.circuit_breaker)
401 }
402
403 fn custom_guidance(&self) -> Option<&str> {
404 self.custom_guidance.as_deref()
405 }
406
407 fn build_headers(&self) -> reqwest::header::HeaderMap {
408 let mut headers = reqwest::header::HeaderMap::new();
409 if let Ok(val) = "application/json".parse() {
410 headers.insert("Content-Type", val);
411 }
412
413 if self.provider.name == super::registry::PROVIDER_ANTHROPIC {
415 if let Ok(val) = self.api_key().expose_secret().parse() {
416 headers.insert("x-api-key", val);
417 }
418 if let Ok(val) = "2023-06-01".parse() {
419 headers.insert("anthropic-version", val);
420 }
421 return headers;
422 }
423
424 if self.provider.name == "openrouter" {
426 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
427 headers.insert("HTTP-Referer", val);
428 }
429 if let Ok(val) = "Aptu CLI".parse() {
430 headers.insert("X-Title", val);
431 }
432 }
433
434 headers
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline configuration: OpenRouter with a free model and paid models disabled.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-sonnet-4-6".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_build_headers_anthropic_has_api_key_and_version() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_api_key"),
            "test-model",
            &config,
        )
        .expect("should create anthropic client");

        let headers = client.build_headers();

        let header_str = |k| headers.get(k).and_then(|v| v.to_str().ok());
        assert_eq!(header_str("x-api-key"), Some("test_api_key"));
        assert_eq!(header_str("anthropic-version"), Some("2023-06-01"));
    }

    #[test]
    fn test_build_headers_non_anthropic_unaffected() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("test_key"),
            "test-model:free",
            &config,
        )
        .expect("should create openrouter client");

        let headers = client.build_headers();

        assert!(!headers.contains_key("anthropic-version"));
        assert!(headers.contains_key("http-referer"));
        assert!(headers.contains_key("x-title"));
    }

    #[test]
    fn test_from_claude_credentials_missing_file() {
        // The lookup uses the real home directory. If the machine running the
        // tests has a credentials file, the "missing file" scenario cannot be
        // exercised (and the call could touch the real keyring), so skip.
        if AiClient::claude_credentials_path().is_some() {
            return;
        }

        let config = test_config();
        let result = AiClient::from_claude_credentials(&config);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_from_claude_credentials_malformed_json() {
        // The home directory cannot be redirected from here, so the parsing
        // step is checked directly instead of going through a file on disk.
        let malformed = "{ invalid json }";
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_missing_access_token() {
        let malformed = r#"{"other_field": "value"}"#;
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_empty_token() {
        let empty_token = r#"{"access_token": ""}"#;
        let creds: ClaudeCredentials = serde_json::from_str(empty_token).expect("should parse");
        assert!(creds.access_token.is_empty());
    }

    #[test]
    fn test_auth_method_api_key() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_key"),
            "test-model",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.auth_method(), AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::ApiKey.to_string(), "api-key");
        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
    }
}