Skip to main content

aptu_core/ai/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Generic AI client for all registered providers.
4//!
5//! Provides a single `AiClient` struct that works with any AI provider
6//! registered in the provider registry. See [`super::registry`] for available providers.
7
8use std::env;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use secrecy::SecretString;
15
16use super::circuit_breaker::CircuitBreaker;
17use super::provider::AiProvider;
18use super::registry::{ProviderConfig, get_provider};
19use crate::config::AiConfig;
20
21/// Generic AI client for all providers.
22///
23/// Holds HTTP client, API key, and model configuration for reuse across multiple requests.
24/// Uses the provider registry to get provider-specific configuration.
25#[derive(Debug)]
26pub struct AiClient {
27    /// Provider configuration from registry.
28    provider: &'static ProviderConfig,
29    /// HTTP client with configured timeout.
30    http: Client,
31    /// API key for provider authentication.
32    api_key: SecretString,
33    /// Model name (e.g., "mistralai/mistral-small-2603").
34    model: String,
35    /// Maximum tokens for API responses.
36    max_tokens: u32,
37    /// Temperature for API requests.
38    temperature: f32,
39    /// Maximum retry attempts for rate-limited requests.
40    max_attempts: u32,
41    /// Circuit breaker for resilience.
42    circuit_breaker: CircuitBreaker,
43    /// Optional custom guidance from config to inject into system prompts.
44    custom_guidance: Option<String>,
45}
46
47impl AiClient {
48    /// Creates a new AI client from configuration.
49    ///
50    /// Validates the model against cost control settings and fetches the API key
51    /// from the environment.
52    ///
53    /// # Arguments
54    ///
55    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
56    /// * `config` - AI configuration with model, timeout, and cost control settings
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if:
61    /// - Provider is not found in registry
62    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
63    /// - API key environment variable is not set
64    /// - HTTP client creation fails
65    pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
66        // Look up provider in registry
67        let provider = get_provider(provider_name)
68            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
69
70        // Validate model against cost control (OpenRouter-specific)
71        if provider_name == "openrouter"
72            && !config.allow_paid_models
73            && !super::is_free_model(&config.model)
74        {
75            anyhow::bail!(
76                "Model '{}' is not in the free tier.\n\
77                 To use paid models, set `allow_paid_models = true` in your config file:\n\
78                 {}\n\n\
79                 Or use a free model like: google/gemma-3-12b-it:free",
80                config.model,
81                crate::config::config_file_path().display()
82            );
83        }
84
85        // Get API key from environment
86        let api_key = env::var(provider.api_key_env).with_context(|| {
87            format!(
88                "Missing {} environment variable.\n\
89                 Set it with: export {}=your_api_key",
90                provider.api_key_env, provider.api_key_env
91            )
92        })?;
93
94        // Create HTTP client with timeout
95        let http = Client::builder()
96            .timeout(Duration::from_secs(config.timeout_seconds))
97            .build()
98            .context("Failed to create HTTP client")?;
99
100        Ok(Self {
101            provider,
102            http,
103            api_key: SecretString::new(api_key.into()),
104            model: config.model.clone(),
105            max_tokens: config.max_tokens,
106            temperature: config.temperature,
107            max_attempts: config.retry_max_attempts,
108            circuit_breaker: CircuitBreaker::new(
109                config.circuit_breaker_threshold,
110                config.circuit_breaker_reset_seconds,
111            ),
112            custom_guidance: config.custom_guidance.clone(),
113        })
114    }
115
116    /// Creates a new AI client with a provided API key and validates the model exists.
117    ///
118    /// This constructor validates that the model exists via the runtime model registry
119    /// before creating the client. It allows callers to provide an API key directly,
120    /// enabling multi-platform credential resolution (e.g., from iOS keychain via FFI).
121    ///
122    /// # Arguments
123    ///
124    /// * `provider_name` - Name of the provider (e.g., "openrouter", "gemini")
125    /// * `api_key` - API key as a `SecretString`
126    /// * `model_name` - Model name to use (e.g., "gemini-3.1-flash-lite-preview")
127    /// * `config` - AI configuration with timeout and cost control settings
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if:
132    /// - Provider is not found in registry
133    /// - Model is not in free tier and `allow_paid_models` is false (for `OpenRouter`)
134    /// - HTTP client creation fails
135    pub fn with_api_key(
136        provider_name: &str,
137        api_key: SecretString,
138        model_name: &str,
139        config: &AiConfig,
140    ) -> Result<Self> {
141        // Look up provider in registry
142        let provider = get_provider(provider_name)
143            .with_context(|| format!("Unknown AI provider: {provider_name}"))?;
144
145        // Validate model against cost control (OpenRouter-specific)
146        if provider_name == "openrouter"
147            && !config.allow_paid_models
148            && !super::is_free_model(model_name)
149        {
150            anyhow::bail!(
151                "Model '{}' is not in the free tier.\n\
152                 To use paid models, set `allow_paid_models = true` in your config file:\n\
153                 {}\n\n\
154                 Or use a free model like: google/gemma-3-12b-it:free",
155                model_name,
156                crate::config::config_file_path().display()
157            );
158        }
159
160        // Create HTTP client with timeout
161        let http = Client::builder()
162            .timeout(Duration::from_secs(config.timeout_seconds))
163            .build()
164            .context("Failed to create HTTP client")?;
165
166        Ok(Self {
167            provider,
168            http,
169            api_key,
170            model: model_name.to_string(),
171            max_tokens: config.max_tokens,
172            temperature: config.temperature,
173            max_attempts: config.retry_max_attempts,
174            circuit_breaker: CircuitBreaker::new(
175                config.circuit_breaker_threshold,
176                config.circuit_breaker_reset_seconds,
177            ),
178            custom_guidance: config.custom_guidance.clone(),
179        })
180    }
181
182    /// Get the circuit breaker for this client.
183    #[must_use]
184    pub fn circuit_breaker(&self) -> &CircuitBreaker {
185        &self.circuit_breaker
186    }
187}
188
189#[async_trait]
190impl AiProvider for AiClient {
191    fn name(&self) -> &str {
192        self.provider.name
193    }
194
195    fn api_url(&self) -> &str {
196        self.provider.api_url
197    }
198
199    fn api_key_env(&self) -> &str {
200        self.provider.api_key_env
201    }
202
203    fn http_client(&self) -> &Client {
204        &self.http
205    }
206
207    fn api_key(&self) -> &SecretString {
208        &self.api_key
209    }
210
211    fn model(&self) -> &str {
212        &self.model
213    }
214
215    fn max_tokens(&self) -> u32 {
216        self.max_tokens
217    }
218
219    fn temperature(&self) -> f32 {
220        self.temperature
221    }
222
223    fn max_attempts(&self) -> u32 {
224        self.max_attempts
225    }
226
227    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
228        Some(&self.circuit_breaker)
229    }
230
231    fn custom_guidance(&self) -> Option<&str> {
232        self.custom_guidance.as_deref()
233    }
234
235    fn build_headers(&self) -> reqwest::header::HeaderMap {
236        let mut headers = reqwest::header::HeaderMap::new();
237        if let Ok(val) = "application/json".parse() {
238            headers.insert("Content-Type", val);
239        }
240
241        // OpenRouter-specific headers
242        if self.provider.name == "openrouter" {
243            if let Ok(val) = "https://github.com/clouatre-labs/aptu".parse() {
244                headers.insert("HTTP-Referer", val);
245            }
246            if let Ok(val) = "Aptu CLI".parse() {
247                headers.insert("X-Title", val);
248            }
249        }
250
251        headers
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline configuration: `OpenRouter`, a `:free` model, paid models disallowed.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Every registered provider must accept an explicit key and a free model.
    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    /// An unregistered provider is rejected, and for that reason specifically.
    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider must be rejected");
        assert!(
            err.to_string().contains("Unknown AI provider: nonexistent"),
            "unexpected error: {err}"
        );
    }

    /// With paid models disallowed, `OpenRouter` rejects a non-free model via cost control.
    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.model = "anthropic/claude-3".to_string();
        config.allow_paid_models = false;
        let err = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect_err("paid model must be rejected when allow_paid_models is false");
        assert!(
            err.to_string().contains("is not in the free tier"),
            "unexpected error: {err}"
        );
    }

    /// Setting `allow_paid_models` skips the free-tier check entirely.
    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect("paid model should be accepted when allow_paid_models is true");
        assert_eq!(client.model(), "anthropic/claude-3");
    }

    /// `retry_max_attempts` from the config is what `max_attempts()` reports.
    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "test-model:free",
            &config,
        )
        .expect("should create client");
        assert_eq!(client.max_attempts(), 5);
    }
}