use crate::error::ProviderError;
use std::sync::Arc;
use std::time::Duration;
/// Generation settings passed to every [`LlmProvider::complete`] call.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal. Start from [`CompletionConfig::default`] and overwrite the public
/// fields instead.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionConfig {
    /// Maximum number of tokens to request from the provider (defaults to 4096).
    pub max_tokens: u32,
    /// Sampling temperature requested from the provider (defaults to `0.0`).
    pub temperature: f64,
}
impl Default for CompletionConfig {
    /// Builds the baseline configuration: a 4096-token budget at temperature `0.0`.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS: u32 = 4096;
        const DEFAULT_TEMPERATURE: f64 = 0.0;

        Self {
            temperature: DEFAULT_TEMPERATURE,
            max_tokens: DEFAULT_MAX_TOKENS,
        }
    }
}
/// A backend that produces a text completion from a system prompt and a user prompt.
///
/// The `Send + Sync` supertraits allow implementors to be shared as
/// `Arc<dyn LlmProvider>`, which is how [`RetryProvider`] holds the provider it wraps.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Identifier of the model this provider sends requests to.
    fn model(&self) -> &str;

    /// Requests a completion for `user_prompt` under `system_prompt`, using the
    /// generation settings in `config`, and returns the generated text.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the request cannot be completed.
    async fn complete(
        &self,
        system_prompt: &str,
        user_prompt: &str,
        config: &CompletionConfig,
    ) -> Result<String, ProviderError>;
}
/// Expands a short Claude model alias into a full model identifier.
///
/// The aliases `sonnet`, `opus` and `haiku` map to pinned model IDs. Any other input that
/// contains the substring `claude-` is returned unchanged, so full IDs such as
/// `claude-opus-4-6` pass straight through.
///
/// NOTE(review): the passthrough uses `contains`, not `starts_with`, so prefixed IDs
/// (e.g. `us.anthropic.claude-...`) are also accepted. Confirm this is intended.
///
/// # Errors
///
/// Returns [`ProviderError::Auth`] with the message `unknown model alias: <model>` for
/// any other input. NOTE(review): `Auth` is an unusual variant for a bad model name;
/// a dedicated variant may be clearer to callers.
pub fn resolve_claude_alias(model: &str) -> Result<String, ProviderError> {
    const ALIASES: [(&str, &str); 3] = [
        ("sonnet", "claude-sonnet-4-6"),
        ("opus", "claude-opus-4-7"),
        ("haiku", "claude-haiku-4-5-20251001"),
    ];

    let expanded = ALIASES
        .iter()
        .find_map(|&(alias, full_id)| (alias == model).then_some(full_id));
    if let Some(full_id) = expanded {
        return Ok(full_id.to_owned());
    }

    if !model.contains("claude-") {
        return Err(ProviderError::Auth {
            message: format!("unknown model alias: {model}"),
        });
    }
    Ok(model.to_owned())
}
/// An [`LlmProvider`] decorator that retries transient failures of the wrapped provider
/// with exponential backoff.
///
/// Construct it with [`RetryProvider::new`] (3 retries, 1 second base delay) or with
/// [`RetryProvider::with_config`].
pub struct RetryProvider {
    /// Extra attempts allowed after the first failure. The wrapped provider is called
    /// at most `max_retries + 1` times per completion.
    pub max_retries: u32,
    /// Wait before the first retry. It doubles after each further failed attempt.
    pub base_delay: Duration,
    /// The provider that actually serves completion requests.
    inner: Arc<dyn LlmProvider>,
}
impl RetryProvider {
    /// Retry budget applied by [`RetryProvider::new`].
    const DEFAULT_MAX_RETRIES: u32 = 3;
    /// Initial backoff applied by [`RetryProvider::new`].
    const DEFAULT_BASE_DELAY: Duration = Duration::from_secs(1);

    /// Wraps `inner` with the default policy: up to 3 retries, starting with a
    /// 1 second delay.
    pub fn new(inner: Arc<dyn LlmProvider>) -> Self {
        Self::with_config(inner, Self::DEFAULT_MAX_RETRIES, Self::DEFAULT_BASE_DELAY)
    }

    /// Wraps `inner` with an explicit retry budget (`max_retries` extra attempts) and
    /// an explicit initial backoff (`base_delay`).
    pub fn with_config(
        inner: Arc<dyn LlmProvider>,
        max_retries: u32,
        base_delay: Duration,
    ) -> Self {
        Self {
            max_retries,
            base_delay,
            inner,
        }
    }
}
/// Reports whether a failed completion is worth retrying.
///
/// Timeouts and network failures count as transient. For HTTP errors, only statuses that
/// conventionally signal a temporary condition are retried:
/// - 408 (request timeout);
/// - 429 (rate limited);
/// - 500, 502, 503 and 504 (server or gateway failure);
/// - 529 (the Anthropic API's "overloaded" status).
///
/// Every other error is treated as permanent and goes straight back to the caller. That
/// includes other 4xx statuses such as 401/403/404, 501 Not Implemented, and every
/// non-HTTP, non-network variant.
fn is_retryable(error: &ProviderError) -> bool {
    match error {
        ProviderError::Timeout { .. } | ProviderError::Network { .. } => true,
        // Previously only 500 and 429 were retried, so gateway errors (502/503/504) and
        // overload responses (529) failed on the first attempt.
        ProviderError::Http { status, .. } => {
            matches!(*status, 408 | 429 | 500 | 502 | 503 | 504 | 529)
        }
        _ => false,
    }
}
#[async_trait::async_trait]
impl LlmProvider for RetryProvider {
    /// Calls the wrapped provider, retrying transient failures with exponential backoff.
    ///
    /// The wrapped provider is called at most `max_retries + 1` times:
    /// - A successful response is returned immediately.
    /// - An error that [`is_retryable`] rejects is returned immediately.
    /// - Any other error makes the call sleep, then try again. The delay starts at
    ///   `base_delay` and doubles after every failed attempt, saturating instead of
    ///   overflowing. No sleep happens after the final attempt.
    ///
    /// # Errors
    ///
    /// Returns the wrapped provider's error when it is not retryable. Returns the error
    /// from the final attempt once the retry budget is exhausted.
    async fn complete(
        &self,
        system_prompt: &str,
        user_prompt: &str,
        config: &CompletionConfig,
    ) -> Result<String, ProviderError> {
        let mut delay = self.base_delay;
        let mut attempt: u32 = 0;
        loop {
            let err = match self
                .inner
                .complete(system_prompt, user_prompt, config)
                .await
            {
                Ok(response) => return Ok(response),
                Err(err) => err,
            };
            // Give up on permanent errors, and on the last allowed attempt. The error is
            // returned directly, so there is no stored "last error" and no unreachable
            // post-loop panic.
            if !is_retryable(&err) || attempt >= self.max_retries {
                return Err(err);
            }
            tokio::time::sleep(delay).await;
            delay = delay.saturating_mul(2);
            // Cannot overflow: attempt < max_retries <= u32::MAX at this point.
            attempt += 1;
        }
    }

    /// Delegates to the wrapped provider's name.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Delegates to the wrapped provider's model.
    fn model(&self) -> &str {
        self.inner.model()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Test double that replays a scripted list of results and counts its calls.
    struct MockProvider {
        provider_name: String,
        provider_model: String,
        // Stored in reverse so `Vec::pop` hands results out in the order they were given.
        responses: std::sync::Mutex<Vec<Result<String, ProviderError>>>,
        call_count: AtomicU32,
    }

    impl MockProvider {
        /// A mock with no scripted results: every call succeeds with "default response".
        fn new(name: &str, model: &str) -> Self {
            Self::with_responses(name, model, Vec::new())
        }

        /// A mock whose successive calls return `responses` in order. Once they run out,
        /// it falls back to "default response".
        fn with_responses(
            name: &str,
            model: &str,
            mut responses: Vec<Result<String, ProviderError>>,
        ) -> Self {
            responses.reverse();
            Self {
                provider_name: name.to_string(),
                provider_model: model.to_string(),
                responses: std::sync::Mutex::new(responses),
                call_count: AtomicU32::new(0),
            }
        }

        /// Number of times `complete` has been invoked.
        fn call_count(&self) -> u32 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let mut responses = self.responses.lock().unwrap();
            responses
                .pop()
                .unwrap_or_else(|| Ok("default response".to_string()))
        }

        fn name(&self) -> &str {
            &self.provider_name
        }

        fn model(&self) -> &str {
            &self.provider_model
        }
    }

    /// Builds a `ProviderError::Timeout` carrying `message`.
    fn timeout(message: &str) -> ProviderError {
        ProviderError::Timeout {
            message: message.into(),
        }
    }

    /// Runs a single `complete` call through a `RetryProvider` that wraps `mock`.
    /// Uses a 1 ms base delay so the retry tests stay fast.
    async fn complete_with_retries(
        mock: &Arc<MockProvider>,
        max_retries: u32,
    ) -> Result<String, ProviderError> {
        let inner: Arc<MockProvider> = Arc::clone(mock);
        let retry = RetryProvider::with_config(inner, max_retries, Duration::from_millis(1));
        let config = CompletionConfig::default();
        retry.complete("sys", "usr", &config).await
    }

    #[test]
    fn test_completion_config_default_values() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!(config.temperature.abs() < f64::EPSILON);
    }

    // Replaces a test named `..._is_non_exhaustive` that only re-checked the defaults.
    // `#[non_exhaustive]` has no observable effect inside this crate.
    #[test]
    fn test_completion_config_clone_is_independent() {
        let base = CompletionConfig::default();
        let mut tuned = base.clone();
        tuned.max_tokens = 512;
        tuned.temperature = 0.7;
        assert_eq!(base.max_tokens, 4096);
        assert!(base.temperature.abs() < f64::EPSILON);
        assert_eq!(tuned.max_tokens, 512);
        assert!((tuned.temperature - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_retry_provider_delegates_name() {
        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.name(), "test-provider");
    }

    #[test]
    fn test_retry_provider_delegates_model() {
        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.model(), "test-model");
    }

    #[tokio::test]
    async fn test_retry_provider_succeeds_without_retrying() {
        let mock = Arc::new(MockProvider::new("p", "m"));
        let result = complete_with_retries(&mock, 3).await;
        assert_eq!(result.unwrap(), "default response");
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_timeout() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(timeout("t1")), Err(timeout("t2")), Ok("success".into())],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(mock.call_count(), 3);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_500() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![
                Err(ProviderError::Http {
                    status: 500,
                    body: "err".into(),
                }),
                Ok("ok".into()),
            ],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_ok());
        assert_eq!(mock.call_count(), 2);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_429() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![
                Err(ProviderError::Http {
                    status: 429,
                    body: "rate limit".into(),
                }),
                Ok("ok".into()),
            ],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_ok());
        assert_eq!(mock.call_count(), 2);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_network() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![
                Err(ProviderError::Network {
                    message: "dns".into(),
                }),
                Ok("ok".into()),
            ],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_ok());
        assert_eq!(mock.call_count(), 2);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_auth() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(ProviderError::Auth {
                message: "bad key".into(),
            })],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_err());
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_process() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(ProviderError::Process {
                exit_code: Some(1),
                stderr: "fail".into(),
            })],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_err());
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_nested_session() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(ProviderError::NestedSession)],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_err());
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_http_4xx() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(ProviderError::Http {
                status: 403,
                body: "forbidden".into(),
            })],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert!(result.is_err());
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_returns_last_error_after_exhausting_retries() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(timeout("t1")), Err(timeout("t2")), Err(timeout("t3"))],
        ));
        let result = complete_with_retries(&mock, 2).await;
        assert_eq!(mock.call_count(), 3);
        match result {
            Err(ProviderError::Timeout { message }) => assert_eq!(message, "t3"),
            Err(other) => panic!("expected Timeout, got: {other}"),
            Ok(response) => panic!("expected an error, got: {response}"),
        }
    }

    #[tokio::test]
    async fn test_retry_provider_with_zero_retries_makes_single_attempt() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(timeout("t1")), Ok("too late".into())],
        ));
        let result = complete_with_retries(&mock, 0).await;
        assert!(matches!(result, Err(ProviderError::Timeout { .. })));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_provider_returns_success_on_first_retry() {
        let mock = Arc::new(MockProvider::with_responses(
            "p",
            "m",
            vec![Err(timeout("t1")), Ok("recovered".into())],
        ));
        let result = complete_with_retries(&mock, 3).await;
        assert_eq!(result.unwrap(), "recovered");
        assert_eq!(mock.call_count(), 2);
    }

    #[test]
    fn test_retry_provider_default_config() {
        let mock = Arc::new(MockProvider::new("p", "m"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.base_delay, Duration::from_secs(1));
    }

    #[test]
    fn test_resolve_claude_alias_opus_returns_claude_opus_4_7() {
        let result = resolve_claude_alias("opus").unwrap();
        assert_eq!(result, "claude-opus-4-7");
    }

    #[test]
    fn test_resolve_claude_alias_sonnet_returns_claude_sonnet_4_6() {
        let result = resolve_claude_alias("sonnet").unwrap();
        assert_eq!(result, "claude-sonnet-4-6");
    }

    #[test]
    fn test_resolve_claude_alias_haiku_returns_claude_haiku_4_5_20251001() {
        let result = resolve_claude_alias("haiku").unwrap();
        assert_eq!(result, "claude-haiku-4-5-20251001");
    }

    #[test]
    fn test_resolve_claude_alias_accepts_literal_claude_opus_4_6_passthrough() {
        assert_eq!(
            resolve_claude_alias("claude-opus-4-6").unwrap(),
            "claude-opus-4-6"
        );
    }

    #[test]
    fn test_resolve_claude_alias_rejects_unknown_alias() {
        match resolve_claude_alias("gpt-4o") {
            Err(ProviderError::Auth { message }) => {
                assert_eq!(message, "unknown model alias: gpt-4o");
            }
            _ => panic!("expected an Auth error for an unknown alias"),
        }
    }
}