use crate::config::RetryPolicy;
use crate::models::{MessageRequest, MessageResponse, StreamEvent};
use anyhow::Result;
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant};
use uuid::Uuid;
#[cfg(test)]
pub mod mock;
/// Boxed, pinned stream of `Result<StreamEvent>` items (`Result` is `anyhow::Result` here),
/// as returned by `LlmClient::create_message_stream`.
/// The `Send + 'static` bounds let the stream be moved to another thread and kept alive
/// without borrowing from the client that produced it.
pub type StreamEventBox =
Pin<Box<dyn futures_util::Stream<Item = Result<StreamEvent>> + Send + 'static>>;
/// Provider-agnostic interface for sending message requests to an LLM backend.
///
/// Implementors must be `Send + Sync`.
/// `async_fn_in_trait` is allowed because `create_message_stream` and `health_check` are
/// declared as `async fn`: the trait signature puts no `Send` bound on their futures.
/// By contrast, `create_message` spells out `impl Future + Send`.
/// NOTE(review): `dead_code` is presumably allowed because not every method is called yet.
#[allow(async_fn_in_trait, dead_code)] pub trait LlmClient: Send + Sync {
/// Provider name as a `'static` string.
fn provider_name(&self) -> &'static str;
/// Model name for this client — presumably the model id sent with requests; confirm in
/// the implementations.
fn model(&self) -> &str;
/// Sends `request` and resolves to the complete `MessageResponse`.
/// The returned future is explicitly `Send`, so it can be spawned on a multi-threaded runtime.
fn create_message(
&self,
request: MessageRequest,
) -> impl Future<Output = Result<MessageResponse>> + Send;
/// Sends `request` and returns a boxed stream of `StreamEvent` results.
async fn create_message_stream(&self, request: MessageRequest) -> Result<StreamEventBox>;
/// Liveness probe.
/// The default implementation performs no I/O and always reports healthy (`Ok(true)`).
/// Implementors may override it.
async fn health_check(&self) -> Result<bool> {
Ok(true)
}
}
/// Read/write access to an implementor's `RetryConfig`.
/// NOTE(review): `dead_code` is allowed, presumably because nothing calls this trait yet.
#[allow(dead_code)] pub trait RetryConfigurable {
/// Current retry settings.
fn retry_config(&self) -> &RetryConfig;
/// Replaces the retry settings wholesale.
fn set_retry_config(&mut self, config: RetryConfig);
}
/// Classified failure from an LLM provider call.
///
/// `LlmError::is_retryable` decides which variants the retry loop will retry.
/// `LlmError::from_http_response` maps HTTP error responses onto these variants.
#[derive(Debug)]
pub enum LlmError {
/// The provider throttled the request: HTTP 429, or a 400 whose body reports exhausted quota.
RateLimited {
/// Raw response body.
message: String,
/// Server-suggested wait, if known (see `from_http_response_with_retry_after`).
retry_after: Option<Duration>,
},
/// HTTP 5xx response.
ServerError { status: u16, message: String },
/// Connection- or request-level transport failure.
NetworkError(String),
/// A request or retry budget timed out.
/// The duration may be a zero placeholder when the real limit is unknown
/// (`from_reqwest` uses `Duration::from_secs(0)`).
Timeout(Duration),
/// HTTP 401 / 403.
AuthenticationError(String),
/// Request rejected as invalid (fallback for 400 and 404 responses).
InvalidRequest { status: u16, message: String },
/// The requested model appears to be unknown or unavailable (body-text heuristic).
ModelError(String),
/// The request was refused on content-policy / safety grounds (body-text heuristic).
ContentPolicyError(String),
/// A response could not be parsed (e.g. a `serde_json` error).
ParseError(String),
/// The request appears to exceed the model's context window (body-text heuristic).
ContextLengthError(String),
/// Anything not covered above, e.g. an unexpected HTTP status.
Other(String),
}
impl std::fmt::Display for LlmError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LlmError::RateLimited { message, .. } => write!(f, "Rate limit exceeded: {message}"),
LlmError::ServerError { status, message } => {
write!(f, "Server error ({status}): {message}")
}
LlmError::NetworkError(msg) => write!(f, "Network error: {msg}"),
LlmError::Timeout(d) => write!(f, "Request timed out after {d:?}"),
LlmError::AuthenticationError(msg) => write!(f, "Authentication failed: {msg}"),
LlmError::InvalidRequest { status, message } => {
write!(f, "Invalid request ({status}): {message}")
}
LlmError::ModelError(msg) => write!(f, "Model error: {msg}"),
LlmError::ContentPolicyError(msg) => write!(f, "Content policy violation: {msg}"),
LlmError::ParseError(msg) => write!(f, "Response parsing error: {msg}"),
LlmError::ContextLengthError(msg) => write!(f, "Context length exceeded: {msg}"),
LlmError::Other(msg) => write!(f, "LLM error: {msg}"),
}
}
}
impl std::error::Error for LlmError {}
impl LlmError {
    /// Returns `true` for transient failures that are worth retrying.
    ///
    /// These are rate limiting, 5xx server errors, network failures and timeouts.
    /// Every other variant is treated as permanent.
    pub fn is_retryable(&self) -> bool {
        matches!(
            self,
            LlmError::RateLimited { .. }
                | LlmError::ServerError { .. }
                | LlmError::NetworkError(_)
                | LlmError::Timeout(_)
        )
    }

    /// Server-requested back-off, if known.
    ///
    /// Only [`LlmError::RateLimited`] carries one.
    pub fn suggested_retry_delay(&self) -> Option<Duration> {
        match self {
            LlmError::RateLimited { retry_after, .. } => *retry_after,
            _ => None,
        }
    }

    /// Classifies an HTTP error response from its status code and body text.
    ///
    /// * `429` → [`LlmError::RateLimited`] with no delay.
    ///   Use [`LlmError::from_http_response_with_retry_after`] to attach one.
    /// * `401` / `403` → [`LlmError::AuthenticationError`].
    /// * `400` → classified heuristically from the lower-cased body.
    ///   Checks run in this order:
    ///   - quota exhaustion (a retryable [`LlmError::RateLimited`]);
    ///   - context-window overflow;
    ///   - content-policy refusal;
    ///   - unknown model;
    ///   - otherwise [`LlmError::InvalidRequest`].
    /// * `404` → [`LlmError::ModelError`] if the body mentions "model", otherwise
    ///   [`LlmError::InvalidRequest`].
    /// * `500..=599` → [`LlmError::ServerError`].
    /// * any other status → [`LlmError::Other`].
    pub fn from_http_response(status: u16, body: &str) -> Self {
        match status {
            429 => LlmError::RateLimited {
                message: body.to_string(),
                retry_after: None,
            },
            401 | 403 => LlmError::AuthenticationError(body.to_string()),
            400 => Self::classify_bad_request(status, body),
            404 => {
                if body.to_lowercase().contains("model") {
                    LlmError::ModelError(body.to_string())
                } else {
                    LlmError::InvalidRequest {
                        status,
                        message: body.to_string(),
                    }
                }
            }
            500..=599 => LlmError::ServerError {
                status,
                message: body.to_string(),
            },
            _ => LlmError::Other(format!("HTTP {status}: {body}")),
        }
    }

    /// Same as [`LlmError::from_http_response`], but records `retry_after` when the result
    /// is [`LlmError::RateLimited`].
    ///
    /// That includes quota-exhausted `400`s. Other variants have nowhere to store the delay,
    /// so it is ignored for them.
    pub fn from_http_response_with_retry_after(
        status: u16,
        body: &str,
        retry_after: Option<Duration>,
    ) -> Self {
        let mut error = Self::from_http_response(status, body);
        if let LlmError::RateLimited {
            retry_after: ref mut ra,
            ..
        } = error
        {
            *ra = retry_after;
        }
        error
    }

    /// Converts a transport-level `reqwest` error into a classified [`LlmError`].
    pub fn from_reqwest(err: &reqwest::Error) -> Self {
        if err.is_timeout() {
            // The error does not expose which limit fired, so the duration is a zero placeholder.
            LlmError::Timeout(Duration::from_secs(0))
        } else if err.is_connect() {
            LlmError::NetworkError(format!("Connection failed: {err}"))
        } else if err.is_request() {
            LlmError::NetworkError(format!("Request failed: {err}"))
        } else if err.is_decode() {
            // The body arrived but could not be decoded.
            // That is a parse failure, consistent with how `serde_json` errors are classified.
            LlmError::ParseError(err.to_string())
        } else {
            LlmError::Other(err.to_string())
        }
    }

    /// Maps a `400 Bad Request` body onto the most specific variant.
    fn classify_bad_request(status: u16, body: &str) -> Self {
        let body_lower = body.to_lowercase();
        let message = body.to_string();
        if Self::contains_any(
            &body_lower,
            &[
                "insufficientquota",
                "insufficient_quota",
                "exceeded your current quota",
                "quota exceeded",
            ],
        ) {
            // Quota exhaustion is deliberately reported as (retryable) rate limiting.
            LlmError::RateLimited {
                message,
                retry_after: None,
            }
        } else if Self::mentions_context_length(&body_lower) {
            LlmError::ContextLengthError(message)
        } else if Self::contains_any(
            &body_lower,
            &["content_policy", "safety", "harmful", "inappropriate"],
        ) {
            LlmError::ContentPolicyError(message)
        } else if body_lower.contains("model") && body_lower.contains("not found") {
            LlmError::ModelError(message)
        } else {
            LlmError::InvalidRequest { status, message }
        }
    }

    /// Heuristic for "the request exceeded the model's context window".
    ///
    /// A bare "token" or "maximum" is not enough on its own. Bodies such as
    /// "Unexpected token in JSON" or "temperature exceeds the maximum of 2" are ordinary
    /// invalid requests. "token" therefore only counts when paired with a limit-style word.
    fn mentions_context_length(body_lower: &str) -> bool {
        const DIRECT: [&str; 4] = ["context_length", "context length", "context window", "too long"];
        const LIMIT_WORDS: [&str; 5] = ["maximum", "exceed", "limit", "too many", "too large"];
        Self::contains_any(body_lower, &DIRECT)
            || (body_lower.contains("token") && Self::contains_any(body_lower, &LIMIT_WORDS))
    }

    /// Returns `true` if `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }
}
impl From<reqwest::Error> for LlmError {
fn from(err: reqwest::Error) -> Self {
LlmError::from_reqwest(&err)
}
}
impl From<serde_json::Error> for LlmError {
fn from(err: serde_json::Error) -> Self {
LlmError::ParseError(err.to_string())
}
}
/// Tunables for the retry loop in `with_retry` and its back-off schedule.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// When `false`, the operation runs exactly once.
    pub enabled: bool,
    /// Number of retries after the initial attempt.
    pub max_retries: u32,
    /// Back-off before the first retry, in seconds.
    pub initial_delay: f64,
    /// Upper bound for the exponential back-off, in seconds (applied before jitter).
    pub max_delay: f64,
    /// Multiplier applied to the delay for each further attempt.
    pub exponential_base: f64,
    /// Whether to randomise each delay.
    pub jitter: bool,
    /// Jitter range as a fraction of the capped delay (±`jitter_factor * delay`).
    pub jitter_factor: f64,
    /// Prefer a server-supplied `retry_after` hint over the computed back-off.
    pub respect_retry_after: bool,
    /// Statuses reported as retryable by `RetryConfig::is_retryable_status`.
    #[allow(dead_code)]
    pub retryable_status_codes: Vec<u16>,
    /// NOTE(review): not read in this module; presumably a per-request timeout in
    /// seconds applied by the HTTP client — confirm at the call site.
    #[allow(dead_code)]
    pub request_timeout: f64,
    /// Overall deadline for a retry sequence, in seconds.
    /// Non-positive values disable the deadline.
    pub total_timeout: f64,
}

impl Default for RetryConfig {
    /// Default schedule:
    /// - three retries with 1 s → 2 s → 4 s back-off, capped at 60 s, with ±10 % jitter;
    /// - server `Retry-After` hints are honoured;
    /// - there is no overall deadline.
    fn default() -> Self {
        let retryable_status_codes = Vec::from([429, 500, 502, 503, 504]);
        RetryConfig {
            enabled: true,
            max_retries: 3,
            initial_delay: 1.0,
            exponential_base: 2.0,
            max_delay: 60.0,
            jitter: true,
            jitter_factor: 0.1,
            respect_retry_after: true,
            retryable_status_codes,
            request_timeout: 120.0,
            // Zero disables the overall deadline (only positive values are enforced).
            total_timeout: 0.0,
        }
    }
}
// NOTE(review): `dead_code` is presumably allowed because not every builder is used outside tests.
#[allow(dead_code)]
impl RetryConfig {
    /// Equivalent to [`RetryConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Default settings with retrying turned off.
    ///
    /// `with_retry` then runs the operation exactly once.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }

    /// Sets how many retries may follow the initial attempt.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the back-off before the first retry, in seconds.
    pub fn with_initial_delay(mut self, delay: f64) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the cap for the exponential back-off, in seconds (applied before jitter).
    pub fn with_max_delay(mut self, delay: f64) -> Self {
        self.max_delay = delay;
        self
    }

    /// Enables or disables random jitter on each delay.
    pub fn with_jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// Sets `request_timeout`.
    ///
    /// NOTE(review): this value is not read in this module; units presumably seconds.
    pub fn with_request_timeout(mut self, timeout: f64) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the overall retry deadline, in seconds.
    ///
    /// Non-positive values disable it.
    pub fn with_total_timeout(mut self, timeout: f64) -> Self {
        self.total_timeout = timeout;
        self
    }

    /// Back-off to wait after zero-based attempt `attempt` has failed.
    ///
    /// Computes `initial_delay * exponential_base^attempt` and caps it at `max_delay`.
    /// When `jitter` is on, the capped value is then perturbed by up to
    /// ±`jitter_factor * capped`, so the result may slightly exceed `max_delay`.
    ///
    /// Never panics, even for unusual user-supplied settings:
    /// - negative or NaN results become zero;
    /// - results too large for a [`Duration`] (e.g. an infinite `max_delay`) saturate to
    ///   [`Duration::MAX`].
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let base_delay = self.initial_delay * self.exponential_base.powi(exponent);
        let capped_delay = base_delay.min(self.max_delay);
        let final_delay = if self.jitter {
            let jitter_range = capped_delay * self.jitter_factor;
            // Two bytes of a random (v4) UUID serve as the jitter source.
            let bytes = *Uuid::new_v4().as_bytes();
            let sample = u16::from_le_bytes([bytes[0], bytes[1]]);
            let random_factor = f64::from(sample) / f64::from(u16::MAX);
            let jitter = jitter_range * (2.0 * random_factor - 1.0);
            capped_delay + jitter
        } else {
            capped_delay
        };
        // `Duration::from_secs_f64` panics on negative, NaN, infinite or overflowing input.
        // All of these are reachable from configuration.
        // `f64::max` returns the non-NaN operand, so NaN and negatives clamp to 0.0.
        Duration::try_from_secs_f64(final_delay.max(0.0)).unwrap_or(Duration::MAX)
    }

    /// Whether `status` is listed in `retryable_status_codes`.
    pub fn is_retryable_status(&self, status: u16) -> bool {
        self.retryable_status_codes.contains(&status)
    }
}
impl From<RetryPolicy> for RetryConfig {
fn from(policy: RetryPolicy) -> Self {
Self {
enabled: policy.enabled,
max_retries: policy.max_retries,
initial_delay: policy.initial_delay,
max_delay: policy.max_delay,
exponential_base: policy.exponential_base,
..Default::default()
}
}
}
impl From<RetryConfig> for RetryPolicy {
fn from(config: RetryConfig) -> Self {
Self {
enabled: config.enabled,
max_retries: config.max_retries,
initial_delay: config.initial_delay,
max_delay: config.max_delay,
exponential_base: config.exponential_base,
}
}
}
/// Final failure of a retry sequence.
///
/// Carries the last error seen, how many attempts were made and how long the whole
/// sequence took.
#[derive(Debug)]
pub struct RetryError {
    /// The error that ended the sequence.
    pub last_error: LlmError,
    /// Number of times the operation was invoked.
    pub attempts: u32,
    /// Wall-clock time spent across all attempts and back-offs.
    pub total_time: Duration,
}

impl std::fmt::Display for RetryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            last_error,
            attempts,
            total_time,
        } = self;
        write!(
            f,
            "Retry exhausted after {attempts} attempts ({total_time:?}): {last_error}"
        )
    }
}

impl std::error::Error for RetryError {
    /// Exposes the underlying [`LlmError`] so error chains can be walked.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.last_error;
        Some(cause)
    }
}
/// Outcome of `with_retry`: the operation's value, or a `RetryError` describing the failure.
/// `Result` here is `anyhow::Result` with an explicit error type, i.e. a plain
/// `Result<T, RetryError>`.
pub type RetryResult<T> = Result<T, RetryError>;
/// Observer that `with_retry` invokes right before each back-off sleep.
/// Arguments are `(error, zero-based attempt index, delay about to be slept)`.
pub type RetryCallback = Box<dyn Fn(&LlmError, u32, Duration) + Send + Sync>;
pub async fn with_retry<F, Fut, T>(
config: &RetryConfig,
mut operation: F,
callback: Option<RetryCallback>,
) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>>,
{
if !config.enabled {
return operation().await.map_err(|e| RetryError {
last_error: e,
attempts: 1,
total_time: Duration::ZERO,
});
}
let start_time = Instant::now();
let total_timeout = if config.total_timeout > 0.0 {
Some(Duration::from_secs_f64(config.total_timeout))
} else {
None
};
let mut last_error: Option<LlmError> = None;
for attempt in 0..=config.max_retries {
if let Some(timeout) = total_timeout
&& start_time.elapsed() >= timeout
{
return Err(RetryError {
last_error: last_error.unwrap_or(LlmError::Timeout(timeout)),
attempts: attempt,
total_time: start_time.elapsed(),
});
}
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
if !err.is_retryable() {
return Err(RetryError {
last_error: err,
attempts: attempt + 1,
total_time: start_time.elapsed(),
});
}
if attempt >= config.max_retries {
return Err(RetryError {
last_error: err,
attempts: attempt + 1,
total_time: start_time.elapsed(),
});
}
let base_delay = config.delay_for_attempt(attempt);
let delay = if config.respect_retry_after {
err.suggested_retry_delay().unwrap_or(base_delay)
} else {
base_delay
};
if let Some(ref cb) = callback {
cb(&err, attempt, delay);
}
last_error = Some(err);
tokio::time::sleep(delay).await;
}
}
}
Err(RetryError {
last_error: last_error.unwrap_or(LlmError::Other("Unknown retry error".to_string())),
attempts: config.max_retries + 1,
total_time: start_time.elapsed(),
})
}
#[allow(dead_code)] pub async fn with_retry_simple<F, Fut, T>(config: &RetryConfig, operation: F) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>>,
{
with_retry(config, operation, None).await
}
/// Parses the delay-seconds form of an HTTP `Retry-After` header value.
///
/// Accepts whole seconds (`"120"`) and, leniently, fractional seconds (`"1.5"`).
/// Surrounding whitespace is ignored.
///
/// Returns `None` for anything that is not a non-negative, finite number of seconds
/// representable as a [`Duration`]. That includes the HTTP-date form, which is not
/// supported here. Callers can then fall back to their own back-off.
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    let value = value.trim();
    if let Ok(seconds) = value.parse::<u64>() {
        return Some(Duration::from_secs(seconds));
    }
    // The header comes from a remote server, so it must never be able to panic us.
    // `Duration::from_secs_f64` panics on negative, NaN, infinite or overflowing input;
    // `try_from_secs_f64` reports those cases as errors instead.
    value
        .parse::<f64>()
        .ok()
        .and_then(|seconds| Duration::try_from_secs_f64(seconds).ok())
}
pub fn extract_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
headers
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(parse_retry_after)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Float comparison helper: passes when `actual` is within `f64::EPSILON` of `expected`.
    fn assert_f64_eq(actual: f64, expected: f64) {
        let difference = (actual - expected).abs();
        assert!(difference < f64::EPSILON, "expected {expected}, got {actual}");
    }

    #[test]
    fn test_retry_config_defaults() {
        let cfg = RetryConfig::default();
        assert!(cfg.enabled);
        assert!(cfg.jitter);
        assert_eq!(cfg.max_retries, 3);
        assert_f64_eq(cfg.initial_delay, 1.0);
        assert_f64_eq(cfg.max_delay, 60.0);
        assert_f64_eq(cfg.exponential_base, 2.0);
    }

    #[test]
    fn test_retry_config_disabled() {
        assert!(!RetryConfig::disabled().enabled);
    }

    #[test]
    fn test_retry_config_builder() {
        let cfg = RetryConfig::new()
            .with_max_retries(5)
            .with_initial_delay(2.0)
            .with_max_delay(120.0)
            .with_jitter(false);
        assert_eq!(cfg.max_retries, 5);
        assert_f64_eq(cfg.initial_delay, 2.0);
        assert_f64_eq(cfg.max_delay, 120.0);
        assert!(!cfg.jitter);
    }

    #[test]
    fn test_delay_for_attempt_exponential() {
        let cfg = RetryConfig::new().with_jitter(false);
        let expected = [(0, 1.0), (1, 2.0), (2, 4.0), (3, 8.0)];
        for (attempt, seconds) in expected {
            assert_eq!(cfg.delay_for_attempt(attempt), Duration::from_secs_f64(seconds));
        }
    }

    #[test]
    fn test_delay_for_attempt_capped() {
        let cfg = RetryConfig::new().with_jitter(false).with_max_delay(5.0);
        assert_eq!(cfg.delay_for_attempt(3), Duration::from_secs_f64(5.0));
    }

    #[test]
    fn test_delay_for_attempt_with_jitter() {
        let cfg = RetryConfig::new().with_jitter(true);
        let base = 2.0;
        let range = base * 0.1;
        // Two independent samples must both land inside the ±10 % jitter window.
        for _ in 0..2 {
            let seconds = cfg.delay_for_attempt(1).as_secs_f64();
            assert!(seconds >= base - range);
            assert!(seconds <= base + range);
        }
    }

    #[test]
    fn test_is_retryable_status() {
        let cfg = RetryConfig::default();
        for code in [429, 500, 502, 503, 504] {
            assert!(cfg.is_retryable_status(code), "{code} should be retryable");
        }
        for code in [400, 401, 403, 404] {
            assert!(!cfg.is_retryable_status(code), "{code} should not be retryable");
        }
    }

    #[test]
    fn test_llm_error_retryable() {
        let transient = [
            LlmError::RateLimited {
                message: "too many requests".to_string(),
                retry_after: None,
            },
            LlmError::ServerError {
                status: 500,
                message: "internal error".to_string(),
            },
            LlmError::NetworkError("connection refused".to_string()),
            LlmError::Timeout(Duration::from_secs(30)),
        ];
        for err in &transient {
            assert!(err.is_retryable(), "{err:?} should be retryable");
        }

        let permanent = [
            LlmError::AuthenticationError("invalid key".to_string()),
            LlmError::InvalidRequest {
                status: 400,
                message: "bad json".to_string(),
            },
            LlmError::ContentPolicyError("unsafe content".to_string()),
            LlmError::ContextLengthError("too long".to_string()),
        ];
        for err in &permanent {
            assert!(!err.is_retryable(), "{err:?} should not be retryable");
        }
    }

    #[test]
    fn test_llm_error_from_http_response() {
        let classify = LlmError::from_http_response;

        assert!(matches!(
            classify(429, "rate limit exceeded"),
            LlmError::RateLimited { .. }
        ));
        assert!(matches!(
            classify(401, "invalid api key"),
            LlmError::AuthenticationError(_)
        ));
        assert!(matches!(
            classify(403, "forbidden"),
            LlmError::AuthenticationError(_)
        ));
        assert!(matches!(
            classify(500, "internal server error"),
            LlmError::ServerError { status: 500, .. }
        ));
        assert!(matches!(
            classify(503, "service unavailable"),
            LlmError::ServerError { status: 503, .. }
        ));
        assert!(matches!(
            classify(400, "context_length_exceeded"),
            LlmError::ContextLengthError(_)
        ));

        // Quota exhaustion arrives as a 400 but is treated as retryable rate limiting.
        let quota = classify(
            400,
            r#"{"error":{"code":"insufficientquota","message":"You exceeded your current quota"}}"#,
        );
        assert!(matches!(quota, LlmError::RateLimited { .. }));
        assert!(quota.is_retryable());

        assert!(matches!(
            classify(400, "content_policy_violation"),
            LlmError::ContentPolicyError(_)
        ));
        assert!(matches!(
            classify(400, "invalid json"),
            LlmError::InvalidRequest { status: 400, .. }
        ));
    }

    #[test]
    fn test_llm_error_suggested_retry_delay() {
        let throttled = LlmError::RateLimited {
            message: "slow down".to_string(),
            retry_after: Some(Duration::from_secs(60)),
        };
        assert_eq!(
            throttled.suggested_retry_delay(),
            Some(Duration::from_secs(60))
        );

        let server = LlmError::ServerError {
            status: 500,
            message: "error".to_string(),
        };
        assert_eq!(server.suggested_retry_delay(), None);
    }

    #[test]
    fn test_parse_retry_after() {
        let cases = [
            ("120", Some(Duration::from_secs(120))),
            ("0", Some(Duration::from_secs(0))),
            ("1.5", Some(Duration::from_secs_f64(1.5))),
            ("invalid", None),
            ("", None),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_retry_after(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_retry_policy_conversion() {
        let policy = RetryPolicy {
            enabled: true,
            max_retries: 5,
            initial_delay: 2.0,
            max_delay: 30.0,
            exponential_base: 3.0,
        };

        let cfg: RetryConfig = policy.clone().into();
        assert_eq!(cfg.enabled, policy.enabled);
        assert_eq!(cfg.max_retries, policy.max_retries);
        assert_f64_eq(cfg.initial_delay, policy.initial_delay);
        assert_f64_eq(cfg.max_delay, policy.max_delay);
        assert_f64_eq(cfg.exponential_base, policy.exponential_base);

        let round_tripped: RetryPolicy = cfg.into();
        assert_eq!(round_tripped.enabled, policy.enabled);
        assert_eq!(round_tripped.max_retries, policy.max_retries);
    }

    #[tokio::test]
    async fn test_with_retry_success_first_attempt() {
        let cfg = RetryConfig::default();
        let mut calls = 0;
        let outcome = with_retry(
            &cfg,
            || {
                calls += 1;
                async { Ok::<_, LlmError>(42) }
            },
            None,
        )
        .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_with_retry_disabled() {
        let cfg = RetryConfig::disabled();
        let mut calls = 0;
        let outcome: RetryResult<i32> = with_retry(
            &cfg,
            || {
                calls += 1;
                async {
                    Err(LlmError::ServerError {
                        status: 500,
                        message: "error".to_string(),
                    })
                }
            },
            None,
        )
        .await;
        assert!(outcome.is_err());
        // Even a retryable error is attempted only once when retries are disabled.
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_with_retry_non_retryable_error() {
        let cfg = RetryConfig::default();
        let mut calls = 0;
        let outcome: RetryResult<i32> = with_retry(
            &cfg,
            || {
                calls += 1;
                async { Err(LlmError::AuthenticationError("bad key".to_string())) }
            },
            None,
        )
        .await;
        assert!(outcome.is_err());
        // Authentication failures are permanent: no retries.
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_with_retry_eventual_success() {
        let cfg = RetryConfig::new()
            .with_max_retries(3)
            .with_initial_delay(0.01);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = with_retry(
            &cfg,
            || {
                let previous = counter.fetch_add(1, Ordering::SeqCst);
                async move {
                    if previous < 2 {
                        Err(LlmError::ServerError {
                            status: 500,
                            message: "temporary error".to_string(),
                        })
                    } else {
                        Ok::<_, LlmError>(42)
                    }
                }
            },
            None,
        )
        .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        // Two failures followed by one success.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_exhausted() {
        let cfg = RetryConfig::new()
            .with_max_retries(2)
            .with_initial_delay(0.01);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome: RetryResult<i32> = with_retry(
            &cfg,
            || {
                counter.fetch_add(1, Ordering::SeqCst);
                async {
                    Err(LlmError::ServerError {
                        status: 500,
                        message: "persistent error".to_string(),
                    })
                }
            },
            None,
        )
        .await;
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        // One initial attempt plus two retries.
        assert_eq!(failure.attempts, 3);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_callback() {
        let cfg = RetryConfig::new()
            .with_max_retries(2)
            .with_initial_delay(0.01);
        let notified = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&notified);
        let on_retry: RetryCallback = Box::new(move |_err, _attempt, _delay| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
        let _: RetryResult<i32> = with_retry(
            &cfg,
            || async {
                Err(LlmError::ServerError {
                    status: 500,
                    message: "error".to_string(),
                })
            },
            Some(on_retry),
        )
        .await;
        // Three attempts mean two back-off sleeps, each announced once.
        assert_eq!(notified.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retry_error_display() {
        let failure = RetryError {
            last_error: LlmError::ServerError {
                status: 500,
                message: "internal error".to_string(),
            },
            attempts: 4,
            total_time: Duration::from_secs(10),
        };
        let rendered = failure.to_string();
        for needle in ["4 attempts", "10", "Server error"] {
            assert!(rendered.contains(needle), "{rendered:?} lacks {needle:?}");
        }
    }
}