use std::fmt;
/// Errors produced when talking to a model provider.
///
/// Variants fall into three groups:
/// - rejections reported by the provider (authentication, permission, unknown
///   model, context window, safety filter), each carrying the provider's detail
///   text in `provider_message`;
/// - status-code failures ([`RateLimited`](Self::RateLimited) and
///   [`UnexpectedStatus`](Self::UnexpectedStatus)) that may carry a retry hint;
/// - connection, streaming, and invalid-response failures, described by a
///   free-form `reason`.
///
/// Retry policy is exposed through `ProviderError::is_retryable` and
/// `ProviderError::retry_after_ms`. The enum is `#[non_exhaustive]`, so matches
/// outside this crate must include a wildcard arm.
///
/// `Clone`/`PartialEq`/`Eq` are derived so errors can be retained across retry
/// attempts and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProviderError {
    /// Authentication with the provider failed.
    AuthenticationFailed { provider_message: String },
    /// The provider denied permission for the request.
    PermissionDenied { provider_message: String },
    /// The requested model was not found.
    ModelNotFound { provider_message: String },
    /// The request exceeded the model's context window.
    ContextWindowExceeded { provider_message: String },
    /// The provider's safety filter blocked the request or response.
    SafetyFilterTriggered { provider_message: String },
    /// The provider rate-limited the request; always considered retryable.
    RateLimited {
        /// Human-readable description of the failure.
        message: String,
        /// HTTP status code of the response.
        status: u16,
        /// Suggested delay before retrying, in milliseconds, if one was given.
        retry_after_ms: Option<u64>,
    },
    /// A response with a status code not covered by a more specific variant.
    UnexpectedStatus {
        /// HTTP status code of the response.
        status: u16,
        /// Human-readable description of the failure.
        message: String,
        /// Whether a retry may succeed; this flag alone decides retryability.
        retryable: bool,
        /// Suggested delay before retrying, in milliseconds, if one was given.
        retry_after_ms: Option<u64>,
    },
    /// The connection to the provider could not be established or was lost.
    ConnectionFailed { reason: String },
    /// A streaming response stopped before completion.
    StreamInterrupted { reason: String },
    /// The provider's response could not be interpreted.
    InvalidResponse { reason: String },
}
impl ProviderError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
ProviderError::RateLimited { .. }
| ProviderError::ConnectionFailed { .. }
| ProviderError::StreamInterrupted { .. }
| ProviderError::UnexpectedStatus {
retryable: true,
..
}
)
}
pub fn retry_after_ms(&self) -> Option<u64> {
match self {
ProviderError::RateLimited { retry_after_ms, .. } => *retry_after_ms,
ProviderError::UnexpectedStatus { retry_after_ms, .. } => *retry_after_ms,
_ => None,
}
}
}
/// Renders a single-line, human-readable description of the error.
///
/// Every message has the form `"<label>: <detail>"`. Status-bearing variants
/// add the status code to the label, and `UnexpectedStatus` also reports its
/// `retryable` flag. The `retry_after_ms` hint is never printed.
impl fmt::Display for ProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants with extra fields in the message are written directly.
        // All others reduce to a fixed label plus their single text field.
        let (label, detail) = match self {
            Self::RateLimited {
                message, status, ..
            } => {
                return write!(f, "Rate limited (status {status}): {message}");
            }
            Self::UnexpectedStatus {
                status,
                message,
                retryable,
                ..
            } => {
                return write!(
                    f,
                    "HTTP error (status {status}): {message} (retryable: {retryable})"
                );
            }
            Self::AuthenticationFailed { provider_message } => {
                ("Authentication failed", provider_message)
            }
            Self::PermissionDenied { provider_message } => ("Permission denied", provider_message),
            Self::ModelNotFound { provider_message } => ("Model not found", provider_message),
            Self::ContextWindowExceeded { provider_message } => {
                ("Context window exceeded", provider_message)
            }
            Self::SafetyFilterTriggered { provider_message } => {
                ("Safety filter triggered", provider_message)
            }
            Self::ConnectionFailed { reason } => ("Connection failed", reason),
            Self::StreamInterrupted { reason } => ("Stream interrupted", reason),
            Self::InvalidResponse { reason } => ("Invalid response", reason),
        };
        write!(f, "{label}: {detail}")
    }
}
impl std::error::Error for ProviderError {}
pub type ProviderResult<T> = std::result::Result<T, ProviderError>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rate_limited_is_retryable_and_carries_retry_after() {
        let err = ProviderError::RateLimited {
            message: "slow down".into(),
            status: 429,
            retry_after_ms: Some(500),
        };
        assert!(err.is_retryable());
        assert_eq!(err.retry_after_ms(), Some(500));
    }

    #[test]
    fn connection_failed_is_retryable() {
        let err = ProviderError::ConnectionFailed {
            reason: "dns".into(),
        };
        assert!(err.is_retryable());
        assert_eq!(err.retry_after_ms(), None);
    }

    #[test]
    fn stream_interrupted_is_retryable() {
        let err = ProviderError::StreamInterrupted {
            reason: "error decoding response body".into(),
        };
        assert!(err.is_retryable());
        assert_eq!(err.retry_after_ms(), None);
        assert!(err.to_string().starts_with("Stream interrupted:"));
    }

    #[test]
    fn unexpected_status_honours_retryable_flag() {
        let retryable = ProviderError::UnexpectedStatus {
            status: 503,
            message: "unavailable".into(),
            retryable: true,
            retry_after_ms: Some(250),
        };
        let terminal = ProviderError::UnexpectedStatus {
            status: 418,
            message: "teapot".into(),
            retryable: false,
            retry_after_ms: None,
        };
        assert!(retryable.is_retryable());
        // The retry hint must be passed through unchanged.
        assert_eq!(retryable.retry_after_ms(), Some(250));
        assert!(!terminal.is_retryable());
        assert_eq!(terminal.retry_after_ms(), None);
    }

    #[test]
    fn classified_variants_are_not_retryable() {
        for err in [
            ProviderError::AuthenticationFailed {
                provider_message: String::new(),
            },
            ProviderError::PermissionDenied {
                provider_message: String::new(),
            },
            ProviderError::ModelNotFound {
                provider_message: String::new(),
            },
            ProviderError::ContextWindowExceeded {
                provider_message: String::new(),
            },
            ProviderError::SafetyFilterTriggered {
                provider_message: String::new(),
            },
            ProviderError::InvalidResponse {
                reason: String::new(),
            },
        ] {
            assert!(!err.is_retryable(), "expected terminal: {err:?}");
            // Terminal variants never carry a retry hint.
            assert_eq!(err.retry_after_ms(), None, "unexpected retry hint: {err:?}");
        }
    }

    /// Pins the exact user-facing message of every variant, so that wording
    /// regressions are caught rather than just empty output.
    #[test]
    fn all_variants_display_exact_messages() {
        let cases = [
            (
                ProviderError::AuthenticationFailed {
                    provider_message: "bad key".into(),
                },
                "Authentication failed: bad key",
            ),
            (
                ProviderError::PermissionDenied {
                    provider_message: "nope".into(),
                },
                "Permission denied: nope",
            ),
            (
                ProviderError::ModelNotFound {
                    provider_message: "no such model".into(),
                },
                "Model not found: no such model",
            ),
            (
                ProviderError::ContextWindowExceeded {
                    provider_message: "too long".into(),
                },
                "Context window exceeded: too long",
            ),
            (
                ProviderError::SafetyFilterTriggered {
                    provider_message: "blocked".into(),
                },
                "Safety filter triggered: blocked",
            ),
            (
                ProviderError::RateLimited {
                    message: "slow".into(),
                    status: 429,
                    retry_after_ms: Some(1000),
                },
                "Rate limited (status 429): slow",
            ),
            (
                ProviderError::UnexpectedStatus {
                    status: 500,
                    message: "boom".into(),
                    retryable: true,
                    retry_after_ms: None,
                },
                "HTTP error (status 500): boom (retryable: true)",
            ),
            (
                ProviderError::ConnectionFailed {
                    reason: "dns".into(),
                },
                "Connection failed: dns",
            ),
            (
                ProviderError::StreamInterrupted {
                    reason: "chunk read error".into(),
                },
                "Stream interrupted: chunk read error",
            ),
            (
                ProviderError::InvalidResponse {
                    reason: "bad json".into(),
                },
                "Invalid response: bad json",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected, "wrong display for {err:?}");
        }
    }
}