use std::time::Duration;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, LlmError>;
/// Strategy describing whether and how a failed LLM call should be retried.
#[derive(Debug, Clone, PartialEq)]
pub enum RetryStrategy {
    /// Retry with exponentially growing delays: start at `base_delay`,
    /// cap each delay at `max_delay`, give up after `max_attempts` tries.
    ExponentialBackoff {
        base_delay: Duration,
        max_delay: Duration,
        max_attempts: u32,
    },
    /// Wait a fixed duration (e.g. a server-provided retry-after hint)
    /// before retrying.
    WaitAndRetry {
        wait: Duration,
    },
    /// The request was too large; shrink the prompt/context and retry.
    ReduceContext,
    /// Permanent failure; retrying cannot help.
    NoRetry,
}
impl RetryStrategy {
pub fn network_backoff() -> Self {
Self::ExponentialBackoff {
base_delay: Duration::from_millis(125),
max_delay: Duration::from_secs(30),
max_attempts: 5,
}
}
pub fn server_backoff() -> Self {
Self::ExponentialBackoff {
base_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
max_attempts: 3,
}
}
pub fn should_retry(&self) -> bool {
!matches!(self, Self::NoRetry)
}
}
/// Unified error type for LLM provider interactions.
///
/// Display text comes from the `thiserror` `#[error(...)]` attributes;
/// each variant maps to a retry policy via `LlmError::retry_strategy`.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Generic API-level failure that fits no more specific variant.
    #[error("API error: {0}")]
    ApiError(String),
    /// The provider throttled the request; the message may embed a
    /// "try again in ..." hint that `retry_strategy` parses.
    #[error("Rate limit exceeded: {0}")]
    RateLimited(String),
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    #[error("Authentication error: {0}")]
    AuthError(String),
    /// Prompt/context exceeded the model's token budget.
    #[error("Token limit exceeded: max {max}, got {got}")]
    TokenLimitExceeded { max: usize, got: usize },
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// Transport-level failure (e.g. connection refused).
    #[error("Network error: {0}")]
    NetworkError(String),
    /// JSON (de)serialization failure; auto-converts from `serde_json::Error`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
    #[error("Configuration error: {0}")]
    ConfigError(String),
    /// Provider-side (server) failure, e.g. "server_error" or stream errors.
    #[error("Provider error: {0}")]
    ProviderError(String),
    #[error("Request timed out")]
    Timeout,
    #[error("Not supported: {0}")]
    NotSupported(String),
    /// Catch-all for errors that fit none of the variants above.
    #[error("Unknown error: {0}")]
    Unknown(String),
}
/// Translate transport-level `reqwest` failures into `LlmError` variants.
impl From<reqwest::Error> for LlmError {
    fn from(err: reqwest::Error) -> Self {
        // Timeouts get a dedicated variant; connection failures get an
        // annotated message; anything else is a generic network error.
        if err.is_timeout() {
            return LlmError::Timeout;
        }
        if err.is_connect() {
            return LlmError::NetworkError(format!("Connection failed: {}", err));
        }
        LlmError::NetworkError(err.to_string())
    }
}
/// Map an OpenAI `ApiError` onto an `LlmError`, preferring the specific
/// machine-readable `code` field and falling back to the coarser `type`.
fn map_openai_api_error(api_err: async_openai::error::ApiError) -> LlmError {
    let full_message = api_err.to_string();

    // First pass: match on the error `code`, when the API supplied one.
    if let Some(code) = api_err.code.as_deref() {
        match code {
            "rate_limit_exceeded" => return LlmError::RateLimited(full_message),
            "insufficient_quota" => return LlmError::ApiError(full_message),
            "invalid_api_key" | "no_auth" | "ip_not_authorized" | "no_such_organization" => {
                return LlmError::AuthError(full_message);
            }
            // NOTE(review): real token counts are not present in the error
            // payload here, so 0/0 placeholders are used.
            "context_length_exceeded" | "max_tokens_exceeded" => {
                return LlmError::TokenLimitExceeded { max: 0, got: 0 };
            }
            "model_not_found" | "invalid_model" => {
                return LlmError::ModelNotFound(full_message);
            }
            "content_filter" | "content_policy_violation" => {
                return LlmError::ApiError(full_message);
            }
            _ => {}
        }
    }

    // Second pass: fall back to the broader error `type`.
    match api_err.r#type.as_deref() {
        Some("tokens") | Some("requests") => LlmError::RateLimited(full_message),
        Some("authentication_error") => LlmError::AuthError(full_message),
        Some("invalid_request_error") => LlmError::InvalidRequest(full_message),
        Some("server_error") => LlmError::ProviderError(full_message),
        _ => LlmError::ApiError(full_message),
    }
}
impl From<async_openai::error::OpenAIError> for LlmError {
fn from(err: async_openai::error::OpenAIError) -> Self {
use async_openai::error::OpenAIError;
match err {
OpenAIError::ApiError(api_err) => map_openai_api_error(api_err),
OpenAIError::Reqwest(req_err) => LlmError::from(req_err),
OpenAIError::JSONDeserialize(json_err, _content) => {
LlmError::SerializationError(json_err)
}
OpenAIError::StreamError(stream_err) => {
LlmError::ProviderError(format!("Stream error: {stream_err}"))
}
OpenAIError::InvalidArgument(msg) => LlmError::InvalidRequest(msg),
OpenAIError::FileSaveError(msg) | OpenAIError::FileReadError(msg) => {
LlmError::Unknown(format!("File I/O error: {msg}"))
}
}
}
}
/// Extract a wait hint from a rate-limit message such as
/// "Please try again in 20s" / "1m30s" / "500ms".
///
/// Returns the parsed duration padded by 10% (to avoid retrying a hair too
/// early) and capped at 120 s, or `None` when no usable hint is found.
///
/// Fixes over the previous version:
/// - "ms" is now recognized as milliseconds; previously "500ms" was read as
///   500 *minutes* and silently clamped to the 120 s cap.
/// - The unit letters are matched case-insensitively (the marker search
///   already was) by scanning the lowercased tail.
fn parse_retry_after_secs(message: &str) -> Option<Duration> {
    let lower = message.to_ascii_lowercase();
    let marker = "try again in ";
    // to_ascii_lowercase preserves byte offsets, so this index is valid
    // in both `message` and `lower`.
    let start = lower.find(marker)? + marker.len();
    let tail = &lower[start..];

    let mut total_ms = 0u64;
    let mut num_str = String::with_capacity(8);
    let mut chars = tail.chars().peekable();
    while let Some(c) = chars.peek().copied() {
        match c {
            // Accumulate the numeric part (may be fractional, e.g. "1.5s").
            '0'..='9' | '.' => {
                num_str.push(c);
                chars.next();
            }
            'm' => {
                chars.next();
                if chars.peek().copied() == Some('s') {
                    // "ms" — milliseconds; this is the final component.
                    if let Ok(ms) = num_str.parse::<f64>() {
                        total_ms += ms as u64;
                    }
                    break;
                }
                // Bare "m" — minutes; keep scanning for a seconds component.
                if let Ok(mins) = num_str.parse::<f64>() {
                    total_ms += (mins * 60_000.0) as u64;
                }
                num_str.clear();
            }
            's' => {
                if let Ok(secs) = num_str.parse::<f64>() {
                    total_ms += (secs * 1_000.0) as u64;
                }
                break;
            }
            _ => break,
        }
    }
    if total_ms == 0 {
        return None;
    }
    // 10% safety margin, but never wait more than two minutes.
    let buffered = ((total_ms as f64 * 1.1) as u64).min(120_000);
    Some(Duration::from_millis(buffered))
}
impl LlmError {
    /// Pick the retry policy appropriate for this class of error.
    pub fn retry_strategy(&self) -> RetryStrategy {
        match self {
            // Transient transport problems: retry quickly and persistently.
            Self::NetworkError(_) | Self::Timeout => RetryStrategy::network_backoff(),
            // Honor a server-suggested wait when the message embeds one;
            // otherwise fall back to a conservative 60 s pause.
            Self::RateLimited(msg) => {
                let wait = parse_retry_after_secs(msg).unwrap_or(Duration::from_secs(60));
                RetryStrategy::WaitAndRetry { wait }
            }
            Self::ProviderError(_) => RetryStrategy::server_backoff(),
            Self::TokenLimitExceeded { .. } => RetryStrategy::ReduceContext,
            // Caller-side problems: repeating the same request cannot succeed.
            Self::AuthError(_)
            | Self::InvalidRequest(_)
            | Self::ModelNotFound(_)
            | Self::ConfigError(_)
            | Self::NotSupported(_) => RetryStrategy::NoRetry,
            // Ambiguous failures get a short, cautious backoff.
            Self::ApiError(_) | Self::SerializationError(_) | Self::Unknown(_) => {
                RetryStrategy::ExponentialBackoff {
                    base_delay: Duration::from_secs(1),
                    max_delay: Duration::from_secs(30),
                    max_attempts: 2,
                }
            }
        }
    }

    /// Human-readable, actionable description suitable for end users.
    pub fn user_description(&self) -> String {
        match self {
            Self::NetworkError(_) => {
                String::from("Unable to connect to the API. Check your internet connection.")
            }
            Self::Timeout => String::from("Request timed out. The server may be overloaded."),
            Self::RateLimited(_) => {
                String::from("Rate limited by the API. Waiting before retry...")
            }
            Self::TokenLimitExceeded { max, got } => format!(
                "Context too large ({}/{} tokens). Reducing context and retrying...",
                got, max
            ),
            Self::AuthError(_) => String::from(
                "Authentication failed. Please check your API key is valid and not expired.",
            ),
            Self::ModelNotFound(model) => format!(
                "Model '{}' not found. Use a supported model like 'gpt-4o-mini'.",
                model
            ),
            Self::InvalidRequest(msg) => {
                format!("Invalid request: {}. Check your parameters.", msg)
            }
            Self::ConfigError(msg) => format!("Configuration error: {}.", msg),
            Self::NotSupported(feature) => {
                format!("Feature '{}' is not supported by this provider.", feature)
            }
            Self::ApiError(_) | Self::ProviderError(_) => {
                String::from("API server error. Retrying...")
            }
            Self::SerializationError(_) => {
                String::from("Failed to parse API response. This may be a temporary issue.")
            }
            Self::Unknown(msg) => format!("An unexpected error occurred: {}", msg),
        }
    }

    /// An error is recoverable exactly when its retry strategy allows
    /// another attempt.
    pub fn is_recoverable(&self) -> bool {
        self.retry_strategy().should_retry()
    }
}
// Unit tests: Display formatting, retry-strategy selection, user-facing
// descriptions, and recoverability for every `LlmError` variant.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Display formatting (from the thiserror #[error] attributes) ---

    #[test]
    fn test_llm_error_display() {
        let error = LlmError::ApiError("something went wrong".to_string());
        assert_eq!(error.to_string(), "API error: something went wrong");
        let error = LlmError::RateLimited("too many requests".to_string());
        assert_eq!(error.to_string(), "Rate limit exceeded: too many requests");
        let error = LlmError::InvalidRequest("bad params".to_string());
        assert_eq!(error.to_string(), "Invalid request: bad params");
    }
    #[test]
    fn test_llm_error_auth() {
        let error = LlmError::AuthError("invalid key".to_string());
        assert_eq!(error.to_string(), "Authentication error: invalid key");
    }
    #[test]
    fn test_llm_error_token_limit() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        assert_eq!(
            error.to_string(),
            "Token limit exceeded: max 4096, got 5000"
        );
    }
    #[test]
    fn test_llm_error_model_not_found() {
        let error = LlmError::ModelNotFound("gpt-5-turbo".to_string());
        assert_eq!(error.to_string(), "Model not found: gpt-5-turbo");
    }
    #[test]
    fn test_llm_error_network() {
        let error = LlmError::NetworkError("connection refused".to_string());
        assert_eq!(error.to_string(), "Network error: connection refused");
    }
    #[test]
    fn test_llm_error_config() {
        let error = LlmError::ConfigError("missing api key".to_string());
        assert_eq!(error.to_string(), "Configuration error: missing api key");
    }
    #[test]
    fn test_llm_error_provider() {
        let error = LlmError::ProviderError("openai specific error".to_string());
        assert_eq!(error.to_string(), "Provider error: openai specific error");
    }
    #[test]
    fn test_llm_error_timeout() {
        let error = LlmError::Timeout;
        assert_eq!(error.to_string(), "Request timed out");
    }
    #[test]
    fn test_llm_error_not_supported() {
        let error = LlmError::NotSupported("function calling".to_string());
        assert_eq!(error.to_string(), "Not supported: function calling");
    }
    #[test]
    fn test_llm_error_unknown() {
        let error = LlmError::Unknown("mystery error".to_string());
        assert_eq!(error.to_string(), "Unknown error: mystery error");
    }
    #[test]
    fn test_llm_error_debug() {
        let error = LlmError::ApiError("test".to_string());
        let debug = format!("{:?}", error);
        assert!(debug.contains("ApiError"));
        assert!(debug.contains("test"));
    }

    // --- Conversions ---

    #[test]
    fn test_llm_error_from_serde_json() {
        let json_str = "not json at all";
        let json_err: serde_json::Error =
            serde_json::from_str::<serde_json::Value>(json_str).unwrap_err();
        let llm_err: LlmError = json_err.into();
        assert!(matches!(llm_err, LlmError::SerializationError(_)));
    }

    // --- Retry strategy selection per variant ---

    #[test]
    fn test_network_error_retry_strategy() {
        let error = LlmError::NetworkError("connection failed".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 5);
            }
            _ => panic!("Expected ExponentialBackoff for network error"),
        }
        assert!(strategy.should_retry());
        assert!(error.is_recoverable());
    }
    #[test]
    fn test_timeout_retry_strategy() {
        let error = LlmError::Timeout;
        let strategy = error.retry_strategy();
        assert!(matches!(strategy, RetryStrategy::ExponentialBackoff { .. }));
        assert!(strategy.should_retry());
    }
    #[test]
    fn test_rate_limited_retry_strategy() {
        // No "try again in" hint in the message, so the 60 s default applies.
        let error = LlmError::RateLimited("too many requests".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::WaitAndRetry { wait } => {
                assert_eq!(wait, Duration::from_secs(60));
            }
            _ => panic!("Expected WaitAndRetry for rate limit"),
        }
        assert!(strategy.should_retry());
    }
    #[test]
    fn test_token_limit_reduce_context_strategy() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        let strategy = error.retry_strategy();
        assert!(matches!(strategy, RetryStrategy::ReduceContext));
        assert!(strategy.should_retry());
    }
    #[test]
    fn test_auth_error_no_retry() {
        let error = LlmError::AuthError("invalid key".to_string());
        let strategy = error.retry_strategy();
        assert!(matches!(strategy, RetryStrategy::NoRetry));
        assert!(!strategy.should_retry());
        assert!(!error.is_recoverable());
    }
    #[test]
    fn test_invalid_request_no_retry() {
        let error = LlmError::InvalidRequest("bad params".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
    }
    #[test]
    fn test_model_not_found_no_retry() {
        let error = LlmError::ModelNotFound("gpt-5".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
    }

    // --- User-facing descriptions ---

    #[test]
    fn test_user_description_network() {
        let error = LlmError::NetworkError("connection refused".to_string());
        let desc = error.user_description();
        assert!(desc.contains("internet connection"));
    }
    #[test]
    fn test_user_description_auth() {
        let error = LlmError::AuthError("invalid".to_string());
        let desc = error.user_description();
        assert!(desc.contains("API key"));
    }
    #[test]
    fn test_user_description_token_limit() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        let desc = error.user_description();
        assert!(desc.contains("5000/4096"));
        assert!(desc.contains("Reducing"));
    }
    #[test]
    fn test_retry_strategy_equality() {
        let s1 = RetryStrategy::network_backoff();
        let s2 = RetryStrategy::network_backoff();
        assert_eq!(s1, s2);
        let s3 = RetryStrategy::NoRetry;
        assert_ne!(s1, s3);
    }
    #[test]
    fn test_user_description_timeout() {
        let error = LlmError::Timeout;
        let desc = error.user_description();
        assert!(desc.contains("timed out"));
    }
    #[test]
    fn test_user_description_rate_limited() {
        let error = LlmError::RateLimited("slow down".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Rate limited"));
    }
    #[test]
    fn test_user_description_model_not_found() {
        let error = LlmError::ModelNotFound("gpt-5".to_string());
        let desc = error.user_description();
        assert!(desc.contains("gpt-5"));
        assert!(desc.contains("not found"));
    }
    #[test]
    fn test_user_description_not_supported() {
        let error = LlmError::NotSupported("streaming".to_string());
        let desc = error.user_description();
        assert!(desc.contains("streaming"));
        assert!(desc.contains("not supported"));
    }
    #[test]
    fn test_user_description_unknown() {
        let error = LlmError::Unknown("mystery".to_string());
        let desc = error.user_description();
        assert!(desc.contains("mystery"));
    }
    #[test]
    fn test_user_description_api_error() {
        let error = LlmError::ApiError("server crashed".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Retrying"));
    }
    #[test]
    fn test_user_description_provider_error() {
        let error = LlmError::ProviderError("internal failure".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Retrying"));
    }
    #[test]
    fn test_user_description_serialization() {
        let json_err = serde_json::from_str::<serde_json::Value>("bad").unwrap_err();
        let error = LlmError::SerializationError(json_err);
        let desc = error.user_description();
        assert!(desc.contains("parse"));
    }
    #[test]
    fn test_user_description_config() {
        let error = LlmError::ConfigError("missing field".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Configuration"));
    }
    #[test]
    fn test_user_description_invalid_request() {
        let error = LlmError::InvalidRequest("empty prompt".to_string());
        let desc = error.user_description();
        assert!(desc.contains("empty prompt"));
    }

    // --- Backoff parameters for ambiguous / server errors ---

    #[test]
    fn test_api_error_500_server_backoff() {
        let error = LlmError::ApiError("HTTP 500 internal server error".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 2);
            }
            _ => panic!("Expected ExponentialBackoff for ApiError"),
        }
    }
    #[test]
    fn test_api_error_502_server_backoff() {
        let error = LlmError::ApiError("502 bad gateway".to_string());
        assert!(matches!(
            error.retry_strategy(),
            RetryStrategy::ExponentialBackoff { .. }
        ));
    }
    #[test]
    fn test_api_error_503_server_backoff() {
        let error = LlmError::ApiError("503 service unavailable".to_string());
        assert!(matches!(
            error.retry_strategy(),
            RetryStrategy::ExponentialBackoff { .. }
        ));
    }
    #[test]
    fn test_provider_error_server_backoff() {
        let error = LlmError::ProviderError("internal issue".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_secs(1));
                assert_eq!(max_delay, Duration::from_secs(60));
                assert_eq!(max_attempts, 3);
            }
            _ => panic!("Expected server_backoff for ProviderError"),
        }
    }
    #[test]
    fn test_unknown_error_retry_strategy() {
        let error = LlmError::Unknown("something".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 2);
            }
            _ => panic!("Expected ExponentialBackoff for Unknown"),
        }
    }
    #[test]
    fn test_serialization_error_retry_strategy() {
        let json_err = serde_json::from_str::<serde_json::Value>("bad").unwrap_err();
        let error = LlmError::SerializationError(json_err);
        let strategy = error.retry_strategy();
        assert!(matches!(strategy, RetryStrategy::ExponentialBackoff { .. }));
    }
    #[test]
    fn test_api_error_non_5xx_retry_strategy() {
        let error = LlmError::ApiError("generic error".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 2);
            }
            _ => panic!("Expected ExponentialBackoff for generic ApiError"),
        }
    }
    #[test]
    fn test_config_error_no_retry() {
        let error = LlmError::ConfigError("bad config".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
        assert!(!error.is_recoverable());
    }
    #[test]
    fn test_not_supported_no_retry() {
        let error = LlmError::NotSupported("embeddings".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
        assert!(!error.is_recoverable());
    }

    // --- RetryStrategy constructors and predicates ---

    #[test]
    fn test_server_backoff_values() {
        let strategy = RetryStrategy::server_backoff();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_secs(1));
                assert_eq!(max_delay, Duration::from_secs(60));
                assert_eq!(max_attempts, 3);
            }
            _ => panic!("Expected ExponentialBackoff"),
        }
    }
    #[test]
    fn test_network_backoff_values() {
        let strategy = RetryStrategy::network_backoff();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_millis(125));
                assert_eq!(max_delay, Duration::from_secs(30));
                assert_eq!(max_attempts, 5);
            }
            _ => panic!("Expected ExponentialBackoff"),
        }
    }
    #[test]
    fn test_reduce_context_should_retry() {
        let strategy = RetryStrategy::ReduceContext;
        assert!(strategy.should_retry());
    }
    #[test]
    fn test_wait_and_retry_should_retry() {
        let strategy = RetryStrategy::WaitAndRetry {
            wait: Duration::from_secs(1),
        };
        assert!(strategy.should_retry());
    }

    // --- Recoverability shortcuts ---

    #[test]
    fn test_is_recoverable_network() {
        assert!(LlmError::NetworkError("fail".to_string()).is_recoverable());
    }
    #[test]
    fn test_is_recoverable_timeout() {
        assert!(LlmError::Timeout.is_recoverable());
    }
    #[test]
    fn test_is_recoverable_rate_limited() {
        assert!(LlmError::RateLimited("wait".to_string()).is_recoverable());
    }
    #[test]
    fn test_is_not_recoverable_invalid_request() {
        assert!(!LlmError::InvalidRequest("bad".to_string()).is_recoverable());
    }
    #[test]
    fn test_is_not_recoverable_model_not_found() {
        assert!(!LlmError::ModelNotFound("x".to_string()).is_recoverable());
    }
}