Skip to main content

aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
21/// Generic AI client for all providers.
22///
23/// Holds HTTP client, API key, and model configuration for reuse across multiple requests.
24/// Uses the provider registry to get provider-specific configuration.
25#[derive(Debug)]
26pub struct AiClient {
27    /// Provider configuration from registry.
28    provider: &'static ProviderConfig,
29    /// HTTP client with configured timeout.
30    http: Client,
31    /// API key for provider authentication.
32    api_key: SecretString,
33    /// Model name (e.g., "mistralai/devstral-2512:free").
34    model: String,
35    /// Maximum tokens for API responses.
36    max_tokens: u32,
37    /// Temperature for API requests.
38    temperature: f32,
39    /// Maximum retry attempts for rate-limited requests.
40    max_attempts: u32,
41    /// Circuit breaker for resilience.
42    circuit_breaker: CircuitBreaker,
43}
44
45impl AiClient {
46    /// Creates a new AI client from configuration.
47    ///
48    /// Validates the model against cost control settings and fetches the API key
49    /// from the environment.
50    ///
51    /// # Arguments
52    ///
53    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
54    /// * `config` - AI configuration with model, timeout, and cost control settings
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if:
59    /// - Provider is not found in registry
60    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
61    /// - API key environment variable is not set
62    /// - HTTP client creation fails
63    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
64        // Look up provider in registry
65        let provider = get_provider(provider_name)
66            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
67
68        // Validate model against cost control (OpenRouter-specific)
69        if provider_name == "openrouter"
70            && !config.allow_paid_models
71            && !super::is_free_model(&config.model)
72        {
73            anyhow::bail!(
74                "Model '{}' is not in the free tier.\n\
75                 To use paid models, set `allow_paid_models = true` in your config file:\n\
76                 {}\n\n\
77                 Or use a free model like: mistralai/devstral-2512:free",
78                config.model,
79                crate::config::config_file_path().display()
80            );
81        }
82
83        // Get API key from environment
84        let api_key = env::var(provider.api_key_env).with_context(|| {
85            format!(
86                "Missing {} environment variable.\n\
87                 Set it with: export {}=your_api_key",
88                provider.api_key_env, provider.api_key_env
89            )
90        })?;
91
92        // Create HTTP client with timeout
93        let http = Client::builder()
94            .timeout(Duration::from_secs(config.timeout_seconds))
95            .build()
96            .context("Failed to create HTTP client")?;
97
98        Ok(Self {
99            provider,
100            http,
101            api_key: SecretString::new(api_key.into()),
102            model: config.model.clone(),
103            max_tokens: config.max_tokens,
104            temperature: config.temperature,
105            max_attempts: config.retry_max_attempts,
106            circuit_breaker: CircuitBreaker::new(
107                config.circuit_breaker_threshold,
108                config.circuit_breaker_reset_seconds,
109            ),
110        })
111    }
112
113    /// Creates a new AI client with a provided API key and validates the model exists.
114    ///
115    /// This constructor validates that the model exists via the runtime model registry
116    /// before creating the client. It allows callers to provide an API key directly,
117    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
118    ///
119    /// # Arguments
120    ///
121    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
122    /// * `api_key` - API key as a `SecretString`
123    /// * `model_name` - Model name to use (e.g., "gemini-3-flash-preview")
124    /// * `config` - AI configuration with timeout and cost control settings
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if:
129    /// - Provider is not found in registry
130    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
131    /// - HTTP client creation fails
132    pub fn with_api_key(
133        provider_name: &str,
134        api_key: SecretString,
135        model_name: &str,
136        config: &AiConfig,
137    ) -> Result<Self> {
138        // Look up provider in registry
139        let provider = get_provider(provider_name)
140            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
141
142        // Validate model against cost control (OpenRouter-specific)
143        if provider_name == "openrouter"
144            && !config.allow_paid_models
145            && !super::is_free_model(model_name)
146        {
147            anyhow::bail!(
148                "Model '{}' is not in the free tier.\n\
149                 To use paid models, set `allow_paid_models = true` in your config file:\n\
150                 {}\n\n\
151                 Or use a free model like: mistralai/devstral-2512:free",
152                model_name,
153                crate::config::config_file_path().display()
154            );
155        }
156
157        // Create HTTP client with timeout
158        let http = Client::builder()
159            .timeout(Duration::from_secs(config.timeout_seconds))
160            .build()
161            .context("Failed to create HTTP client")?;
162
163        Ok(Self {
164            provider,
165            http,
166            api_key,
167            model: model_name.to_string(),
168            max_tokens: config.max_tokens,
169            temperature: config.temperature,
170            max_attempts: config.retry_max_attempts,
171            circuit_breaker: CircuitBreaker::new(
172                config.circuit_breaker_threshold,
173                config.circuit_breaker_reset_seconds,
174            ),
175        })
176    }
177
178    /// Get the circuit breaker for this client.
179    #[must_use]
180    pub fn circuit_breaker(&self) -> &CircuitBreaker {
181        &self.circuit_breaker
182    }
183}
184
185#[async_trait]
186impl AiProvider for AiClient {
187    fn name(&self) -> &str {
188        self.provider.name
189    }
190
191    fn api_url(&self) -> &str {
192        self.provider.api_url
193    }
194
195    fn api_key_env(&self) -> &str {
196        self.provider.api_key_env
197    }
198
199    fn http_client(&self) -> &Client {
200        &self.http
201    }
202
203    fn api_key(&self) -> &SecretString {
204        &self.api_key
205    }
206
207    fn model(&self) -> &str {
208        &self.model
209    }
210
211    fn max_tokens(&self) -> u32 {
212        self.max_tokens
213    }
214
215    fn temperature(&self) -> f32 {
216        self.temperature
217    }
218
219    fn max_attempts(&self) -> u32 {
220        self.max_attempts
221    }
222
223    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
224        Some(&self.circuit_breaker)
225    }
226
227    fn build_headers(&self) -> reqwest::header::HeaderMap {
228        let mut headers = reqwest::header::HeaderMap::new();
229        if let Ok(val) = "application/json".parse() {
230            headers.insert("Content-Type", val);
231        }
232
233        // OpenRouter-specific headers
234        if self.provider.name == "openrouter" {
235            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
236                headers.insert("HTTP-Referer", val);
237            }
238            if let Ok(val) = "Aptu CLI".parse() {
239                headers.insert("X-Title", val);
240            }
241        }
242
243        headers
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: `OpenRouter`, free-tier model, paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider must be rejected");
        assert!(
            err.to_string().contains("Unknown AI provider: nonexistent"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_cost_control_only_applies_to_openrouter() {
        let config = test_config();
        for provider_config in all_providers() {
            if provider_config.name == "openrouter" {
                continue;
            }
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("key"),
                "anthropic/claude-3",
                &config,
            );
            assert!(
                result.is_ok(),
                "cost control wrongly applied to provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_new_rejects_paid_model_before_reading_env() {
        // `new` validates cost control before the API-key env lookup, so this result is
        // independent of whether OPENROUTER's key variable is set in the test process.
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        let err = AiClient::new("openrouter", &config).expect_err("paid model must be rejected");
        assert!(
            err.to_string().contains("is not in the free tier"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_with_api_key_uses_explicit_model_and_config_settings() {
        let mut config = test_config();
        config.model = "ignored-model:free".to_string();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.model(), "test-model:free");
        assert_eq!(client.max_tokens(), 2048);
    }

    #[test]
    fn test_openrouter_headers() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        let headers = client.build_headers();
        assert_eq!(
            headers
                .get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
        assert_eq!(
            headers.get("x-title").and_then(|v| v.to_str().ok()),
            Some("Aptu CLI")
        );
        assert_eq!(
            headers.get("http-referer").and_then(|v| v.to_str().ok()),
            Some("https://github.com/clouatre-labs/aptu")
        );
    }

    #[test]
    fn test_non_openrouter_headers_omit_attribution() {
        let config = test_config();
        for provider_config in all_providers() {
            if provider_config.name == "openrouter" {
                continue;
            }
            let client = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("key"),
                "test-model:free",
                &config,
            )
            .expect("should create client");
            let headers = client.build_headers();
            assert!(headers.get("x-title").is_none());
            assert!(headers.get("http-referer").is_none());
            assert!(headers.get(reqwest::header::CONTENT_TYPE).is_some());
        }
    }
}