use axum::http::{HeaderValue, StatusCode, header};
use axum::response::{IntoResponse, Response};
use serde::Serialize;
/// Minimal error type that carries nothing but a message.
///
/// The `*_msg` constructors and `AutumnError::validation` wrap plain strings in this type so
/// that a message can be stored wherever a boxed `std::error::Error` is expected.
#[derive(Debug)]
struct StringError(String);

// No source is reported: the trait's default `source()` is kept.
impl std::error::Error for StringError {}

impl std::fmt::Display for StringError {
    /// Writes the wrapped message verbatim (formatter width/padding flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}
/// Serializable error payload produced by `problem_details`.
///
/// `AutumnError::into_response` sends this as the JSON body of an error response with the
/// `application/problem+json` content type.
#[derive(Clone, Debug, Serialize)]
pub struct ProblemDetails {
/// Problem type URI; serialized under the key `type`. Either an explicit override or the
/// value chosen by `problem_type_for`.
#[serde(rename = "type")]
pub type_uri: String,
/// Short title chosen by `problem_title_for` (e.g. `"Not Found"`).
pub title: String,
/// Numeric HTTP status code (`StatusCode::as_u16`).
pub status: u16,
/// Error message. For server-error statuses it is replaced by a generic text when the
/// builder is told not to expose internal detail.
pub detail: String,
/// Optional instance identifier, passed through unchanged from the builder's caller.
pub instance: Option<String>,
/// Machine-readable code chosen by `problem_code_for` (e.g. `"autumn.not_found"`).
pub code: String,
/// Optional request identifier, passed through unchanged from the builder's caller.
pub request_id: Option<String>,
/// Per-field validation errors, sorted by field name; empty when there are none.
pub errors: Vec<ProblemFieldError>,
}
/// Validation messages attached to a single field; one entry of `ProblemDetails::errors`.
#[derive(Clone, Debug, Serialize, PartialEq, Eq)]
pub struct ProblemFieldError {
/// Field name (a key of the validation details map).
pub field: String,
/// Messages reported for that field (the map value for that key).
pub messages: Vec<String>,
}
/// Error type that pairs an underlying error with the HTTP status used to report it.
///
/// Built through the blanket `From` impl (status 500) or one of the named constructors, and
/// turned into a problem-details JSON response by its `IntoResponse` impl.
pub struct AutumnError {
// Underlying error; its `Display` output becomes the message of this error.
inner: Box<dyn std::error::Error + Send + Sync>,
// HTTP status reported for this error.
status: StatusCode,
// Optional field-name -> messages map; only the `validation` constructor sets it.
details: Option<std::collections::HashMap<String, Vec<String>>>,
// Optional explicit problem type URI; only the `conflict` constructor sets it.
problem_type: Option<&'static str>,
}
/// Shorthand for a `Result` whose error type is [`AutumnError`].
pub type AutumnResult<T> = Result<T, AutumnError>;
impl<E> From<E> for AutumnError
where
E: std::error::Error + Send + Sync + 'static,
{
fn from(err: E) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::INTERNAL_SERVER_ERROR,
details: None,
problem_type: None,
}
}
}
impl AutumnError {
#[must_use]
pub const fn with_status(mut self, status: StatusCode) -> Self {
self.status = status;
self
}
pub fn internal_server_error(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::INTERNAL_SERVER_ERROR,
details: None,
problem_type: None,
}
}
pub fn not_found(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::NOT_FOUND,
details: None,
problem_type: None,
}
}
pub fn bad_request(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::BAD_REQUEST,
details: None,
problem_type: None,
}
}
pub fn unprocessable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::UNPROCESSABLE_ENTITY,
details: None,
problem_type: None,
}
}
pub fn service_unavailable(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::SERVICE_UNAVAILABLE,
details: None,
problem_type: None,
}
}
pub fn unauthorized(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::UNAUTHORIZED,
details: None,
problem_type: None,
}
}
pub fn forbidden(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::FORBIDDEN,
details: None,
problem_type: None,
}
}
#[must_use]
pub fn validation(details: std::collections::HashMap<String, Vec<String>>) -> Self {
Self {
inner: Box::new(StringError("Validation failed".into())),
status: StatusCode::UNPROCESSABLE_ENTITY,
details: Some(details),
problem_type: None,
}
}
pub fn internal_server_error_msg(msg: impl Into<String>) -> Self {
Self::internal_server_error(StringError(msg.into()))
}
pub fn not_found_msg(msg: impl Into<String>) -> Self {
Self::not_found(StringError(msg.into()))
}
pub fn bad_request_msg(msg: impl Into<String>) -> Self {
Self::bad_request(StringError(msg.into()))
}
pub fn unprocessable_msg(msg: impl Into<String>) -> Self {
Self::unprocessable(StringError(msg.into()))
}
pub fn unauthorized_msg(msg: impl Into<String>) -> Self {
Self::unauthorized(StringError(msg.into()))
}
pub fn forbidden_msg(msg: impl Into<String>) -> Self {
Self::forbidden(StringError(msg.into()))
}
pub fn service_unavailable_msg(msg: impl Into<String>) -> Self {
Self::service_unavailable(StringError(msg.into()))
}
pub fn conflict(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self {
inner: Box::new(err),
status: StatusCode::CONFLICT,
details: None,
problem_type: Some("https://autumn.dev/problems/conflict"),
}
}
pub fn conflict_msg(msg: impl Into<String>) -> Self {
Self::conflict(StringError(msg.into()))
}
#[must_use]
pub const fn status(&self) -> StatusCode {
self.status
}
#[must_use]
pub fn source_chain(&self) -> Vec<String> {
let mut chain = Vec::new();
let mut source = self.inner.source();
while let Some(error) = source {
chain.push(error.to_string());
source = error.source();
}
chain
}
}
/// Displays the wrapped error's own message; status, details and problem type are not shown.
impl std::fmt::Display for AutumnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.inner))
    }
}
/// Hand-written `Debug` that prints the fields in the order `status`, `inner`, `details`,
/// `problem_type` (status first, unlike the declaration order).
impl std::fmt::Debug for AutumnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("AutumnError");
        repr.field("status", &self.status);
        repr.field("inner", &self.inner);
        repr.field("details", &self.details);
        repr.field("problem_type", &self.problem_type);
        repr.finish()
    }
}
impl ProblemDetails {
    /// Builds a payload for `status` with the given `detail` message and optional validation
    /// `details` (field name -> messages).
    ///
    /// No explicit type URI, request id or instance is set, and `detail` is kept as given even
    /// for server-error statuses.
    #[must_use]
    pub fn new(
        status: StatusCode,
        detail: impl Into<String>,
        details: Option<&std::collections::HashMap<String, Vec<String>>>,
    ) -> Self {
        let explicit_type = None;
        let request_id = None;
        let instance = None;
        let expose_internal_detail = true;
        problem_details(
            status,
            detail.into(),
            details,
            explicit_type,
            request_id,
            instance,
            expose_internal_detail,
        )
    }
}
/// Assembles a [`ProblemDetails`] payload.
///
/// * `status` – HTTP status; drives the default type URI, title and code.
/// * `detail` – error message. For server-error statuses it is replaced by a generic text
///   unless `expose_internal_detail` is `true`.
/// * `details` – optional validation map (field name -> messages). When present and
///   non-empty, the validation title/code (and default type URI) are used instead of the
///   status-based ones.
/// * `explicit_type` – type URI that overrides the default one when `Some`.
/// * `request_id`, `instance` – copied into the payload unchanged.
#[must_use]
pub(crate) fn problem_details(
    status: StatusCode,
    detail: String,
    details: Option<&std::collections::HashMap<String, Vec<String>>>,
    explicit_type: Option<&'static str>,
    request_id: Option<String>,
    instance: Option<String>,
    expose_internal_detail: bool,
) -> ProblemDetails {
    // An empty map counts as "no validation errors".
    let has_validation_errors = matches!(details, Some(map) if !map.is_empty());

    // Only server errors are ever masked; client-error messages are always passed through.
    let mask_detail = status.is_server_error() && !expose_internal_detail;
    let detail = if mask_detail {
        server_error_detail(status)
    } else {
        detail
    };

    let type_uri = explicit_type
        .unwrap_or_else(|| problem_type_for(status, has_validation_errors))
        .to_owned();
    let title = problem_title_for(status, has_validation_errors).to_owned();
    let code = problem_code_for(status, has_validation_errors).to_owned();
    let errors = validation_errors(details);

    ProblemDetails {
        type_uri,
        title,
        status: status.as_u16(),
        detail,
        instance,
        code,
        request_id,
        errors,
    }
}
/// Builds a problem-details payload (same arguments as `problem_details`, except that
/// `detail` accepts anything convertible into a `String`) and returns it as a JSON string.
#[must_use]
pub(crate) fn problem_details_json_string(
    status: StatusCode,
    detail: impl Into<String>,
    details: Option<&std::collections::HashMap<String, Vec<String>>>,
    explicit_type: Option<&'static str>,
    request_id: Option<String>,
    instance: Option<String>,
    expose_internal_detail: bool,
) -> String {
    problem_details_to_json_string(&problem_details(
        status,
        detail.into(),
        details,
        explicit_type,
        request_id,
        instance,
        expose_internal_detail,
    ))
}
/// Serializes `problem` to a JSON string.
///
/// Never fails: if serialization returns an error, a fixed, pre-rendered
/// "internal server error" problem document is returned instead.
#[must_use]
pub(crate) fn problem_details_to_json_string(problem: &ProblemDetails) -> String {
    // Hand-written document used only when `serde_json` reports an error.
    const FALLBACK_JSON: &str = r#"{"type":"https://autumn.dev/problems/internal-server-error","title":"Internal Server Error","status":500,"detail":"Internal server error","instance":null,"code":"autumn.internal_server_error","request_id":null,"errors":[]}"#;

    serde_json::to_string(problem).unwrap_or_else(|_| FALLBACK_JSON.to_owned())
}
/// Converts the optional validation map into a list of [`ProblemFieldError`]s sorted by field
/// name.
///
/// Sorting makes the output order stable, since `HashMap` iteration order is unspecified.
/// Returns an empty vector when `details` is `None` or the map is empty.
fn validation_errors(
    details: Option<&std::collections::HashMap<String, Vec<String>>>,
) -> Vec<ProblemFieldError> {
    let Some(by_field) = details else {
        return Vec::new();
    };

    let mut errors: Vec<ProblemFieldError> = by_field
        .iter()
        .map(|(field, messages)| ProblemFieldError {
            field: field.clone(),
            messages: messages.clone(),
        })
        .collect();
    errors.sort_by(|a, b| a.field.cmp(&b.field));
    errors
}
/// Default problem type URI for a response.
///
/// Validation failures always map to the validation-failed URI, whatever the status.
/// Otherwise the URI is chosen by status; statuses without a dedicated URI use `about:blank`.
const fn problem_type_for(status: StatusCode, has_validation_errors: bool) -> &'static str {
    if has_validation_errors {
        "https://autumn.dev/problems/validation-failed"
    } else {
        match status {
            StatusCode::BAD_REQUEST => "https://autumn.dev/problems/bad-request",
            StatusCode::UNAUTHORIZED => "https://autumn.dev/problems/unauthorized",
            StatusCode::FORBIDDEN => "https://autumn.dev/problems/forbidden",
            StatusCode::NOT_FOUND => "https://autumn.dev/problems/not-found",
            StatusCode::CONFLICT => "https://autumn.dev/problems/conflict",
            StatusCode::PAYLOAD_TOO_LARGE => "https://autumn.dev/problems/payload-too-large",
            StatusCode::UNPROCESSABLE_ENTITY => {
                "https://autumn.dev/problems/unprocessable-entity"
            }
            StatusCode::INTERNAL_SERVER_ERROR => {
                "https://autumn.dev/problems/internal-server-error"
            }
            StatusCode::NOT_IMPLEMENTED => "https://autumn.dev/problems/not-implemented",
            StatusCode::SERVICE_UNAVAILABLE => "https://autumn.dev/problems/service-unavailable",
            _ => "about:blank",
        }
    }
}
/// Human-readable title for a response.
///
/// Validation failures always use `"Validation Failed"`. Otherwise the title is chosen by
/// status; other statuses use their canonical reason phrase, or `"Error"` when there is none.
fn problem_title_for(status: StatusCode, has_validation_errors: bool) -> &'static str {
    if has_validation_errors {
        "Validation Failed"
    } else {
        match status {
            StatusCode::BAD_REQUEST => "Bad Request",
            StatusCode::UNAUTHORIZED => "Unauthorized",
            StatusCode::FORBIDDEN => "Forbidden",
            StatusCode::NOT_FOUND => "Not Found",
            StatusCode::CONFLICT => "Conflict",
            StatusCode::PAYLOAD_TOO_LARGE => "Payload Too Large",
            StatusCode::UNPROCESSABLE_ENTITY => "Unprocessable Entity",
            StatusCode::INTERNAL_SERVER_ERROR => "Internal Server Error",
            StatusCode::NOT_IMPLEMENTED => "Not Implemented",
            StatusCode::SERVICE_UNAVAILABLE => "Service Unavailable",
            other => other.canonical_reason().unwrap_or("Error"),
        }
    }
}
/// Machine-readable code for a response.
///
/// Validation failures always use `"autumn.validation_failed"`. Otherwise the code is chosen
/// by status; other statuses fall back to a generic code for their class (client error,
/// server error, or anything else).
fn problem_code_for(status: StatusCode, has_validation_errors: bool) -> &'static str {
    if has_validation_errors {
        "autumn.validation_failed"
    } else {
        match status {
            StatusCode::BAD_REQUEST => "autumn.bad_request",
            StatusCode::UNAUTHORIZED => "autumn.unauthorized",
            StatusCode::FORBIDDEN => "autumn.forbidden",
            StatusCode::NOT_FOUND => "autumn.not_found",
            StatusCode::CONFLICT => "autumn.conflict",
            StatusCode::PAYLOAD_TOO_LARGE => "autumn.payload_too_large",
            StatusCode::UNPROCESSABLE_ENTITY => "autumn.unprocessable_entity",
            StatusCode::INTERNAL_SERVER_ERROR => "autumn.internal_server_error",
            StatusCode::NOT_IMPLEMENTED => "autumn.not_implemented",
            StatusCode::SERVICE_UNAVAILABLE => "autumn.service_unavailable",
            other => {
                if other.is_client_error() {
                    "autumn.client_error"
                } else if other.is_server_error() {
                    "autumn.server_error"
                } else {
                    "autumn.error"
                }
            }
        }
    }
}
/// Generic message that stands in for the real detail when a server error must not expose
/// its internal message.
fn server_error_detail(status: StatusCode) -> String {
    let generic = match status {
        StatusCode::SERVICE_UNAVAILABLE => "Service unavailable",
        StatusCode::NOT_IMPLEMENTED => "Not implemented",
        _ => "Internal server error",
    };
    generic.to_owned()
}
/// Renders the error as a problem-details JSON response.
///
/// The response has:
/// * the error's HTTP status,
/// * a JSON body built by `problem_details`,
/// * `Content-Type: application/problem+json`,
/// * for `409 Conflict` only, an `HX-Trigger` header carrying `{"autumn:conflict":true}`,
/// * an `AutumnErrorInfo` response extension holding the status, message, validation details
///   and explicit problem type.
impl IntoResponse for AutumnError {
    fn into_response(self) -> Response {
        // `self` is owned, so take its fields apart instead of cloning them.
        let Self {
            inner,
            status,
            details,
            problem_type,
        } = self;
        let message = inner.to_string();

        // The body only needs to borrow the validation details; the owned map is moved into
        // the extension below, so no copy of it is made.
        // The original message is kept in the body even for server errors (`true` below).
        // NOTE(review): presumably a later layer uses `AutumnErrorInfo` to re-render or mask
        // 5xx details — confirm against `crate::middleware`.
        let body = problem_details(
            status,
            message.clone(),
            details.as_ref(),
            problem_type,
            None,
            None,
            true,
        );

        let mut response = (status, axum::Json(body)).into_response();
        // `axum::Json` sets `application/json`; replace it with the problem-details media type.
        response.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/problem+json"),
        );
        if status == StatusCode::CONFLICT {
            response.headers_mut().insert(
                "HX-Trigger",
                HeaderValue::from_static(r#"{"autumn:conflict":true}"#),
            );
        }

        response
            .extensions_mut()
            .insert(crate::middleware::AutumnErrorInfo {
                status,
                message,
                details,
                problem_type,
            });
        response
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    /// Leaf error without a source; the wrapped error in most tests.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for TestError {}

    /// Error that reports a `TestError` as its source, for exercising `source_chain`.
    #[derive(Debug)]
    struct WrappedError {
        message: String,
        source: TestError,
    }

    impl std::fmt::Display for WrappedError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.message)
        }
    }

    impl std::error::Error for WrappedError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            Some(&self.source)
        }
    }

    /// Consumes `response` and parses its whole body as JSON.
    ///
    /// Panics if the body is not valid JSON; body read errors are returned to the caller.
    async fn response_json(
        response: axum::response::Response,
    ) -> Result<serde_json::Value, axum::Error> {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
        Ok(serde_json::from_slice(&bytes).expect("valid json"))
    }

    #[test]
    fn blanket_from_defaults_to_500() {
        let err: AutumnError = TestError("boom".into()).into();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn internal_server_error_is_500() {
        let err = AutumnError::internal_server_error(TestError("boom".into()));
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // Same as `not_found_is_404`, but with a standard-library error type as the wrapped error.
    #[test]
    fn test_not_found_error() {
        let err = AutumnError::not_found(std::io::Error::other("no such user"));
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn not_found_is_404() {
        let err = AutumnError::not_found(TestError("missing".into()));
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn bad_request_is_400() {
        let err = AutumnError::bad_request(TestError("invalid input".into()));
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn unprocessable_is_422() {
        let err = AutumnError::unprocessable(TestError("bad entity".into()));
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn unauthorized_is_401() {
        let err = AutumnError::unauthorized(TestError("unauthorized".into()));
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn forbidden_is_403() {
        let err = AutumnError::forbidden(TestError("forbidden".into()));
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn validation_is_422() {
        let mut details = std::collections::HashMap::new();
        details.insert("field".to_string(), vec!["error".to_string()]);
        let err = AutumnError::validation(details);
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn service_unavailable_is_503() {
        let err = AutumnError::service_unavailable(TestError("pool exhausted".into()));
        assert_eq!(err.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn internal_server_error_msg_is_500() {
        let err = AutumnError::internal_server_error_msg("db failure");
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.to_string(), "db failure");
    }

    #[test]
    fn not_found_msg_is_404() {
        let err = AutumnError::not_found_msg("no such user");
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert_eq!(err.to_string(), "no such user");
    }

    #[test]
    fn bad_request_msg_is_400() {
        let err = AutumnError::bad_request_msg("invalid input");
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn unprocessable_msg_is_422() {
        let err = AutumnError::unprocessable_msg("title required");
        assert_eq!(err.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[test]
    fn unauthorized_msg_is_401() {
        let err = AutumnError::unauthorized_msg("login required");
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn forbidden_msg_is_403() {
        let err = AutumnError::forbidden_msg("no access");
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn service_unavailable_msg_is_503() {
        let err = AutumnError::service_unavailable_msg("db down");
        assert_eq!(err.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(err.to_string(), "db down");
    }

    #[test]
    fn with_status_overrides() {
        let err: AutumnError = TestError("forbidden".into()).into();
        let err = err.with_status(StatusCode::FORBIDDEN);
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn display_uses_inner_message() {
        let err: AutumnError = TestError("something broke".into()).into();
        assert_eq!(err.to_string(), "something broke");
    }

    // The chain holds only the sources, not the wrapped error's own message.
    #[test]
    fn source_chain_lists_inner_sources() {
        let err = AutumnError::internal_server_error(WrappedError {
            message: "failed to backfill".to_string(),
            source: TestError("database connection dropped".to_string()),
        });
        assert_eq!(
            err.source_chain(),
            vec!["database connection dropped".to_string()]
        );
    }

    #[test]
    fn into_response_has_correct_status() {
        let err = AutumnError::not_found(TestError("not found".into()));
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // `into_response` replaces the JSON content type with the problem-details media type.
    #[test]
    fn into_response_sets_problem_json_content_type() {
        let err = AutumnError::not_found(TestError("not found".into()));
        let response = err.into_response();
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .expect("Content-Type header present");
        assert_eq!(content_type, "application/problem+json");
    }

    #[tokio::test]
    async fn into_response_has_json_body() -> Result<(), axum::Error> {
        let err = AutumnError::not_found(TestError("not found".into()));
        let json = response_json(err.into_response()).await?;
        assert_eq!(json["status"], 404);
        assert_eq!(json["detail"], "not found");
        assert_eq!(json["code"], "autumn.not_found");
        Ok(())
    }

    #[test]
    fn debug_shows_status_and_inner() {
        let err = AutumnError::bad_request(TestError("oops".into()));
        let debug = format!("{err:?}");
        assert!(debug.contains("AutumnError"));
        assert!(debug.contains("400"));
    }

    #[tokio::test]
    async fn msg_constructor_produces_valid_json_response() -> Result<(), axum::Error> {
        let err = AutumnError::unprocessable_msg("title required");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let json = response_json(response).await?;
        assert_eq!(json["status"], 422);
        assert_eq!(json["detail"], "title required");
        assert_eq!(json["code"], "autumn.unprocessable_entity");
        Ok(())
    }

    // `into_response` keeps the original message even for a 5xx status.
    #[tokio::test]
    async fn service_unavailable_response_is_503() -> Result<(), axum::Error> {
        let err = AutumnError::service_unavailable_msg("db down");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        let json = response_json(response).await?;
        assert_eq!(json["status"], 503);
        assert_eq!(json["detail"], "db down");
        assert_eq!(json["code"], "autumn.service_unavailable");
        Ok(())
    }

    #[test]
    fn conflict_is_409() {
        let err = AutumnError::conflict(TestError("stale version".into()));
        assert_eq!(err.status(), StatusCode::CONFLICT);
    }

    #[test]
    fn conflict_msg_is_409() {
        let err = AutumnError::conflict_msg("please reload and retry");
        assert_eq!(err.status(), StatusCode::CONFLICT);
        assert_eq!(err.to_string(), "please reload and retry");
    }

    #[tokio::test]
    async fn conflict_response_is_409_json() -> Result<(), axum::Error> {
        let err = AutumnError::conflict_msg("version mismatch");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::CONFLICT);
        let json = response_json(response).await?;
        assert_eq!(json["status"], 409);
        assert_eq!(json["detail"], "version mismatch");
        assert_eq!(json["type"], "https://autumn.dev/problems/conflict");
        assert_eq!(json["title"], "Conflict");
        Ok(())
    }

    // Only headers are inspected, so this test needs neither an async runtime nor a `Result`.
    #[test]
    fn conflict_response_has_hx_trigger_header() {
        let err = AutumnError::conflict_msg("version mismatch");
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::CONFLICT);
        let hx_trigger = response
            .headers()
            .get("HX-Trigger")
            .expect("HX-Trigger header present");
        assert_eq!(hx_trigger, r#"{"autumn:conflict":true}"#);
    }
}