use crate::config::RetryPolicy;
use crate::models::{MessageRequest, MessageResponse, StreamEvent};
use anyhow::Result;
use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant};
use uuid::Uuid;
#[cfg(test)]
pub mod mock;
/// Boxed, pinned, `Send + 'static` stream of fallible [`StreamEvent`]s, as
/// returned by [`LlmClient::create_message_stream`].
pub type StreamEventBox =
Pin<Box<dyn futures_util::Stream<Item = Result<StreamEvent>> + Send + 'static>>;
/// Provider-agnostic interface to an LLM backend.
///
/// Implementations must be `Send + Sync`.
#[allow(async_fn_in_trait, dead_code)] pub trait LlmClient: Send + Sync {
/// Static identifier of the backing provider.
fn provider_name(&self) -> &'static str;
/// Model identifier this client targets.
fn model(&self) -> &str;
/// Sends `request` and resolves to the complete (non-streaming) response.
///
/// Declared as `-> impl Future + Send` instead of `async fn` so the returned
/// future is explicitly `Send`; an `async fn` in a trait cannot state that
/// bound (which is what the `async_fn_in_trait` lint warns about).
fn create_message(
&self,
request: MessageRequest,
) -> impl Future<Output = Result<MessageResponse>> + Send;
/// Sends `request` and returns a stream of incremental events.
///
/// NOTE(review): unlike `create_message`, this `async fn` makes no `Send`
/// promise for the future itself — confirm callers never need to spawn it.
async fn create_message_stream(&self, request: MessageRequest) -> Result<StreamEventBox>;
/// Presumably reports whether the backend is usable (inferred from the name).
/// The default implementation performs no check and always returns `Ok(true)`.
async fn health_check(&self) -> Result<bool> {
Ok(true)
}
}
/// Types whose retry settings can be read and replaced.
#[allow(dead_code)] pub trait RetryConfigurable {
/// Current retry settings.
fn retry_config(&self) -> &RetryConfig;
/// Replaces the retry settings with `config`.
fn set_retry_config(&mut self, config: RetryConfig);
}
/// Classified failure from an LLM provider call.
///
/// See `LlmError::is_retryable` for which variants are treated as transient.
#[derive(Debug)]
pub enum LlmError {
/// HTTP 429, or an HTTP 400 whose body reports quota exhaustion.
/// `retry_after` carries the server's back-off hint when one was supplied.
RateLimited {
message: String,
retry_after: Option<Duration>,
},
/// HTTP 5xx response.
ServerError { status: u16, message: String },
/// Connection or request-send failure.
NetworkError(String),
/// Request or retry budget timed out. The duration is the relevant limit
/// where known; `from_reqwest` reports zero.
Timeout(Duration),
/// HTTP 401, or an HTTP 403 whose body reads like a credential problem.
AuthenticationError(String),
/// HTTP 403 that does not look like a credential problem (e.g. block pages).
AuthorizationError(String),
/// HTTP 400/404 not matched by a more specific keyword heuristic.
InvalidRequest { status: u16, message: String },
/// Model-related failure (404 mentioning "model", or 400 "model ... not found").
ModelError(String),
/// Request rejected by keyword heuristics that suggest a content-safety block.
ContentPolicyError(String),
/// Response could not be parsed (e.g. `serde_json` errors).
ParseError(String),
/// Presumably a prompt that exceeds the model's context window; detected by
/// broad keyword heuristics on HTTP 400 bodies.
ContextLengthError(String),
/// Anything else, including unexpected HTTP statuses.
Other(String),
}
impl std::fmt::Display for LlmError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LlmError::RateLimited { message, .. } => write!(f, "Rate limit exceeded: {message}"),
LlmError::ServerError { status, message } => {
write!(f, "Server error ({status}): {message}")
}
LlmError::NetworkError(msg) => write!(f, "Network error: {msg}"),
LlmError::Timeout(d) => write!(f, "Request timed out after {d:?}"),
LlmError::AuthenticationError(msg) => write!(f, "Authentication failed: {msg}"),
LlmError::AuthorizationError(msg) => write!(f, "Authorization failed: {msg}"),
LlmError::InvalidRequest { status, message } => {
write!(f, "Invalid request ({status}): {message}")
}
LlmError::ModelError(msg) => write!(f, "Model error: {msg}"),
LlmError::ContentPolicyError(msg) => write!(f, "Content policy violation: {msg}"),
LlmError::ParseError(msg) => write!(f, "Response parsing error: {msg}"),
LlmError::ContextLengthError(msg) => write!(f, "Context length exceeded: {msg}"),
LlmError::Other(msg) => write!(f, "LLM error: {msg}"),
}
}
}
impl std::error::Error for LlmError {}
impl LlmError {
    /// Returns `true` for transient failures worth retrying: rate limits,
    /// 5xx server errors, network failures and timeouts. Every other variant
    /// is treated as permanent for the same request.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::RateLimited { .. }
            | Self::ServerError { .. }
            | Self::NetworkError(_)
            | Self::Timeout(_) => true,
            Self::AuthenticationError(_)
            | Self::AuthorizationError(_)
            | Self::InvalidRequest { .. }
            | Self::ModelError(_)
            | Self::ContentPolicyError(_)
            | Self::ParseError(_)
            | Self::ContextLengthError(_)
            | Self::Other(_) => false,
        }
    }

    /// The server-supplied back-off hint, if any; only `RateLimited` carries one.
    pub fn suggested_retry_delay(&self) -> Option<Duration> {
        if let Self::RateLimited { retry_after, .. } = self {
            *retry_after
        } else {
            None
        }
    }

    /// Classifies an HTTP error response.
    ///
    /// The status code decides first. For 400/403/404 the body is also
    /// matched case-insensitively against keywords:
    ///
    /// * 429 → `RateLimited` (no `retry_after`)
    /// * 401 → `AuthenticationError`
    /// * 403 → `AuthenticationError` if the body reads like a credential
    ///   problem, otherwise `AuthorizationError`
    /// * 400 → checked in order for quota exhaustion (`RateLimited`), context
    ///   length, content policy, and "model ... not found" (`ModelError`),
    ///   falling back to `InvalidRequest`
    /// * 404 → `ModelError` if the body mentions "model", else `InvalidRequest`
    /// * 5xx → `ServerError`
    /// * anything else → `Other("HTTP {status}: {body}")`
    ///
    /// NOTE(review): the context-length keywords `token` and `maximum` are very
    /// broad (e.g. a 400 "Unexpected token in JSON" also matches) — confirm
    /// this is intended before relying on `ContextLengthError` for recovery.
    pub fn from_http_response(status: u16, body: &str) -> Self {
        // True when `haystack` contains at least one of `needles`.
        fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
            needles.iter().any(|needle| haystack.contains(needle))
        }

        let text = || body.to_string();
        match status {
            429 => Self::RateLimited {
                message: text(),
                retry_after: None,
            },
            401 => Self::AuthenticationError(text()),
            403 if looks_like_authentication_failure(body) => Self::AuthenticationError(text()),
            403 => Self::AuthorizationError(text()),
            400 => {
                let lowered = body.to_lowercase();
                if mentions_any(
                    &lowered,
                    &[
                        "insufficientquota",
                        "insufficient_quota",
                        "exceeded your current quota",
                        "quota exceeded",
                    ],
                ) {
                    // Quota exhaustion is reported as a (retryable) rate limit.
                    Self::RateLimited {
                        message: text(),
                        retry_after: None,
                    }
                } else if mentions_any(
                    &lowered,
                    &["context_length", "token", "too long", "maximum"],
                ) {
                    Self::ContextLengthError(text())
                } else if mentions_any(
                    &lowered,
                    &["content_policy", "safety", "harmful", "inappropriate"],
                ) {
                    Self::ContentPolicyError(text())
                } else if lowered.contains("model") && lowered.contains("not found") {
                    Self::ModelError(text())
                } else {
                    Self::InvalidRequest {
                        status,
                        message: text(),
                    }
                }
            }
            404 if body.to_lowercase().contains("model") => Self::ModelError(text()),
            404 => Self::InvalidRequest {
                status,
                message: text(),
            },
            500..=599 => Self::ServerError {
                status,
                message: text(),
            },
            _ => Self::Other(format!("HTTP {status}: {body}")),
        }
    }

    /// Like [`LlmError::from_http_response`], but attaches `retry_after` when
    /// the response classifies as `RateLimited`. It is ignored for every
    /// other variant.
    pub fn from_http_response_with_retry_after(
        status: u16,
        body: &str,
        retry_after: Option<Duration>,
    ) -> Self {
        match Self::from_http_response(status, body) {
            Self::RateLimited { message, .. } => Self::RateLimited {
                message,
                retry_after,
            },
            other => other,
        }
    }

    /// Maps a transport-level `reqwest` error onto an `LlmError`.
    ///
    /// Timeouts become `Timeout` with a zero duration; the configured limit is
    /// not recovered here. Connect and request failures become
    /// `NetworkError`, and everything else falls back to `Other`.
    pub fn from_reqwest(err: &reqwest::Error) -> Self {
        if err.is_timeout() {
            return Self::Timeout(Duration::ZERO);
        }
        let stage = if err.is_connect() {
            "Connection failed"
        } else if err.is_request() {
            "Request failed"
        } else {
            return Self::Other(err.to_string());
        };
        Self::NetworkError(format!("{stage}: {err}"))
    }
}
/// Turns a raw HTTP error body into a concise, single-line message for users.
///
/// The body is handled by the first rule that applies:
/// 1. JSON bodies: the first usable message found by
///    `extract_json_error_message`, whitespace-collapsed and capped at 2000 chars.
/// 2. HTML "Access Denied" block pages: a fixed explanatory sentence. It names
///    the provider (`provider_label`, default `"Provider"`) and the status,
///    plus an error ID when one can be found.
/// 3. Other HTML bodies: the page's visible text, capped at 900 chars.
/// 4. Anything else: the body, whitespace-collapsed and capped at 2000 chars.
#[must_use]
pub(crate) fn sanitize_http_error_body(
    provider_label: Option<&str>,
    status: u16,
    body: &str,
) -> String {
    if let Some(message) = extract_json_error_message(body) {
        return truncate_for_error(&collapse_whitespace(&message), 2_000);
    }
    if !is_probably_html(body) {
        return truncate_for_error(&collapse_whitespace(body), 2_000);
    }

    let provider = provider_label.unwrap_or("Provider");
    let page_text = html_to_text(body);
    let lowered = page_text.to_ascii_lowercase();
    let is_cloudflare = lowered.contains("cloudflare");
    let error_id = extract_cloudflare_error_id(&page_text);

    // "Access Denied" alone is too generic; require a second hint that this
    // is a block page before replacing the page text with a canned summary.
    let has_block_page_hint = is_cloudflare
        || error_id.is_some()
        || ["security alert", "contact support", "contact us"]
            .iter()
            .any(|hint| lowered.contains(hint));

    if lowered.contains("access denied") && has_block_page_hint {
        let label = if is_cloudflare {
            "Cloudflare Access Denied"
        } else {
            "Access Denied"
        };
        let id_suffix = error_id
            .map(|id| format!(" with ID {id}"))
            .unwrap_or_default();
        return format!(
            "{provider} API returned {label} (HTTP {status}). \
             The request was blocked before it reached the model; retry with a \
             smaller request or fewer tools, or contact provider support{id_suffix}."
        );
    }

    let excerpt = truncate_for_error(&collapse_whitespace(&page_text), 900);
    format!("{provider} API returned an HTML error page (HTTP {status}): {excerpt}")
}
/// Heuristically decides whether an HTTP error body describes a failed
/// credential check, as opposed to a permissions or policy denial.
///
/// Uses ASCII case-insensitive substring search over a fixed phrase list.
fn looks_like_authentication_failure(body: &str) -> bool {
    const AUTH_MARKERS: [&str; 7] = [
        "authentication",
        "unauthorized",
        "api key",
        "invalid key",
        "invalid token",
        "bearer token",
        "missing token",
    ];
    let haystack = body.to_ascii_lowercase();
    AUTH_MARKERS.iter().any(|marker| haystack.contains(marker))
}
/// Pulls a human-readable error message out of a JSON error body.
///
/// JSON pointers are tried in priority order:
/// * a non-blank string at a pointer is returned verbatim;
/// * an object or array is returned as its compact JSON serialization;
/// * blank strings and other scalars are skipped.
///
/// Returns `None` when `body` is not JSON or no pointer yields a usable value.
fn extract_json_error_message(body: &str) -> Option<String> {
    const CANDIDATE_POINTERS: [&str; 5] = [
        "/error/message",
        "/error",
        "/message",
        "/detail",
        "/error_description",
    ];
    let parsed: Value = serde_json::from_str(body).ok()?;
    CANDIDATE_POINTERS.iter().find_map(|pointer| {
        let node = parsed.pointer(pointer)?;
        match node.as_str() {
            Some(text) => (!text.trim().is_empty()).then(|| text.to_string()),
            None if node.is_object() || node.is_array() => Some(node.to_string()),
            None => None,
        }
    })
}
/// Cheap sniff test for an HTML document.
///
/// Only the first 512 characters are examined, ASCII-lowercased. The body
/// counts as HTML if they contain an HTML doctype, an `<html` tag or a
/// `<head` tag.
fn is_probably_html(body: &str) -> bool {
    let mut head = String::with_capacity(512);
    head.extend(body.chars().take(512).map(|c| c.to_ascii_lowercase()));
    ["<!doctype html", "<html", "<head"]
        .iter()
        .any(|needle| head.contains(needle))
}
/// Converts an HTML document into one line of readable text.
///
/// The conversion steps are:
/// * `<script>` and `<style>` blocks are dropped entirely;
/// * every tag is replaced by a space;
/// * whitespace runs are collapsed;
/// * a small set of common character entities is decoded.
fn html_to_text(html: &str) -> String {
    let cleaned = strip_html_block(&strip_html_block(html, "script"), "style");
    let mut plain = String::with_capacity(cleaned.len().min(4096));
    let mut inside_tag = false;
    for c in cleaned.chars() {
        if c == '<' || c == '>' {
            // Tag boundaries become spaces so adjacent text nodes don't fuse.
            inside_tag = c == '<';
            plain.push(' ');
        } else if !inside_tag {
            plain.push(c);
        }
    }
    decode_basic_html_entities(&collapse_whitespace(&plain))
}
/// Removes every `<tag ...>...</tag>` block from `input`, replacing each
/// removed block with a single space.
///
/// Matching is ASCII case-insensitive. If an opening marker has no closing
/// marker, everything from the opening marker to the end of `input` is dropped.
fn strip_html_block(input: &str, tag: &str) -> String {
    // ASCII lowercasing keeps byte offsets identical, so positions found in
    // `folded` can be used to slice `input` directly.
    let folded = input.to_ascii_lowercase();
    let open = format!("<{tag}");
    let close = format!("</{tag}>");
    let mut kept = String::with_capacity(input.len());
    let mut pos = 0usize;
    loop {
        let Some(found) = folded[pos..].find(&open) else {
            kept.push_str(&input[pos..]);
            break;
        };
        let block_start = pos + found;
        kept.push_str(&input[pos..block_start]);
        let body_start = block_start + open.len();
        match folded[body_start..].find(&close) {
            Some(found_close) => {
                pos = body_start + found_close + close.len();
                kept.push(' ');
            }
            None => break,
        }
    }
    kept
}
/// Decodes a small set of common HTML character references in a single pass.
///
/// Recognized references are the named entities `nbsp` (decoded to a plain
/// space), `amp`, `lt`, `gt`, `quot` and `apos`, plus the numeric
/// apostrophe forms `#39` and `#x27`. Each reference is `&` + name + `;`.
///
/// Decoding happens in one left-to-right pass, so an escaped ampersand
/// followed by `lt;` yields the literal text `&lt;` rather than being
/// double-decoded to `<`. That double-decoding is what chained
/// `str::replace` calls would do. Any `&` that does not start a recognized
/// reference is kept verbatim.
fn decode_basic_html_entities(input: &str) -> String {
    // Reference names *without* the leading `&`, so this table contains no
    // literal entities that could be mangled by HTML-aware tooling.
    const ENTITIES: [(&str, char); 8] = [
        ("nbsp;", ' '),
        ("amp;", '&'),
        ("lt;", '<'),
        ("gt;", '>'),
        ("quot;", '"'),
        ("#39;", '\''),
        ("#x27;", '\''),
        ("apos;", '\''),
    ];

    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(pos) = rest.find('&') {
        out.push_str(&rest[..pos]);
        let tail = &rest[pos + 1..];
        match ENTITIES.iter().find(|(name, _)| tail.starts_with(name)) {
            Some((name, decoded)) => {
                out.push(*decoded);
                rest = &tail[name.len()..];
            }
            None => {
                out.push('&');
                rest = tail;
            }
        }
    }
    out.push_str(rest);
    out
}
/// Collapses every run of Unicode whitespace into a single ASCII space and
/// trims both ends.
fn collapse_whitespace(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for word in input.split_whitespace() {
        // `split_whitespace` never yields empty words, so `out` is only empty
        // before the first one.
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
/// Limits `input` to at most `max_chars` characters (not bytes), appending
/// `...` when anything was cut off.
fn truncate_for_error(input: &str, max_chars: usize) -> String {
    // Byte offset of the first character that would exceed the limit, if any.
    match input.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            let mut clipped = String::with_capacity(cut + 3);
            clipped.push_str(&input[..cut]);
            clipped.push_str("...");
            clipped
        }
        None => input.to_string(),
    }
}
/// Returns the last run of ASCII hex digits in `text` that is 16–64 characters
/// long and contains at least one letter. Requiring a letter means purely
/// numeric runs are ignored.
///
/// Presumably this targets the error/ray ID printed on block pages (as in
/// this module's tests) — confirm against real pages.
fn extract_cloudflare_error_id(text: &str) -> Option<String> {
    text.split(|c: char| !c.is_ascii_hexdigit())
        .rev()
        .find(|candidate| {
            (16..=64).contains(&candidate.len())
                && candidate.chars().any(|c| c.is_ascii_alphabetic())
        })
        .map(str::to_string)
}
impl From<reqwest::Error> for LlmError {
fn from(err: reqwest::Error) -> Self {
LlmError::from_reqwest(&err)
}
}
impl From<serde_json::Error> for LlmError {
fn from(err: serde_json::Error) -> Self {
LlmError::ParseError(err.to_string())
}
}
/// Tuning knobs for `with_retry`'s exponential back-off.
///
/// Time values are `f64` seconds; they are converted to `Duration` when used.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// When `false`, `with_retry` runs the operation exactly once.
pub enabled: bool,
/// Retries allowed after the first attempt (at most `max_retries + 1` attempts).
pub max_retries: u32,
/// Delay in seconds before the first retry.
pub initial_delay: f64,
/// Cap in seconds on the exponential delay. It is applied *before* jitter, so
/// the final delay can exceed it by up to `jitter_factor * max_delay`.
pub max_delay: f64,
/// Multiplier applied to the delay for each successive attempt.
pub exponential_base: f64,
/// Whether to randomize each delay by up to `±jitter_factor` of its value.
pub jitter: bool,
/// Fraction of the capped delay used as the symmetric jitter range.
pub jitter_factor: f64,
/// When `true`, a rate-limit error's server-supplied delay replaces the
/// computed back-off.
pub respect_retry_after: bool,
/// Status codes reported by `is_retryable_status`.
/// NOTE(review): `with_retry` classifies errors via `LlmError::is_retryable`
/// and does not consult this list.
#[allow(dead_code)] pub retryable_status_codes: Vec<u16>,
/// Per-request timeout in seconds. Not read anywhere in this file; presumably
/// consumed when building HTTP clients — TODO confirm.
#[allow(dead_code)] pub request_timeout: f64,
/// Overall budget in seconds across all attempts and sleeps; `0.0` (or any
/// non-positive value) disables it.
pub total_timeout: f64,
}
impl Default for RetryConfig {
    /// Up to 3 retries starting at 1 s and doubling each time, capped at 60 s,
    /// with ±10% jitter, honoring server retry hints, and no overall time budget.
    fn default() -> Self {
        let retryable_status_codes = Vec::from([429, 500, 502, 503, 504]);
        Self {
            total_timeout: 0.0,
            request_timeout: 120.0,
            retryable_status_codes,
            respect_retry_after: true,
            jitter_factor: 0.1,
            jitter: true,
            exponential_base: 2.0,
            max_delay: 60.0,
            initial_delay: 1.0,
            max_retries: 3,
            enabled: true,
        }
    }
}
#[allow(dead_code)]
impl RetryConfig {
    /// Same as [`RetryConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Default settings with retrying turned off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }

    /// Sets how many retries are allowed after the first attempt.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the delay, in seconds, before the first retry.
    pub fn with_initial_delay(mut self, delay: f64) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the cap, in seconds, on the exponential delay (applied before jitter).
    pub fn with_max_delay(mut self, delay: f64) -> Self {
        self.max_delay = delay;
        self
    }

    /// Enables or disables random jitter on retry delays.
    pub fn with_jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// Sets the per-request timeout in seconds (stored only; not read in this file).
    pub fn with_request_timeout(mut self, timeout: f64) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the overall retry budget in seconds; `0.0` disables it.
    pub fn with_total_timeout(mut self, timeout: f64) -> Self {
        self.total_timeout = timeout;
        self
    }

    /// Computes the back-off delay before retry number `attempt` (0-based).
    ///
    /// The delay is computed as follows:
    /// * base value: `initial_delay * exponential_base^attempt`;
    /// * capped at `max_delay`;
    /// * if `jitter` is enabled, shifted by a uniformly random amount in
    ///   `±jitter_factor * capped`.
    ///
    /// The result is always a valid `Duration`, even for a misconfigured
    /// policy: negative or NaN delays become zero, and delays too large to
    /// represent saturate to `Duration::MAX`. `Duration::from_secs_f64`
    /// would panic in those cases instead.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let base_delay = self.initial_delay * self.exponential_base.powi(exponent);
        let capped_delay = base_delay.min(self.max_delay);
        let final_delay = if self.jitter {
            let jitter_range = capped_delay * self.jitter_factor;
            // Cheap randomness from the first two bytes of a fresh v4 UUID
            // (presumably to avoid a dedicated RNG dependency).
            let bytes = *Uuid::new_v4().as_bytes();
            let sample = u16::from_le_bytes([bytes[0], bytes[1]]);
            // Uniform in [0, 1], so the jitter is uniform in [-range, +range].
            let random_factor = f64::from(sample) / f64::from(u16::MAX);
            let jitter = jitter_range * (2.0 * random_factor - 1.0);
            (capped_delay + jitter).max(0.0)
        } else {
            capped_delay
        };
        if final_delay.is_nan() || final_delay <= 0.0 {
            Duration::ZERO
        } else {
            // Only overflow (including +inf) can fail here; saturate instead of panicking.
            Duration::try_from_secs_f64(final_delay).unwrap_or(Duration::MAX)
        }
    }

    /// Whether `status` appears in `retryable_status_codes`.
    pub fn is_retryable_status(&self, status: u16) -> bool {
        self.retryable_status_codes.contains(&status)
    }
}
impl From<RetryPolicy> for RetryConfig {
fn from(policy: RetryPolicy) -> Self {
Self {
enabled: policy.enabled,
max_retries: policy.max_retries,
initial_delay: policy.initial_delay,
max_delay: policy.max_delay,
exponential_base: policy.exponential_base,
..Default::default()
}
}
}
impl From<RetryConfig> for RetryPolicy {
fn from(config: RetryConfig) -> Self {
Self {
enabled: config.enabled,
max_retries: config.max_retries,
initial_delay: config.initial_delay,
max_delay: config.max_delay,
exponential_base: config.exponential_base,
}
}
}
/// Final failure returned by `with_retry` once it stops retrying.
#[derive(Debug)]
pub struct RetryError {
/// Error from the last attempt, or a synthesized `Timeout`/`Other` when no
/// attempt error was recorded.
pub last_error: LlmError,
/// Number of attempts actually made.
pub attempts: u32,
/// Elapsed time as reported by `with_retry`.
pub total_time: Duration,
}
impl std::fmt::Display for RetryError {
    /// e.g. `Retry exhausted after 4 attempts (10s): Server error (500): ...`
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            last_error,
            attempts,
            total_time,
        } = self;
        write!(
            f,
            "Retry exhausted after {attempts} attempts ({total_time:?}): {last_error}"
        )
    }
}

impl std::error::Error for RetryError {
    /// Exposes the final attempt's error as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.last_error;
        Some(cause)
    }
}
/// Outcome of `with_retry`: the operation's value, or a [`RetryError`].
pub type RetryResult<T> = Result<T, RetryError>;
/// Observer invoked by `with_retry` before each retry sleep, with the error
/// that triggered the retry, the 0-based attempt index that failed, and the
/// delay about to be slept.
pub type RetryCallback = Box<dyn Fn(&LlmError, u32, Duration) + Send + Sync>;
pub async fn with_retry<F, Fut, T>(
config: &RetryConfig,
mut operation: F,
callback: Option<RetryCallback>,
) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>>,
{
if !config.enabled {
return operation().await.map_err(|e| RetryError {
last_error: e,
attempts: 1,
total_time: Duration::ZERO,
});
}
let start_time = Instant::now();
let total_timeout = if config.total_timeout > 0.0 {
Some(Duration::from_secs_f64(config.total_timeout))
} else {
None
};
let mut last_error: Option<LlmError> = None;
for attempt in 0..=config.max_retries {
if let Some(timeout) = total_timeout
&& start_time.elapsed() >= timeout
{
return Err(RetryError {
last_error: last_error.unwrap_or(LlmError::Timeout(timeout)),
attempts: attempt,
total_time: start_time.elapsed(),
});
}
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
if !err.is_retryable() {
return Err(RetryError {
last_error: err,
attempts: attempt + 1,
total_time: start_time.elapsed(),
});
}
if attempt >= config.max_retries {
return Err(RetryError {
last_error: err,
attempts: attempt + 1,
total_time: start_time.elapsed(),
});
}
let base_delay = config.delay_for_attempt(attempt);
let delay = if config.respect_retry_after {
err.suggested_retry_delay().unwrap_or(base_delay)
} else {
base_delay
};
if let Some(ref cb) = callback {
cb(&err, attempt, delay);
}
last_error = Some(err);
tokio::time::sleep(delay).await;
}
}
}
Err(RetryError {
last_error: last_error.unwrap_or(LlmError::Other("Unknown retry error".to_string())),
attempts: config.max_retries + 1,
total_time: start_time.elapsed(),
})
}
#[allow(dead_code)] pub async fn with_retry_simple<F, Fut, T>(config: &RetryConfig, operation: F) -> RetryResult<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>>,
{
with_retry(config, operation, None).await
}
/// Parses a `Retry-After` header value given in delay-seconds form.
///
/// Accepts a non-negative integer (`"120"`) or a non-negative decimal
/// (`"1.5"`), ignoring surrounding whitespace. Everything else returns `None`,
/// including negative, NaN, infinite or unrepresentably large values.
/// The header comes from an untrusted server, so such values must never
/// reach the panicking `Duration::from_secs_f64`.
///
/// The HTTP-date form of `Retry-After` is not supported and yields `None`.
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    let value = value.trim();
    if let Ok(seconds) = value.parse::<u64>() {
        return Some(Duration::from_secs(seconds));
    }
    // `try_from_secs_f64` rejects negative, NaN and overflowing inputs
    // instead of panicking.
    value
        .parse::<f64>()
        .ok()
        .and_then(|seconds| Duration::try_from_secs_f64(seconds).ok())
}
pub fn extract_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
headers
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(parse_retry_after)
}
// Unit tests for error classification, HTTP error-body sanitizing, retry
// configuration and the `with_retry` loop.
#[cfg(test)]
mod tests {
use super::*;
// Compares floats with an absolute tolerance of `f64::EPSILON`; adequate for
// the small exact literals used in these tests.
fn assert_f64_eq(actual: f64, expected: f64) {
assert!(
(actual - expected).abs() < f64::EPSILON,
"expected {expected}, got {actual}"
);
}
// Pins the documented default retry settings.
#[test]
fn test_retry_config_defaults() {
let config = RetryConfig::default();
assert!(config.enabled);
assert_eq!(config.max_retries, 3);
assert_f64_eq(config.initial_delay, 1.0);
assert_f64_eq(config.max_delay, 60.0);
assert_f64_eq(config.exponential_base, 2.0);
assert!(config.jitter);
}
#[test]
fn test_retry_config_disabled() {
let config = RetryConfig::disabled();
assert!(!config.enabled);
}
// Builder methods overwrite only the field they name.
#[test]
fn test_retry_config_builder() {
let config = RetryConfig::new()
.with_max_retries(5)
.with_initial_delay(2.0)
.with_max_delay(120.0)
.with_jitter(false);
assert_eq!(config.max_retries, 5);
assert_f64_eq(config.initial_delay, 2.0);
assert_f64_eq(config.max_delay, 120.0);
assert!(!config.jitter);
}
// Without jitter the delay doubles per attempt (default base 2, initial 1 s).
#[test]
fn test_delay_for_attempt_exponential() {
let config = RetryConfig::new().with_jitter(false);
let d0 = config.delay_for_attempt(0);
assert_eq!(d0, Duration::from_secs_f64(1.0));
let d1 = config.delay_for_attempt(1);
assert_eq!(d1, Duration::from_secs_f64(2.0));
let d2 = config.delay_for_attempt(2);
assert_eq!(d2, Duration::from_secs_f64(4.0));
let d3 = config.delay_for_attempt(3);
assert_eq!(d3, Duration::from_secs_f64(8.0));
}
// 2^3 = 8 s would exceed the 5 s cap.
#[test]
fn test_delay_for_attempt_capped() {
let config = RetryConfig::new().with_jitter(false).with_max_delay(5.0);
let d3 = config.delay_for_attempt(3);
assert_eq!(d3, Duration::from_secs_f64(5.0));
}
// With jitter the delay stays within ±jitter_factor (10%) of the base delay.
#[test]
fn test_delay_for_attempt_with_jitter() {
let config = RetryConfig::new().with_jitter(true);
let d1 = config.delay_for_attempt(1);
let d2 = config.delay_for_attempt(1);
let base = 2.0;
let range = base * 0.1;
assert!(d1.as_secs_f64() >= base - range);
assert!(d1.as_secs_f64() <= base + range);
assert!(d2.as_secs_f64() >= base - range);
assert!(d2.as_secs_f64() <= base + range);
}
// Default retryable status list: 429 and 500/502/503/504 only.
#[test]
fn test_is_retryable_status() {
let config = RetryConfig::default();
assert!(config.is_retryable_status(429)); assert!(config.is_retryable_status(500)); assert!(config.is_retryable_status(502)); assert!(config.is_retryable_status(503)); assert!(config.is_retryable_status(504));
assert!(!config.is_retryable_status(400)); assert!(!config.is_retryable_status(401)); assert!(!config.is_retryable_status(403)); assert!(!config.is_retryable_status(404)); }
// Transient variants are retryable; client-side/permanent ones are not.
#[test]
fn test_llm_error_retryable() {
assert!(
LlmError::RateLimited {
message: "too many requests".to_string(),
retry_after: None
}
.is_retryable()
);
assert!(
LlmError::ServerError {
status: 500,
message: "internal error".to_string()
}
.is_retryable()
);
assert!(LlmError::NetworkError("connection refused".to_string()).is_retryable());
assert!(LlmError::Timeout(Duration::from_secs(30)).is_retryable());
assert!(!LlmError::AuthenticationError("invalid key".to_string()).is_retryable());
assert!(!LlmError::AuthorizationError("blocked".to_string()).is_retryable());
assert!(
!LlmError::InvalidRequest {
status: 400,
message: "bad json".to_string()
}
.is_retryable()
);
assert!(!LlmError::ContentPolicyError("unsafe content".to_string()).is_retryable());
assert!(!LlmError::ContextLengthError("too long".to_string()).is_retryable());
}
// Status-code and body-keyword classification, including quota-as-rate-limit.
#[test]
fn test_llm_error_from_http_response() {
let err = LlmError::from_http_response(429, "rate limit exceeded");
assert!(matches!(err, LlmError::RateLimited { .. }));
let err = LlmError::from_http_response(401, "invalid api key");
assert!(matches!(err, LlmError::AuthenticationError(_)));
let err = LlmError::from_http_response(403, "forbidden");
assert!(matches!(err, LlmError::AuthorizationError(_)));
let err = LlmError::from_http_response(403, "invalid api key");
assert!(matches!(err, LlmError::AuthenticationError(_)));
let err = LlmError::from_http_response(500, "internal server error");
assert!(matches!(err, LlmError::ServerError { status: 500, .. }));
let err = LlmError::from_http_response(503, "service unavailable");
assert!(matches!(err, LlmError::ServerError { status: 503, .. }));
let err = LlmError::from_http_response(400, "context_length_exceeded");
assert!(matches!(err, LlmError::ContextLengthError(_)));
let err = LlmError::from_http_response(
400,
r#"{"error":{"code":"insufficientquota","message":"You exceeded your current quota"}}"#,
);
assert!(matches!(err, LlmError::RateLimited { .. }));
assert!(err.is_retryable());
let err = LlmError::from_http_response(400, "content_policy_violation");
assert!(matches!(err, LlmError::ContentPolicyError(_)));
let err = LlmError::from_http_response(400, "invalid json");
assert!(matches!(err, LlmError::InvalidRequest { status: 400, .. }));
}
// A Cloudflare block page collapses to a short summary that keeps the error ID.
#[test]
fn cloudflare_html_error_is_summarized_without_raw_markup() {
let body = r#"<!DOCTYPE html><html><head><title>Access Denied</title><style>
.hidden { display: none; }
</style></head><body>
<h1>Access Denied</h1>
<p>The action you just performed triggered a security alert.</p>
<script>window.noisy = true;</script>
<span>2600:1700:467:d410:f137:b94f:1dd0:d1e4</span>
<span>a059a2873f3fdf82</span>
<div>Cloudflare Error Pages</div>
</body></html>"#;
let message = sanitize_http_error_body(Some("Arcee AI"), 403, body);
assert!(message.contains("Arcee AI API returned Cloudflare Access Denied"));
assert!(message.contains("ID a059a2873f3fdf82"));
assert!(!message.contains("<!DOCTYPE"));
// NOTE(review): this fixture contains no "tailwindcss", so the next assert
// is vacuous; presumably left over from a larger real-world page.
assert!(!message.contains("tailwindcss"));
assert!(message.len() < 300);
}
// The summarized block-page message must not be mistaken for a credential failure.
#[test]
fn cloudflare_access_denied_403_is_authorization_not_authentication() {
let message = sanitize_http_error_body(
Some("Arcee AI"),
403,
r#"<!doctype html><html><body><h1>Access Denied</h1><p>Cloudflare Error Pages</p></body></html>"#,
);
let err = LlmError::from_http_response(403, &message);
assert!(matches!(err, LlmError::AuthorizationError(_)));
}
// "cloudflare" appears only inside tags/styles here, so it is stripped and the
// generic "Access Denied" label is used; the security-alert/contact hints still
// trigger the summary.
#[test]
fn arcee_access_denied_without_literal_cloudflare_is_still_summarized() {
let body = r#"<!DOCTYPE html><html lang="en"><head>
<meta name="description" content="Cloudflare Error Pages">
<title>Access Denied</title>
<style>:root{--accent:cloudflare}</style></head><body>
<h1>Access Denied</h1>
<p>The action you just performed triggered a security alert.</p>
<p>Please contact us if this was a mistake.</p>
<a>Contact Support</a>
<span>2600:1700:467:d410:f137:b94f:1dd0:d1e4</span>
<span>a059c0d4caf1f9cc</span>
</body></html>"#;
let message = sanitize_http_error_body(Some("Arcee AI"), 403, body);
assert!(
message.contains("Arcee AI API returned Access Denied"),
"got: {message}"
);
assert!(message.contains("ID a059c0d4caf1f9cc"), "got: {message}");
assert!(
!message.to_ascii_lowercase().contains("cloudflare"),
"stripped Arcee page has no literal Cloudflare: {message}"
);
assert!(!message.contains('<'), "no raw markup: {message}");
assert!(message.len() < 300, "stays concise: {message}");
let err = LlmError::from_http_response(403, &message);
assert!(matches!(err, LlmError::AuthorizationError(_)));
}
// Only RateLimited exposes a retry hint.
#[test]
fn test_llm_error_suggested_retry_delay() {
let err = LlmError::RateLimited {
message: "slow down".to_string(),
retry_after: Some(Duration::from_secs(60)),
};
assert_eq!(err.suggested_retry_delay(), Some(Duration::from_secs(60)));
let err = LlmError::ServerError {
status: 500,
message: "error".to_string(),
};
assert_eq!(err.suggested_retry_delay(), None);
}
// Integer and fractional delay-seconds parse; junk and empty input do not.
#[test]
fn test_parse_retry_after() {
assert_eq!(parse_retry_after("120"), Some(Duration::from_secs(120)));
assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
assert_eq!(parse_retry_after("1.5"), Some(Duration::from_secs_f64(1.5)));
assert_eq!(parse_retry_after("invalid"), None);
assert_eq!(parse_retry_after(""), None);
}
// RetryPolicy -> RetryConfig -> RetryPolicy preserves the shared fields.
#[test]
fn test_retry_policy_conversion() {
let policy = RetryPolicy {
enabled: true,
max_retries: 5,
initial_delay: 2.0,
max_delay: 30.0,
exponential_base: 3.0,
};
let config: RetryConfig = policy.clone().into();
assert_eq!(config.enabled, policy.enabled);
assert_eq!(config.max_retries, policy.max_retries);
assert_f64_eq(config.initial_delay, policy.initial_delay);
assert_f64_eq(config.max_delay, policy.max_delay);
assert_f64_eq(config.exponential_base, policy.exponential_base);
let policy2: RetryPolicy = config.into();
assert_eq!(policy2.enabled, policy.enabled);
assert_eq!(policy2.max_retries, policy.max_retries);
}
// A first-try success calls the operation once.
#[tokio::test]
async fn test_with_retry_success_first_attempt() {
let config = RetryConfig::default();
let mut call_count = 0;
let result = with_retry(
&config,
|| {
call_count += 1;
async { Ok::<_, LlmError>(42) }
},
None,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 1);
}
// Disabled config: a retryable error is still returned after one call.
#[tokio::test]
async fn test_with_retry_disabled() {
let config = RetryConfig::disabled();
let mut call_count = 0;
let result: RetryResult<i32> = with_retry(
&config,
|| {
call_count += 1;
async {
Err(LlmError::ServerError {
status: 500,
message: "error".to_string(),
})
}
},
None,
)
.await;
assert!(result.is_err());
assert_eq!(call_count, 1); }
// Non-retryable errors short-circuit after one call.
#[tokio::test]
async fn test_with_retry_non_retryable_error() {
let config = RetryConfig::default();
let mut call_count = 0;
let result: RetryResult<i32> = with_retry(
&config,
|| {
call_count += 1;
async { Err(LlmError::AuthenticationError("bad key".to_string())) }
},
None,
)
.await;
assert!(result.is_err());
assert_eq!(call_count, 1); }
// Two transient failures, then success on the third call.
#[tokio::test]
async fn test_with_retry_eventual_success() {
let config = RetryConfig::new()
.with_max_retries(3)
.with_initial_delay(0.01);
let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let cc = call_count.clone();
let result = with_retry(
&config,
|| {
let count = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
async move {
if count < 2 {
Err(LlmError::ServerError {
status: 500,
message: "temporary error".to_string(),
})
} else {
Ok::<_, LlmError>(42)
}
}
},
None,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3); }
// max_retries = 2 means 3 attempts in total before giving up.
#[tokio::test]
async fn test_with_retry_exhausted() {
let config = RetryConfig::new()
.with_max_retries(2)
.with_initial_delay(0.01);
let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let cc = call_count.clone();
let result: RetryResult<i32> = with_retry(
&config,
|| {
cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
async {
Err(LlmError::ServerError {
status: 500,
message: "persistent error".to_string(),
})
}
},
None,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.attempts, 3); assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
}
// The callback fires once per retry (not for the final failing attempt).
#[tokio::test]
async fn test_with_retry_callback() {
let config = RetryConfig::new()
.with_max_retries(2)
.with_initial_delay(0.01);
let callback_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let cc = callback_count.clone();
let _: RetryResult<i32> = with_retry(
&config,
|| async {
Err(LlmError::ServerError {
status: 500,
message: "error".to_string(),
})
},
Some(Box::new(move |_err, _attempt, _delay| {
cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
})),
)
.await;
assert_eq!(callback_count.load(std::sync::atomic::Ordering::SeqCst), 2);
}
// Display includes attempt count, elapsed time and the inner error text.
#[test]
fn test_retry_error_display() {
let err = RetryError {
last_error: LlmError::ServerError {
status: 500,
message: "internal error".to_string(),
},
attempts: 4,
total_time: Duration::from_secs(10),
};
let display = format!("{err}");
assert!(display.contains("4 attempts"));
assert!(display.contains("10"));
assert!(display.contains("Server error"));
}
}