Skip to main content

aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// Generic AI client for all providers.
///
/// Bundles everything needed to issue requests to a single provider: the
/// provider's registry entry, a reusable HTTP client, the API key, and the
/// request parameters copied out of [`AiConfig`] at construction time.
///
/// Instances are created through [`AiClient::new`] (API key read from the
/// environment) or [`AiClient::with_api_key`] (API key supplied by the caller).
///
/// The type has an explicit `Drop` implementation, so its fields cannot be moved
/// out by destructuring.
#[derive(Debug)]
pub struct AiClient {
    /// Provider configuration from registry (name, API URL, API key env var name).
    provider: &'static ProviderConfig,
    /// HTTP client, built with the timeout taken from `AiConfig::timeout_seconds`.
    http: Client,
    /// API key for provider authentication.
    api_key: SecretString,
    /// Model name (e.g., "mistralai/mistral-small-2603").
    model: String,
    /// Maximum tokens for API responses (from `AiConfig::max_tokens`).
    max_tokens: u32,
    /// Temperature for API requests (from `AiConfig::temperature`).
    temperature: f32,
    /// Maximum retry attempts (from `AiConfig::retry_max_attempts`).
    max_attempts: u32,
    /// Circuit breaker for resilience, configured with the threshold and reset
    /// interval from `AiConfig`.
    circuit_breaker: CircuitBreaker,
    /// Optional custom guidance from config to inject into system prompts.
    custom_guidance: Option<String>,
}
46
47impl Drop for AiClient {
48    fn drop(&mut self) {
49        use zeroize::Zeroize;
50        // Safety: SecretString wraps String, which implements Zeroize.
51        // Calling zeroize() overwrites the backing buffer before deallocation.
52        self.api_key.zeroize();
53    }
54}
55
56impl AiClient {
57    /// Creates a new AI client from configuration.
58    ///
59    /// Validates the model against cost control settings and fetches the API key
60    /// from the environment.
61    ///
62    /// # Arguments
63    ///
64    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
65    /// * `config` - AI configuration with model, timeout, and cost control settings
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if:
70    /// - Provider is not found in registry
71    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
72    /// - API key environment variable is not set
73    /// - HTTP client creation fails
74    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
75        // Look up provider in registry
76        let provider = get_provider(provider_name)
77            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
78
79        // Validate model against cost control (OpenRouter-specific)
80        if provider_name == "openrouter"
81            && !config.allow_paid_models
82            && !super::is_free_model(&config.model)
83        {
84            anyhow::bail!(
85                "Model '{}' is not in the free tier.\n\
86                 To use paid models, set `allow_paid_models = true` in your config file:\n\
87                 {}\n\n\
88                 Or use a free model like: google/gemma-3-12b-it:free",
89                config.model,
90                crate::config::config_file_path().display()
91            );
92        }
93
94        // Get API key from environment
95        let api_key = env::var(provider.api_key_env).with_context(|| {
96            format!(
97                "Missing {} environment variable.\n\
98                 Set it with: export {}=your_api_key",
99                provider.api_key_env, provider.api_key_env
100            )
101        })?;
102
103        // Create HTTP client with timeout
104        let http = Client::builder()
105            .timeout(Duration::from_secs(config.timeout_seconds))
106            .build()
107            .context("Failed to create HTTP client")?;
108
109        Ok(Self {
110            provider,
111            http,
112            api_key: SecretString::new(api_key.into()),
113            model: config.model.clone(),
114            max_tokens: config.max_tokens,
115            temperature: config.temperature,
116            max_attempts: config.retry_max_attempts,
117            circuit_breaker: CircuitBreaker::new(
118                config.circuit_breaker_threshold,
119                config.circuit_breaker_reset_seconds,
120            ),
121            custom_guidance: config.custom_guidance.clone(),
122        })
123    }
124
125    /// Creates a new AI client with a provided API key and validates the model exists.
126    ///
127    /// This constructor validates that the model exists via the runtime model registry
128    /// before creating the client. It allows callers to provide an API key directly,
129    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
130    ///
131    /// # Arguments
132    ///
133    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
134    /// * `api_key` - API key as a `SecretString`
135    /// * `model_name` - Model name to use (e.g., "gemini-3.1-flash-lite-preview")
136    /// * `config` - AI configuration with timeout and cost control settings
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if:
141    /// - Provider is not found in registry
142    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
143    /// - HTTP client creation fails
144    pub fn with_api_key(
145        provider_name: &str,
146        api_key: SecretString,
147        model_name: &str,
148        config: &AiConfig,
149    ) -> Result<Self> {
150        // Look up provider in registry
151        let provider = get_provider(provider_name)
152            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
153
154        // Validate model against cost control (OpenRouter-specific)
155        if provider_name == "openrouter"
156            && !config.allow_paid_models
157            && !super::is_free_model(model_name)
158        {
159            anyhow::bail!(
160                "Model '{}' is not in the free tier.\n\
161                 To use paid models, set `allow_paid_models = true` in your config file:\n\
162                 {}\n\n\
163                 Or use a free model like: google/gemma-3-12b-it:free",
164                model_name,
165                crate::config::config_file_path().display()
166            );
167        }
168
169        // Create HTTP client with timeout
170        let http = Client::builder()
171            .timeout(Duration::from_secs(config.timeout_seconds))
172            .build()
173            .context("Failed to create HTTP client")?;
174
175        Ok(Self {
176            provider,
177            http,
178            api_key,
179            model: model_name.to_string(),
180            max_tokens: config.max_tokens,
181            temperature: config.temperature,
182            max_attempts: config.retry_max_attempts,
183            circuit_breaker: CircuitBreaker::new(
184                config.circuit_breaker_threshold,
185                config.circuit_breaker_reset_seconds,
186            ),
187            custom_guidance: config.custom_guidance.clone(),
188        })
189    }
190
191    /// Get the circuit breaker for this client.
192    #[must_use]
193    pub fn circuit_breaker(&self) -> &CircuitBreaker {
194        &self.circuit_breaker
195    }
196}
197
198#[async_trait]
199impl AiProvider for AiClient {
200    fn name(&self) -> &str {
201        self.provider.name
202    }
203
204    fn api_url(&self) -> &str {
205        self.provider.api_url
206    }
207
208    fn api_key_env(&self) -> &str {
209        self.provider.api_key_env
210    }
211
212    fn http_client(&self) -> &Client {
213        &self.http
214    }
215
216    fn api_key(&self) -> &SecretString {
217        &self.api_key
218    }
219
220    fn model(&self) -> &str {
221        &self.model
222    }
223
224    fn max_tokens(&self) -> u32 {
225        self.max_tokens
226    }
227
228    fn temperature(&self) -> f32 {
229        self.temperature
230    }
231
232    fn max_attempts(&self) -> u32 {
233        self.max_attempts
234    }
235
236    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
237        Some(&self.circuit_breaker)
238    }
239
240    fn custom_guidance(&self) -> Option<&str> {
241        self.custom_guidance.as_deref()
242    }
243
244    fn build_headers(&self) -> reqwest::header::HeaderMap {
245        let mut headers = reqwest::header::HeaderMap::new();
246        if let Ok(val) = "application/json".parse() {
247            headers.insert("Content-Type", val);
248        }
249
250        // OpenRouter-specific headers
251        if self.provider.name == "openrouter" {
252            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
253                headers.insert("HTTP-Referer", val);
254            }
255            if let Ok(val) = "Aptu CLI".parse() {
256                headers.insert("X-Title", val);
257            }
258        }
259
260        headers
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline configuration shared by the tests: `OpenRouter`, a `:free` model,
    /// and paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: String::from("openrouter"),
            model: String::from("test-model:free"),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Every provider in the registry must accept a caller-supplied key.
    #[test]
    fn test_with_api_key_all_providers() {
        let cfg = test_config();
        for entry in all_providers() {
            let outcome = AiClient::with_api_key(
                entry.name,
                SecretString::from("test_key"),
                "test-model:free",
                &cfg,
            );
            assert!(outcome.is_ok(), "Failed for provider: {}", entry.name);
        }
    }

    /// A provider name missing from the registry is rejected.
    #[test]
    fn test_unknown_provider_error() {
        let cfg = test_config();
        let outcome =
            AiClient::with_api_key("nonexistent", SecretString::from("key"), "test-model", &cfg);
        assert!(outcome.is_err());
    }

    /// With paid models disallowed, `OpenRouter` refuses a model outside the free tier.
    #[test]
    fn test_openrouter_rejects_paid_model() {
        let cfg = AiConfig {
            model: String::from("anthropic/claude-3"),
            allow_paid_models: false,
            ..test_config()
        };
        let outcome = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &cfg,
        );
        assert!(outcome.is_err());
    }

    /// `retry_max_attempts` from the config is surfaced through `max_attempts()`.
    #[test]
    fn test_max_attempts_from_config() {
        let cfg = AiConfig {
            retry_max_attempts: 5,
            ..test_config()
        };
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &cfg,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }
}
344}