use std::any::TypeId;
use bytes::Bytes;
use http::StatusCode;
use serde::Serialize;
use crate::constants::{APPLICATION_JSON, INTERNAL_ERROR_MESSAGE};
use crate::response::{with_body, IntoResponse, Response};
/// Convenience alias for `core::result::Result` whose error type defaults to
/// this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// Code attached to errors built by `Error::from_garde_report`;
/// `Error::is_validation` compares the effective code against it.
const VALIDATION_ERROR_CODE: &str = "VALIDATION_ERROR";
/// Message used for errors built by `Error::from_garde_report`.
const VALIDATION_ERROR_MESSAGE: &str = "The submitted data failed validation.";
/// Issue label returned by `classify_issue` when no known phrase matches.
const GENERIC_ISSUE: &str = "INVALID";
/// Prefix placed in front of the UUID by `generate_trace_id`.
const TRACE_ID_PREFIX: &str = "req-";
/// Category of an [`Error`].
///
/// Each variant maps to exactly one HTTP status (see `ErrorKind::status`) and
/// one default machine-readable code (see `ErrorKind::code`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
/// Maps to `StatusCode::BAD_REQUEST`.
BadRequest,
/// Maps to `StatusCode::UNAUTHORIZED`.
Unauthorized,
/// Maps to `StatusCode::FORBIDDEN`.
Forbidden,
/// Maps to `StatusCode::NOT_FOUND`.
NotFound,
/// Maps to `StatusCode::METHOD_NOT_ALLOWED`.
MethodNotAllowed,
/// Maps to `StatusCode::CONFLICT`.
Conflict,
/// Maps to `StatusCode::PAYLOAD_TOO_LARGE`.
PayloadTooLarge,
/// Maps to `StatusCode::UNPROCESSABLE_ENTITY`.
Unprocessable,
/// Maps to `StatusCode::TOO_MANY_REQUESTS`.
TooManyRequests,
/// Maps to `StatusCode::INTERNAL_SERVER_ERROR`.
Internal,
/// Maps to `StatusCode::SERVICE_UNAVAILABLE`.
ServiceUnavailable,
/// Maps to `StatusCode::GATEWAY_TIMEOUT`.
GatewayTimeout,
}
impl ErrorKind {
pub fn status(self) -> StatusCode {
match self {
ErrorKind::BadRequest => StatusCode::BAD_REQUEST,
ErrorKind::Unauthorized => StatusCode::UNAUTHORIZED,
ErrorKind::Forbidden => StatusCode::FORBIDDEN,
ErrorKind::NotFound => StatusCode::NOT_FOUND,
ErrorKind::MethodNotAllowed => StatusCode::METHOD_NOT_ALLOWED,
ErrorKind::Conflict => StatusCode::CONFLICT,
ErrorKind::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
ErrorKind::Unprocessable => StatusCode::UNPROCESSABLE_ENTITY,
ErrorKind::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
ErrorKind::Internal => StatusCode::INTERNAL_SERVER_ERROR,
ErrorKind::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE,
ErrorKind::GatewayTimeout => StatusCode::GATEWAY_TIMEOUT,
}
}
pub fn code(self) -> &'static str {
match self {
ErrorKind::BadRequest => "BAD_REQUEST",
ErrorKind::Unauthorized => "UNAUTHORIZED",
ErrorKind::Forbidden => "FORBIDDEN",
ErrorKind::NotFound => "NOT_FOUND",
ErrorKind::MethodNotAllowed => "METHOD_NOT_ALLOWED",
ErrorKind::Conflict => "CONFLICT",
ErrorKind::PayloadTooLarge => "PAYLOAD_TOO_LARGE",
ErrorKind::Unprocessable => "UNPROCESSABLE_ENTITY",
ErrorKind::TooManyRequests => "TOO_MANY_REQUESTS",
ErrorKind::Internal => "INTERNAL_SERVER_ERROR",
ErrorKind::ServiceUnavailable => "SERVICE_UNAVAILABLE",
ErrorKind::GatewayTimeout => "GATEWAY_TIMEOUT",
}
}
}
/// Error value that can be turned into an HTTP response.
///
/// Built through `Error::new` or one of the per-kind constructors and then
/// optionally refined with the `with_*` builder methods.
#[derive(Debug)]
pub struct Error {
// Category; selects the HTTP status and the default code.
kind: ErrorKind,
// Optional custom code; when `None`, `kind.code()` is used instead.
code: Option<&'static str>,
// Human-readable message; `into_response` replaces it with
// `INTERNAL_ERROR_MESSAGE` when the status is a server error.
message: String,
// Optional underlying cause, exposed through `std::error::Error::source`.
source: Option<Box<dyn std::error::Error + Send + Sync>>,
// `TypeId` of the concrete cause recorded by `with_source`; `take_source`
// checks it before attempting the downcast.
source_type: Option<TypeId>,
// Per-field details; `into_response` omits them for server errors.
details: Vec<ErrorDetail>,
}
/// One entry of an error's `details` list, serialized into the response body.
#[derive(Debug, Clone, Serialize)]
pub struct ErrorDetail {
/// Field the detail refers to; `Error::from_garde_report` fills it with the
/// stringified garde path.
pub field: String,
/// Short issue label such as `"TOO_SHORT"` (see `classify_issue`).
pub issue: String,
/// Human-readable description of the problem.
pub message: String,
}
impl ErrorDetail {
pub fn new(
field: impl Into<String>,
issue: impl Into<String>,
message: impl Into<String>,
) -> Self {
Self {
field: field.into(),
issue: issue.into(),
message: message.into(),
}
}
}
impl Error {
pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
Self {
kind,
code: None,
message: message.into(),
source: None,
source_type: None,
details: Vec::new(),
}
}
pub fn bad_request(message: impl Into<String>) -> Self {
Self::new(ErrorKind::BadRequest, message)
}
pub fn unauthorized(message: impl Into<String>) -> Self {
Self::new(ErrorKind::Unauthorized, message)
}
pub fn forbidden(message: impl Into<String>) -> Self {
Self::new(ErrorKind::Forbidden, message)
}
pub fn not_found(message: impl Into<String>) -> Self {
Self::new(ErrorKind::NotFound, message)
}
pub fn method_not_allowed(message: impl Into<String>) -> Self {
Self::new(ErrorKind::MethodNotAllowed, message)
}
pub fn conflict(message: impl Into<String>) -> Self {
Self::new(ErrorKind::Conflict, message)
}
pub fn unprocessable(message: impl Into<String>) -> Self {
Self::new(ErrorKind::Unprocessable, message)
}
pub fn payload_too_large(message: impl Into<String>) -> Self {
Self::new(ErrorKind::PayloadTooLarge, message)
}
pub fn too_many_requests(message: impl Into<String>) -> Self {
Self::new(ErrorKind::TooManyRequests, message)
}
pub fn internal(message: impl Into<String>) -> Self {
Self::new(ErrorKind::Internal, message)
}
pub fn service_unavailable(message: impl Into<String>) -> Self {
Self::new(ErrorKind::ServiceUnavailable, message)
}
pub fn gateway_timeout(message: impl Into<String>) -> Self {
Self::new(ErrorKind::GatewayTimeout, message)
}
pub fn with_code(mut self, code: &'static str) -> Self {
self.code = Some(code);
self
}
pub fn with_source<E>(mut self, source: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
self.source = Some(Box::new(source));
self.source_type = Some(TypeId::of::<E>());
self
}
pub fn with_details(mut self, details: Vec<ErrorDetail>) -> Self {
self.details = details;
self
}
pub fn from_garde_report(report: garde::error::Report) -> Self {
let details = report
.iter()
.map(|(path, error)| {
let message = error.to_string();
ErrorDetail::new(path.to_string(), classify_issue(&message), message)
})
.collect();
Self::unprocessable(VALIDATION_ERROR_MESSAGE)
.with_code(VALIDATION_ERROR_CODE)
.with_details(details)
}
pub fn kind(&self) -> ErrorKind {
self.kind
}
pub fn code(&self) -> &str {
self.code.unwrap_or_else(|| self.kind.code())
}
pub(crate) fn static_code(&self) -> &'static str {
self.code.unwrap_or_else(|| self.kind.code())
}
pub fn details(&self) -> &[ErrorDetail] {
&self.details
}
pub fn message(&self) -> &str {
&self.message
}
pub(crate) fn source_type(&self) -> Option<TypeId> {
self.source_type
}
pub(crate) fn is_validation(&self) -> bool {
self.code() == VALIDATION_ERROR_CODE
}
pub fn take_source<E>(&mut self) -> Option<E>
where
E: std::error::Error + Send + Sync + 'static,
{
if self.source_type != Some(TypeId::of::<E>()) {
return None;
}
let source = self.source.take()?;
self.source_type = None;
match source.downcast::<E>() {
Ok(typed) => Some(*typed),
Err(restored) => {
self.source = Some(restored);
self.source_type = Some(TypeId::of::<E>());
None
}
}
}
}
impl std::fmt::Display for Error {
    /// Formats the error as `<code>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = self.code();
        let message = self.message.as_str();
        write!(f, "{}: {}", code, message)
    }
}
impl std::error::Error for Error {
    /// Returns the attached cause, if one was set via `with_source`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self.source {
            Some(ref boxed) => {
                // Reborrow the boxed cause and drop the `Send + Sync` auto
                // traits from the trait-object type.
                let cause: &(dyn std::error::Error + 'static) = &**boxed;
                Some(cause)
            }
            None => None,
        }
    }
}
// JSON payload written by `into_response`; borrows its data from the error
// being converted.
#[derive(Serialize)]
struct ErrorBody<'a> {
// Numeric HTTP status, taken from `StatusCode::as_u16`.
status: u16,
// Effective error code (custom code or the kind's default).
code: &'a str,
// Canonical reason phrase of the status, or "Error" when there is none.
title: &'a str,
// Message sent to the client (redacted for server errors).
message: &'a str,
// Left out of the JSON entirely when the slice is empty.
#[serde(skip_serializing_if = "slice_is_empty")]
details: &'a [ErrorDetail],
// Serialized under the key `traceId`.
#[serde(rename = "traceId")]
trace_id: &'a str,
// Filled with the string returned by `now_rfc3339`.
timestamp: String,
}
/// Serde `skip_serializing_if` predicate: true when there are no details.
///
/// Takes `&&[ErrorDetail]` because serde passes a reference to the field,
/// which is itself a slice reference.
fn slice_is_empty(details: &&[ErrorDetail]) -> bool {
    matches!(*details, [])
}
// Pre-serialized JSON payload that `into_response` sends when serializing
// `ErrorBody` fails. It always describes a 500 / INTERNAL_SERVER_ERROR and
// carries no `traceId` or `timestamp`.
const FALLBACK_ERROR_BODY: &[u8] = br#"{"status":500,"code":"INTERNAL_SERVER_ERROR","title":"Internal Server Error","message":"Internal server error"}"#;
impl IntoResponse for Error {
fn into_response(self) -> Response {
let status = self.kind.status();
let trace_id = generate_trace_id();
let message: &str = if status.is_server_error() {
log_server_error(&self, &trace_id);
INTERNAL_ERROR_MESSAGE
} else {
&self.message
};
let details: &[ErrorDetail] = if status.is_server_error() {
&[]
} else {
&self.details
};
let body = ErrorBody {
status: status.as_u16(),
code: self.code(),
title: status.canonical_reason().unwrap_or("Error"),
message,
details,
trace_id: &trace_id,
timestamp: now_rfc3339(),
};
let mut response = match serde_json::to_vec(&body) {
Ok(buffer) => with_body(status, APPLICATION_JSON, Bytes::from(buffer)),
Err(_) => with_body(
status,
APPLICATION_JSON,
Bytes::from_static(FALLBACK_ERROR_BODY),
),
};
response.headers_mut().insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_static("no-store"),
);
response
}
}
/// Returns a new trace id: `TRACE_ID_PREFIX` followed by a random v4 UUID.
fn generate_trace_id() -> String {
    let unique = uuid::Uuid::new_v4();
    format!("{TRACE_ID_PREFIX}{}", unique)
}
/// Returns the current UTC time formatted as RFC 3339 with the sub-second
/// part zeroed, or an empty string if truncation or formatting fails.
fn now_rfc3339() -> String {
    use time::format_description::well_known::Rfc3339;

    match time::OffsetDateTime::now_utc().replace_nanosecond(0) {
        Ok(whole_second) => whole_second.format(&Rfc3339).unwrap_or_default(),
        Err(_) => String::new(),
    }
}
/// Derives a short issue label from a validation error message.
///
/// The message is lowercased (ASCII only) and checked against a list of
/// phrases; the first phrase found decides the label. Messages matching no
/// phrase are labelled `GENERIC_ISSUE`.
fn classify_issue(message: &str) -> &'static str {
    // Order matters: the specific "length is ..." and "must be ..." phrases
    // have to be tried before the bare "lower than" / "greater than" ones.
    const RULES: &[(&str, &str)] = &[
        ("email", "INVALID_FORMAT"),
        ("length is lower", "TOO_SHORT"),
        ("length is greater", "TOO_LONG"),
        ("must be greater than", "TOO_SMALL"),
        ("must be less than", "TOO_LARGE"),
        ("lower than", "TOO_SMALL"),
        ("greater than", "TOO_LARGE"),
    ];

    let normalized = message.to_ascii_lowercase();
    RULES
        .iter()
        .find(|(phrase, _)| normalized.contains(*phrase))
        .map_or(GENERIC_ISSUE, |(_, label)| *label)
}
/// Writes a server error to stderr, tagged with the response's trace id.
///
/// The line contains the effective code and the unredacted message, plus the
/// cause when a source is attached.
fn log_server_error(error: &Error, trace_id: &str) {
    let code = error.code();
    let message = &error.message;
    if let Some(source) = error.source.as_ref() {
        eprintln!(
            "tork: server error [{trace_id}]: {}: {} (cause: {source})",
            code, message,
        );
    } else {
        eprintln!("tork: server error [{trace_id}]: {}: {}", code, message);
    }
}
// Unit tests for the error type: kind/status/code mappings, response
// rendering, source handling and issue classification.
#[cfg(test)]
mod tests {
use super::*;
use crate::response::Response;
use http_body_util::BodyExt;
use serde_json::Value;
// Collects the whole response body and parses it as JSON; panics on failure.
async fn body_json(response: Response) -> Value {
let bytes = response.into_body().collect().await.unwrap().to_bytes();
serde_json::from_slice(&bytes).unwrap()
}
// Spot-checks a handful of kind -> status mappings.
#[test]
fn status_mapping_matches_kind() {
assert_eq!(ErrorKind::Forbidden.status(), StatusCode::FORBIDDEN);
assert_eq!(ErrorKind::NotFound.status(), StatusCode::NOT_FOUND);
assert_eq!(
ErrorKind::Internal.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
assert_eq!(
ErrorKind::PayloadTooLarge.status(),
StatusCode::PAYLOAD_TOO_LARGE
);
assert_eq!(
ErrorKind::GatewayTimeout.status(),
StatusCode::GATEWAY_TIMEOUT
);
}
// A client error keeps its message, omits `details` when there are none,
// and carries a prefixed trace id and a UTC ("Z") timestamp.
#[tokio::test]
async fn client_error_uses_problem_format() {
let response = Error::forbidden("Access denied").into_response();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
let body = body_json(response).await;
assert_eq!(body["status"], 403);
assert_eq!(body["code"], "FORBIDDEN");
assert_eq!(body["title"], "Forbidden");
assert_eq!(body["message"], "Access denied");
assert!(body.get("details").is_none(), "no details expected: {body}");
assert!(
body["traceId"].as_str().unwrap().starts_with("req-"),
"traceId expected: {body}"
);
assert!(
body["timestamp"].as_str().unwrap().ends_with('Z'),
"timestamp: {body}"
);
}
// A server error must not expose its original message in the body.
#[tokio::test]
async fn server_error_is_redacted() {
let response = Error::internal("database password is hunter2").into_response();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = body_json(response).await;
assert_eq!(body["code"], "INTERNAL_SERVER_ERROR");
assert_eq!(body["message"], INTERNAL_ERROR_MESSAGE);
assert!(
!serde_json::to_string(&body).unwrap().contains("hunter2"),
"internal detail must not leak"
);
assert!(body["traceId"].as_str().unwrap().starts_with("req-"));
}
// Details attached to a client error appear in the body with all three fields.
#[tokio::test]
async fn validation_details_are_serialized() {
let response = Error::unprocessable(VALIDATION_ERROR_MESSAGE)
.with_code(VALIDATION_ERROR_CODE)
.with_details(vec![ErrorDetail::new(
"price",
"TOO_SMALL",
"must be greater than 0",
)])
.into_response();
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
let body = body_json(response).await;
assert_eq!(body["code"], "VALIDATION_ERROR");
assert_eq!(body["details"][0]["field"], "price");
assert_eq!(body["details"][0]["issue"], "TOO_SMALL");
assert_eq!(body["details"][0]["message"], "must be greater than 0");
}
// Comparable error type used as the attached source in the tests below.
#[derive(Debug, PartialEq)]
struct SampleCause(&'static str);
impl std::fmt::Display for SampleCause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.0)
}
}
impl std::error::Error for SampleCause {}
// Second error type, used to exercise type mismatches.
#[derive(Debug)]
struct OtherCause;
impl std::fmt::Display for OtherCause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("other")
}
}
impl std::error::Error for OtherCause {}
// `with_source` stores the TypeId of the concrete source type.
#[test]
fn with_source_records_the_type() {
let error = Error::internal("boom").with_source(SampleCause("cause"));
assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
}
// The source comes back by value once; afterwards both the source and the
// recorded type are gone.
#[test]
fn take_source_round_trips_the_typed_cause() {
let mut error = Error::internal("boom").with_source(SampleCause("cause"));
assert_eq!(
error.take_source::<SampleCause>(),
Some(SampleCause("cause"))
);
assert_eq!(error.take_source::<SampleCause>(), None);
assert_eq!(error.source_type, None);
}
// Asking for the wrong type returns None and leaves the source in place.
#[test]
fn take_source_rejects_a_mismatched_type() {
let mut error = Error::internal("boom").with_source(SampleCause("cause"));
assert!(error.take_source::<OtherCause>().is_none());
assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
assert_eq!(
error.take_source::<SampleCause>(),
Some(SampleCause("cause"))
);
}
// Without any source, `take_source` returns None.
#[test]
fn take_source_is_none_without_a_source() {
let mut error = Error::internal("boom");
assert!(error.take_source::<SampleCause>().is_none());
}
// A real garde report with a too-short string yields one TOO_SHORT detail.
#[test]
fn from_garde_report_classifies_field_errors() {
use garde::Validate;
#[derive(garde::Validate)]
struct Sample {
#[garde(length(min = 3))]
name: String,
}
let report = Sample {
name: String::new(),
}
.validate()
.unwrap_err();
let error = Error::from_garde_report(report);
assert_eq!(error.code(), "VALIDATION_ERROR");
assert_eq!(error.details().len(), 1);
assert_eq!(error.details()[0].field, "name");
assert_eq!(error.details()[0].issue, "TOO_SHORT");
}
// Exhaustive check of the kind -> status mapping.
#[test]
fn status_mapping_covers_every_kind() {
use ErrorKind::*;
assert_eq!(BadRequest.status(), StatusCode::BAD_REQUEST);
assert_eq!(Unauthorized.status(), StatusCode::UNAUTHORIZED);
assert_eq!(Forbidden.status(), StatusCode::FORBIDDEN);
assert_eq!(NotFound.status(), StatusCode::NOT_FOUND);
assert_eq!(MethodNotAllowed.status(), StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(Conflict.status(), StatusCode::CONFLICT);
assert_eq!(Unprocessable.status(), StatusCode::UNPROCESSABLE_ENTITY);
assert_eq!(PayloadTooLarge.status(), StatusCode::PAYLOAD_TOO_LARGE);
assert_eq!(TooManyRequests.status(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(Internal.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(ServiceUnavailable.status(), StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(GatewayTimeout.status(), StatusCode::GATEWAY_TIMEOUT);
}
// Exhaustive check of the kind -> default code mapping.
#[test]
fn code_mapping_covers_every_kind() {
use ErrorKind::*;
assert_eq!(BadRequest.code(), "BAD_REQUEST");
assert_eq!(Unauthorized.code(), "UNAUTHORIZED");
assert_eq!(Forbidden.code(), "FORBIDDEN");
assert_eq!(NotFound.code(), "NOT_FOUND");
assert_eq!(MethodNotAllowed.code(), "METHOD_NOT_ALLOWED");
assert_eq!(Conflict.code(), "CONFLICT");
assert_eq!(Unprocessable.code(), "UNPROCESSABLE_ENTITY");
assert_eq!(PayloadTooLarge.code(), "PAYLOAD_TOO_LARGE");
assert_eq!(TooManyRequests.code(), "TOO_MANY_REQUESTS");
assert_eq!(Internal.code(), "INTERNAL_SERVER_ERROR");
assert_eq!(ServiceUnavailable.code(), "SERVICE_UNAVAILABLE");
assert_eq!(GatewayTimeout.code(), "GATEWAY_TIMEOUT");
}
// The following four tests check that a per-kind constructor sets the
// matching kind and stores the message unchanged.
#[test]
fn method_not_allowed_constructor_uses_method_not_allowed_kind() {
let error = Error::method_not_allowed("GET not allowed");
assert_eq!(error.kind(), ErrorKind::MethodNotAllowed);
assert_eq!(error.message(), "GET not allowed");
}
#[test]
fn conflict_constructor_uses_conflict_kind() {
let error = Error::conflict("duplicate key");
assert_eq!(error.kind(), ErrorKind::Conflict);
assert_eq!(error.message(), "duplicate key");
}
#[test]
fn too_many_requests_constructor_uses_too_many_requests_kind() {
let error = Error::too_many_requests("slow down");
assert_eq!(error.kind(), ErrorKind::TooManyRequests);
assert_eq!(error.message(), "slow down");
}
#[test]
fn service_unavailable_constructor_uses_service_unavailable_kind() {
let error = Error::service_unavailable("maintenance");
assert_eq!(error.kind(), ErrorKind::ServiceUnavailable);
assert_eq!(error.message(), "maintenance");
}
// `std::error::Error::source` exposes the attached cause.
#[test]
fn error_trait_source_returns_attached_source() {
use std::error::Error as _;
let error = Error::internal("boom").with_source(SampleCause("inner"));
let source = error.source().expect("source should be present");
assert_eq!(source.to_string(), "inner");
}
// ...and is None when no cause was attached.
#[test]
fn error_trait_source_is_none_when_unset() {
use std::error::Error as _;
let error = Error::internal("boom");
assert!(error.source().is_none());
}
// Writes the private fields directly so the recorded TypeId disagrees with
// the boxed value, forcing the failed-downcast branch of `take_source`.
#[test]
fn take_source_restores_state_when_downcast_defensively_fails() {
let mut error = Error::internal("boom");
error.source = Some(Box::new(OtherCause));
error.source_type = Some(TypeId::of::<SampleCause>());
assert!(error.take_source::<SampleCause>().is_none());
assert_eq!(error.source_type, Some(TypeId::of::<SampleCause>()));
}
// Display checks for the two helper error types.
#[test]
fn sample_cause_display_formats_inner_message() {
assert_eq!(SampleCause("payload").to_string(), "payload");
}
#[test]
fn other_cause_display_formats_inner_message() {
assert_eq!(OtherCause.to_string(), "other");
}
// The static fallback payload must itself parse as JSON.
#[test]
fn fallback_body_constant_is_valid_json() {
let parsed: Value = serde_json::from_slice(FALLBACK_ERROR_BODY).unwrap();
assert_eq!(parsed["status"], 500);
assert_eq!(parsed["code"], "INTERNAL_SERVER_ERROR");
}
// "email" is matched regardless of ASCII case.
#[test]
fn classify_issue_recognizes_email_format() {
assert_eq!(classify_issue("email is not valid"), "INVALID_FORMAT");
assert_eq!(classify_issue("Email is invalid"), "INVALID_FORMAT");
}
// "length is greater" wins over the bare "greater than" rule.
#[test]
fn classify_issue_recognizes_too_long() {
assert_eq!(classify_issue("length is greater than 10"), "TOO_LONG");
}
// "must be greater than" means the value was too small, and vice versa.
#[test]
fn classify_issue_recognizes_strict_numeric_bounds() {
assert_eq!(classify_issue("value must be greater than 0"), "TOO_SMALL");
assert_eq!(classify_issue("value must be less than 100"), "TOO_LARGE");
}
// Unknown or empty messages get the generic label.
#[test]
fn classify_issue_falls_back_to_generic() {
assert_eq!(classify_issue("something unrelated"), "INVALID");
assert_eq!(classify_issue(""), "INVALID");
}
}