use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
/// Error type returned by HTTP handlers and converted into an HTTP response.
///
/// Only `message` (together with `status`) is exposed to the client; `source`
/// and `context` are never serialized into the response body and exist for
/// server-side diagnostics.
#[derive(Debug)]
pub struct ServerError {
    /// HTTP status code used for the response.
    pub status: StatusCode,
    /// Client-facing message, serialized as the `"error"` field of the JSON body.
    pub message: String,
    /// Optional underlying cause; kept server-side and never sent to the client.
    pub source: Option<anyhow::Error>,
    /// Key/value pairs attached via `with_context`; never sent to the client,
    /// but included in server-side log records.
    pub context: Vec<(&'static str, String)>,
}
impl ServerError {
    /// Creates an error with the given HTTP `status` and client-facing `message`,
    /// with no underlying cause and an empty context.
    pub fn new(status: StatusCode, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self {
            context: Vec::new(),
            source: None,
            message,
            status,
        }
    }

    /// Creates an error that wraps an underlying `anyhow::Error` cause.
    ///
    /// The cause is kept for server-side diagnostics only; the client sees
    /// `message`.
    pub fn from_anyhow(
        source: anyhow::Error,
        status: StatusCode,
        message: impl Into<String>,
    ) -> Self {
        let mut error = Self::new(status, message);
        error.source = Some(source);
        error
    }

    /// Appends a `key`/`value` pair to the error's diagnostic context and
    /// returns the updated error, so calls can be chained.
    pub fn with_context(self, key: &'static str, value: impl Into<String>) -> Self {
        let mut this = self;
        this.context.push((key, value.into()));
        this
    }

    /// Shorthand for a 500 Internal Server Error without an underlying cause.
    // NOTE(review): only called from the tests in this file, hence the allow.
    #[allow(dead_code)]
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new(StatusCode::INTERNAL_SERVER_ERROR, message)
    }

    /// Shorthand for a 500 Internal Server Error that wraps `source`.
    pub fn internal_anyhow(source: anyhow::Error, message: impl Into<String>) -> Self {
        Self {
            source: Some(source),
            ..Self::new(StatusCode::INTERNAL_SERVER_ERROR, message)
        }
    }

    /// Shorthand for a 400 Bad Request.
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::new(StatusCode::BAD_REQUEST, message)
    }

    /// Shorthand for a 403 Forbidden.
    pub fn forbidden(message: impl Into<String>) -> Self {
        Self::new(StatusCode::FORBIDDEN, message)
    }

    /// Shorthand for a 404 Not Found.
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(StatusCode::NOT_FOUND, message)
    }
}
impl IntoResponse for ServerError {
    /// Converts the error into an HTTP response with status `self.status` and
    /// a JSON body of the form `{"error": <message>}`.
    ///
    /// `source` and `context` are never serialized into the body. Server
    /// errors (5xx) are logged at `error` level. Any other status that carries
    /// an underlying `source` is logged at `debug` level, so the cause is
    /// recorded rather than dropped. Non-5xx errors without a source are not
    /// logged.
    fn into_response(self) -> Response {
        match (self.status.is_server_error(), &self.source) {
            (true, Some(source)) => {
                tracing::error!(
                    status = self.status.as_u16(),
                    message = %self.message,
                    context = ?self.context,
                    error = ?source,
                    "Server error"
                );
            }
            (true, None) => {
                tracing::error!(
                    status = self.status.as_u16(),
                    message = %self.message,
                    context = ?self.context,
                    "Server error"
                );
            }
            // Without this arm, a cause attached to a non-5xx error (e.g. via
            // `ServerErrorExt::server_err(StatusCode::NOT_FOUND, ..)`) would be
            // discarded without any trace. `debug` keeps production logs quiet.
            (false, Some(source)) => {
                tracing::debug!(
                    status = self.status.as_u16(),
                    message = %self.message,
                    context = ?self.context,
                    error = ?source,
                    "Client error"
                );
            }
            (false, None) => {}
        }

        let body = Json(json!({
            "error": self.message,
        }));
        (self.status, body).into_response()
    }
}
impl From<sqlx::Error> for ServerError {
fn from(err: sqlx::Error) -> Self {
Self::internal_anyhow(err.into(), "Database operation failed")
}
}
impl From<anyhow::Error> for ServerError {
fn from(err: anyhow::Error) -> Self {
Self::internal_anyhow(err, "Internal server error")
}
}
/// Extension methods that turn a `Result<T, E>` into a `Result<T, ServerError>`,
/// keeping the original error as the `ServerError`'s `source`.
pub trait ServerErrorExt<T> {
    /// Maps `Err(e)` to a `ServerError` with the given `status` and
    /// client-facing `message`; `Ok` values pass through unchanged.
    fn server_err(self, status: StatusCode, message: impl Into<String>) -> Result<T, ServerError>;
    /// Maps `Err(e)` to a 500 Internal Server Error with the given
    /// client-facing `message`; `Ok` values pass through unchanged.
    fn internal_err(self, message: impl Into<String>) -> Result<T, ServerError>;
}
/// Blanket implementation for any `Result` whose error type converts into
/// `anyhow::Error`.
impl<T, E> ServerErrorExt<T> for Result<T, E>
where
    E: Into<anyhow::Error>,
{
    fn server_err(self, status: StatusCode, message: impl Into<String>) -> Result<T, ServerError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(ServerError::from_anyhow(cause.into(), status, message)),
        }
    }

    fn internal_err(self, message: impl Into<String>) -> Result<T, ServerError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(ServerError::internal_anyhow(cause.into(), message)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::StatusCode;

    /// Renders `error` into a response and returns its status together with
    /// the parsed JSON body.
    async fn render(error: ServerError) -> (StatusCode, serde_json::Value) {
        let response = error.into_response();
        let status = response.status();
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body should be readable");
        let json = serde_json::from_slice(&bytes).expect("response body should be valid JSON");
        (status, json)
    }

    #[tokio::test]
    async fn test_server_error_response_shape() {
        let (status, body) = render(ServerError::bad_request("Invalid input")).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"], "Invalid input");
    }

    #[tokio::test]
    async fn test_server_error_with_context() {
        let error = ServerError::internal("Database error")
            .with_context("user_id", "123")
            .with_context("operation", "fetch_user");
        let (status, body) = render(error).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "Database error");
    }

    #[tokio::test]
    async fn test_server_error_from_anyhow() {
        let error = ServerError::from_anyhow(
            anyhow::anyhow!("Something went wrong"),
            StatusCode::INTERNAL_SERVER_ERROR,
            "Operation failed",
        );
        let (status, body) = render(error).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"], "Operation failed");
    }

    // No `.await` below, so a plain test harness suffices.
    #[test]
    fn test_server_error_status_codes() {
        assert_eq!(ServerError::bad_request("Bad input").status, StatusCode::BAD_REQUEST);
        assert_eq!(ServerError::forbidden("Access denied").status, StatusCode::FORBIDDEN);
        assert_eq!(ServerError::not_found("Resource missing").status, StatusCode::NOT_FOUND);
        assert_eq!(
            ServerError::internal("Server error").status,
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn test_server_error_ext_trait() {
        let missing: Result<(), std::io::Error> = Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "file not found",
        ));
        let not_found = missing
            .server_err(StatusCode::NOT_FOUND, "File operation failed")
            .unwrap_err();
        assert_eq!(not_found.status, StatusCode::NOT_FOUND);
        assert_eq!(not_found.message, "File operation failed");

        let broken: Result<(), std::io::Error> = Err(std::io::Error::other("io error"));
        let internal = broken.internal_err("Internal operation failed").unwrap_err();
        assert_eq!(internal.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(internal.message, "Internal operation failed");
    }
}