use serde::Serialize;
use std::fmt;
use crate::response::{APPLICATION_JSON, APPLICATION_PROBLEM_JSON, BoxBody, IntoResponse};
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http_body_util::Full;
/// Controls how [`Error`] values are rendered into HTTP responses.
#[derive(Debug, Clone)]
pub struct ErrorConfig {
    /// When `true`, errors are rendered as RFC 7807 `application/problem+json`
    /// bodies; otherwise the standard `{"error": ..., "trace_id": ...}` shape is used.
    pub use_rfc7807: bool,
    /// Prefix for RFC 7807 `type` URIs; the value `"about:blank"` is emitted verbatim.
    pub base_uri: String,
}

impl Default for ErrorConfig {
    /// Standard (non-RFC 7807) rendering with an `about:blank` type URI.
    fn default() -> Self {
        let base_uri = String::from("about:blank");
        ErrorConfig {
            base_uri,
            use_rfc7807: false,
        }
    }
}
// Per-task error rendering configuration. `Error::into_response` reads it via
// `try_with` and falls back to `ErrorConfig::default()` when the current task
// is not running inside an `ERROR_CONFIG.scope(...)`.
tokio::task_local! {
pub(crate) static ERROR_CONFIG: ErrorConfig;
}
pub mod standard {
//! Default (non-RFC 7807) JSON error body:
//! `{"error": {"code": ..., "message": ..., "details": ...}, "trace_id": ...}`.
use serde::Serialize;
/// Top-level standard error body. Field declaration order is the serialized key order.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// Error payload, serialized under the `"error"` key.
pub error: ErrorDetail,
/// Trace id for correlating this error; `into_response` takes it from the
/// error or generates a UUID when none was set.
pub trace_id: String,
}
/// Inner `"error"` object of the standard body.
#[derive(Debug, Serialize)]
pub struct ErrorDetail {
/// Machine-readable error code, e.g. `"NOT_FOUND"`.
pub code: String,
/// Human-readable message.
pub message: String,
// Omitted from the JSON entirely (not emitted as `null`) when unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
}
pub mod rfc7807 {
    //! RFC 7807 "Problem Details for HTTP APIs" body and the helpers that derive
    //! its `type` and `title` members from an error code.
    use serde::Serialize;

    /// An `application/problem+json` body. Field declaration order is the
    /// serialized key order.
    #[derive(Debug, Serialize, PartialEq)]
    pub struct ProblemDetails {
        /// URI identifying the problem type (serialized as `"type"`).
        #[serde(rename = "type")]
        pub type_uri: String,
        /// Short human-readable summary of the problem type.
        pub title: String,
        /// HTTP status code of the response.
        pub status: u16,
        /// Human-readable explanation specific to this occurrence.
        pub detail: String,
        /// URI reference identifying this specific occurrence; omitted when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub instance: Option<String>,
        /// Trace id for correlating this error.
        pub trace_id: String,
        /// Extra members merged into the top-level object. `flatten` can only
        /// serialize map-like values, so this must be a JSON object (or `None`).
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        pub extensions: Option<serde_json::Value>,
    }

    /// Builds the `type` URI for `code` under `base_uri`.
    ///
    /// `"about:blank"` is returned unchanged. Otherwise the code is lower-cased,
    /// `_` becomes `-`, and it is appended to `base_uri` with exactly one `/`
    /// separator (`"https://x/errors/"` + `NOT_FOUND` → `"https://x/errors/not-found"`).
    pub(crate) fn get_error_type_uri(base_uri: &str, code: &str) -> String {
        if base_uri == "about:blank" {
            return base_uri.to_string();
        }
        format!(
            "{}/{}",
            base_uri.trim_end_matches('/'),
            code.to_lowercase().replace('_', "-")
        )
    }

    /// Converts a `SCREAMING_SNAKE_CASE` code into a title-cased phrase
    /// (`"NOT_FOUND"` → `"Not Found"`).
    ///
    /// Empty segments produced by leading, trailing, or repeated underscores are
    /// skipped, so the title never contains stray spaces (`"NOT__FOUND"` → `"Not Found"`).
    pub(crate) fn get_error_title(code: &str) -> String {
        code.split('_')
            .filter_map(|word| {
                let mut chars = word.chars();
                // `?` drops empty segments instead of emitting an empty word.
                let first = chars.next()?;
                Some(first.to_uppercase().collect::<String>() + &chars.as_str().to_lowercase())
            })
            .collect::<Vec<_>>()
            .join(" ")
    }
}
/// State behind [`Error`]; each field is exposed through an accessor on `Error`.
#[derive(Debug)]
struct ErrorInner {
/// HTTP status code used for the response.
status: u16,
/// Machine-readable error code, e.g. `"NOT_FOUND"`.
code: String,
/// Human-readable message.
message: String,
/// Optional structured payload: `details` in the standard body, extension
/// members in the RFC 7807 body.
details: Option<serde_json::Value>,
/// Caller-supplied trace id; `into_response` generates a UUID when this is unset.
trace_id: Option<String>,
/// RFC 7807 `instance` member; only emitted in the RFC 7807 body.
instance: Option<String>,
}
/// API error returned from handlers and rendered by `into_response`.
///
/// The fields live behind a `Box`, so the error side of `Result<T, Error>` is a
/// single pointer regardless of how many fields `ErrorInner` grows.
#[derive(Debug)]
pub struct Error(Box<ErrorInner>);
impl Error {
    /// Creates an error from an HTTP status, a machine-readable code, and a
    /// human-readable message. Details, trace id, and instance start unset.
    pub fn new(status: u16, code: impl Into<String>, message: impl Into<String>) -> Self {
        Self(Box::new(ErrorInner {
            status,
            code: code.into(),
            message: message.into(),
            details: None,
            trace_id: None,
            instance: None,
        }))
    }

    /// HTTP status code this error was created with.
    pub fn status(&self) -> u16 {
        self.0.status
    }

    /// Machine-readable error code, e.g. `"NOT_FOUND"`.
    pub fn code(&self) -> &str {
        &self.0.code
    }

    /// Human-readable message.
    pub fn message(&self) -> &str {
        &self.0.message
    }

    /// Structured details attached with [`Error::with_details`], if any.
    pub fn details(&self) -> Option<&serde_json::Value> {
        self.0.details.as_ref()
    }

    /// Trace id attached with [`Error::with_trace_id`], if any. This does not
    /// report the UUID that `into_response` generates when none was set.
    pub fn trace_id(&self) -> Option<&str> {
        self.0.trace_id.as_deref()
    }

    /// RFC 7807 `instance` attached with [`Error::with_instance`], if any.
    pub fn instance(&self) -> Option<&str> {
        self.0.instance.as_deref()
    }

    /// Attaches structured details, replacing any previous value.
    #[must_use]
    pub fn with_details(mut self, details: serde_json::Value) -> Self {
        self.0.details = Some(details);
        self
    }

    /// Attaches a trace id, replacing any previous value.
    #[must_use]
    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
        self.0.trace_id = Some(trace_id.into());
        self
    }

    /// Attaches an RFC 7807 `instance` reference, replacing any previous value.
    #[must_use]
    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
        self.0.instance = Some(instance.into());
        self
    }

    /// 400 `BAD_REQUEST`.
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::new(400, "BAD_REQUEST", message)
    }

    /// 401 `UNAUTHORIZED`.
    pub fn unauthorized(message: impl Into<String>) -> Self {
        Self::new(401, "UNAUTHORIZED", message)
    }

    /// 403 `FORBIDDEN`.
    pub fn forbidden(message: impl Into<String>) -> Self {
        Self::new(403, "FORBIDDEN", message)
    }

    /// 404 `NOT_FOUND`.
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(404, "NOT_FOUND", message)
    }

    /// 409 `CONFLICT`.
    pub fn conflict(message: impl Into<String>) -> Self {
        Self::new(409, "CONFLICT", message)
    }

    /// 422 `VALIDATION_ERROR`.
    pub fn validation(message: impl Into<String>) -> Self {
        Self::new(422, "VALIDATION_ERROR", message)
    }

    /// 429 `RATE_LIMITED`.
    pub fn rate_limited(message: impl Into<String>) -> Self {
        Self::new(429, "RATE_LIMITED", message)
    }

    /// 500 `INTERNAL_ERROR`.
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new(500, "INTERNAL_ERROR", message)
    }

    /// Builds the RFC 7807 problem-details body for this error.
    ///
    /// `type` and `title` are derived from the code (see `rfc7807`), `detail` is
    /// the message, and the details become extension members:
    /// - an object's members are merged into the top level, except those whose
    ///   names collide with a standard member (these would produce duplicate keys);
    /// - any other non-null value is placed under a `"details"` member, since
    ///   `#[serde(flatten)]` cannot serialize it directly;
    /// - `null` or no details produce no extension members.
    pub fn to_rfc7807_response(&self, trace_id: String, base_uri: &str) -> rfc7807::ProblemDetails {
        rfc7807::ProblemDetails {
            type_uri: rfc7807::get_error_type_uri(base_uri, &self.0.code),
            title: rfc7807::get_error_title(&self.0.code),
            status: self.0.status,
            detail: self.0.message.clone(),
            instance: self.0.instance.clone(),
            trace_id,
            extensions: Self::problem_extensions(self.0.details.as_ref()),
        }
    }

    /// Converts optional details into a JSON object that is safe to flatten
    /// into a `ProblemDetails` body.
    fn problem_extensions(details: Option<&serde_json::Value>) -> Option<serde_json::Value> {
        // Member names the problem body already emits; letting details re-emit
        // them would duplicate keys (and last-wins parsers would override e.g. `status`).
        const RESERVED_MEMBERS: [&str; 6] =
            ["type", "title", "status", "detail", "instance", "trace_id"];

        match details? {
            serde_json::Value::Null => None,
            serde_json::Value::Object(members) => Some(serde_json::Value::Object(
                members
                    .iter()
                    .filter(|(name, _)| !RESERVED_MEMBERS.contains(&name.as_str()))
                    .map(|(name, value)| (name.clone(), value.clone()))
                    .collect(),
            )),
            other => {
                let mut wrapped = serde_json::Map::new();
                wrapped.insert("details".to_string(), other.clone());
                Some(serde_json::Value::Object(wrapped))
            }
        }
    }
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {}", self.0.code, self.0.message)
}
}
impl std::error::Error for Error {}
pub trait IntoApiError {
fn into_api_error(self) -> Error;
}
impl<T: IntoApiError> From<T> for Error {
fn from(err: T) -> Self {
err.into_api_error()
}
}
/// Static description of one API error a domain error type can produce.
/// NOTE(review): presumably consumed by API-documentation generation, which is not visible here.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ErrorVariant {
/// HTTP status code of this variant.
pub status: u16,
/// Machine-readable error code, e.g. `"NOT_FOUND"`.
pub code: &'static str,
/// Human-readable explanation of when this variant occurs.
pub description: &'static str,
}
/// A domain error that can list the API errors it converts into.
///
/// NOTE(review): nothing enforces that `error_variants()` matches what
/// `into_api_error` actually returns; keep the two in sync by hand.
pub trait DocumentedError: IntoApiError {
/// Every status/code combination this type may produce.
fn error_variants() -> Vec<ErrorVariant>;
}
impl IntoResponse for Error {
    /// Renders the error as a JSON HTTP response.
    ///
    /// The format comes from the task-local `ERROR_CONFIG` (default config when
    /// no scope is active). RFC 7807 mode uses `application/problem+json`;
    /// otherwise the standard `{"error": ..., "trace_id": ...}` body is used. A
    /// missing trace id is replaced by a fresh UUID v4.
    ///
    /// A status outside the range `http::StatusCode` accepts falls back to
    /// `500 Internal Server Error` instead of panicking. The RFC 7807 `status`
    /// member mirrors the status actually sent.
    fn into_response(self) -> http::Response<BoxBody> {
        let trace_id = self
            .0
            .trace_id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        let config = ERROR_CONFIG.try_with(|c| c.clone()).unwrap_or_default();
        let status = http::StatusCode::from_u16(self.0.status)
            .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);

        let (content_type, body) = if config.use_rfc7807 {
            let mut problem = self.to_rfc7807_response(trace_id, &config.base_uri);
            problem.status = status.as_u16();
            // `flatten` rejects non-map extensions; retry without them rather
            // than sending an empty problem+json body.
            let body = serde_json::to_vec(&problem).unwrap_or_else(|_| {
                problem.extensions = None;
                serde_json::to_vec(&problem).unwrap_or_default()
            });
            (APPLICATION_PROBLEM_JSON, body)
        } else {
            // `self` is consumed here, so move the fields out instead of cloning them.
            let inner = *self.0;
            let response = standard::ErrorResponse {
                error: standard::ErrorDetail {
                    code: inner.code,
                    message: inner.message,
                    details: inner.details,
                },
                trace_id,
            };
            // Only strings and `serde_json::Value` are involved, so this is not expected to fail.
            (APPLICATION_JSON, serde_json::to_vec(&response).unwrap_or_default())
        };

        http::Response::builder()
            .status(status)
            .header(CONTENT_TYPE, content_type)
            .body(Full::new(Bytes::from(body)))
            .expect("status is a validated StatusCode and content type is a valid header constant")
    }
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    // Unit tests for construction, conversion, and both response formats of `Error`.
    use super::*;
    use http_body_util::BodyExt;

    #[derive(Debug)]
    enum TestUserError {
        NotFound(u64),
        EmailTaken(String),
    }

    impl IntoApiError for TestUserError {
        fn into_api_error(self) -> Error {
            match self {
                TestUserError::NotFound(id) => Error::not_found(format!("user {} not found", id)),
                TestUserError::EmailTaken(email) => {
                    Error::conflict(format!("email {} already taken", email))
                }
            }
        }
    }

    impl DocumentedError for TestUserError {
        fn error_variants() -> Vec<ErrorVariant> {
            vec![
                ErrorVariant {
                    status: 404,
                    code: "NOT_FOUND",
                    description: "User not found",
                },
                ErrorVariant {
                    status: 409,
                    code: "CONFLICT",
                    description: "Email already taken",
                },
            ]
        }
    }

    #[test]
    fn test_into_api_error_not_found() {
        let domain_err = TestUserError::NotFound(42);
        let api_err: Error = domain_err.into_api_error();
        assert_eq!(api_err.status(), 404);
        assert_eq!(api_err.code(), "NOT_FOUND");
        assert_eq!(api_err.message(), "user 42 not found");
    }

    #[test]
    fn test_into_api_error_conflict() {
        let domain_err = TestUserError::EmailTaken("test@example.com".to_string());
        let api_err: Error = domain_err.into_api_error();
        assert_eq!(api_err.status(), 409);
        assert_eq!(api_err.code(), "CONFLICT");
        assert_eq!(api_err.message(), "email test@example.com already taken");
    }

    #[test]
    fn test_domain_error_from_conversion() {
        let domain_err = TestUserError::NotFound(123);
        let api_err = Error::from(domain_err);
        assert_eq!(api_err.status(), 404);
        assert_eq!(api_err.code(), "NOT_FOUND");
    }

    #[test]
    fn test_documented_error_variants() {
        let variants = TestUserError::error_variants();
        assert_eq!(variants.len(), 2);
        assert_eq!(variants[0].status, 404);
        assert_eq!(variants[0].code, "NOT_FOUND");
        assert_eq!(variants[1].status, 409);
        assert_eq!(variants[1].code, "CONFLICT");
    }

    #[test]
    fn test_error_new() {
        let err = Error::new(500, "TEST_ERROR", "test message");
        assert_eq!(err.status(), 500);
        assert_eq!(err.code(), "TEST_ERROR");
        assert_eq!(err.message(), "test message");
        assert!(err.details().is_none());
        assert!(err.trace_id().is_none());
    }

    #[test]
    fn test_error_bad_request() {
        let err = Error::bad_request("invalid input");
        assert_eq!(err.status(), 400);
        assert_eq!(err.code(), "BAD_REQUEST");
        assert_eq!(err.message(), "invalid input");
    }

    #[test]
    fn test_error_unauthorized() {
        let err = Error::unauthorized("not authenticated");
        assert_eq!(err.status(), 401);
        assert_eq!(err.code(), "UNAUTHORIZED");
    }

    #[test]
    fn test_error_forbidden() {
        let err = Error::forbidden("access denied");
        assert_eq!(err.status(), 403);
        assert_eq!(err.code(), "FORBIDDEN");
    }

    #[test]
    fn test_error_not_found() {
        let err = Error::not_found("resource not found");
        assert_eq!(err.status(), 404);
        assert_eq!(err.code(), "NOT_FOUND");
    }

    #[test]
    fn test_error_conflict() {
        let err = Error::conflict("already exists");
        assert_eq!(err.status(), 409);
        assert_eq!(err.code(), "CONFLICT");
    }

    #[test]
    fn test_error_validation() {
        let err = Error::validation("invalid data");
        assert_eq!(err.status(), 422);
        assert_eq!(err.code(), "VALIDATION_ERROR");
    }

    #[test]
    fn test_error_rate_limited() {
        let err = Error::rate_limited("too many requests");
        assert_eq!(err.status(), 429);
        assert_eq!(err.code(), "RATE_LIMITED");
    }

    #[test]
    fn test_error_internal() {
        let err = Error::internal("server error");
        assert_eq!(err.status(), 500);
        assert_eq!(err.code(), "INTERNAL_ERROR");
    }

    #[test]
    fn test_error_with_details() {
        let details = serde_json::json!({"field": "email", "error": "invalid format"});
        let err = Error::bad_request("validation failed").with_details(details.clone());
        assert_eq!(err.details(), Some(&details));
    }

    #[test]
    fn test_error_with_trace_id() {
        let err = Error::bad_request("test").with_trace_id("trace-123");
        assert_eq!(err.trace_id(), Some("trace-123"));
    }

    #[test]
    fn test_error_display() {
        let err = Error::bad_request("invalid input");
        let display = format!("{}", err);
        assert_eq!(display, "BAD_REQUEST: invalid input");
    }

    #[test]
    fn test_error_to_rfc7807_response() {
        let err = Error::not_found("user not found");
        let response = err.to_rfc7807_response("trace-abc".to_string(), "https://myapp.com/errors");
        assert_eq!(response.trace_id, "trace-abc");
        assert_eq!(response.type_uri, "https://myapp.com/errors/not-found");
        assert_eq!(response.title, "Not Found");
        assert_eq!(response.detail, "user not found");
    }

    #[test]
    fn test_error_to_rfc7807_response_includes_instance() {
        let err = Error::not_found("user not found").with_instance("/users/1");
        let response = err.to_rfc7807_response("trace".to_string(), "https://myapp.com/errors");
        assert_eq!(response.instance.as_deref(), Some("/users/1"));
    }

    #[tokio::test]
    async fn test_error_into_rfc7807_response() {
        let config = ErrorConfig {
            use_rfc7807: true,
            base_uri: "https://myapp.com/errors".to_string(),
        };
        ERROR_CONFIG
            .scope(config, async {
                let err = Error::bad_request("test error").with_trace_id("my-trace");
                let response = err.into_response();
                assert_eq!(response.status(), 400);
                assert_eq!(
                    response.headers().get("content-type").unwrap(),
                    "application/problem+json"
                );
                let body = response.into_body().collect().await.unwrap().to_bytes();
                let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
                assert_eq!(json["type"], "https://myapp.com/errors/bad-request");
                assert_eq!(json["title"], "Bad Request");
                assert_eq!(json["status"], 400);
                assert_eq!(json["detail"], "test error");
                assert_eq!(json["trace_id"], "my-trace");
            })
            .await;
    }

    #[tokio::test]
    async fn test_error_into_standard_response() {
        ERROR_CONFIG
            .scope(ErrorConfig::default(), async {
                let err = Error::bad_request("test error").with_trace_id("my-trace");
                let response = err.into_response();
                assert_eq!(response.status(), 400);
                assert_eq!(
                    response.headers().get("content-type").unwrap(),
                    "application/json"
                );
                let body = response.into_body().collect().await.unwrap().to_bytes();
                let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
                assert_eq!(json["error"]["code"], "BAD_REQUEST");
                assert_eq!(json["error"]["message"], "test error");
                assert_eq!(json["trace_id"], "my-trace");
            })
            .await;
    }

    #[tokio::test]
    async fn test_error_into_response_without_config_scope_uses_standard_format() {
        // No ERROR_CONFIG scope: `into_response` must fall back to the default config.
        let response = Error::not_found("missing").into_response();
        assert_eq!(response.status(), 404);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["error"]["code"], "NOT_FOUND");
        assert_eq!(json["error"]["message"], "missing");
    }

    #[tokio::test]
    async fn test_error_into_response_generates_trace_id() {
        let config = ErrorConfig {
            use_rfc7807: true,
            base_uri: "https://myapp.com/errors".to_string(),
        };
        ERROR_CONFIG
            .scope(config, async {
                let err = Error::internal("error");
                let response = err.into_response();
                let body = response.into_body().collect().await.unwrap().to_bytes();
                let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
                let trace_id = json["trace_id"].as_str().unwrap();
                // A generated id must be a well-formed hyphenated UUID, not just 36 chars.
                assert_eq!(trace_id.len(), 36);
                assert!(uuid::Uuid::parse_str(trace_id).is_ok());
            })
            .await;
    }

    #[test]
    fn test_error_response_skips_none_details() {
        let err = Error::bad_request("test");
        let response = err.to_rfc7807_response("trace".to_string(), "https://myapp.com/errors");
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("extensions"));
        assert!(!json.contains("\"error\":"));
        assert!(!json.contains("\"details\":"));
    }

    #[test]
    fn test_error_response_includes_details() {
        let details = serde_json::json!({"field": "email"});
        let err = Error::bad_request("test").with_details(details);
        let response = err.to_rfc7807_response("trace".to_string(), "https://myapp.com/errors");
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"field\":\"email\""));
    }

    #[test]
    fn test_error_is_std_error() {
        let err = Error::internal("test");
        let _: &dyn std::error::Error = &err;
    }

    #[test]
    fn test_error_builder_chain() {
        let details = serde_json::json!({"field": "name"});
        let err = Error::validation("invalid")
            .with_details(details.clone())
            .with_trace_id("trace-123")
            .with_instance("/users/1");
        assert_eq!(err.status(), 422);
        assert_eq!(err.code(), "VALIDATION_ERROR");
        assert_eq!(err.details(), Some(&details));
        assert_eq!(err.trace_id(), Some("trace-123"));
        assert_eq!(err.instance(), Some("/users/1"));
    }

    #[test]
    fn test_get_error_type_uri() {
        assert_eq!(
            rfc7807::get_error_type_uri("https://myapp.com/errors", "NOT_FOUND"),
            "https://myapp.com/errors/not-found"
        );
        assert_eq!(
            rfc7807::get_error_type_uri("https://myapp.com/errors", "BAD_REQUEST"),
            "https://myapp.com/errors/bad-request"
        );
    }

    #[test]
    fn test_get_error_type_uri_about_blank() {
        assert_eq!(
            rfc7807::get_error_type_uri("about:blank", "NOT_FOUND"),
            "about:blank"
        );
    }

    #[test]
    fn test_get_error_type_uri_trailing_slash() {
        assert_eq!(
            rfc7807::get_error_type_uri("https://myapp.com/errors/", "NOT_FOUND"),
            "https://myapp.com/errors/not-found"
        );
    }

    #[test]
    fn test_get_error_title() {
        assert_eq!(rfc7807::get_error_title("NOT_FOUND"), "Not Found");
        assert_eq!(rfc7807::get_error_title("BAD_REQUEST"), "Bad Request");
        assert_eq!(
            rfc7807::get_error_title("INTERNAL_SERVER_ERROR"),
            "Internal Server Error"
        );
    }
}