use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::time::Duration;
use super::ErrorKind;
/// Structured error value: a categorical [`ErrorKind`] plus a human-readable
/// message, optionally enriched with a request identifier, a retry hint and an
/// underlying cause (reported through [`std::error::Error::source`]).
#[derive(Debug)]
pub struct Error {
    /// Category of the failure.
    kind: ErrorKind,
    /// Human-readable description; borrowed for static text, owned otherwise.
    message: Cow<'static, str>,
    /// Identifier of the request that failed, when one was attached.
    request_id: Option<String>,
    /// Optional delay hint attached via `with_retry_after` / `rate_limited`.
    retry_after: Option<Duration>,
    /// Underlying cause, if any. (A boxed trait object in a field defaults to
    /// a `'static` bound, so this is the same type as `... + 'static`.)
    source: Option<Box<dyn StdError + Send + Sync>>,
}
impl Error {
    /// Creates an error of the given `kind` with a caller-supplied `message`.
    ///
    /// `message` may be a `&'static str` (stored without allocating) or an
    /// owned `String`. Request id, retry hint and source all start unset.
    pub fn new(kind: ErrorKind, message: impl Into<Cow<'static, str>>) -> Self {
        Self {
            kind,
            message: message.into(),
            request_id: None,
            retry_after: None,
            source: None,
        }
    }

    /// Creates an error of the given `kind` carrying a generic, kind-specific
    /// default message (e.g. `NotFound` -> "resource not found").
    pub fn from_kind(kind: ErrorKind) -> Self {
        // No `_` arm: adding an `ErrorKind` variant makes this match fail to
        // compile until a default message is chosen for it.
        let message = match kind {
            ErrorKind::Unauthorized => "authentication failed",
            ErrorKind::Forbidden => "permission denied",
            ErrorKind::NotFound => "resource not found",
            ErrorKind::InvalidArgument => "invalid argument",
            ErrorKind::SchemaViolation => "schema violation",
            ErrorKind::RateLimited => "rate limit exceeded",
            ErrorKind::Unavailable => "service unavailable",
            ErrorKind::Timeout => "request timed out",
            ErrorKind::Internal => "internal server error",
            ErrorKind::Cancelled => "request cancelled",
            ErrorKind::CircuitOpen => "circuit breaker open",
            ErrorKind::Connection => "connection failed",
            ErrorKind::Protocol => "protocol error",
            ErrorKind::Configuration => "configuration error",
            ErrorKind::Unknown => "unknown error",
            ErrorKind::Conflict => "resource conflict",
            ErrorKind::Transport => "transport error",
            ErrorKind::InvalidResponse => "invalid response",
        };
        Self::new(kind, message)
    }

    /// Returns the category of this error.
    #[inline]
    pub fn kind(&self) -> ErrorKind {
        self.kind
    }

    /// Returns the bare message, without the kind prefix or request id that
    /// the `Display` output adds.
    #[inline]
    pub fn message(&self) -> &str {
        &self.message
    }

    /// Returns the identifier of the failed request, if one was attached.
    #[inline]
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Returns the suggested delay before retrying, if one was attached.
    #[inline]
    pub fn retry_after(&self) -> Option<Duration> {
        self.retry_after
    }

    /// Reports whether the failed operation may be retried. This is decided
    /// solely by the error's kind ([`ErrorKind::is_retriable`]).
    #[inline]
    pub fn is_retriable(&self) -> bool {
        self.kind.is_retriable()
    }

    /// Attaches the identifier of the request that failed.
    #[must_use]
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        self.request_id = Some(request_id.into());
        self
    }

    /// Attaches a suggested delay before retrying.
    #[must_use]
    pub fn with_retry_after(mut self, duration: Duration) -> Self {
        self.retry_after = Some(duration);
        self
    }

    /// Attaches the underlying cause, later exposed via
    /// [`std::error::Error::source`].
    ///
    /// Accepts any `Error + Send + Sync + 'static` value (boxed through std's
    /// blanket `From` impl). It also accepts an already-boxed
    /// `Box<dyn Error + Send + Sync>`, which does not itself implement `Error`
    /// and so could not be passed under a plain `E: Error` bound.
    #[must_use]
    pub fn with_source<E>(mut self, source: E) -> Self
    where
        E: Into<Box<dyn StdError + Send + Sync + 'static>>,
    {
        self.source = Some(source.into());
        self
    }

    /// Creates an [`ErrorKind::Unauthorized`] error.
    pub fn unauthorized(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Unauthorized, message)
    }

    /// Creates an [`ErrorKind::Forbidden`] error.
    pub fn forbidden(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Forbidden, message)
    }

    /// Creates an [`ErrorKind::NotFound`] error.
    pub fn not_found(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::NotFound, message)
    }

    /// Creates an [`ErrorKind::InvalidArgument`] error.
    pub fn invalid_argument(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::InvalidArgument, message)
    }

    /// Creates an [`ErrorKind::SchemaViolation`] error.
    pub fn schema_violation(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::SchemaViolation, message)
    }

    /// Creates an [`ErrorKind::RateLimited`] error with the default message
    /// and the given (optional) retry hint.
    pub fn rate_limited(retry_after: Option<Duration>) -> Self {
        Self {
            retry_after,
            ..Self::from_kind(ErrorKind::RateLimited)
        }
    }

    /// Creates an [`ErrorKind::Unavailable`] error.
    pub fn unavailable(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Unavailable, message)
    }

    /// Creates an [`ErrorKind::Timeout`] error.
    pub fn timeout(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Timeout, message)
    }

    /// Creates an [`ErrorKind::Internal`] error.
    pub fn internal(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Internal, message)
    }

    /// Creates an [`ErrorKind::Cancelled`] error with the default message.
    pub fn cancelled() -> Self {
        Self::from_kind(ErrorKind::Cancelled)
    }

    /// Creates an [`ErrorKind::CircuitOpen`] error with the default message.
    pub fn circuit_open() -> Self {
        Self::from_kind(ErrorKind::CircuitOpen)
    }

    /// Creates an [`ErrorKind::Connection`] error.
    pub fn connection(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Connection, message)
    }

    /// Creates an [`ErrorKind::Protocol`] error.
    pub fn protocol(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Protocol, message)
    }

    /// Creates an [`ErrorKind::Configuration`] error.
    pub fn configuration(message: impl Into<Cow<'static, str>>) -> Self {
        Self::new(ErrorKind::Configuration, message)
    }
}
impl fmt::Display for Error {
    /// Formats as `"{kind}: {message}"`, followed by `" (request_id: ...)"`
    /// when a request id is attached. An empty message is skipped so the
    /// output never ends in a dangling `": "`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.kind)?;
        if !self.message.is_empty() {
            write!(f, ": {}", self.message)?;
        }
        if let Some(request_id) = &self.request_id {
            write!(f, " (request_id: {})", request_id)?;
        }
        Ok(())
    }
}
impl StdError for Error {
    /// Returns the cause attached with `with_source`, or `None` if there is none.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // Borrow the boxed cause, then drop the `Send + Sync` markers so the
        // reference matches the trait's return type.
        let cause = self.source.as_deref()?;
        Some(cause as &(dyn StdError + 'static))
    }
}
impl From<ErrorKind> for Error {
fn from(kind: ErrorKind) -> Self {
Self::from_kind(kind)
}
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
let kind = match err.kind() {
std::io::ErrorKind::NotFound => ErrorKind::NotFound,
std::io::ErrorKind::PermissionDenied => ErrorKind::Forbidden,
std::io::ErrorKind::ConnectionRefused
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::NotConnected => ErrorKind::Connection,
std::io::ErrorKind::TimedOut => ErrorKind::Timeout,
_ => ErrorKind::Internal,
};
Error::new(kind, err.to_string()).with_source(err)
}
}
impl From<url::ParseError> for Error {
    /// Reports an unparsable URL as a `Configuration` error whose message is
    /// prefixed with "invalid URL"; the parse error is kept as the source.
    fn from(err: url::ParseError) -> Self {
        let message = format!("invalid URL: {}", err);
        Self::configuration(message).with_source(err)
    }
}
impl From<serde_json::Error> for Error {
    /// Reports a JSON (de)serialization failure as a `Protocol` error whose
    /// message is prefixed with "JSON error"; the serde error is kept as the
    /// source.
    fn from(err: serde_json::Error) -> Self {
        let message = format!("JSON error: {}", err);
        Self::protocol(message).with_source(err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_new() {
        let err = Error::new(ErrorKind::InvalidArgument, "test message");
        assert_eq!(err.kind(), ErrorKind::InvalidArgument);
        assert!(err.to_string().contains("test message"));
        assert!(err.request_id().is_none());
        assert!(err.retry_after().is_none());
    }

    #[test]
    fn test_error_from_kind() {
        let err = Error::from_kind(ErrorKind::Unauthorized);
        assert_eq!(err.kind(), ErrorKind::Unauthorized);
        assert!(err.to_string().contains("authentication failed"));
    }

    #[test]
    fn test_error_with_request_id() {
        let err = Error::new(ErrorKind::Internal, "server error").with_request_id("req_abc123");
        assert_eq!(err.request_id(), Some("req_abc123"));
        assert!(err.to_string().contains("req_abc123"));
    }

    #[test]
    fn test_error_with_retry_after() {
        let err = Error::rate_limited(Some(Duration::from_secs(30)));
        assert_eq!(err.kind(), ErrorKind::RateLimited);
        assert_eq!(err.retry_after(), Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_error_is_retriable() {
        assert!(Error::from_kind(ErrorKind::Timeout).is_retriable());
        assert!(Error::from_kind(ErrorKind::Unavailable).is_retriable());
        assert!(Error::from_kind(ErrorKind::RateLimited).is_retriable());
        assert!(!Error::from_kind(ErrorKind::Unauthorized).is_retriable());
        assert!(!Error::from_kind(ErrorKind::NotFound).is_retriable());
    }

    #[test]
    fn test_error_with_source() {
        let io_err = std::io::Error::other("underlying error");
        let err = Error::new(ErrorKind::Connection, "connection failed").with_source(io_err);
        // Check that the exact cause is reachable, not just that something is attached.
        let source = err.source().expect("source should be attached");
        assert_eq!(source.to_string(), "underlying error");
        assert!(source.downcast_ref::<std::io::Error>().is_some());
    }

    #[test]
    fn test_convenience_constructors() {
        assert_eq!(Error::unauthorized("test").kind(), ErrorKind::Unauthorized);
        assert_eq!(Error::forbidden("test").kind(), ErrorKind::Forbidden);
        assert_eq!(Error::not_found("test").kind(), ErrorKind::NotFound);
        assert_eq!(
            Error::invalid_argument("test").kind(),
            ErrorKind::InvalidArgument
        );
        assert_eq!(
            Error::schema_violation("test").kind(),
            ErrorKind::SchemaViolation
        );
        assert_eq!(Error::unavailable("test").kind(), ErrorKind::Unavailable);
        assert_eq!(Error::timeout("test").kind(), ErrorKind::Timeout);
        assert_eq!(Error::internal("test").kind(), ErrorKind::Internal);
        assert_eq!(Error::cancelled().kind(), ErrorKind::Cancelled);
        assert_eq!(Error::circuit_open().kind(), ErrorKind::CircuitOpen);
        assert_eq!(Error::connection("test").kind(), ErrorKind::Connection);
        assert_eq!(Error::protocol("test").kind(), ErrorKind::Protocol);
        assert_eq!(
            Error::configuration("test").kind(),
            ErrorKind::Configuration
        );
    }

    #[test]
    fn test_from_error_kind() {
        let err: Error = ErrorKind::Timeout.into();
        assert_eq!(err.kind(), ErrorKind::Timeout);
    }

    #[test]
    fn test_from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timed out");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Timeout);
        assert!(err.source().is_some());
    }

    #[test]
    fn test_display_format() {
        let err = Error::new(ErrorKind::NotFound, "vault not found").with_request_id("req_xyz789");
        let display = err.to_string();
        assert!(display.contains("not found"));
        assert!(display.contains("vault not found"));
        assert!(display.contains("req_xyz789"));
    }

    /// Every variant keeps its kind and renders its default message.
    #[test]
    fn test_from_kind_all_variants() {
        let cases = [
            (ErrorKind::Unauthorized, "authentication failed"),
            (ErrorKind::Forbidden, "permission denied"),
            (ErrorKind::NotFound, "resource not found"),
            (ErrorKind::InvalidArgument, "invalid argument"),
            (ErrorKind::SchemaViolation, "schema violation"),
            (ErrorKind::RateLimited, "rate limit exceeded"),
            (ErrorKind::Unavailable, "service unavailable"),
            (ErrorKind::Timeout, "request timed out"),
            (ErrorKind::Internal, "internal server error"),
            (ErrorKind::Cancelled, "request cancelled"),
            (ErrorKind::CircuitOpen, "circuit breaker open"),
            (ErrorKind::Connection, "connection failed"),
            (ErrorKind::Protocol, "protocol error"),
            (ErrorKind::Configuration, "configuration error"),
            (ErrorKind::Unknown, "unknown error"),
            (ErrorKind::Conflict, "resource conflict"),
            (ErrorKind::Transport, "transport error"),
            (ErrorKind::InvalidResponse, "invalid response"),
        ];
        for (kind, message) in cases {
            let err = Error::from_kind(kind);
            assert_eq!(err.kind(), kind);
            let display = err.to_string();
            assert!(
                display.contains(message),
                "{:?}: {:?} does not contain {:?}",
                kind,
                display,
                message
            );
        }
    }

    #[test]
    fn test_from_io_error_not_found() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::NotFound);
    }

    #[test]
    fn test_from_io_error_permission_denied() {
        let io_err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Forbidden);
    }

    #[test]
    fn test_from_io_error_connection_refused() {
        let io_err =
            std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "connection refused");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Connection);
    }

    #[test]
    fn test_from_io_error_connection_reset() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "connection reset");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Connection);
    }

    #[test]
    fn test_from_io_error_connection_aborted() {
        let io_err =
            std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "connection aborted");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Connection);
    }

    #[test]
    fn test_from_io_error_not_connected() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, "not connected");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Connection);
    }

    #[test]
    fn test_from_io_error_other() {
        let io_err = std::io::Error::other("other error");
        let err: Error = io_err.into();
        assert_eq!(err.kind(), ErrorKind::Internal);
    }

    #[test]
    fn test_from_url_parse_error() {
        let url_err = url::Url::parse("not a valid url").unwrap_err();
        let err: Error = url_err.into();
        assert_eq!(err.kind(), ErrorKind::Configuration);
        assert!(err.to_string().contains("invalid URL"));
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err: serde_json::Error =
            serde_json::from_str::<serde_json::Value>("{invalid}").unwrap_err();
        let err: Error = json_err.into();
        assert_eq!(err.kind(), ErrorKind::Protocol);
        assert!(err.to_string().contains("JSON error"));
    }

    #[test]
    fn test_error_display_without_request_id() {
        let err = Error::new(ErrorKind::NotFound, "vault not found");
        let display = err.to_string();
        assert!(!display.contains("request_id"));
    }

    #[test]
    fn test_error_source_none() {
        let err = Error::new(ErrorKind::Internal, "test");
        assert!(err.source().is_none());
    }

    #[test]
    fn test_error_debug() {
        let err = Error::new(ErrorKind::Internal, "test error");
        let debug = format!("{:?}", err);
        assert!(debug.contains("Error"));
    }

    /// The most recently added variants round-trip and carry their messages.
    #[test]
    fn test_from_kind_remaining_variants() {
        let cases = [
            (ErrorKind::Conflict, "resource conflict"),
            (ErrorKind::Transport, "transport error"),
            (ErrorKind::InvalidResponse, "invalid response"),
        ];
        for (kind, message) in cases {
            let err = Error::from_kind(kind);
            assert_eq!(err.kind(), kind);
            assert!(err.to_string().contains(message));
        }
    }

    #[test]
    fn test_rate_limited_without_retry_after() {
        let err = Error::rate_limited(None);
        assert_eq!(err.kind(), ErrorKind::RateLimited);
        assert!(err.retry_after().is_none());
    }

    #[test]
    fn test_error_with_retry_after_builder() {
        let err = Error::new(ErrorKind::RateLimited, "rate limited")
            .with_retry_after(Duration::from_secs(60));
        assert_eq!(err.retry_after(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_error_from_kind_message_content() {
        let err = Error::from_kind(ErrorKind::Unauthorized);
        assert!(err.to_string().contains("authentication"));
        let err = Error::from_kind(ErrorKind::NotFound);
        assert!(err.to_string().contains("not found"));
        let err = Error::from_kind(ErrorKind::Conflict);
        assert!(err.to_string().contains("conflict"));
    }
}