use std::env;
use std::time::Duration;
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use secrecy::SecretString;
use super::circuit_breaker::CircuitBreaker;
use super::provider::AiProvider;
use super::registry::{ProviderConfig, get_provider};
use crate::config::AiConfig;
/// Generic AI provider client configured from the provider registry and an [`AiConfig`].
///
/// Holds what the [`AiProvider`] implementation below exposes: the resolved
/// provider metadata, an HTTP client, the API key, and the generation/retry
/// settings copied out of the config when the client is constructed.
#[derive(Debug)]
pub struct AiClient {
    /// Registry entry for the selected provider (name, API URL, API-key env var).
    provider: &'static ProviderConfig,
    /// HTTP client built with the configured request timeout.
    http: Client,
    /// API key: read from the provider's env var in `new`, or supplied to `with_api_key`.
    api_key: SecretString,
    /// Model identifier exposed via [`AiProvider::model`].
    model: String,
    /// Token limit copied from `AiConfig::max_tokens`.
    max_tokens: u32,
    /// Sampling temperature copied from `AiConfig::temperature`.
    temperature: f32,
    /// Retry limit copied from `AiConfig::retry_max_attempts`.
    max_attempts: u32,
    /// Circuit breaker built from the config's threshold and reset-seconds settings.
    circuit_breaker: CircuitBreaker,
    /// Optional extra guidance copied from `AiConfig::custom_guidance`.
    custom_guidance: Option<String>,
}
impl AiClient {
pub fn new(provider_name: &str, config: &AiConfig) -> Result<Self> {
let provider = get_provider(provider_name)
.with_context(|| format!("Unknown AI provider: {provider_name}"))?;
if provider_name == "openrouter"
&& !config.allow_paid_models
&& !super::is_free_model(&config.model)
{
anyhow::bail!(
"Model '{}' is not in the free tier.\n\
To use paid models, set `allow_paid_models = true` in your config file:\n\
{}\n\n\
Or use a free model like: google/gemma-3-12b-it:free",
config.model,
crate::config::config_file_path().display()
);
}
let api_key = env::var(provider.api_key_env).with_context(|| {
format!(
"Missing {} environment variable.\n\
Set it with: export {}=your_api_key",
provider.api_key_env, provider.api_key_env
)
})?;
let http = Client::builder()
.timeout(Duration::from_secs(config.timeout_seconds))
.build()
.context("Failed to create HTTP client")?;
Ok(Self {
provider,
http,
api_key: SecretString::new(api_key.into()),
model: config.model.clone(),
max_tokens: config.max_tokens,
temperature: config.temperature,
max_attempts: config.retry_max_attempts,
circuit_breaker: CircuitBreaker::new(
config.circuit_breaker_threshold,
config.circuit_breaker_reset_seconds,
),
custom_guidance: config.custom_guidance.clone(),
})
}
pub fn with_api_key(
provider_name: &str,
api_key: SecretString,
model_name: &str,
config: &AiConfig,
) -> Result<Self> {
let provider = get_provider(provider_name)
.with_context(|| format!("Unknown AI provider: {provider_name}"))?;
if provider_name == "openrouter"
&& !config.allow_paid_models
&& !super::is_free_model(model_name)
{
anyhow::bail!(
"Model '{}' is not in the free tier.\n\
To use paid models, set `allow_paid_models = true` in your config file:\n\
{}\n\n\
Or use a free model like: google/gemma-3-12b-it:free",
model_name,
crate::config::config_file_path().display()
);
}
let http = Client::builder()
.timeout(Duration::from_secs(config.timeout_seconds))
.build()
.context("Failed to create HTTP client")?;
Ok(Self {
provider,
http,
api_key,
model: model_name.to_string(),
max_tokens: config.max_tokens,
temperature: config.temperature,
max_attempts: config.retry_max_attempts,
circuit_breaker: CircuitBreaker::new(
config.circuit_breaker_threshold,
config.circuit_breaker_reset_seconds,
),
custom_guidance: config.custom_guidance.clone(),
})
}
#[must_use]
pub fn circuit_breaker(&self) -> &CircuitBreaker {
&self.circuit_breaker
}
}
/// Exposes the client's stored configuration through the [`AiProvider`] trait.
#[async_trait]
impl AiProvider for AiClient {
    fn name(&self) -> &str {
        self.provider.name
    }

    fn api_url(&self) -> &str {
        self.provider.api_url
    }

    fn api_key_env(&self) -> &str {
        self.provider.api_key_env
    }

    fn http_client(&self) -> &Client {
        &self.http
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn max_tokens(&self) -> u32 {
        self.max_tokens
    }

    fn temperature(&self) -> f32 {
        self.temperature
    }

    fn max_attempts(&self) -> u32 {
        self.max_attempts
    }

    fn circuit_breaker(&self) -> Option<&super::CircuitBreaker> {
        Some(&self.circuit_breaker)
    }

    fn custom_guidance(&self) -> Option<&str> {
        self.custom_guidance.as_deref()
    }

    /// Builds the request headers.
    ///
    /// Every provider gets `Content-Type: application/json`. OpenRouter also gets
    /// the `HTTP-Referer` and `X-Title` headers.
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};

        let mut headers = HeaderMap::new();
        // Every value is a compile-time ASCII literal, so `from_static` cannot
        // panic. This replaces the previous fallible parse whose error branch
        // could never run.
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        if self.provider.name == "openrouter" {
            headers.insert(
                "HTTP-Referer",
                HeaderValue::from_static("https://github.com/clouatre-labs/aptu"),
            );
            headers.insert("X-Title", HeaderValue::from_static("Aptu CLI"));
        }
        headers
    }
}
#[cfg(test)]
mod tests {
    use super::super::registry::all_providers;
    use super::*;

    /// Baseline config: OpenRouter with a free-tier model and paid models disabled.
    fn test_config() -> AiConfig {
        AiConfig {
            provider: "openrouter".to_string(),
            model: "test-model:free".to_string(),
            max_tokens: 2048,
            temperature: 0.3,
            timeout_seconds: 30,
            allow_paid_models: false,
            circuit_breaker_threshold: 3,
            circuit_breaker_reset_seconds: 60,
            retry_max_attempts: 3,
            tasks: None,
            fallback: None,
            custom_guidance: None,
            validation_enabled: true,
        }
    }

    /// Builds a client with an explicit key and a free-tier model.
    fn free_client(provider_name: &str, config: &AiConfig) -> AiClient {
        AiClient::with_api_key(
            provider_name,
            SecretString::from("key"),
            "test-model:free",
            config,
        )
        .expect("should create client")
    }

    #[test]
    fn test_with_api_key_all_providers() {
        let config = test_config();
        for provider_config in all_providers() {
            let result = AiClient::with_api_key(
                provider_config.name,
                SecretString::from("test_key"),
                "test-model:free",
                &config,
            );
            assert!(
                result.is_ok(),
                "Failed for provider: {}",
                provider_config.name
            );
        }
    }

    #[test]
    fn test_unknown_provider_error() {
        let config = test_config();
        let err = AiClient::with_api_key(
            "nonexistent",
            SecretString::from("key"),
            "test-model",
            &config,
        )
        .expect_err("unknown provider should be rejected");
        let msg = err.to_string();
        assert!(
            msg.contains("Unknown AI provider: nonexistent"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_openrouter_rejects_paid_model() {
        let mut config = test_config();
        config.allow_paid_models = false;
        let err = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        )
        .expect_err("paid model should be rejected");
        // Check the reason, not just that construction failed.
        let msg = err.to_string();
        assert!(
            msg.contains("Model 'anthropic/claude-3' is not in the free tier"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_openrouter_allows_paid_model_when_enabled() {
        let mut config = test_config();
        config.allow_paid_models = true;
        let result = AiClient::with_api_key(
            "openrouter",
            SecretString::from("key"),
            "anthropic/claude-3",
            &config,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_max_attempts_from_config() {
        let mut config = test_config();
        config.retry_max_attempts = 5;
        let client = free_client("openrouter", &config);
        assert_eq!(client.max_attempts(), 5);
    }

    #[test]
    fn test_openrouter_headers() {
        let client = free_client("openrouter", &test_config());
        let headers = client.build_headers();
        assert_eq!(
            headers.get("content-type").and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
        assert_eq!(
            headers.get("http-referer").and_then(|v| v.to_str().ok()),
            Some("https://github.com/clouatre-labs/aptu")
        );
        assert_eq!(
            headers.get("x-title").and_then(|v| v.to_str().ok()),
            Some("Aptu CLI")
        );
    }

    #[test]
    fn test_non_openrouter_headers_omit_openrouter_fields() {
        let config = test_config();
        for provider_config in all_providers() {
            if provider_config.name == "openrouter" {
                continue;
            }
            let headers = free_client(provider_config.name, &config).build_headers();
            assert!(headers.contains_key("content-type"));
            assert!(
                !headers.contains_key("http-referer") && !headers.contains_key("x-title"),
                "OpenRouter-only headers set for provider: {}",
                provider_config.name
            );
        }
    }
}