use std::time::Duration;
use thiserror::Error;
/// Every error the BigRAG client can produce.
///
/// Most variants correspond to an HTTP error response and expose a status code
/// through `BigRagError::status`; `Timeout`, `Connection`, `FileRead` and
/// `Serialization` carry no HTTP status.
#[derive(Error, Debug)]
pub enum BigRagError {
    /// Produced for HTTP 400 responses; `status` holds the response code.
    #[error("bad request: {message}")]
    BadRequest { message: String, status: u16 },

    /// Produced for HTTP 401 and 403 responses.
    #[error("authentication failed: {message}")]
    Authentication { message: String },

    /// Produced for HTTP 404 responses.
    #[error("not found: {message}")]
    NotFound { message: String },

    /// Produced for HTTP 409 responses.
    #[error("conflict: {message}")]
    Conflict { message: String },

    /// Produced for HTTP 429 responses; the response body is not kept.
    #[error("rate limited")]
    RateLimited,

    /// Produced for HTTP 5xx responses; `status` holds the exact code.
    #[error("server error: {message}")]
    ServerError { message: String, status: u16 },

    /// A request did not complete in time; carries the timeout duration.
    #[error("request timed out after {0:?}")]
    Timeout(Duration),

    /// A connection-level failure, described as text.
    #[error("connection failed: {0}")]
    Connection(String),

    /// Wraps an I/O error (converted automatically via `From`).
    #[error("failed to read file: {0}")]
    FileRead(#[from] std::io::Error),

    /// Wraps a JSON error (converted automatically via `From`).
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Any HTTP status not covered by a more specific variant.
    #[error("unexpected status {status}: {message}")]
    Api { status: u16, message: String },
}
impl BigRagError {
    /// Returns the HTTP status code associated with this error, if any.
    ///
    /// Variants built from an HTTP response report their status code.
    /// `Timeout`, `Connection`, `FileRead` and `Serialization` return `None`.
    ///
    /// Note: `Authentication` always reports `401`, even when it was created
    /// from a `403` response, because the variant does not record the
    /// original status.
    #[must_use]
    pub fn status(&self) -> Option<u16> {
        match self {
            Self::BadRequest { status, .. }
            | Self::ServerError { status, .. }
            | Self::Api { status, .. } => Some(*status),
            Self::Authentication { .. } => Some(401),
            Self::NotFound { .. } => Some(404),
            Self::Conflict { .. } => Some(409),
            Self::RateLimited => Some(429),
            Self::Timeout(_) | Self::Connection(_) | Self::FileRead(_) | Self::Serialization(_) => {
                None
            }
        }
    }

    /// Returns `true` if sending the same request again may succeed.
    ///
    /// The following are treated as transient:
    /// - rate limiting;
    /// - timeouts;
    /// - connection failures;
    /// - 5xx server errors, except `501 Not Implemented` and
    ///   `505 HTTP Version Not Supported`.
    ///
    /// Those two codes describe a permanent mismatch between the request and
    /// the server, so retrying them cannot help.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a retry decision.
        match self {
            Self::RateLimited | Self::Timeout(_) | Self::Connection(_) => true,
            Self::ServerError { status, .. } => !matches!(*status, 501 | 505),
            Self::BadRequest { .. }
            | Self::Authentication { .. }
            | Self::NotFound { .. }
            | Self::Conflict { .. }
            | Self::FileRead(_)
            | Self::Serialization(_)
            | Self::Api { .. } => false,
        }
    }
}
pub(crate) async fn parse_error_response(response: reqwest::Response) -> BigRagError {
let status = response.status().as_u16();
let body: serde_json::Value = response.json().await.unwrap_or_default();
let message = body
.get("detail")
.and_then(|v| v.as_str())
.or_else(|| {
body.get("error")
.and_then(|e| e.get("message"))
.and_then(|v| v.as_str())
})
.or_else(|| body.get("message").and_then(|v| v.as_str()))
.unwrap_or("unknown error")
.to_string();
match status {
400 => BigRagError::BadRequest { message, status },
401 | 403 => BigRagError::Authentication { message },
404 => BigRagError::NotFound { message },
409 => BigRagError::Conflict { message },
429 => BigRagError::RateLimited,
500..=599 => BigRagError::ServerError { message, status },
_ => BigRagError::Api { status, message },
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_status_returns_http_code() {
        let err = BigRagError::NotFound {
            message: "not found".into(),
        };
        assert_eq!(err.status(), Some(404));
    }

    #[test]
    fn test_status_returns_none_for_non_http_errors() {
        let err = BigRagError::Timeout(Duration::from_secs(5));
        assert_eq!(err.status(), None);
    }

    #[test]
    fn test_status_for_authentication_is_401() {
        let err = BigRagError::Authentication {
            message: "bad key".into(),
        };
        assert_eq!(err.status(), Some(401));
    }

    #[test]
    fn test_status_passes_through_for_api_error() {
        let err = BigRagError::Api {
            status: 418,
            message: "teapot".into(),
        };
        assert_eq!(err.status(), Some(418));
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_retryable_for_server_error() {
        let err = BigRagError::ServerError {
            message: "bad gateway".into(),
            status: 502,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_retryable_for_rate_limit() {
        let err = BigRagError::RateLimited;
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_for_not_found() {
        let err = BigRagError::NotFound {
            message: "gone".into(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_retryable_for_timeout() {
        let err = BigRagError::Timeout(Duration::from_secs(30));
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_retryable_for_connection() {
        let err = BigRagError::Connection("reset by peer".into());
        assert!(err.is_retryable());
        assert_eq!(err.status(), None);
    }

    #[test]
    fn test_io_error_converts_to_file_read() {
        let err: BigRagError =
            std::io::Error::new(std::io::ErrorKind::NotFound, "missing").into();
        assert!(matches!(err, BigRagError::FileRead(_)));
        assert_eq!(err.status(), None);
        assert!(!err.is_retryable());
        assert_eq!(err.to_string(), "failed to read file: missing");
    }

    #[test]
    fn test_serde_error_converts_to_serialization() {
        let source = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err = BigRagError::from(source);
        assert!(matches!(err, BigRagError::Serialization(_)));
        assert_eq!(err.status(), None);
        assert!(!err.is_retryable());
        assert!(err.to_string().starts_with("serialization error: "));
    }

    #[test]
    fn test_display_formatting() {
        let err = BigRagError::BadRequest {
            message: "invalid name".into(),
            status: 400,
        };
        assert_eq!(err.to_string(), "bad request: invalid name");
    }

    #[test]
    fn test_timeout_display_includes_duration() {
        let err = BigRagError::Timeout(Duration::from_secs(5));
        assert_eq!(err.to_string(), "request timed out after 5s");
    }

    #[test]
    fn test_api_display_includes_status() {
        let err = BigRagError::Api {
            status: 418,
            message: "teapot".into(),
        };
        assert_eq!(err.to_string(), "unexpected status 418: teapot");
    }
}