1use std::env;
9
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Serialize};
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{PROVIDER_ANTHROPIC, ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum AuthMethod {
25 ApiKey,
27 OAuth,
29}
30
31impl std::fmt::Display for AuthMethod {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 AuthMethod::ApiKey => write!(f, "api-key"),
35 AuthMethod::OAuth => write!(f, "oauth"),
36 }
37 }
38}
39
40#[derive(Debug, Deserialize)]
42pub struct ClaudeCredentials {
43 pub access_token: String,
45}
46
47fn build_http_client(timeout_seconds: u64) -> Result<Client> {
51 #[cfg(not(target_arch = "wasm32"))]
52 let http = Client::builder()
53 .timeout(std::time::Duration::from_secs(timeout_seconds))
54 .build()
55 .context("Failed to create HTTP client")?;
56 #[cfg(target_arch = "wasm32")]
57 let http = Client::builder()
58 .build()
59 .context("Failed to create HTTP client")?;
60 Ok(http)
61}
62
/// Client for a single AI provider described by a registry [`ProviderConfig`].
///
/// Holds the HTTP client, credential and generation settings, and exposes
/// them to the shared request logic through its `AiProvider` impl. `Debug` is
/// derived; the credential is a `SecretString`, whose `Debug` output is
/// redacted by the `secrecy` crate.
#[derive(Debug)]
pub struct AiClient {
    /// Provider description resolved via `get_provider`.
    provider: &'static ProviderConfig,
    /// HTTP client from `build_http_client` (timeout applied on non-wasm targets).
    http: Client,
    /// API key or OAuth token; explicitly zeroized in `Drop`.
    api_key: SecretString,
    /// Model identifier used for requests.
    model: String,
    /// Copied from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Copied from `AiConfig::temperature`.
    temperature: f32,
    /// Copied from `AiConfig::retry_max_attempts`.
    max_attempts: u32,
    /// Built from the configured threshold and reset interval (seconds).
    circuit_breaker: CircuitBreaker,
    /// Optional extra guidance copied from `AiConfig::custom_guidance`.
    custom_guidance: Option<String>,
    /// How `api_key` was obtained (plain key vs. OAuth token).
    auth_method: AuthMethod,
}
90
91impl Drop for AiClient {
92 fn drop(&mut self) {
93 use zeroize::Zeroize;
94 self.api_key.zeroize();
97 }
98}
99
100impl AiClient {
101 pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
119 let provider = get_provider(provider_name)
121 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
122
123 if provider_name == "openrouter"
125 && !config.allow_paid_models
126 && !super::is_free_model(&config.model)
127 {
128 anyhow::bail!(
129 "Model '{}' is not in the free tier.\n\
130 To use paid models, set `allow_paid_models = true` in your config file:\n\
131 {}\n\n\
132 Or use a free model like: google/gemma-3-12b-it:free",
133 config.model,
134 crate::config::config_file_path().display()
135 );
136 }
137
138 let api_key = env::var(provider.api_key_env).with_context(|| {
140 format!(
141 "Missing {} environment variable.\n\
142 Set it with: export {}=your_api_key",
143 provider.api_key_env, provider.api_key_env
144 )
145 })?;
146
147 let http = build_http_client(config.timeout_seconds)?;
149
150 Ok(Self {
151 provider,
152 http,
153 api_key: SecretString::new(api_key.into()),
154 model: config.model.clone(),
155 max_tokens: config.max_tokens,
156 temperature: config.temperature,
157 max_attempts: config.retry_max_attempts,
158 circuit_breaker: CircuitBreaker::new(
159 config.circuit_breaker_threshold,
160 config.circuit_breaker_reset_seconds,
161 ),
162 custom_guidance: config.custom_guidance.clone(),
163 auth_method: AuthMethod::ApiKey,
164 })
165 }
166
167 pub fn with_api_key(
187 provider_name: &str,
188 api_key: SecretString,
189 model_name: &str,
190 config: &AiConfig,
191 ) -> Result<Self> {
192 let provider = get_provider(provider_name)
194 .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
195
196 if provider_name == "openrouter"
198 && !config.allow_paid_models
199 && !super::is_free_model(model_name)
200 {
201 anyhow::bail!(
202 "Model '{}' is not in the free tier.\n\
203 To use paid models, set `allow_paid_models = true` in your config file:\n\
204 {}\n\n\
205 Or use a free model like: google/gemma-3-12b-it:free",
206 model_name,
207 crate::config::config_file_path().display()
208 );
209 }
210
211 let http = build_http_client(config.timeout_seconds)?;
213
214 Ok(Self {
215 provider,
216 http,
217 api_key,
218 model: model_name.to_string(),
219 max_tokens: config.max_tokens,
220 temperature: config.temperature,
221 max_attempts: config.retry_max_attempts,
222 circuit_breaker: CircuitBreaker::new(
223 config.circuit_breaker_threshold,
224 config.circuit_breaker_reset_seconds,
225 ),
226 custom_guidance: config.custom_guidance.clone(),
227 auth_method: AuthMethod::ApiKey,
228 })
229 }
230
231 pub fn from_claude_credentials(config: &AiConfig) -> Result<Option<Self>> {
246 let Some(home) = dirs::home_dir() else {
248 return Ok(None);
249 };
250
251 let creds_path = home.join(".claude").join("credentials.json");
252
253 if !creds_path.exists() {
255 return Ok(None);
256 }
257
258 let creds_content =
260 std::fs::read_to_string(&creds_path).context("Failed to read credentials file")?;
261
262 let creds: ClaudeCredentials =
263 serde_json::from_str(&creds_content).context("Failed to parse credentials JSON")?;
264
265 if creds.access_token.is_empty() {
267 return Ok(None);
268 }
269
270 #[cfg(feature = "keyring")]
272 {
273 use keyring_core::Entry;
274 let entry = Entry::new("aptu", "anthropic_oauth_token")
275 .context("Failed to create keyring entry")?;
276 entry
277 .set_password(&creds.access_token)
278 .context("Failed to store token in keyring")?;
279 }
280
281 let client = Self::with_api_key(
283 PROVIDER_ANTHROPIC,
284 SecretString::from(creds.access_token),
285 &config.model,
286 config,
287 )?;
288
289 let mut client = client;
291 client.auth_method = AuthMethod::OAuth;
292 Ok(Some(client))
293 }
294
295 #[must_use]
302 pub fn claude_credentials_path() -> Option<std::path::PathBuf> {
303 let home = dirs::home_dir()?;
304 let creds_path = home.join(".claude").join("credentials.json");
305 if creds_path.exists() {
306 Some(creds_path)
307 } else {
308 None
309 }
310 }
311
312 pub fn from_keyring_oauth(config: &AiConfig) -> Result<Option<Self>> {
317 #[cfg(feature = "keyring")]
318 {
319 use keyring_core::Entry;
320 let entry = Entry::new("aptu", "anthropic_oauth_token")
321 .context("Failed to create keyring entry")?;
322
323 match entry.get_password() {
324 Ok(token) => {
325 let client = Self::with_api_key(
326 PROVIDER_ANTHROPIC,
327 SecretString::from(token),
328 &config.model,
329 config,
330 )?;
331
332 let mut client = client;
333 client.auth_method = AuthMethod::OAuth;
334 Ok(Some(client))
335 }
336 Err(_) => Ok(None),
337 }
338 }
339
340 #[cfg(not(feature = "keyring"))]
341 {
342 let _ = config;
343 Ok(None)
344 }
345 }
346
347 #[must_use]
349 pub fn auth_method(&self) -> AuthMethod {
350 self.auth_method
351 }
352
353 #[must_use]
355 pub fn circuit_breaker(&self) -> &CircuitBreaker {
356 &self.circuit_breaker
357 }
358}
359
360#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
361#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
362impl AiProvider for AiClient {
363 fn config(&self) -> &ProviderConfig {
364 self.provider
365 }
366
367 fn http_client(&self) -> &Client {
368 &self.http
369 }
370
371 fn api_key(&self) -> &SecretString {
372 &self.api_key
373 }
374
375 fn model(&self) -> &str {
376 &self.model
377 }
378
379 fn max_tokens(&self) -> u32 {
380 self.max_tokens
381 }
382
383 fn temperature(&self) -> f32 {
384 self.temperature
385 }
386
387 fn max_attempts(&self) -> u32 {
388 self.max_attempts
389 }
390
391 fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
392 Some(&self.circuit_breaker)
393 }
394
395 fn custom_guidance(&self) -> Option<&str> {
396 self.custom_guidance.as_deref()
397 }
398
399 fn build_headers(&self) -> reqwest::header::HeaderMap {
400 let mut headers = reqwest::header::HeaderMap::new();
401 if let Ok(val) = "application/json".parse() {
402 headers.insert("Content-Type", val);
403 }
404
405 if self.provider.name == super::registry::PROVIDER_ANTHROPIC {
407 if let Ok(val) = self.api_key().expose_secret().parse() {
408 headers.insert("x-api-key", val);
409 }
410 if let Ok(val) = "2023-06-01".parse() {
411 headers.insert("anthropic-version", val);
412 }
413 return headers;
414 }
415
416 if self.provider.name == "openrouter" {
418 if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
419 headers.insert("HTTP-Referer", val);
420 }
421 if let Ok(val) = "Aptu CLI".parse() {
422 headers.insert("X-Title", val);
423 }
424 }
425
426 headers
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter, a free-tier model, paid models disabled.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-sonnet-4-6".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_build_headers_anthropic_has_api_key_and_version() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_api_key"),
            "test-model",
            &config,
        )
        .expect("should create anthropic client");

        let headers = client.build_headers();

        let header_str = |k| headers.get(k).and_then(|v| v.to_str().ok());
        assert_eq!(header_str("x-api-key"), Some("test_api_key"));
        assert_eq!(header_str("anthropic-version"), Some("2023-06-01"));
    }

    #[test]
    fn test_build_headers_non_anthropic_unaffected() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("test_key"),
            "test-model:free",
            &config,
        )
        .expect("should create openrouter client");

        let headers = client.build_headers();

        assert!(!headers.contains_key("anthropic-version"));
        assert!(headers.contains_key("http-referer"));
        assert!(headers.contains_key("x-title"));
    }

    #[test]
    fn test_from_claude_credentials_missing_file() {
        // `from_claude_credentials` reads the real home directory. If this
        // machine has a credentials file, the "missing file" scenario does not
        // apply, and running it would parse the real file (and, with the
        // `keyring` feature, write the real token to the keyring). Skip instead.
        if AiClient::claude_credentials_path().is_some() {
            return;
        }
        let config = test_config();
        let result = AiClient::from_claude_credentials(&config);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_from_claude_credentials_malformed_json() {
        let result: Result<ClaudeCredentials, _> = serde_json::from_str("{ invalid json }");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_missing_access_token() {
        let malformed = r#"{"other_field": "value"}"#;
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_claude_credentials_empty_token() {
        let empty_token = r#"{"access_token": ""}"#;
        let creds: ClaudeCredentials = serde_json::from_str(empty_token).expect("should parse");
        assert!(creds.access_token.is_empty());
    }

    #[test]
    fn test_auth_method_api_key() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_key"),
            "test-model",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.auth_method(), AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::ApiKey.to_string(), "api-key");
        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
    }
}