use serde_json::Value;
/// Error type for LLM operations: HTTP and timeout failures, authentication,
/// invalid requests, provider-reported errors, response-format and schema
/// problems, tool failures, retry exhaustion and nesting-depth limits.
///
/// `Display` text comes from each variant's `#[error(...)]` attribute via
/// `thiserror`. The enum is `#[non_exhaustive]`, so matches on it in other
/// crates must include a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum LlmError {
/// HTTP-level failure.
///
/// `status` is formatted with `{:?}`, so the message renders it as
/// `Some(<code>)` or `None` rather than a bare code.
#[error("HTTP error (status={status:?}): {message}")]
Http {
/// Response status code, if known.
/// NOTE(review): presumably `None` when no response was received
/// (e.g. a connection failure) — confirm at construction sites.
status: Option<http::StatusCode>,
/// Human-readable description, included in the message.
message: String,
/// Returned verbatim by [`LlmError::is_retryable`].
retryable: bool,
},
/// Authentication failure; the string is included in the message.
/// Never retryable.
#[error("Authentication error: {0}")]
Auth(String),
/// The request was invalid; the string explains why. Never retryable.
#[error("Invalid request: {0}")]
InvalidRequest(String),
/// Provider-level error identified by `code`. Both `code` and `message`
/// appear in the message.
#[error("Provider error ({code}): {message}")]
Provider {
/// Provider-defined error code (the tests use e.g. `"overloaded"`).
code: String,
/// Human-readable description, included in the message.
message: String,
/// Returned verbatim by [`LlmError::is_retryable`].
retryable: bool,
},
/// A payload could not be interpreted. Only `message` appears in the
/// `Display` output. Never retryable.
#[error("Response format error: {message}")]
ResponseFormat {
/// Description of what was wrong with the payload.
message: String,
/// Presumably the offending raw text, kept for diagnostics. Empty when
/// the error was produced by the `From<serde_json::Error>` conversion.
raw: String,
},
/// A JSON value failed validation. Only `message` appears in the
/// `Display` output; `schema` and `actual` are carried for inspection.
/// Never retryable.
#[error("Schema validation error: {message}")]
SchemaValidation {
/// Description of the validation failure.
message: String,
/// Presumably the schema that was checked against — confirm at call sites.
schema: Value,
/// Presumably the value that failed validation — confirm at call sites.
actual: Value,
},
/// A tool invocation failed. Never retryable.
///
/// `thiserror` treats the field named `source` as the error source, so it is
/// returned by `Error::source()`. It is also embedded in the message.
/// NOTE(review): reporters that print the whole source chain will therefore
/// show the inner text twice.
#[error("Tool execution error ({tool_name}): {source}")]
ToolExecution {
/// Name of the tool that failed; shown in the message.
tool_name: String,
/// Underlying error raised by the tool.
source: Box<dyn std::error::Error + Send + Sync>,
},
/// Every retry attempt failed. This variant is itself never retryable,
/// whatever `last_error` says.
///
/// `last_error` is both returned by `Error::source()` (`#[source]`) and
/// embedded in the message.
/// NOTE(review): reporters that walk the source chain will show its text twice.
#[error("Retry exhausted after {attempts} attempts: {last_error}")]
RetryExhausted {
/// Number of attempts made; shown in the message.
attempts: u32,
/// The error from the final attempt.
#[source]
last_error: Box<LlmError>,
},
/// An operation timed out. Always treated as retryable by
/// [`LlmError::is_retryable`].
#[error("Operation timed out after {elapsed_ms}ms")]
Timeout {
/// Elapsed time in milliseconds, as shown in the message.
elapsed_ms: u64,
},
/// A nesting-depth limit was hit. Never retryable.
#[error("max nesting depth exceeded (current: {current}, limit: {limit})")]
MaxDepthExceeded {
/// Depth reported at the time of the check.
current: u32,
/// The depth limit that was enforced.
limit: u32,
},
}
impl LlmError {
    /// Returns `true` if retrying the operation that produced this error may succeed.
    ///
    /// Policy:
    /// - [`LlmError::Http`] and [`LlmError::Provider`] return the `retryable`
    ///   flag they carry.
    /// - [`LlmError::Timeout`] is always retryable.
    /// - Every other variant is not retryable. This includes
    ///   [`LlmError::RetryExhausted`], even when its inner error is retryable,
    ///   because the retry budget is already spent.
    ///
    /// The match deliberately has no `_` arm. Adding a variant to `LlmError`
    /// then fails to compile here, which forces an explicit retry decision
    /// instead of silently defaulting to `false`.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Http { retryable, .. } | Self::Provider { retryable, .. } => *retryable,
            Self::Timeout { .. } => true,
            Self::Auth(_)
            | Self::InvalidRequest(_)
            | Self::ResponseFormat { .. }
            | Self::SchemaValidation { .. }
            | Self::ToolExecution { .. }
            | Self::RetryExhausted { .. }
            | Self::MaxDepthExceeded { .. } => false,
        }
    }
}
impl From<serde_json::Error> for LlmError {
fn from(err: serde_json::Error) -> Self {
Self::ResponseFormat {
message: err.to_string(),
raw: String::new(),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LlmError`: `Display` output, source chaining, the
    //! `From<serde_json::Error>` conversion, and `is_retryable` classification.

    use super::*;

    #[test]
    fn test_error_display_http() {
        let err = LlmError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        };
        let display = format!("{err}");
        assert!(display.contains("429"));
        assert!(display.contains("rate limited"));
    }

    #[test]
    fn test_error_display_auth() {
        let err = LlmError::Auth("bad key".into());
        assert!(format!("{err}").contains("bad key"));
    }

    #[test]
    fn test_error_display_invalid_request() {
        let err = LlmError::InvalidRequest("missing model".into());
        assert!(format!("{err}").contains("missing model"));
    }

    #[test]
    fn test_error_display_provider() {
        let err = LlmError::Provider {
            code: "overloaded".into(),
            message: "server busy".into(),
            retryable: true,
        };
        let display = format!("{err}");
        assert!(display.contains("overloaded"));
        assert!(display.contains("server busy"));
    }

    #[test]
    fn test_error_display_response_format() {
        let err = LlmError::ResponseFormat {
            message: "not json".into(),
            raw: "hello".into(),
        };
        assert!(format!("{err}").contains("not json"));
    }

    #[test]
    fn test_error_display_schema_validation() {
        let err = LlmError::SchemaValidation {
            message: "missing field".into(),
            schema: serde_json::json!({"type": "object"}),
            actual: serde_json::json!({}),
        };
        assert!(format!("{err}").contains("missing field"));
    }

    #[test]
    fn test_error_display_tool_execution() {
        let err = LlmError::ToolExecution {
            tool_name: "calculator".into(),
            source: Box::new(std::io::Error::other("boom")),
        };
        let display = format!("{err}");
        assert!(display.contains("calculator"));
        assert!(display.contains("boom"));
    }

    #[test]
    fn test_error_display_retry_exhausted() {
        let inner = LlmError::Http {
            status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
            message: "server error".into(),
            retryable: true,
        };
        let err = LlmError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };
        let display = format!("{err}");
        // Match the count in context; a lone '3' could come from unrelated text.
        assert!(display.contains("after 3 attempts"));
        assert!(display.contains("server error"));
    }

    #[test]
    fn test_error_display_timeout() {
        let err = LlmError::Timeout { elapsed_ms: 5000 };
        assert!(format!("{err}").contains("5000"));
    }

    #[test]
    fn test_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LlmError>();
    }

    #[test]
    fn test_error_retryable_http() {
        // `is_retryable` must report the carried flag in both directions.
        let transient = LlmError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        };
        assert!(transient.is_retryable());

        let permanent = LlmError::Http {
            status: Some(http::StatusCode::BAD_REQUEST),
            message: "bad request".into(),
            retryable: false,
        };
        assert!(!permanent.is_retryable());
    }

    #[test]
    fn test_error_retryable_provider() {
        let permanent = LlmError::Provider {
            code: "bad_request".into(),
            message: "invalid".into(),
            retryable: false,
        };
        assert!(!permanent.is_retryable());

        let transient = LlmError::Provider {
            code: "overloaded".into(),
            message: "server busy".into(),
            retryable: true,
        };
        assert!(transient.is_retryable());
    }

    #[test]
    fn test_error_timeout_is_retryable() {
        assert!(LlmError::Timeout { elapsed_ms: 5000 }.is_retryable());
    }

    #[test]
    fn test_error_retry_exhausted_not_retryable() {
        // Even when the last error was transient, the retry budget is spent.
        let err = LlmError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(LlmError::Timeout { elapsed_ms: 100 }),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_error_non_transient_variants_not_retryable() {
        let errors = [
            LlmError::Auth("bad key".into()),
            LlmError::InvalidRequest("missing model".into()),
            LlmError::ResponseFormat {
                message: "not json".into(),
                raw: String::new(),
            },
            LlmError::SchemaValidation {
                message: "missing field".into(),
                schema: serde_json::json!({"type": "object"}),
                actual: serde_json::json!({}),
            },
            LlmError::ToolExecution {
                tool_name: "calculator".into(),
                source: Box::new(std::io::Error::other("boom")),
            },
        ];
        for err in &errors {
            assert!(!err.is_retryable(), "{err} should not be retryable");
        }
    }

    #[test]
    fn test_error_retry_exhausted_nests() {
        let inner = LlmError::Auth("expired".into());
        let err = LlmError::RetryExhausted {
            attempts: 2,
            last_error: Box::new(inner),
        };
        assert!(matches!(
            &err,
            LlmError::RetryExhausted { last_error, .. }
                if matches!(last_error.as_ref(), LlmError::Auth(_))
        ));
    }

    #[test]
    fn test_error_retry_exhausted_source_chain() {
        use std::error::Error;
        let inner = LlmError::Auth("expired".into());
        let err = LlmError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };
        let source = err.source().expect("RetryExhausted should have a source");
        assert!(format!("{source}").contains("expired"));
    }

    #[test]
    fn test_error_source_trait() {
        use std::error::Error;
        let err = LlmError::ToolExecution {
            tool_name: "test".into(),
            source: Box::new(std::io::Error::new(std::io::ErrorKind::NotFound, "gone")),
        };
        let source = err.source().expect("ToolExecution should expose its source");
        assert_eq!(source.to_string(), "gone");
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("not valid json").unwrap_err();
        let expected_message = json_err.to_string();
        let llm_err: LlmError = json_err.into();
        match llm_err {
            LlmError::ResponseFormat { message, raw } => {
                assert_eq!(message, expected_message);
                // The serde error does not carry the payload, so `raw` stays empty.
                assert!(raw.is_empty());
            }
            other => panic!("expected ResponseFormat, got {other:?}"),
        }
    }

    #[test]
    fn test_error_display_max_depth_exceeded() {
        let err = LlmError::MaxDepthExceeded {
            current: 3,
            limit: 3,
        };
        let display = format!("{err}");
        assert!(display.contains("max nesting depth exceeded"));
        assert!(display.contains("current: 3"));
        assert!(display.contains("limit: 3"));
    }

    #[test]
    fn test_error_max_depth_not_retryable() {
        let err = LlmError::MaxDepthExceeded {
            current: 2,
            limit: 2,
        };
        assert!(!err.is_retryable());
    }
}