use axum::response::{IntoResponse, Response};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use uuid::Uuid;
pub use rusty_gasket_macros::ApiError;
/// Contract for error types that can be rendered as an HTTP error response.
///
/// Implement it by hand or through the re-exported `ApiError` derive macro, which is driven by
/// `#[api_error(code = ..., status = ..., expose = ...)]` attributes on each variant.
pub trait ApiError: std::error::Error + Send + Sync + 'static {
    /// Stable, machine-readable code placed in the `error` field of the response body.
    fn error_code(&self) -> &str;

    /// HTTP status code to send with the response.
    fn status_code(&self) -> StatusCode;

    /// Whether the `Display` message and [`ApiError::details`] may be shown to the client.
    ///
    /// By default only 4xx (client error) statuses are exposed. Every other status is redacted
    /// unless the implementor overrides this.
    fn expose_details(&self) -> bool {
        let status = self.status_code();
        status.is_client_error()
    }

    /// Structured per-issue information for the response body; empty unless overridden.
    fn details(&self) -> Vec<ErrorDetail> {
        vec![]
    }
}
/// JSON body sent for API errors.
///
/// Keys are camelCase on the wire: `error`, `message`, `correlationId`, and `details`.
/// `details` is left out when empty and defaults to an empty list when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ErrorResponse {
    /// Machine-readable error code (see `ApiError::error_code`).
    pub error: String,
    /// Human-readable message. For redacted errors this is a generic text rather than the
    /// error's own `Display` output.
    pub message: String,
    /// Identifier tying this response to server-side logs. This module's response helpers also
    /// send it in the `X-Correlation-ID` header.
    pub correlation_id: Uuid,
    /// Optional per-issue breakdown; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub details: Vec<ErrorDetail>,
}
/// Problem-details body served as `application/problem+json`, modelled on RFC 9457.
///
/// Serializes the standard members `type`, `title`, and `status`, plus `detail` and `instance`
/// when they are set. Every entry of `extensions` is flattened into the same top-level object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ProblemDetails {
    /// Problem type identifier (a URI reference in RFC 9457 terms); serialized as `type`.
    #[serde(rename = "type")]
    pub problem_type: String,
    /// Short, human-readable summary of the problem type.
    pub title: String,
    /// Numeric HTTP status code.
    pub status: u16,
    /// Explanation specific to this occurrence; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Identifier of this specific occurrence; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instance: Option<String>,
    /// Extra members merged into the top-level JSON object. When deserializing, members that
    /// don't match a named field are collected here.
    ///
    /// NOTE(review): a key equal to one of the standard member names above would likely be
    /// emitted as a duplicate JSON member, so avoid inserting such keys directly.
    #[serde(flatten)]
    pub extensions: Map<String, Value>,
}
impl ProblemDetails {
    /// Member names defined by the problem-details format itself.
    ///
    /// The `extensions` map is flattened into the same JSON object. Reusing one of these names
    /// there would therefore emit a duplicate member, which RFC 8259 says SHOULD NOT happen.
    const RESERVED_MEMBERS: [&str; 5] = ["type", "title", "status", "detail", "instance"];

    /// Creates a problem with the given `status`, `type` identifier, and short `title`.
    ///
    /// `detail`, `instance`, and extensions start empty; add them with the `with_*` builders.
    #[must_use]
    pub fn new(
        status: StatusCode,
        problem_type: impl Into<String>,
        title: impl Into<String>,
    ) -> Self {
        Self {
            problem_type: problem_type.into(),
            title: title.into(),
            status: status.as_u16(),
            detail: None,
            instance: None,
            extensions: Map::new(),
        }
    }

    /// Sets the human-readable `detail` for this specific occurrence.
    #[must_use]
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }

    /// Sets the `instance` identifier for this specific occurrence.
    #[must_use]
    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
        self.instance = Some(instance.into());
        self
    }

    /// Adds or replaces an extension member, serialized alongside the standard members.
    ///
    /// Names that collide with a standard member (`type`, `title`, `status`, `detail`,
    /// `instance`) are ignored. Serializing them from the flattened map would produce an
    /// object with duplicate keys. Use the dedicated field or builder for those values.
    #[must_use]
    pub fn with_extension(mut self, name: impl Into<String>, value: Value) -> Self {
        let name = name.into();
        if !Self::RESERVED_MEMBERS.contains(&name.as_str()) {
            self.extensions.insert(name, value);
        }
        self
    }
}
impl ErrorResponse {
    /// Builds an error body with no per-issue details.
    #[must_use]
    pub fn new(error: impl Into<String>, message: impl Into<String>, correlation_id: Uuid) -> Self {
        let (error, message) = (error.into(), message.into());
        Self {
            error,
            message,
            correlation_id,
            details: Vec::new(),
        }
    }

    /// Builds an error body carrying the given `details`.
    ///
    /// Despite the `with_` prefix, this is a full constructor rather than a builder method.
    #[must_use]
    pub fn with_details(
        error: impl Into<String>,
        message: impl Into<String>,
        correlation_id: Uuid,
        details: Vec<ErrorDetail>,
    ) -> Self {
        let mut body = Self::new(error, message, correlation_id);
        body.details = details;
        body
    }
}
/// One entry in an error response's `details` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ErrorDetail {
    /// Short description of what went wrong.
    pub issue: String,
    /// Optional longer explanation; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
impl ErrorDetail {
    /// Creates a detail entry with only an `issue` and no description.
    #[must_use]
    pub fn new(issue: impl Into<String>) -> Self {
        let issue = issue.into();
        Self {
            issue,
            description: None,
        }
    }

    /// Creates a detail entry with both an `issue` and a `description`.
    ///
    /// This is a constructor, not a builder: it does not take an existing `ErrorDetail`.
    #[must_use]
    pub fn with_description(issue: impl Into<String>, description: impl Into<String>) -> Self {
        let mut detail = Self::new(issue);
        detail.description = Some(description.into());
        detail
    }
}
/// Renders `error` and every error in its `source()` chain as a single line.
///
/// Messages are ordered outermost first and joined with `" | caused by: "`, for example
/// `"request failed | caused by: connection refused"`. An error without a source yields just
/// its own `Display` text.
pub fn full_error_chain(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    for cause in std::iter::successors(error.source(), |&cause| cause.source()) {
        rendered.push_str(" | caused by: ");
        rendered.push_str(&cause.to_string());
    }
    rendered
}
/// Returns the correlation ID to attach to an error response.
///
/// Reuses `crate::observability::current_request_id()` when it is present and parses as a
/// UUID. Otherwise a fresh v7 UUID is generated, so a value is always returned.
///
/// NOTE(review): a request ID that is not a UUID is silently replaced, so the response would
/// not match that request's logs. Confirm request IDs are always UUIDs.
fn correlation_id() -> Uuid {
    match crate::observability::current_request_id() {
        Some(request_id) => Uuid::parse_str(&request_id).unwrap_or_else(|_| Uuid::now_v7()),
        None => Uuid::now_v7(),
    }
}
/// Serializes `body` as JSON under `status` and copies its correlation ID into the
/// `X-Correlation-ID` response header.
fn finalize_error_response(status: StatusCode, body: ErrorResponse) -> Response {
    // Render the header value before `body` is moved into the JSON wrapper.
    let correlation_header = http::HeaderValue::from_str(&body.correlation_id.to_string());
    let mut response = (status, axum::Json(body)).into_response();
    if let Ok(value) = correlation_header {
        // A formatted UUID is plain ASCII, so this branch is expected to always run.
        response.headers_mut().insert("X-Correlation-ID", value);
    }
    response
}
/// Converts an [`ApiError`] into an HTTP [`Response`] with an [`ErrorResponse`] JSON body.
///
/// The error's full `source()` chain is logged together with the correlation ID. 5xx statuses
/// log at `error` level and all other statuses at `warn`. The body depends on
/// [`ApiError::expose_details`]:
///
/// * exposed: the error's `Display` text and its [`ApiError::details`];
/// * hidden: a generic message and no details. 5xx statuses say `"Internal server error"`.
///   Any other status uses its canonical reason phrase (e.g. `"Unauthorized"` for 401), so a
///   redacted client error is not misreported as a server failure.
///
/// The correlation ID is also sent in the `X-Correlation-ID` response header.
pub fn error_into_response(error: &impl ApiError) -> Response {
    let cid = correlation_id();
    let status = error.status_code();
    let (message, details) = if error.expose_details() {
        (error.to_string(), error.details())
    } else if status.is_server_error() {
        ("Internal server error".to_string(), Vec::new())
    } else {
        // A redacted non-5xx error must not claim the server failed. Fall back to the
        // status's standard reason phrase, or a neutral text for non-standard codes.
        let reason = status.canonical_reason().unwrap_or("Request failed");
        (reason.to_string(), Vec::new())
    };
    // `tracing` macros need a constant level, so the two severities get separate calls.
    if status.is_server_error() {
        tracing::error!(
            error_code = error.error_code(),
            status = status.as_u16(),
            correlation_id = %cid,
            error_chain = %full_error_chain(error),
            "Server error"
        );
    } else {
        tracing::warn!(
            error_code = error.error_code(),
            status = status.as_u16(),
            correlation_id = %cid,
            error_chain = %full_error_chain(error),
            "Client error"
        );
    }
    let body = ErrorResponse::with_details(error.error_code(), message, cid, details);
    finalize_error_response(status, body)
}
#[must_use]
pub fn quick_error_response(status: StatusCode, code: &str, message: &str) -> Response {
let body = ErrorResponse::new(code, message, correlation_id());
finalize_error_response(status, body)
}
/// Builds an error response with an explicit code, message, and per-issue `details`.
///
/// Unlike `error_into_response`, nothing is logged and `message` is sent verbatim. The
/// response still carries a correlation ID in the body and the `X-Correlation-ID` header.
#[must_use]
pub fn quick_error_response_with_details(
    status: StatusCode,
    code: &str,
    message: &str,
    details: Vec<ErrorDetail>,
) -> Response {
    let correlation = correlation_id();
    finalize_error_response(
        status,
        ErrorResponse::with_details(code, message, correlation, details),
    )
}
/// Builds an `application/problem+json` response (RFC 9457) from `problem`.
///
/// The body's `status` member is overwritten with `status`. The advisory value in the document
/// then always matches the real HTTP status line, as RFC 9457 §3.1.4 requires. The
/// `Content-Type` set by `axum::Json` is replaced with `application/problem+json`.
#[must_use]
pub fn problem_response(status: StatusCode, mut problem: ProblemDetails) -> Response {
    problem.status = status.as_u16();
    let mut response = (status, axum::Json(problem)).into_response();
    response.headers_mut().insert(
        http::header::CONTENT_TYPE,
        http::HeaderValue::from_static("application/problem+json"),
    );
    response
}
/// Built-in error type provided by the framework.
///
/// Derives `ApiError`, so values can be rendered with `error_into_response`.
#[derive(Debug, thiserror::Error, rusty_gasket_macros::ApiError)]
#[non_exhaustive]
pub enum FrameworkError {
    /// `404` with code `NOT_FOUND`.
    #[error("Not found")]
    #[api_error(code = "NOT_FOUND", status = 404)]
    NotFound,
    /// `405` with code `METHOD_NOT_ALLOWED`.
    #[error("Method not allowed")]
    #[api_error(code = "METHOD_NOT_ALLOWED", status = 405)]
    MethodNotAllowed,
    /// `500` with code `INTERNAL_ERROR`, wrapping an arbitrary underlying error.
    ///
    /// `expose = false` keeps details out of the response body. The wrapped error is marked
    /// `#[source]`, so it still appears in the logged error chain.
    #[error("Internal server error")]
    #[api_error(code = "INTERNAL_ERROR", status = 500, expose = false)]
    Internal(#[source] Box<dyn std::error::Error + Send + Sync>),
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    // Fixture covering every variant shape the `ApiError` derive must handle (tuple, unit and
    // struct variants), with both defaulted and explicit `expose` settings on 500s.
    #[derive(Debug, thiserror::Error, ApiError)]
    enum TestError {
        #[error("thing not found: {0}")]
        #[api_error(code = "NOT_FOUND", status = 404)]
        NotFound(String),
        #[error("bad input")]
        #[api_error(code = "BAD_REQUEST", status = 400)]
        BadRequest,
        #[error("kaboom")]
        #[api_error(code = "INTERNAL", status = 500, expose = false)]
        Internal,
        #[error("custom exposed 500")]
        #[api_error(code = "CUSTOM_500", status = 500, expose = true)]
        CustomExposed,
        #[error("validation failed")]
        #[api_error(code = "VALIDATION", status = 422)]
        Validation { _field: String },
    }

    /// Asserts the derived error code and HTTP status of `err` together.
    #[track_caller]
    fn assert_code_and_status(err: &TestError, code: &str, status: StatusCode) {
        assert_eq!(err.error_code(), code);
        assert_eq!(err.status_code(), status);
    }

    #[test]
    fn derived_error_code_and_status() {
        let missing = TestError::NotFound(String::from("x"));
        assert_code_and_status(&missing, "NOT_FOUND", StatusCode::NOT_FOUND);
        // 4xx variants expose their message unless told otherwise.
        assert!(missing.expose_details());
    }

    #[test]
    fn derived_500_hides_by_default() {
        let hidden = TestError::Internal;
        assert_eq!(hidden.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!hidden.expose_details());
    }

    #[test]
    fn derived_500_with_explicit_expose() {
        let shown = TestError::CustomExposed;
        assert_eq!(shown.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(shown.expose_details());
    }

    #[test]
    fn derived_unit_variant() {
        assert_code_and_status(&TestError::BadRequest, "BAD_REQUEST", StatusCode::BAD_REQUEST);
    }

    #[test]
    fn derived_struct_variant() {
        let invalid = TestError::Validation {
            _field: String::from("email"),
        };
        assert_code_and_status(&invalid, "VALIDATION", StatusCode::UNPROCESSABLE_ENTITY);
    }
}