Skip to main content

aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use secrecy::{ExposeSecret, SecretString};
14use serde::{Deserialize, Serialize};
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{PROVIDER_ANTHROPIC, ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
21/// Authentication method used by the AI client.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum AuthMethod {
25    /// API key from environment variable.
26    ApiKey,
27    /// OAuth token from Claude credentials file.
28    OAuth,
29}
30
31impl std::fmt::Display for AuthMethod {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            AuthMethod::ApiKey => write!(f, "api-key"),
35            AuthMethod::OAuth => write!(f, "oauth"),
36        }
37    }
38}
39
40/// Claude credentials from ~/.claude/credentials.json.
41#[derive(Debug, Deserialize)]
42pub struct ClaudeCredentials {
43    /// OAuth access token.
44    pub access_token: String,
45}
46
47/// Creates an HTTP client with timeout. On native targets, the request timeout
48/// is set on the client; on wasm32, the browser's fetch API manages timeouts
49/// independently and reqwest's `timeout()` is unavailable.
50fn build_http_client(timeout_seconds: u64) -> Result<Client> {
51    #[cfg(not(target_arch = "wasm32"))]
52    let http = Client::builder()
53        .timeout(std::time::Duration::from_secs(timeout_seconds))
54        .build()
55        .context("Failed to create HTTP client")?;
56    #[cfg(target_arch = "wasm32")]
57    let http = Client::builder()
58        .build()
59        .context("Failed to create HTTP client")?;
60    Ok(http)
61}
62
63/// Generic AI client for all providers.
64///
65/// Holds HTTP client, API key, and model configuration for reuse across multiple requests.
66/// Uses the provider registry to get provider-specific configuration.
67#[derive(Debug)]
68pub struct AiClient {
69    /// Provider configuration from registry.
70    provider: &'static ProviderConfig,
71    /// HTTP client with configured timeout.
72    http: Client,
73    /// API key for provider authentication.
74    api_key: SecretString,
75    /// Model name (e.g., "mistralai/mistral-small-2603").
76    model: String,
77    /// Maximum tokens for API responses.
78    max_tokens: u32,
79    /// Temperature for API requests.
80    temperature: f32,
81    /// Maximum retry attempts for rate-limited requests.
82    max_attempts: u32,
83    /// Circuit breaker for resilience.
84    circuit_breaker: CircuitBreaker,
85    /// Optional custom guidance from config to inject into system prompts.
86    custom_guidance: Option<String>,
87    /// Authentication method used.
88    auth_method: AuthMethod,
89}
90
91impl Drop for AiClient {
92    fn drop(&mut self) {
93        use zeroize::Zeroize;
94        // Safety: SecretString wraps String, which implements Zeroize.
95        // Calling zeroize() overwrites the backing buffer before deallocation.
96        self.api_key.zeroize();
97    }
98}
99
100impl AiClient {
101    /// Creates a new AI client from configuration.
102    ///
103    /// Validates the model against cost control settings and fetches the API key
104    /// from the environment.
105    ///
106    /// # Arguments
107    ///
108    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
109    /// * `config` - AI configuration with model, timeout, and cost control settings
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if:
114    /// - Provider is not found in registry
115    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
116    /// - API key environment variable is not set
117    /// - HTTP client creation fails
118    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
119        // Look up provider in registry
120        let provider = get_provider(provider_name)
121            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
122
123        // Validate model against cost control (OpenRouter-specific)
124        if provider_name == "openrouter"
125            && !config.allow_paid_models
126            && !super::is_free_model(&config.model)
127        {
128            anyhow::bail!(
129                "Model '{}' is not in the free tier.\n\
130                 To use paid models, set `allow_paid_models = true` in your config file:\n\
131                 {}\n\n\
132                 Or use a free model like: google/gemma-3-12b-it:free",
133                config.model,
134                crate::config::config_file_path().display()
135            );
136        }
137
138        // Get API key from environment
139        let api_key = env::var(provider.api_key_env).with_context(|| {
140            format!(
141                "Missing {} environment variable.\n\
142                 Set it with: export {}=your_api_key",
143                provider.api_key_env, provider.api_key_env
144            )
145        })?;
146
147        // Create HTTP client with timeout (timeout() is native-only; wasm32 uses fetch API)
148        let http = build_http_client(config.timeout_seconds)?;
149
150        Ok(Self {
151            provider,
152            http,
153            api_key: SecretString::new(api_key.into()),
154            model: config.model.clone(),
155            max_tokens: config.max_tokens,
156            temperature: config.temperature,
157            max_attempts: config.retry_max_attempts,
158            circuit_breaker: CircuitBreaker::new(
159                config.circuit_breaker_threshold,
160                config.circuit_breaker_reset_seconds,
161            ),
162            custom_guidance: config.custom_guidance.clone(),
163            auth_method: AuthMethod::ApiKey,
164        })
165    }
166
167    /// Creates a new AI client with a provided API key and validates the model exists.
168    ///
169    /// This constructor validates that the model exists via the runtime model registry
170    /// before creating the client. It allows callers to provide an API key directly,
171    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
172    ///
173    /// # Arguments
174    ///
175    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
176    /// * `api_key` - API key as a `SecretString`
177    /// * `model_name` - Model name to use (e.g., "gemini-3.1-flash-lite-preview")
178    /// * `config` - AI configuration with timeout and cost control settings
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if:
183    /// - Provider is not found in registry
184    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
185    /// - HTTP client creation fails
186    pub fn with_api_key(
187        provider_name: &str,
188        api_key: SecretString,
189        model_name: &str,
190        config: &AiConfig,
191    ) -> Result<Self> {
192        // Look up provider in registry
193        let provider = get_provider(provider_name)
194            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
195
196        // Validate model against cost control (OpenRouter-specific)
197        if provider_name == "openrouter"
198            && !config.allow_paid_models
199            && !super::is_free_model(model_name)
200        {
201            anyhow::bail!(
202                "Model '{}' is not in the free tier.\n\
203                 To use paid models, set `allow_paid_models = true` in your config file:\n\
204                 {}\n\n\
205                 Or use a free model like: google/gemma-3-12b-it:free",
206                model_name,
207                crate::config::config_file_path().display()
208            );
209        }
210
211        // Create HTTP client with timeout (timeout() is native-only; wasm32 uses fetch API)
212        let http = build_http_client(config.timeout_seconds)?;
213
214        Ok(Self {
215            provider,
216            http,
217            api_key,
218            model: model_name.to_string(),
219            max_tokens: config.max_tokens,
220            temperature: config.temperature,
221            max_attempts: config.retry_max_attempts,
222            circuit_breaker: CircuitBreaker::new(
223                config.circuit_breaker_threshold,
224                config.circuit_breaker_reset_seconds,
225            ),
226            custom_guidance: config.custom_guidance.clone(),
227            auth_method: AuthMethod::ApiKey,
228        })
229    }
230
231    /// Creates a new AI client from Claude credentials file (~/.claude/credentials.json).
232    ///
233    /// Reads the credentials file, extracts the access token, stores it in the OS keyring,
234    /// and returns an `AiClient` configured for the Anthropic provider.
235    ///
236    /// # Arguments
237    ///
238    /// * `config` - AI configuration with timeout and cost control settings
239    ///
240    /// # Returns
241    ///
242    /// Returns `Ok(Some(AiClient))` if credentials are found and valid,
243    /// `Ok(None)` if the credentials file is missing or invalid,
244    /// or an error if keyring operations fail.
245    pub fn from_claude_credentials(config: &AiConfig) -> Result<Option<Self>> {
246        // Resolve credentials file path
247        let Some(home) = dirs::home_dir() else {
248            return Ok(None);
249        };
250
251        let creds_path = home.join(".claude").join("credentials.json");
252
253        // Check if file exists
254        if !creds_path.exists() {
255            return Ok(None);
256        }
257
258        // Read and parse credentials file
259        let creds_content =
260            std::fs::read_to_string(&creds_path).context("Failed to read credentials file")?;
261
262        let creds: ClaudeCredentials =
263            serde_json::from_str(&creds_content).context("Failed to parse credentials JSON")?;
264
265        // Validate token is not empty
266        if creds.access_token.is_empty() {
267            return Ok(None);
268        }
269
270        // Store token in keyring
271        #[cfg(feature = "keyring")]
272        {
273            use keyring_core::Entry;
274            let entry = Entry::new("aptu", "anthropic_oauth_token")
275                .context("Failed to create keyring entry")?;
276            entry
277                .set_password(&creds.access_token)
278                .context("Failed to store token in keyring")?;
279        }
280
281        // Create client with the token
282        let client = Self::with_api_key(
283            PROVIDER_ANTHROPIC,
284            SecretString::from(creds.access_token),
285            &config.model,
286            config,
287        )?;
288
289        // Mark as OAuth
290        let mut client = client;
291        client.auth_method = AuthMethod::OAuth;
292        Ok(Some(client))
293    }
294
295    /// Returns the path to the Claude credentials file if it exists.
296    ///
297    /// This helper centralizes the path resolution logic for ~/.claude/credentials.json,
298    /// keeping the CLI command layer thin and avoiding duplicate path construction.
299    ///
300    /// Returns `Some(path)` if the file exists, `None` otherwise.
301    #[must_use]
302    pub fn claude_credentials_path() -> Option<std::path::PathBuf> {
303        let home = dirs::home_dir()?;
304        let creds_path = home.join(".claude").join("credentials.json");
305        if creds_path.exists() {
306            Some(creds_path)
307        } else {
308            None
309        }
310    }
311
312    /// Attempts to retrieve a Claude OAuth token from the OS keyring.
313    ///
314    /// Returns `Ok(Some(AiClient))` if a token is found in the keyring,
315    /// `Ok(None)` if no token is stored, or an error if keyring operations fail.
316    pub fn from_keyring_oauth(config: &AiConfig) -> Result<Option<Self>> {
317        #[cfg(feature = "keyring")]
318        {
319            use keyring_core::Entry;
320            let entry = Entry::new("aptu", "anthropic_oauth_token")
321                .context("Failed to create keyring entry")?;
322
323            match entry.get_password() {
324                Ok(token) => {
325                    let client = Self::with_api_key(
326                        PROVIDER_ANTHROPIC,
327                        SecretString::from(token),
328                        &config.model,
329                        config,
330                    )?;
331
332                    let mut client = client;
333                    client.auth_method = AuthMethod::OAuth;
334                    Ok(Some(client))
335                }
336                Err(_) => Ok(None),
337            }
338        }
339
340        #[cfg(not(feature = "keyring"))]
341        {
342            let _ = config;
343            Ok(None)
344        }
345    }
346
347    /// Returns the authentication method used by this client.
348    #[must_use]
349    pub fn auth_method(&self) -> AuthMethod {
350        self.auth_method
351    }
352
353    /// Get the circuit breaker for this client.
354    #[must_use]
355    pub fn circuit_breaker(&self) -> &CircuitBreaker {
356        &self.circuit_breaker
357    }
358}
359
360#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
361#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
362impl AiProvider for AiClient {
363    fn config(&self) -> &ProviderConfig {
364        self.provider
365    }
366
367    fn http_client(&self) -> &Client {
368        &self.http
369    }
370
371    fn api_key(&self) -> &SecretString {
372        &self.api_key
373    }
374
375    fn model(&self) -> &str {
376        &self.model
377    }
378
379    fn max_tokens(&self) -> u32 {
380        self.max_tokens
381    }
382
383    fn temperature(&self) -> f32 {
384        self.temperature
385    }
386
387    fn max_attempts(&self) -> u32 {
388        self.max_attempts
389    }
390
391    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
392        Some(&self.circuit_breaker)
393    }
394
395    fn custom_guidance(&self) -> Option<&str> {
396        self.custom_guidance.as_deref()
397    }
398
399    fn build_headers(&self) -> reqwest::header::HeaderMap {
400        let mut headers = reqwest::header::HeaderMap::new();
401        if let Ok(val) = "application/json".parse() {
402            headers.insert("Content-Type", val);
403        }
404
405        // Anthropic-specific headers
406        if self.provider.name == super::registry::PROVIDER_ANTHROPIC {
407            if let Ok(val) = self.api_key().expose_secret().parse() {
408                headers.insert("x-api-key", val);
409            }
410            if let Ok(val) = "2023-06-01".parse() {
411                headers.insert("anthropic-version", val);
412            }
413            return headers;
414        }
415
416        // OpenRouter-specific headers
417        if self.provider.name == "openrouter" {
418            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
419                headers.insert("HTTP-Referer", val);
420            }
421            if let Ok(val) = "Aptu CLI".parse() {
422                headers.insert("X-Title", val);
423            }
424        }
425
426        headers
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline configuration: `OpenRouter`, a free-tier model name, paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Every registered provider can be constructed with an explicit key.
    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    /// A provider name missing from the registry is rejected.
    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    /// `OpenRouter` refuses a non-free model while paid models are disallowed.
    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-sonnet-4-6".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_err());
    }

    /// `retry_max_attempts` from config is exposed through `max_attempts()`.
    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    /// Anthropic requests carry the key and the pinned API version.
    #[test]
    fn test_build_headers_anthropic_has_api_key_and_version() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_api_key"),
            "test-model",
            &config,
        )
        .expect("should create anthropic client");

        let headers = client.build_headers();

        let header_str = |k| headers.get(k).and_then(|v| v.to_str().ok());
        assert_eq!(header_str("x-api-key"), Some("test_api_key"));
        assert_eq!(header_str("anthropic-version"), Some("2023-06-01"));
    }

    /// `OpenRouter` requests get their own headers and none of Anthropic's.
    #[test]
    fn test_build_headers_non_anthropic_unaffected() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("test_key"),
            "test-model:free",
            &config,
        )
        .expect("should create openrouter client");

        let headers = client.build_headers();

        assert!(!headers.contains_key("anthropic-version"));
        assert!(headers.contains_key("http-referer"));
        assert!(headers.contains_key("x-title"));
    }

    /// With no credentials file present, construction yields `Ok(None)`.
    #[test]
    fn test_from_claude_credentials_missing_file() {
        // This test looks at the real home directory. On a machine that has a
        // real credentials file the "missing" premise does not hold, and
        // calling the constructor could write the token to the OS keyring,
        // so skip instead of failing or causing side effects.
        if AiClient::claude_credentials_path().is_some() {
            return;
        }

        let config = test_config();
        let result = AiClient::from_claude_credentials(&config);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    /// Malformed JSON does not deserialize into `ClaudeCredentials`.
    #[test]
    fn test_from_claude_credentials_malformed_json() {
        // `dirs::home_dir` cannot easily be redirected from a test, so the
        // parsing step is exercised directly rather than through a file.
        let malformed = "{ invalid json }";
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    /// JSON without `access_token` does not deserialize.
    #[test]
    fn test_from_claude_credentials_missing_access_token() {
        let malformed = r#"{"other_field": "value"}"#;
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    /// An empty token parses successfully; the constructor is what rejects it.
    #[test]
    fn test_from_claude_credentials_empty_token() {
        let empty_token = r#"{"access_token": ""}"#;
        let creds: ClaudeCredentials = serde_json::from_str(empty_token).expect("should parse");
        assert!(creds.access_token.is_empty());
    }

    /// Clients built with an explicit key report `AuthMethod::ApiKey`.
    #[test]
    fn test_auth_method_api_key() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_key"),
            "test-model",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.auth_method(), AuthMethod::ApiKey);
    }
}