/// Convenience alias for a `Result` whose error type is [`LlmError`].
pub type LlmResult<T> = std::result::Result<T, LlmError>;
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
#[error("configuration error: {0}")]
Config(String),
#[error("authentication error: {0}")]
Auth(String),
#[error("rate limited: {message}")]
RateLimited {
message: String,
retry_after_secs: Option<u64>,
},
#[error("timeout: {0}")]
Timeout(String),
#[error("model not found: {0}")]
ModelNotFound(String),
#[error("content filtered: {0}")]
ContentFiltered(String),
#[error("token limit exceeded: {0}")]
TokenLimitExceeded(String),
#[error("provider error ({status}): {message}")]
Provider { status: u16, message: String },
#[error("connection error: {0}")]
Connection(String),
#[error("serialization error: {0}")]
Serialization(String),
#[error("transient error: {0}")]
Transient(String),
#[error("internal error: {0}")]
Internal(String),
}
impl LlmError {
pub fn timeout(duration: std::time::Duration) -> Self {
let secs = duration.as_secs();
let ms = duration.subsec_millis();
if secs >= 60 {
Self::Timeout(format!("{}m{}s", secs / 60, secs % 60))
} else {
Self::Timeout(format!("{secs}.{ms:03}s"))
}
}
pub fn is_retryable(&self) -> bool {
matches!(
self,
LlmError::RateLimited { .. }
| LlmError::Timeout(_)
| LlmError::Connection(_)
| LlmError::Transient(_)
)
}
pub fn is_auth(&self) -> bool {
matches!(self, LlmError::Auth(_))
}
pub fn is_rate_limited(&self) -> bool {
matches!(self, LlmError::RateLimited { .. })
}
pub fn retry_after_secs(&self) -> Option<u64> {
match self {
LlmError::RateLimited {
retry_after_secs, ..
} => *retry_after_secs,
_ => None,
}
}
pub fn truncated_message(&self, max_len: usize) -> String {
let msg = self.to_string();
if msg.len() <= max_len {
msg
} else {
let boundary = msg
.char_indices()
.take_while(|(i, _)| *i <= max_len)
.last()
.map(|(i, _)| i)
.unwrap_or(0);
format!("{}...[truncated]", &msg[..boundary])
}
}
}
#[cfg(feature = "openai")]
impl From<reqwest::Error> for LlmError {
    /// Classifies a `reqwest::Error` into an [`LlmError`] variant.
    ///
    /// Checks are applied in this order:
    /// 1. timeout errors become `Timeout` with a fixed message;
    /// 2. connect errors become `Connection`;
    /// 3. errors carrying an HTTP status become `Auth` (401, 403),
    ///    `RateLimited` without a retry hint (429), or `Provider` (any other status);
    /// 4. everything else becomes `Connection`.
    ///
    /// Apart from the timeout case, the message is the error's `Display` text.
    fn from(source: reqwest::Error) -> Self {
        if source.is_timeout() {
            return Self::Timeout(String::from("request timed out"));
        }
        if source.is_connect() {
            return Self::Connection(source.to_string());
        }

        let message = source.to_string();
        match source.status().map(|code| code.as_u16()) {
            Some(401 | 403) => Self::Auth(message),
            Some(429) => Self::RateLimited {
                message,
                retry_after_secs: None,
            },
            Some(status) => Self::Provider { status, message },
            // No status attached and not a timeout/connect error.
            None => Self::Connection(message),
        }
    }
}
impl From<serde_json::Error> for LlmError {
fn from(err: serde_json::Error) -> Self {
LlmError::Serialization(err.to_string())
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LlmError`: classification helpers, message
    //! truncation, timeout rendering and the `serde_json` conversion.
    use super::*;

    #[test]
    fn test_is_retryable() {
        assert!(LlmError::RateLimited {
            message: "slow down".into(),
            retry_after_secs: None,
        }
        .is_retryable());
        assert!(LlmError::Timeout("t".into()).is_retryable());
        assert!(LlmError::Connection("c".into()).is_retryable());
        assert!(LlmError::Transient("t".into()).is_retryable());
        assert!(!LlmError::Auth("a".into()).is_retryable());
        assert!(!LlmError::Config("c".into()).is_retryable());
    }

    #[test]
    fn test_is_not_retryable_for_permanent_failures() {
        assert!(!LlmError::ModelNotFound("m".into()).is_retryable());
        assert!(!LlmError::ContentFiltered("m".into()).is_retryable());
        assert!(!LlmError::TokenLimitExceeded("m".into()).is_retryable());
        assert!(!LlmError::Serialization("m".into()).is_retryable());
        assert!(!LlmError::Internal("m".into()).is_retryable());
    }

    #[test]
    fn test_is_auth() {
        assert!(LlmError::Auth("bad key".into()).is_auth());
        assert!(!LlmError::Config("x".into()).is_auth());
    }

    #[test]
    fn test_is_rate_limited() {
        assert!(LlmError::RateLimited {
            message: "x".into(),
            retry_after_secs: Some(5),
        }
        .is_rate_limited());
        assert!(!LlmError::Timeout("x".into()).is_rate_limited());
    }

    #[test]
    fn test_retry_after_secs() {
        let err = LlmError::RateLimited {
            message: "x".into(),
            retry_after_secs: Some(42),
        };
        assert_eq!(err.retry_after_secs(), Some(42));

        // A rate-limit error without a hint reports `None`, like any other variant.
        let unhinted = LlmError::RateLimited {
            message: "x".into(),
            retry_after_secs: None,
        };
        assert_eq!(unhinted.retry_after_secs(), None);

        let err2 = LlmError::Config("x".into());
        assert_eq!(err2.retry_after_secs(), None);
    }

    #[test]
    fn test_display_of_struct_variants() {
        let limited = LlmError::RateLimited {
            message: "slow down".into(),
            retry_after_secs: Some(3),
        };
        assert_eq!(limited.to_string(), "rate limited: slow down");

        let provider = LlmError::Provider {
            status: 503,
            message: "unavailable".into(),
        };
        assert_eq!(provider.to_string(), "provider error (503): unavailable");
    }

    #[test]
    fn test_truncated_message_short() {
        let err = LlmError::Config("short".into());
        let msg = err.truncated_message(1000);
        assert_eq!(msg, "configuration error: short");
    }

    #[test]
    fn test_truncated_message_exact_fit() {
        // "configuration error: hello" is exactly 26 bytes long.
        let err = LlmError::Config("hello".into());
        assert_eq!(err.truncated_message(26), "configuration error: hello");
        assert_eq!(
            err.truncated_message(25),
            "configuration error: hell...[truncated]"
        );
    }

    #[test]
    fn test_truncated_message_long() {
        let long = "a".repeat(500);
        let err = LlmError::Config(long);
        let msg = err.truncated_message(50);
        // The 21-byte prefix plus 29 payload bytes fill the 50-byte budget;
        // the suffix is appended on top of that.
        let expected = format!("configuration error: {}...[truncated]", "a".repeat(29));
        assert_eq!(msg, expected);
    }

    #[test]
    fn test_truncated_message_utf8_boundary() {
        // Each lock emoji is 4 bytes, so character boundaries after the
        // 21-byte ASCII prefix fall at 21, 25, 29, ...
        let emoji_msg = "🔒".repeat(20);
        let err = LlmError::Config(emoji_msg);

        // 26 lands inside the second emoji: the cut must fall back to 25.
        assert_eq!(
            err.truncated_message(26),
            "configuration error: 🔒...[truncated]"
        );
        // 22..=24 land inside the first emoji: the cut must fall back to 21.
        for max_len in 22..=24 {
            assert_eq!(
                err.truncated_message(max_len),
                "configuration error: ...[truncated]"
            );
        }
    }

    #[test]
    fn test_truncated_message_zero() {
        let err = LlmError::Config("hello".into());
        let msg = err.truncated_message(0);
        assert_eq!(msg, "...[truncated]");
    }

    #[test]
    fn test_timeout_display() {
        let err = LlmError::timeout(std::time::Duration::from_millis(1500));
        assert_eq!(err.to_string(), "timeout: 1.500s");
    }

    #[test]
    fn test_timeout_display_pads_milliseconds() {
        let err = LlmError::timeout(std::time::Duration::from_millis(5));
        assert_eq!(err.to_string(), "timeout: 0.005s");
    }

    #[test]
    fn test_timeout_display_minute_threshold() {
        let below = LlmError::timeout(std::time::Duration::from_millis(59_999));
        assert_eq!(below.to_string(), "timeout: 59.999s");
        let at = LlmError::timeout(std::time::Duration::from_secs(60));
        assert_eq!(at.to_string(), "timeout: 1m0s");
    }

    #[test]
    fn test_timeout_display_minutes() {
        let err = LlmError::timeout(std::time::Duration::from_secs(125));
        assert_eq!(err.to_string(), "timeout: 2m5s");
    }

    #[test]
    fn test_from_serde_json_error() {
        let parse_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let detail = parse_err.to_string();
        let err: LlmError = parse_err.into();
        assert!(matches!(err, LlmError::Serialization(_)));
        assert_eq!(err.to_string(), format!("serialization error: {detail}"));
        assert!(!err.is_retryable());
    }
}