use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
/// Crate-wide error type.
///
/// Errors from wrapped libraries convert into it automatically through the
/// `#[from]` attributes on the variants below, so `?` can be applied to their
/// results directly.
#[derive(Error, Debug)]
pub enum Error {
    /// A configuration problem, described by the payload.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// A failure reported by the API layer (see [`ApiError`]).
    #[error("API error: {0}")]
    Api(#[from] ApiError),

    /// A failure reported by the `safebrowsing_db` crate.
    #[error("Database error: {0}")]
    Database(#[from] safebrowsing_db::DatabaseError),

    /// A rejected URL; the payload explains the rejection.
    #[error("Invalid URL: {0}")]
    InvalidUrl(String),

    /// A hashing failure, described by the payload.
    #[error("Hash error: {0}")]
    Hash(String),

    /// An operation that ran out of time; `is_retryable` reports it as retryable.
    #[error("Timeout: {0}")]
    Timeout(String),

    /// A transport-level failure from `reqwest`.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// A JSON failure from `serde_json`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// A protobuf decoding failure from `prost`.
    #[error("Protobuf error: {0}")]
    Protobuf(#[from] prost::DecodeError),

    /// An I/O failure from the standard library.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A string that the `url` crate could not parse as a URL.
    #[error("URL parse error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// Catch-all for failures without a dedicated variant; plain strings and
    /// context-wrapped errors end up here.
    #[error("Internal error: {0}")]
    Internal(String),

    /// A cache failure, described by the payload.
    #[error("Cache error: {0}")]
    Cache(String),

    /// An encoding failure, described by the payload.
    #[error("Encoding error: {0}")]
    Encoding(String),

    /// Input that failed validation; `is_permanent` reports it as permanent.
    #[error("Validation error: {0}")]
    Validation(String),
}
/// Failures reported while talking to the remote API.
///
/// `ApiError::from_status` builds these from an HTTP status code and body, and
/// `ApiError::is_retryable` classifies them for retry logic.
#[derive(Error, Debug)]
pub enum ApiError {
    /// A non-success HTTP status that has no more specific variant.
    #[error("HTTP {status}: {message}")]
    HttpStatus {
        /// The HTTP status code.
        status: u16,
        /// The message (response body) accompanying the status.
        message: String,
    },

    /// The response did not have the expected format.
    #[error("Invalid response format: {0}")]
    InvalidResponse(String),

    /// The request was rate limited.
    #[error("Rate limited, retry after: {retry_after:?}")]
    RateLimit {
        /// Optional delay to wait before retrying, when one is known.
        retry_after: Option<std::time::Duration>,
    },

    /// The API quota is used up (`from_status` maps HTTP 403 here).
    #[error("API quota exceeded")]
    QuotaExceeded,

    /// Authentication was rejected (`from_status` maps HTTP 401 here).
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// The service is unavailable (`from_status` maps HTTP 503 here).
    #[error("Server unavailable: {0}")]
    ServerUnavailable(String),

    /// The request was rejected as malformed (`from_status` maps HTTP 400 here).
    #[error("Bad request: {0}")]
    BadRequest(String),

    /// A network-level failure; classified as retryable.
    #[error("Network error: {0}")]
    Network(String),
}
/// Failure conditions of the local threat database.
///
/// NOTE(review): `Error::Database` wraps `safebrowsing_db::DatabaseError`, not
/// this type, and nothing in this module constructs it; confirm it is still
/// used (for example re-exported) before extending it.
#[derive(Error, Debug)]
pub enum DatabaseError {
    /// The database has not been initialized yet.
    #[error("Database not initialized")]
    NotInitialized,

    /// The database contents are out of date.
    #[error("Database is stale (last update: {last_update:?})")]
    Stale {
        /// Instant of the last successful update, if any was recorded.
        last_update: Option<std::time::Instant>,
    },

    /// Stored data failed an integrity check.
    #[error("Database corruption detected: {0}")]
    Corruption(String),

    /// The stored data was written by an incompatible version.
    #[error("Incompatible database version: expected {expected}, found {found}")]
    VersionMismatch {
        /// Version this code expects.
        expected: String,
        /// Version found in storage.
        found: String,
    },

    /// A file-system operation on the database failed.
    #[error("File operation failed: {0}")]
    FileOperation(String),

    /// A computed checksum differs from the stored one (both shown in hex).
    #[error("Checksum mismatch: expected {expected:x}, found {found:x}")]
    ChecksumMismatch {
        /// Checksum that was expected.
        expected: u64,
        /// Checksum that was computed.
        found: u64,
    },

    /// Data required for the operation is absent.
    #[error("Missing required data: {0}")]
    MissingData(String),

    /// Applying an update failed.
    #[error("Update failed: {0}")]
    UpdateFailed(String),

    /// Concurrent access to the database failed.
    #[error("Concurrent access error: {0}")]
    ConcurrentAccess(String),
}
impl From<safebrowsing_api::Error> for Error {
fn from(err: safebrowsing_api::Error) -> Self {
match err {
safebrowsing_api::Error::Api(api_err) => Error::Api(ApiError::from(api_err)),
safebrowsing_api::Error::Http(http_err) => Error::Http(http_err),
safebrowsing_api::Error::Protobuf(msg) => Error::Protobuf(prost::DecodeError::new(msg)),
safebrowsing_api::Error::Configuration(msg) => Error::Configuration(msg),
}
}
}
impl From<safebrowsing_url::UrlError> for Error {
fn from(err: safebrowsing_url::UrlError) -> Self {
match err {
safebrowsing_url::UrlError::Parse(parse_err) => Error::UrlParse(parse_err),
safebrowsing_url::UrlError::InvalidHost(msg) => Error::InvalidUrl(msg),
safebrowsing_url::UrlError::Idna(msg) => {
Error::InvalidUrl(format!("IDNA error: {msg}"))
}
safebrowsing_url::UrlError::InvalidFormat(msg) => Error::InvalidUrl(msg),
}
}
}
impl From<safebrowsing_api::ApiError> for ApiError {
fn from(err: safebrowsing_api::ApiError) -> Self {
match err {
safebrowsing_api::ApiError::BadRequest(msg) => ApiError::BadRequest(msg),
safebrowsing_api::ApiError::Authentication(msg) => ApiError::Authentication(msg),
safebrowsing_api::ApiError::QuotaExceeded => ApiError::QuotaExceeded,
safebrowsing_api::ApiError::RateLimit { retry_after } => {
ApiError::RateLimit { retry_after }
}
safebrowsing_api::ApiError::ServerUnavailable(msg) => ApiError::ServerUnavailable(msg),
safebrowsing_api::ApiError::HttpStatus { status, message } => {
ApiError::HttpStatus { status, message }
}
}
}
}
impl From<&str> for Error {
fn from(msg: &str) -> Self {
Error::Internal(msg.to_string())
}
}
impl From<String> for Error {
fn from(msg: String) -> Self {
Error::Internal(msg)
}
}
/// Extension trait that attaches a human-readable context message to a
/// failed result.
///
/// The blanket implementation for `std::result::Result` turns the failure into
/// [`Error::Internal`] with the text `"<context>: <original error>"`.
pub trait ErrorContext<T> {
    /// Prefixes a failure with the fixed message `msg`.
    fn context(self, msg: &'static str) -> Result<T>;

    /// Prefixes a failure with the message built by `f`.
    ///
    /// In the blanket implementation `f` runs only when the result is an `Err`,
    /// so building the message costs nothing on success.
    fn with_context<F>(self, f: F) -> Result<T>
    where
        F: FnOnce() -> String;
}
impl<T, E> ErrorContext<T> for std::result::Result<T, E>
where
E: Into<Error>,
{
fn with_context<F>(self, f: F) -> Result<T>
where
F: FnOnce() -> String,
{
self.map_err(|e| {
let base_error = e.into();
Error::Internal(format!("{}: {}", f(), base_error))
})
}
fn context(self, msg: &'static str) -> Result<T> {
self.with_context(|| msg.to_string())
}
}
impl Error {
pub fn is_retryable(&self) -> bool {
match self {
Error::Api(api_error) => api_error.is_retryable(),
Error::Database(safebrowsing_db::DatabaseError::Stale(_)) => true,
Error::Http(req_error) => {
req_error.is_timeout() || req_error.is_connect()
}
Error::Timeout(_) => true,
_ => false,
}
}
pub fn is_permanent(&self) -> bool {
matches!(
self,
Error::Configuration(_)
| Error::InvalidUrl(_)
| Error::Api(ApiError::Authentication(_))
| Error::Api(ApiError::BadRequest(_))
| Error::Database(safebrowsing_db::DatabaseError::DecodeError(_))
| Error::Database(safebrowsing_db::DatabaseError::InvalidChecksum { .. })
| Error::Validation(_)
)
}
pub fn user_message(&self) -> String {
match self {
Error::Configuration(_) => "Configuration issue detected".to_string(),
Error::Api(ApiError::Authentication(_)) => {
"Invalid API key or authentication failed".to_string()
}
Error::Api(ApiError::QuotaExceeded) => {
"API quota exceeded, please try again later".to_string()
}
Error::Api(ApiError::RateLimit { .. }) => {
"Rate limited by API, please wait before retrying".to_string()
}
Error::InvalidUrl(url) => format!("Invalid URL format: {url}"),
Error::Database(safebrowsing_db::DatabaseError::Stale(_)) => {
"Database needs updating".to_string()
}
Error::Database(safebrowsing_db::DatabaseError::DecodeError(_)) => {
"Database corruption detected, please reset".to_string()
}
Error::Timeout(_) => "Operation timed out, please try again".to_string(),
Error::Http(_) => "Network connection failed".to_string(),
_ => "An unexpected error occurred".to_string(),
}
}
}
impl ApiError {
    /// Reports whether the request that produced this error may succeed if
    /// sent again.
    ///
    /// These are retryable:
    /// - HTTP 429 and every status of 500 or above;
    /// - [`ApiError::RateLimit`];
    /// - [`ApiError::ServerUnavailable`];
    /// - [`ApiError::Network`].
    ///
    /// The remaining variants are not retryable.
    pub fn is_retryable(&self) -> bool {
        match self {
            ApiError::HttpStatus { status, .. } => matches!(*status, 429 | 500..=u16::MAX),
            ApiError::RateLimit { .. } | ApiError::ServerUnavailable(_) | ApiError::Network(_) => {
                true
            }
            ApiError::Authentication(_)
            | ApiError::BadRequest(_)
            | ApiError::QuotaExceeded
            | ApiError::InvalidResponse(_) => false,
        }
    }

    /// Builds the variant matching an HTTP `status`, keeping `body` where useful.
    ///
    /// | status | variant                                      |
    /// |--------|----------------------------------------------|
    /// | 400    | `BadRequest(body)`                           |
    /// | 401    | `Authentication` with a fixed message        |
    /// | 403    | `QuotaExceeded`                              |
    /// | 429    | `RateLimit` with no `retry_after`            |
    /// | 503    | `ServerUnavailable` with a fixed message     |
    /// | other  | `HttpStatus { status, message: body }`       |
    pub fn from_status(status: u16, body: &str) -> Self {
        match status {
            400 => ApiError::BadRequest(String::from(body)),
            401 => ApiError::Authentication(String::from("Invalid API key")),
            403 => ApiError::QuotaExceeded,
            429 => ApiError::RateLimit { retry_after: None },
            503 => ApiError::ServerUnavailable(String::from("Service temporarily unavailable")),
            other => ApiError::HttpStatus {
                status: other,
                message: String::from(body),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_context() {
        let failing: std::result::Result<(), &str> = Err("test error");
        let err = failing
            .context("operation failed")
            .expect_err("context must keep the Err variant");
        assert!(matches!(err, Error::Internal(_)));

        let rendered = err.to_string();
        for fragment in ["operation failed", "test error"].iter() {
            assert!(rendered.contains(fragment), "missing {fragment:?} in {rendered:?}");
        }
    }

    #[test]
    fn test_retryable_errors() {
        let retryable = [
            Error::Timeout("test".to_string()),
            Error::Api(ApiError::RateLimit { retry_after: None }),
        ];
        for err in retryable.iter() {
            assert!(err.is_retryable(), "expected retryable: {err:?}");
        }

        let not_retryable = [
            Error::Configuration("test".to_string()),
            Error::InvalidUrl("test".to_string()),
        ];
        for err in not_retryable.iter() {
            assert!(!err.is_retryable(), "expected non-retryable: {err:?}");
        }
    }

    #[test]
    fn test_permanent_errors() {
        let permanent = [
            Error::Configuration("test".to_string()),
            Error::InvalidUrl("test".to_string()),
            Error::Api(ApiError::Authentication("test".to_string())),
        ];
        for err in permanent.iter() {
            assert!(err.is_permanent(), "expected permanent: {err:?}");
        }

        let transient = Error::Timeout("test".to_string());
        assert!(!transient.is_permanent());
    }

    #[test]
    fn test_api_error_from_status() {
        type Check = fn(&ApiError) -> bool;
        let cases: [(u16, &str, Check); 4] = [
            (401, "unauthorized", |e: &ApiError| {
                matches!(e, ApiError::Authentication(_))
            }),
            (403, "forbidden", |e: &ApiError| matches!(e, ApiError::QuotaExceeded)),
            (429, "rate limit", |e: &ApiError| {
                matches!(e, ApiError::RateLimit { .. })
            }),
            (500, "server error", |e: &ApiError| {
                matches!(e, ApiError::HttpStatus { status: 500, .. })
            }),
        ];
        for &(status, body, check) in cases.iter() {
            let mapped = ApiError::from_status(status, body);
            assert!(check(&mapped), "HTTP {status} mapped to unexpected {mapped:?}");
        }
    }

    #[test]
    fn test_user_messages() {
        let expectations = [
            (
                Error::Configuration("test".to_string()),
                "Configuration issue detected",
            ),
            (
                Error::Api(ApiError::Authentication("test".to_string())),
                "Invalid API key or authentication failed",
            ),
            (
                Error::InvalidUrl("invalid-url".to_string()),
                "Invalid URL format: invalid-url",
            ),
        ];
        for (err, expected) in expectations.iter() {
            assert_eq!(err.user_message(), *expected);
        }
    }
}