use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::Serialize;
use std::collections::HashMap;
/// Error type returned by request handlers.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute,
/// and `status_code` maps every variant to an HTTP status.
#[derive(Debug, thiserror::Error)]
pub enum TidewayError {
/// Mapped to `404 Not Found`; the payload is the message.
#[error("Not found: {0}")]
NotFound(String),
/// Mapped to `400 Bad Request`; the payload is the message.
#[error("Bad request: {0}")]
BadRequest(String),
/// Mapped to `401 Unauthorized`; the payload is the message.
#[error("Unauthorized: {0}")]
Unauthorized(String),
/// Mapped to `403 Forbidden`; the payload is the message.
#[error("Forbidden: {0}")]
Forbidden(String),
/// Mapped to `500 Internal Server Error`; the payload is the message.
#[error("Internal server error: {0}")]
Internal(String),
/// Mapped to `503 Service Unavailable`; the payload is the message.
#[error("Service unavailable: {0}")]
ServiceUnavailable(String),
/// Mapped to `408 Request Timeout`; carries no message.
#[error("Request timeout")]
RequestTimeout,
/// Mapped to `429 Too Many Requests`; the payload is the message.
#[error("Too many requests: {0}")]
TooManyRequests(String),
/// Any `anyhow::Error`, convertible via `?` thanks to `#[from]`.
/// Mapped to `500 Internal Server Error`; `Display` is forwarded to the wrapped error.
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
/// Database failure described as text; only present with the `database` feature.
/// Mapped to `500 Internal Server Error`.
#[cfg(feature = "database")]
#[error("Database error: {0}")]
Database(String),
}
/// Optional metadata that can accompany an error.
///
/// Every field starts out empty; the chainable `with_*` methods fill them in.
#[derive(Debug, Clone, Default)]
pub struct ErrorContext {
    /// Identifier for this particular error occurrence, if one was assigned.
    pub error_id: Option<String>,
    /// Free-form detail text.
    pub details: Option<String>,
    /// Arbitrary key/value pairs describing the circumstances of the error.
    pub context: HashMap<String, String>,
    /// Messages keyed by field name; a single field may accumulate several messages.
    pub field_errors: HashMap<String, Vec<String>>,
}

impl ErrorContext {
    /// Creates an empty context (identical to `ErrorContext::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the error identifier, replacing any previous one.
    pub fn with_error_id(mut self, id: impl Into<String>) -> Self {
        self.error_id = Some(id.into());
        self
    }

    /// Sets the detail text, replacing any previous one.
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        self.details = Some(detail.into());
        self
    }

    /// Inserts a key/value pair; an existing value under the same key is overwritten.
    pub fn with_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.context.insert(key.into(), value.into());
        self
    }

    /// Appends `error` to the list of messages recorded for `field`.
    ///
    /// Earlier messages for the same field are kept, in insertion order.
    pub fn with_field_error(mut self, field: impl Into<String>, error: impl Into<String>) -> Self {
        // `or_default` creates the empty list on first use of a field name.
        self.field_errors
            .entry(field.into())
            .or_default()
            .push(error.into());
        self
    }
}
#[derive(Debug, Clone)]
pub struct ErrorInfo {
pub context: ErrorContext,
pub stack_trace: Option<String>,
}
impl ErrorInfo {
pub fn new() -> Self {
Self {
context: ErrorContext::new(),
stack_trace: None,
}
}
pub fn with_context(mut self, context: ErrorContext) -> Self {
self.context = context;
self
}
pub fn with_stack_trace(mut self, stack_trace: impl Into<String>) -> Self {
self.stack_trace = Some(stack_trace.into());
self
}
}
#[derive(Debug)]
pub struct ErrorWithContext {
error: TidewayError,
context: ErrorContext,
}
impl ErrorWithContext {
pub fn new(error: TidewayError, context: ErrorContext) -> Self {
Self { error, context }
}
pub fn into_error_info(self) -> ErrorInfo {
ErrorInfo::new().with_context(self.context)
}
pub fn error(&self) -> &TidewayError {
&self.error
}
pub fn context(&self) -> &ErrorContext {
&self.context
}
}
impl From<ErrorWithContext> for TidewayError {
fn from(err: ErrorWithContext) -> Self {
err.error
}
}
impl IntoResponse for ErrorWithContext {
fn into_response(self) -> Response {
let error = self.error;
let error_info = ErrorInfo::new().with_context(self.context);
error.into_response_with_info(Some(error_info), false)
}
}
// JSON body produced for an error response.
// Every `Option` field is left out of the serialized output when it is `None`.
// (Plain `//` comments are used on purpose: doc comments could end up in the
// schema generated under the `openapi` feature.)
#[derive(Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
struct ErrorResponse {
// Display text of the error.
error: String,
// Identifier of this error occurrence.
#[serde(skip_serializing_if = "Option::is_none")]
error_id: Option<String>,
// Free-form detail text taken from the error context.
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<String>,
// Key/value pairs taken from the error context.
#[serde(skip_serializing_if = "Option::is_none")]
context: Option<HashMap<String, String>>,
// Messages keyed by field name, taken from the error context.
#[serde(skip_serializing_if = "Option::is_none")]
field_errors: Option<HashMap<String, Vec<String>>>,
// Stack trace text.
#[serde(skip_serializing_if = "Option::is_none")]
stack_trace: Option<String>,
}
impl TidewayError {
pub fn not_found(msg: impl Into<String>) -> Self {
Self::NotFound(msg.into())
}
pub fn bad_request(msg: impl Into<String>) -> Self {
Self::BadRequest(msg.into())
}
pub fn unauthorized(msg: impl Into<String>) -> Self {
Self::Unauthorized(msg.into())
}
pub fn forbidden(msg: impl Into<String>) -> Self {
Self::Forbidden(msg.into())
}
pub fn internal(msg: impl Into<String>) -> Self {
Self::Internal(msg.into())
}
pub fn service_unavailable(msg: impl Into<String>) -> Self {
Self::ServiceUnavailable(msg.into())
}
pub fn request_timeout() -> Self {
Self::RequestTimeout
}
pub fn too_many_requests(msg: impl Into<String>) -> Self {
Self::TooManyRequests(msg.into())
}
pub fn with_context(self, context: ErrorContext) -> ErrorWithContext {
ErrorWithContext::new(self, context)
}
pub fn into_response_with_info(self, info: Option<ErrorInfo>, dev_mode: bool) -> Response {
let status = self.status_code();
let error_msg = self.to_string();
let mut response = ErrorResponse {
error: error_msg,
error_id: None,
details: None,
context: None,
field_errors: None,
stack_trace: None,
};
if let Some(info) = info {
response.error_id = info.context.error_id;
response.details = info.context.details;
if !info.context.context.is_empty() {
response.context = Some(info.context.context);
}
if !info.context.field_errors.is_empty() {
response.field_errors = Some(info.context.field_errors);
}
if dev_mode {
response.stack_trace = info.stack_trace;
}
}
let error_id = response.error_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
response.error_id = Some(error_id.clone());
let body = Json(response);
tracing::error!(
status = status.as_u16(),
error_id = %error_id,
error = ?self,
"Request failed"
);
(status, body).into_response()
}
fn status_code(&self) -> StatusCode {
match self {
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::BadRequest(_) => StatusCode::BAD_REQUEST,
Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
Self::Forbidden(_) => StatusCode::FORBIDDEN,
Self::Internal(_) | Self::Anyhow(_) => StatusCode::INTERNAL_SERVER_ERROR,
#[cfg(feature = "database")]
Self::Database(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
Self::RequestTimeout => StatusCode::REQUEST_TIMEOUT,
Self::TooManyRequests(_) => StatusCode::TOO_MANY_REQUESTS,
}
}
}
impl IntoResponse for TidewayError {
fn into_response(self) -> Response {
self.into_response_with_info(None, false)
}
}
/// Shorthand for `std::result::Result` with [`TidewayError`] as the error type.
pub type Result<T> = std::result::Result<T, TidewayError>;
/// Converts a SeaORM error into a [`TidewayError`].
///
/// `RecordNotFound` becomes [`TidewayError::NotFound`] (with a fallback message
/// when the original one is empty); every other variant becomes
/// [`TidewayError::Database`] with a text description.
#[cfg(feature = "database")]
impl From<sea_orm::DbErr> for TidewayError {
    fn from(err: sea_orm::DbErr) -> Self {
        use sea_orm::DbErr;

        let description = match err {
            DbErr::RecordNotFound(msg) => {
                let msg = if msg.is_empty() {
                    "Record not found".to_string()
                } else {
                    msg
                };
                return TidewayError::NotFound(msg);
            }
            DbErr::Query(inner) => format!("Query error: {}", inner),
            DbErr::Exec(inner) => format!("Execution error: {}", inner),
            DbErr::Conn(inner) => format!("Connection error: {}", inner),
            DbErr::Type(inner) => format!("Type error: {}", inner),
            DbErr::Json(inner) => format!("JSON error: {}", inner),
            DbErr::Migration(inner) => format!("Migration error: {}", inner),
            // Any remaining variant is described through its own `Display` output.
            other => format!("Database error: {}", other),
        };

        TidewayError::Database(description)
    }
}