aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
/// Generic AI client for all providers.
///
/// Holds HTTP client, API key, and model configuration for reuse across multiple requests.
/// Uses the provider registry to get provider-specific configuration.
///
/// Construct with [`AiClient::new`] (API key read from the provider's environment
/// variable) or [`AiClient::with_api_key`] (API key supplied by the caller).
#[derive(Debug)]
pub struct AiClient {
    /// Provider configuration from the registry (name, API URL, API-key env var name).
    provider: &'static ProviderConfig,
    /// HTTP client built with the request timeout from `AiConfig::timeout_seconds`.
    http: Client,
    /// API key for provider authentication, held as a `SecretString` rather than a plain `String`.
    api_key: SecretString,
    /// Model name (e.g., "mistralai/devstral-2512:free").
    model: String,
    /// Maximum tokens for API responses (copied from `AiConfig::max_tokens`).
    max_tokens: u32,
    /// Temperature for API requests (copied from `AiConfig::temperature`).
    temperature: f32,
    /// Circuit breaker for resilience, built from the configured threshold and reset interval.
    circuit_breaker: CircuitBreaker,
}
42
43impl AiClient {
44    /// Creates a new AI client from configuration.
45    ///
46    /// Validates the model against cost control settings and fetches the API key
47    /// from the environment.
48    ///
49    /// # Arguments
50    ///
51    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
52    /// * `config` - AI configuration with model, timeout, and cost control settings
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if:
57    /// - Provider is not found in registry
58    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
59    /// - API key environment variable is not set
60    /// - HTTP client creation fails
61    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
62        // Look up provider in registry
63        let provider = get_provider(provider_name)
64            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
65
66        // Validate model against cost control (OpenRouter-specific)
67        if provider_name == "openrouter"
68            && !config.allow_paid_models
69            && !super::is_free_model(&config.model)
70        {
71            anyhow::bail!(
72                "Model '{}' is not in the free tier.\n\
73                 To use paid models, set `allow_paid_models = true` in your config file:\n\
74                 {}\n\n\
75                 Or use a free model like: mistralai/devstral-2512:free",
76                config.model,
77                crate::config::config_file_path().display()
78            );
79        }
80
81        // Get API key from environment
82        let api_key = env::var(provider.api_key_env).with_context(|| {
83            format!(
84                "Missing {} environment variable.\n\
85                 Set it with: export {}=your_api_key",
86                provider.api_key_env, provider.api_key_env
87            )
88        })?;
89
90        // Create HTTP client with timeout
91        let http = Client::builder()
92            .timeout(Duration::from_secs(config.timeout_seconds))
93            .build()
94            .context("Failed to create HTTP client")?;
95
96        Ok(Self {
97            provider,
98            http,
99            api_key: SecretString::new(api_key.into()),
100            model: config.model.clone(),
101            max_tokens: config.max_tokens,
102            temperature: config.temperature,
103            circuit_breaker: CircuitBreaker::new(
104                config.circuit_breaker_threshold,
105                config.circuit_breaker_reset_seconds,
106            ),
107        })
108    }
109
110    /// Creates a new AI client with a provided API key and validates the model exists.
111    ///
112    /// This constructor validates that the model exists via the runtime model registry
113    /// before creating the client. It allows callers to provide an API key directly,
114    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
115    ///
116    /// # Arguments
117    ///
118    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
119    /// * `api_key` - API key as a `SecretString`
120    /// * `model_name` - Model name to use (e.g., "gemini-3-flash-preview")
121    /// * `config` - AI configuration with timeout and cost control settings
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if:
126    /// - Provider is not found in registry
127    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
128    /// - HTTP client creation fails
129    pub fn with_api_key(
130        provider_name: &str,
131        api_key: SecretString,
132        model_name: &str,
133        config: &AiConfig,
134    ) -> Result<Self> {
135        // Look up provider in registry
136        let provider = get_provider(provider_name)
137            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
138
139        // Validate model against cost control (OpenRouter-specific)
140        if provider_name == "openrouter"
141            && !config.allow_paid_models
142            && !super::is_free_model(model_name)
143        {
144            anyhow::bail!(
145                "Model '{}' is not in the free tier.\n\
146                 To use paid models, set `allow_paid_models = true` in your config file:\n\
147                 {}\n\n\
148                 Or use a free model like: mistralai/devstral-2512:free",
149                model_name,
150                crate::config::config_file_path().display()
151            );
152        }
153
154        // Create HTTP client with timeout
155        let http = Client::builder()
156            .timeout(Duration::from_secs(config.timeout_seconds))
157            .build()
158            .context("Failed to create HTTP client")?;
159
160        Ok(Self {
161            provider,
162            http,
163            api_key,
164            model: model_name.to_string(),
165            max_tokens: config.max_tokens,
166            temperature: config.temperature,
167            circuit_breaker: CircuitBreaker::new(
168                config.circuit_breaker_threshold,
169                config.circuit_breaker_reset_seconds,
170            ),
171        })
172    }
173
174    /// Get the circuit breaker for this client.
175    #[must_use]
176    pub fn circuit_breaker(&self) -> &CircuitBreaker {
177        &self.circuit_breaker
178    }
179}
180
181#[async_trait]
182impl AiProvider for AiClient {
183    fn name(&self) -> &str {
184        self.provider.name
185    }
186
187    fn api_url(&self) -> &str {
188        self.provider.api_url
189    }
190
191    fn api_key_env(&self) -> &str {
192        self.provider.api_key_env
193    }
194
195    fn http_client(&self) -> &Client {
196        &self.http
197    }
198
199    fn api_key(&self) -> &SecretString {
200        &self.api_key
201    }
202
203    fn model(&self) -> &str {
204        &self.model
205    }
206
207    fn max_tokens(&self) -> u32 {
208        self.max_tokens
209    }
210
211    fn temperature(&self) -> f32 {
212        self.temperature
213    }
214
215    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
216        Some(&self.circuit_breaker)
217    }
218
219    fn build_headers(&self) -> reqwest::header::HeaderMap {
220        let mut headers = reqwest::header::HeaderMap::new();
221        if let Ok(val) = "application/json".parse() {
222            headers.insert("Content-Type", val);
223        }
224
225        // OpenRouter-specific headers
226        if self.provider.name == "openrouter" {
227            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
228                headers.insert("HTTP-Referer", val);
229            }
230            if let Ok(val) = "Aptu CLI".parse() {
231                headers.insert("X-Title", val);
232            }
233        }
234
235        headers
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: `OpenRouter`, a free model, and paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Reads a header as `&str`, returning `None` if absent or not valid visible ASCII.
    fn header_str<'a>(headers: &'a reqwest::header::HeaderMap, name: &str) -> Option<&'a str> {
        headers.get(name).and_then(|v| v.to_str().ok())
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let result = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_paid_model_check_only_applies_to_openrouter() {
        let config = test_config();
        for provider_config in all_providers() {
            if provider_config.name == "openrouter" {
                continue;
            }
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("key"),
                "anthropic/claude-3",
                &config,
            );
            assert!(
                result.is_ok(),
                "Cost control unexpectedly applied to provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_openrouter_headers() {
        let config = test_config();
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("openrouter client should build");
        let headers = client.build_headers();
        assert_eq!(header_str(&headers, "content-type"), Some("application/json"));
        assert_eq!(
            header_str(&headers, "http-referer"),
            Some("https://github.com/clouatre-labs/aptu")
        );
        assert_eq!(header_str(&headers, "x-title"), Some("Aptu CLI"));
    }

    #[test]
    fn test_non_openrouter_omits_attribution_headers() {
        let config = test_config();
        for provider_config in all_providers() {
            if provider_config.name == "openrouter" {
                continue;
            }
            let client = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("key"),
                "test-model:free",
                &config,
            )
            .expect("client should build");
            let headers = client.build_headers();
            assert_eq!(header_str(&headers, "content-type"), Some("application/json"));
            assert!(
                headers.get("http-referer").is_none(),
                "Unexpected HTTP-Referer for provider: {}",
                provider_config.name
            );
            assert!(
                headers.get("x-title").is_none(),
                "Unexpected X-Title for provider: {}",
                provider_config.name
            );
        }
    }
}