use std::time::Duration;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("authentication failed: {0}")]
Auth(String),
#[error("rate limit exceeded (retry after {retry_after:?})")]
RateLimit {
retry_after: Option<Duration>,
},
#[error("invalid request: {0}")]
InvalidRequest(String),
#[error("context too long: requested {requested_tokens} but max is {max_tokens}")]
ContextTooLong {
max_tokens: u32,
requested_tokens: u32,
},
#[error("model unavailable: {0}")]
ModelUnavailable(String),
#[error("server error (HTTP {status}): {body}")]
ServerError {
status: u16,
body: String,
},
#[error("upstream server fault: {0}")]
UpstreamServerFault(String),
#[error("content filter triggered: {0}")]
ContentFilter(String),
#[error("network error: {0}")]
Network(Box<dyn std::error::Error + Send + Sync>),
#[error("stream interrupted mid-response: {0}")]
StreamInterrupted(String),
#[error("operation cancelled")]
Cancelled,
#[error("adapter error: {0}")]
Adapter(#[source] Box<dyn std::error::Error + Send + Sync>),
#[error("stream idle for {0:?}")]
StreamIdle(std::time::Duration),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
    /// Wraps a failure as [`Error::Network`].
    ///
    /// Accepts any `E: Error + Send + Sync + 'static`, as before. It also
    /// accepts anything else that converts into
    /// `Box<dyn Error + Send + Sync>`: an already-boxed error (stored as-is,
    /// with no second box) or a `String`/`&str` message.
    #[must_use]
    pub fn network(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
        Self::Network(e.into())
    }

    /// Wraps a failure as [`Error::Adapter`].
    ///
    /// Accepts the same inputs as [`Error::network`]: any concrete
    /// `Error + Send + Sync + 'static`, an existing boxed error, or a string
    /// message.
    #[must_use]
    pub fn adapter(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
        Self::Adapter(e.into())
    }

    /// Builds [`Error::StreamInterrupted`] from the `Display` rendering of
    /// `inner`. Only the text is kept; the original value is dropped.
    #[must_use]
    pub fn stream_interrupted(inner: impl std::fmt::Display) -> Self {
        Self::StreamInterrupted(inner.to_string())
    }
}
#[must_use]
pub fn is_auth_error(err: &Error) -> bool {
match err {
Error::Auth(_) => true,
Error::ServerError { status, .. } => *status == 401 || *status == 403,
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for the `ServerError` values used by the `is_auth_error` cases.
    fn server_error(status: u16, body: &str) -> Error {
        Error::ServerError {
            status,
            body: body.to_owned(),
        }
    }

    #[test]
    fn stream_interrupted_display_uses_clear_prefix() {
        let rendered = Error::stream_interrupted("hyper: connection reset by peer").to_string();
        assert_eq!(
            rendered,
            "stream interrupted mid-response: hyper: connection reset by peer"
        );
    }

    #[test]
    fn stream_interrupted_constructor_accepts_display() {
        let eof = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof");
        let err = Error::stream_interrupted(eof);
        assert!(matches!(err, Error::StreamInterrupted(_)));
        assert!(err.to_string().contains("eof"));
    }

    #[test]
    fn is_auth_error_matches_explicit_auth_variant() {
        let err = Error::Auth(String::from("bad key"));
        assert!(is_auth_error(&err));
    }

    #[test]
    fn is_auth_error_matches_server_error_401() {
        assert!(is_auth_error(&server_error(401, "unauthorized")));
    }

    #[test]
    fn is_auth_error_matches_server_error_403() {
        assert!(is_auth_error(&server_error(403, "forbidden")));
    }

    #[test]
    fn is_auth_error_rejects_other_server_errors() {
        // Non-auth HTTP failures must not be classified as auth errors.
        for &(status, body) in [(500, "boom"), (429, "slow down")].iter() {
            assert!(
                !is_auth_error(&server_error(status, body)),
                "status {} should not count as an auth error",
                status
            );
        }
    }

    #[test]
    fn is_auth_error_rejects_unrelated_variants() {
        let reset = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "x");
        let unrelated = [
            Error::RateLimit { retry_after: None },
            Error::InvalidRequest("nope".into()),
            Error::Cancelled,
            Error::Network(Box::new(reset)),
        ];
        for err in unrelated.iter() {
            assert!(!is_auth_error(err), "{:?} should not be an auth error", err);
        }
    }
}