Skip to main content

aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::{ExposeSecret, SecretString};
15use serde::{Deserialize, Serialize};
16
17use super::circuit_breaker::CircuitBreaker;
18use super::provider::AiProvider;
19use super::registry::{PROVIDER_ANTHROPIC, ProviderConfig, get_provider};
20use crate::config::AiConfig;
21
22/// Authentication method used by the AI client.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum AuthMethod {
26    /// API key from environment variable.
27    ApiKey,
28    /// OAuth token from Claude credentials file.
29    OAuth,
30}
31
32impl std::fmt::Display for AuthMethod {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            AuthMethod::ApiKey => write!(f, "api-key"),
36            AuthMethod::OAuth => write!(f, "oauth"),
37        }
38    }
39}
40
41/// Claude credentials from ~/.claude/credentials.json.
42#[derive(Debug, Deserialize)]
43pub struct ClaudeCredentials {
44    /// OAuth access token.
45    pub access_token: String,
46}
47
48/// Generic AI client for all providers.
49///
50/// Holds HTTP client, API key, and model configuration for reuse across multiple requests.
51/// Uses the provider registry to get provider-specific configuration.
52#[derive(Debug)]
53pub struct AiClient {
54    /// Provider configuration from registry.
55    provider: &'static ProviderConfig,
56    /// HTTP client with configured timeout.
57    http: Client,
58    /// API key for provider authentication.
59    api_key: SecretString,
60    /// Model name (e.g., "mistralai/mistral-small-2603").
61    model: String,
62    /// Maximum tokens for API responses.
63    max_tokens: u32,
64    /// Temperature for API requests.
65    temperature: f32,
66    /// Maximum retry attempts for rate-limited requests.
67    max_attempts: u32,
68    /// Circuit breaker for resilience.
69    circuit_breaker: CircuitBreaker,
70    /// Optional custom guidance from config to inject into system prompts.
71    custom_guidance: Option<String>,
72    /// Authentication method used.
73    auth_method: AuthMethod,
74}
75
76impl Drop for AiClient {
77    fn drop(&mut self) {
78        use zeroize::Zeroize;
79        // Safety: SecretString wraps String, which implements Zeroize.
80        // Calling zeroize() overwrites the backing buffer before deallocation.
81        self.api_key.zeroize();
82    }
83}
84
85impl AiClient {
86    /// Creates a new AI client from configuration.
87    ///
88    /// Validates the model against cost control settings and fetches the API key
89    /// from the environment.
90    ///
91    /// # Arguments
92    ///
93    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
94    /// * `config` - AI configuration with model, timeout, and cost control settings
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if:
99    /// - Provider is not found in registry
100    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
101    /// - API key environment variable is not set
102    /// - HTTP client creation fails
103    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
104        // Look up provider in registry
105        let provider = get_provider(provider_name)
106            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
107
108        // Validate model against cost control (OpenRouter-specific)
109        if provider_name == "openrouter"
110            && !config.allow_paid_models
111            && !super::is_free_model(&config.model)
112        {
113            anyhow::bail!(
114                "Model '{}' is not in the free tier.\n\
115                 To use paid models, set `allow_paid_models = true` in your config file:\n\
116                 {}\n\n\
117                 Or use a free model like: google/gemma-3-12b-it:free",
118                config.model,
119                crate::config::config_file_path().display()
120            );
121        }
122
123        // Get API key from environment
124        let api_key = env::var(provider.api_key_env).with_context(|| {
125            format!(
126                "Missing {} environment variable.\n\
127                 Set it with: export {}=your_api_key",
128                provider.api_key_env, provider.api_key_env
129            )
130        })?;
131
132        // Create HTTP client with timeout
133        let http = Client::builder()
134            .timeout(Duration::from_secs(config.timeout_seconds))
135            .build()
136            .context("Failed to create HTTP client")?;
137
138        Ok(Self {
139            provider,
140            http,
141            api_key: SecretString::new(api_key.into()),
142            model: config.model.clone(),
143            max_tokens: config.max_tokens,
144            temperature: config.temperature,
145            max_attempts: config.retry_max_attempts,
146            circuit_breaker: CircuitBreaker::new(
147                config.circuit_breaker_threshold,
148                config.circuit_breaker_reset_seconds,
149            ),
150            custom_guidance: config.custom_guidance.clone(),
151            auth_method: AuthMethod::ApiKey,
152        })
153    }
154
155    /// Creates a new AI client with a provided API key and validates the model exists.
156    ///
157    /// This constructor validates that the model exists via the runtime model registry
158    /// before creating the client. It allows callers to provide an API key directly,
159    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
160    ///
161    /// # Arguments
162    ///
163    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
164    /// * `api_key` - API key as a `SecretString`
165    /// * `model_name` - Model name to use (e.g., "gemini-3.1-flash-lite-preview")
166    /// * `config` - AI configuration with timeout and cost control settings
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if:
171    /// - Provider is not found in registry
172    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
173    /// - HTTP client creation fails
174    pub fn with_api_key(
175        provider_name: &str,
176        api_key: SecretString,
177        model_name: &str,
178        config: &AiConfig,
179    ) -> Result<Self> {
180        // Look up provider in registry
181        let provider = get_provider(provider_name)
182            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
183
184        // Validate model against cost control (OpenRouter-specific)
185        if provider_name == "openrouter"
186            && !config.allow_paid_models
187            && !super::is_free_model(model_name)
188        {
189            anyhow::bail!(
190                "Model '{}' is not in the free tier.\n\
191                 To use paid models, set `allow_paid_models = true` in your config file:\n\
192                 {}\n\n\
193                 Or use a free model like: google/gemma-3-12b-it:free",
194                model_name,
195                crate::config::config_file_path().display()
196            );
197        }
198
199        // Create HTTP client with timeout
200        let http = Client::builder()
201            .timeout(Duration::from_secs(config.timeout_seconds))
202            .build()
203            .context("Failed to create HTTP client")?;
204
205        Ok(Self {
206            provider,
207            http,
208            api_key,
209            model: model_name.to_string(),
210            max_tokens: config.max_tokens,
211            temperature: config.temperature,
212            max_attempts: config.retry_max_attempts,
213            circuit_breaker: CircuitBreaker::new(
214                config.circuit_breaker_threshold,
215                config.circuit_breaker_reset_seconds,
216            ),
217            custom_guidance: config.custom_guidance.clone(),
218            auth_method: AuthMethod::ApiKey,
219        })
220    }
221
222    /// Creates a new AI client from Claude credentials file (~/.claude/credentials.json).
223    ///
224    /// Reads the credentials file, extracts the access token, stores it in the OS keyring,
225    /// and returns an `AiClient` configured for the Anthropic provider.
226    ///
227    /// # Arguments
228    ///
229    /// * `config` - AI configuration with timeout and cost control settings
230    ///
231    /// # Returns
232    ///
233    /// Returns `Ok(Some(AiClient))` if credentials are found and valid,
234    /// `Ok(None)` if the credentials file is missing or invalid,
235    /// or an error if keyring operations fail.
236    pub fn from_claude_credentials(config: &AiConfig) -> Result<Option<Self>> {
237        // Resolve credentials file path
238        let Some(home) = dirs::home_dir() else {
239            return Ok(None);
240        };
241
242        let creds_path = home.join(".claude").join("credentials.json");
243
244        // Check if file exists
245        if !creds_path.exists() {
246            return Ok(None);
247        }
248
249        // Read and parse credentials file
250        let creds_content =
251            std::fs::read_to_string(&creds_path).context("Failed to read credentials file")?;
252
253        let creds: ClaudeCredentials =
254            serde_json::from_str(&creds_content).context("Failed to parse credentials JSON")?;
255
256        // Validate token is not empty
257        if creds.access_token.is_empty() {
258            return Ok(None);
259        }
260
261        // Store token in keyring
262        #[cfg(feature = "keyring")]
263        {
264            use keyring_core::Entry;
265            let entry = Entry::new("aptu", "anthropic_oauth_token")
266                .context("Failed to create keyring entry")?;
267            entry
268                .set_password(&creds.access_token)
269                .context("Failed to store token in keyring")?;
270        }
271
272        // Create client with the token
273        let client = Self::with_api_key(
274            PROVIDER_ANTHROPIC,
275            SecretString::from(creds.access_token),
276            &config.model,
277            config,
278        )?;
279
280        // Mark as OAuth
281        let mut client = client;
282        client.auth_method = AuthMethod::OAuth;
283        Ok(Some(client))
284    }
285
286    /// Returns the path to the Claude credentials file if it exists.
287    ///
288    /// This helper centralizes the path resolution logic for ~/.claude/credentials.json,
289    /// keeping the CLI command layer thin and avoiding duplicate path construction.
290    ///
291    /// Returns `Some(path)` if the file exists, `None` otherwise.
292    #[must_use]
293    pub fn claude_credentials_path() -> Option<std::path::PathBuf> {
294        let home = dirs::home_dir()?;
295        let creds_path = home.join(".claude").join("credentials.json");
296        if creds_path.exists() {
297            Some(creds_path)
298        } else {
299            None
300        }
301    }
302
303    /// Attempts to retrieve a Claude OAuth token from the OS keyring.
304    ///
305    /// Returns `Ok(Some(AiClient))` if a token is found in the keyring,
306    /// `Ok(None)` if no token is stored, or an error if keyring operations fail.
307    pub fn from_keyring_oauth(config: &AiConfig) -> Result<Option<Self>> {
308        #[cfg(feature = "keyring")]
309        {
310            use keyring_core::Entry;
311            let entry = Entry::new("aptu", "anthropic_oauth_token")
312                .context("Failed to create keyring entry")?;
313
314            match entry.get_password() {
315                Ok(token) => {
316                    let client = Self::with_api_key(
317                        PROVIDER_ANTHROPIC,
318                        SecretString::from(token),
319                        &config.model,
320                        config,
321                    )?;
322
323                    let mut client = client;
324                    client.auth_method = AuthMethod::OAuth;
325                    Ok(Some(client))
326                }
327                Err(_) => Ok(None),
328            }
329        }
330
331        #[cfg(not(feature = "keyring"))]
332        {
333            let _ = config;
334            Ok(None)
335        }
336    }
337
338    /// Returns the authentication method used by this client.
339    #[must_use]
340    pub fn auth_method(&self) -> AuthMethod {
341        self.auth_method
342    }
343
344    /// Get the circuit breaker for this client.
345    #[must_use]
346    pub fn circuit_breaker(&self) -> &CircuitBreaker {
347        &self.circuit_breaker
348    }
349}
350
351#[async_trait]
352impl AiProvider for AiClient {
353    fn name(&self) -> &str {
354        self.provider.name
355    }
356
357    fn api_url(&self) -> &str {
358        self.provider.api_url
359    }
360
361    fn api_key_env(&self) -> &str {
362        self.provider.api_key_env
363    }
364
365    fn http_client(&self) -> &Client {
366        &self.http
367    }
368
369    fn api_key(&self) -> &SecretString {
370        &self.api_key
371    }
372
373    fn model(&self) -> &str {
374        &self.model
375    }
376
377    fn max_tokens(&self) -> u32 {
378        self.max_tokens
379    }
380
381    fn temperature(&self) -> f32 {
382        self.temperature
383    }
384
385    fn max_attempts(&self) -> u32 {
386        self.max_attempts
387    }
388
389    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
390        Some(&self.circuit_breaker)
391    }
392
393    fn custom_guidance(&self) -> Option<&str> {
394        self.custom_guidance.as_deref()
395    }
396
397    fn build_headers(&self) -> reqwest::header::HeaderMap {
398        let mut headers = reqwest::header::HeaderMap::new();
399        if let Ok(val) = "application/json".parse() {
400            headers.insert("Content-Type", val);
401        }
402
403        // Anthropic-specific headers
404        if self.provider.name == super::registry::PROVIDER_ANTHROPIC {
405            if let Ok(val) = self.api_key().expose_secret().parse() {
406                headers.insert("x-api-key", val);
407            }
408            if let Ok(val) = "2023-06-01".parse() {
409                headers.insert("anthropic-version", val);
410            }
411            return headers;
412        }
413
414        // OpenRouter-specific headers
415        if self.provider.name == "openrouter" {
416            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
417                headers.insert("HTTP-Referer", val);
418            }
419            if let Ok(val) = "Aptu CLI".parse() {
420                headers.insert("X-Title", val);
421            }
422        }
423
424        headers
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config shared by the tests: `OpenRouter`, a model name ending in
    /// `:free`, and paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Every registered provider can be used to construct a client.
    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    /// A provider name missing from the registry is rejected.
    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    /// With paid models disallowed, `OpenRouter` refuses a non-free model.
    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-sonnet-4-6".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-sonnet-4-6",
            &config,
        );
        assert!(result.is_err());
    }

    /// The retry budget is taken from `retry_max_attempts`.
    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    /// Anthropic requests carry the key and the API version header.
    #[test]
    fn test_build_headers_anthropic_has_api_key_and_version() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_api_key"),
            "test-model",
            &config,
        )
        .expect("should create anthropic client");

        let headers = client.build_headers();

        let header_str = |k| headers.get(k).and_then(|v| v.to_str().ok());
        assert_eq!(header_str("x-api-key"), Some("test_api_key"));
        assert_eq!(header_str("anthropic-version"), Some("2023-06-01"));
    }

    /// `OpenRouter` requests get their own headers and no Anthropic ones.
    #[test]
    fn test_build_headers_non_anthropic_unaffected() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("test_key"),
            "test-model:free",
            &config,
        )
        .expect("should create openrouter client");

        let headers = client.build_headers();

        assert!(!headers.contains_key("anthropic-version"));
        assert!(headers.contains_key("http-referer"));
        assert!(headers.contains_key("x-title"));
    }

    /// Without a credentials file the constructor reports "no client" rather than failing.
    #[test]
    fn test_from_claude_credentials_missing_file() {
        // This runs against the real home directory. If genuine Claude credentials are
        // present, the "missing file" premise does not hold, and calling the constructor
        // could write to the OS keyring (when the `keyring` feature is enabled), so skip.
        if AiClient::claude_credentials_path().is_some() {
            return;
        }

        let config = test_config();
        let result = AiClient::from_claude_credentials(&config);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    /// Malformed JSON does not deserialize into `ClaudeCredentials`.
    #[test]
    fn test_from_claude_credentials_malformed_json() {
        // `dirs::home_dir` cannot easily be redirected from a test, so the parsing step
        // is checked directly instead of going through a credentials file on disk.
        let malformed = "{ invalid json }";
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    /// JSON without an `access_token` field does not deserialize.
    #[test]
    fn test_from_claude_credentials_missing_access_token() {
        let malformed = r#"{"other_field": "value"}"#;
        let result: Result<ClaudeCredentials, _> = serde_json::from_str(malformed);
        assert!(result.is_err());
    }

    /// An empty token parses successfully and is observable as empty.
    #[test]
    fn test_from_claude_credentials_empty_token() {
        let empty_token = r#"{"access_token": ""}"#;
        let creds: ClaudeCredentials = serde_json::from_str(empty_token).expect("should parse");
        assert!(creds.access_token.is_empty());
    }

    /// Clients built with `with_api_key` report API-key authentication.
    #[test]
    fn test_auth_method_api_key() {
        let config = test_config();
        let client = AiClient::with_api_key(
            PROVIDER_ANTHROPIC,
            SecretString::from("test_key"),
            "test-model",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.auth_method(), AuthMethod::ApiKey);
    }
}