use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
/// Minimal `std::error::Error` implementation that carries nothing but a message.
///
/// Used by the `*_msg` convenience constructors and by `AutumnError::validation`
/// to wrap plain strings as errors.
#[derive(Debug)]
struct StringError(String);

impl std::error::Error for StringError {}

impl std::fmt::Display for StringError {
    /// Writes the stored message verbatim (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StringError(message) = self;
        write!(f, "{message}")
    }
}
/// JSON envelope for error responses; serializes as `{"error": { ... }}`.
#[derive(Serialize)]
struct ErrorBody {
/// The error payload, nested under the `error` key.
error: ErrorInner,
}
/// Payload of an error response body.
///
/// Field names and declaration order define the JSON keys and their order,
/// so renaming or reordering fields changes the wire format.
#[derive(Serialize)]
struct ErrorInner {
/// Numeric HTTP status code (e.g. `404`).
status: u16,
/// Human-readable error message.
message: String,
/// Optional map from string keys to lists of messages (e.g. per-field
/// validation messages). When `None`, the `details` key is omitted from the
/// JSON entirely rather than serialized as `null`.
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<std::collections::HashMap<String, Vec<String>>>,
}
/// HTTP-aware application error: a boxed source error, the HTTP status to
/// respond with, and optional structured details.
///
/// Any `std::error::Error + Send + Sync + 'static` converts into it through the
/// blanket `From` impl (defaulting to 500), so `?` works in handlers returning
/// [`AutumnResult`]. Returned from a handler, it renders as a JSON error body.
///
/// Note: this type deliberately does not implement `std::error::Error` itself;
/// doing so would make the blanket `From<E>` impl overlap with core's
/// `impl<T> From<T> for T`.
pub struct AutumnError {
/// The wrapped error; its `Display` text becomes the response message.
inner: Box<dyn std::error::Error + Send + Sync>,
/// HTTP status code sent with the response.
status: StatusCode,
/// Optional structured details, serialized under `details` in the JSON body.
details: Option<std::collections::HashMap<String, Vec<String>>>,
}
/// Convenience alias for results whose error type is [`AutumnError`].
pub type AutumnResult<T> = Result<T, AutumnError>;
impl<E> From<E> for AutumnError
where
E: std::error::Error + Send + Sync + 'static,
{
fn from(err: E) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::INTERNAL_SERVER_ERROR,
details: None,
}
}
}
impl AutumnError {
    /// Wraps `err` with the given HTTP `status` and no structured details.
    ///
    /// Private building block shared by the status-specific constructors so the
    /// field initialization lives in exactly one place.
    fn new_with_status(
        err: impl std::error::Error + Send + Sync + 'static,
        status: StatusCode,
    ) -> Self {
        Self {
            inner: Box::new(err),
            status,
            details: None,
        }
    }

    /// Returns this error with its HTTP status replaced by `status`.
    ///
    /// The wrapped error and any details are kept unchanged.
    #[must_use]
    pub const fn with_status(mut self, status: StatusCode) -> Self {
        self.status = status;
        self
    }

    /// Wraps `err` as a `500 Internal Server Error`.
    #[must_use]
    pub fn internal_server_error(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::INTERNAL_SERVER_ERROR)
    }

    /// Wraps `err` as a `404 Not Found`.
    #[must_use]
    pub fn not_found(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::NOT_FOUND)
    }

    /// Wraps `err` as a `400 Bad Request`.
    #[must_use]
    pub fn bad_request(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::BAD_REQUEST)
    }

    /// Wraps `err` as a `422 Unprocessable Entity`.
    #[must_use]
    pub fn unprocessable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::UNPROCESSABLE_ENTITY)
    }

    /// Wraps `err` as a `503 Service Unavailable`.
    #[must_use]
    pub fn service_unavailable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::SERVICE_UNAVAILABLE)
    }

    /// Wraps `err` as a `401 Unauthorized`.
    #[must_use]
    pub fn unauthorized(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::UNAUTHORIZED)
    }

    /// Wraps `err` as a `403 Forbidden`.
    #[must_use]
    pub fn forbidden(err: impl std::error::Error + Send + Sync + 'static) -> Self {
        Self::new_with_status(err, StatusCode::FORBIDDEN)
    }

    /// Builds a `422 Unprocessable Entity` error carrying structured `details`.
    ///
    /// The message is always `"Validation failed"`; `details` is rendered under
    /// the `details` key of the JSON error body.
    #[must_use]
    pub fn validation(details: std::collections::HashMap<String, Vec<String>>) -> Self {
        Self {
            inner: Box::new(StringError("Validation failed".into())),
            status: StatusCode::UNPROCESSABLE_ENTITY,
            details: Some(details),
        }
    }

    /// Like [`Self::internal_server_error`], but built from a plain message.
    #[must_use]
    pub fn internal_server_error_msg(msg: impl Into<String>) -> Self {
        Self::internal_server_error(StringError(msg.into()))
    }

    /// Like [`Self::not_found`], but built from a plain message.
    #[must_use]
    pub fn not_found_msg(msg: impl Into<String>) -> Self {
        Self::not_found(StringError(msg.into()))
    }

    /// Like [`Self::bad_request`], but built from a plain message.
    #[must_use]
    pub fn bad_request_msg(msg: impl Into<String>) -> Self {
        Self::bad_request(StringError(msg.into()))
    }

    /// Like [`Self::unprocessable`], but built from a plain message.
    #[must_use]
    pub fn unprocessable_msg(msg: impl Into<String>) -> Self {
        Self::unprocessable(StringError(msg.into()))
    }

    /// Like [`Self::unauthorized`], but built from a plain message.
    #[must_use]
    pub fn unauthorized_msg(msg: impl Into<String>) -> Self {
        Self::unauthorized(StringError(msg.into()))
    }

    /// Like [`Self::forbidden`], but built from a plain message.
    #[must_use]
    pub fn forbidden_msg(msg: impl Into<String>) -> Self {
        Self::forbidden(StringError(msg.into()))
    }

    /// Like [`Self::service_unavailable`], but built from a plain message.
    #[must_use]
    pub fn service_unavailable_msg(msg: impl Into<String>) -> Self {
        Self::service_unavailable(StringError(msg.into()))
    }

    /// The HTTP status this error will respond with.
    #[must_use]
    pub const fn status(&self) -> StatusCode {
        self.status
    }
}
impl std::fmt::Display for AutumnError {
    /// Shows only the wrapped error's message; the HTTP status is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.inner;
        write!(f, "{inner}")
    }
}
impl std::fmt::Debug for AutumnError {
    /// Debug view listing `status`, `inner`, then `details`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `AutumnError` forces this impl to be revisited.
        let Self {
            inner,
            status,
            details,
        } = self;
        let mut view = f.debug_struct("AutumnError");
        view.field("status", status);
        view.field("inner", inner);
        view.field("details", details);
        view.finish()
    }
}
impl IntoResponse for AutumnError {
fn into_response(self) -> Response {
let status = self.status;
let message = self.inner.to_string();
let error_info = crate::middleware::AutumnErrorInfo {
status,
message: message.clone(),
details: self.details.clone(),
};
let body = ErrorBody {
error: ErrorInner {
status: status.as_u16(),
message,
details: self.details,
},
};
let mut response = (status, axum::Json(body)).into_response();
response.extensions_mut().insert(error_info);
response
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    /// Simple error type used as the wrapped source in these tests.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for TestError {}

    #[test]
    fn blanket_from_defaults_to_500() {
        let err: AutumnError = TestError("boom".into()).into();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn internal_server_error_is_500() {
        let err = AutumnError::internal_server_error(TestError("boom".into()));
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_not_found_error() {
        let err = AutumnError::not_found(std::io::Error::other("no such user"));
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn not_found_is_404() {
        let err = AutumnError::not_found(TestError("missing".into()));
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn bad_request_is_400() {
        let err = AutumnError::bad_request(TestError("invalid input".into()));
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn unprocessable_is_422() {
        let err = AutumnError::unprocessable(TestError("bad entity".into()));
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn unauthorized_is_401() {
        let err = AutumnError::unauthorized(TestError("unauthorized".into()));
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn forbidden_is_403() {
        let err = AutumnError::forbidden(TestError("forbidden".into()));
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn validation_is_422() {
        let mut details = std::collections::HashMap::new();
        details.insert("field".to_string(), vec!["error".to_string()]);
        let err = AutumnError::validation(details);
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn validation_message_is_fixed() {
        let err = AutumnError::validation(std::collections::HashMap::new());
        assert_eq!(err.to_string(), "Validation failed");
    }

    #[test]
    fn service_unavailable_is_503() {
        let err = AutumnError::service_unavailable(TestError("pool exhausted".into()));
        assert_eq!(err.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn internal_server_error_msg_is_500() {
        let err = AutumnError::internal_server_error_msg("db failure");
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "db failure");
    }

    #[test]
    fn not_found_msg_is_404() {
        let err = AutumnError::not_found_msg("no such user");
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert_eq!(err.to_string(), "no such user");
    }

    #[test]
    fn bad_request_msg_is_400() {
        let err = AutumnError::bad_request_msg("invalid input");
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn unprocessable_msg_is_422() {
        let err = AutumnError::unprocessable_msg("title required");
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn unauthorized_msg_is_401() {
        let err = AutumnError::unauthorized_msg("login required");
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn forbidden_msg_is_403() {
        let err = AutumnError::forbidden_msg("no access");
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn service_unavailable_msg_is_503() {
        let err = AutumnError::service_unavailable_msg("db down");
        assert_eq!(err.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(err.to_string(), "db down");
    }

    #[test]
    fn with_status_overrides() {
        let err: AutumnError = TestError("forbidden".into()).into();
        let err = err.with_status(StatusCode::FORBIDDEN);
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn with_status_preserves_message() {
        let err = AutumnError::not_found_msg("gone").with_status(StatusCode::GONE);
        assert_eq!(err.status(), StatusCode::GONE);
        assert_eq!(err.to_string(), "gone");
    }

    #[test]
    fn display_uses_inner_message() {
        let err: AutumnError = TestError("something broke".into()).into();
        assert_eq!(err.to_string(), "something broke");
    }

    #[test]
    fn into_response_has_correct_status() {
        let err = AutumnError::not_found(TestError("not found".into()));
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn into_response_has_json_body() -> Result<(), axum::Error> {
        let err = AutumnError::not_found(TestError("not found".into()));
        let response = err.into_response();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        let json: serde_json::Value = serde_json::from_slice(&body).expect("valid json");
        assert_eq!(json["error"]["status"], 404);
        assert_eq!(json["error"]["message"], "not found");
        Ok(())
    }

    #[test]
    fn debug_shows_status_and_inner() {
        let err = AutumnError::bad_request(TestError("oops".into()));
        let debug = format!("{err:?}");
        assert!(debug.contains("AutumnError"));
        assert!(debug.contains("400"));
    }

    #[tokio::test]
    async fn msg_constructor_produces_valid_json_response() -> Result<(), axum::Error> {
        let err = AutumnError::unprocessable_msg("title required");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        let json: serde_json::Value = serde_json::from_slice(&body).expect("valid json");
        assert_eq!(json["error"]["status"], 422);
        assert_eq!(json["error"]["message"], "title required");
        Ok(())
    }

    #[tokio::test]
    async fn service_unavailable_response_is_503() -> Result<(), axum::Error> {
        let err = AutumnError::service_unavailable_msg("db down");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        let json: serde_json::Value = serde_json::from_slice(&body).expect("valid json");
        assert_eq!(json["error"]["status"], 503);
        assert_eq!(json["error"]["message"], "db down");
        Ok(())
    }

    /// The `details` map of a validation error must reach the JSON body.
    #[tokio::test]
    async fn validation_response_includes_details() -> Result<(), axum::Error> {
        let mut details = std::collections::HashMap::new();
        details.insert("title".to_string(), vec!["is required".to_string()]);
        let response = AutumnError::validation(details).into_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        let json: serde_json::Value = serde_json::from_slice(&body).expect("valid json");
        assert_eq!(json["error"]["status"], 422);
        assert_eq!(json["error"]["message"], "Validation failed");
        assert_eq!(json["error"]["details"]["title"][0], "is required");
        Ok(())
    }

    /// Errors without details must omit the key entirely (not serialize `null`).
    #[tokio::test]
    async fn details_key_omitted_when_absent() -> Result<(), axum::Error> {
        let response = AutumnError::bad_request_msg("nope").into_response();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        let json: serde_json::Value = serde_json::from_slice(&body).expect("valid json");
        assert!(json["error"].get("details").is_none());
        Ok(())
    }

    /// The response must carry an `AutumnErrorInfo` extension mirroring the body.
    #[test]
    fn into_response_attaches_error_info_extension() {
        let response = AutumnError::forbidden_msg("no access").into_response();
        let info = response
            .extensions()
            .get::<crate::middleware::AutumnErrorInfo>()
            .expect("into_response inserts AutumnErrorInfo");
        assert_eq!(info.status, StatusCode::FORBIDDEN);
        assert_eq!(info.message, "no access");
        assert!(info.details.is_none());
    }
}