use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
};
use sea_orm::DbErr;
use serde::{Deserialize, Serialize};
use std::fmt;
use utoipa::ToSchema;
/// A single element of a batch operation that could not be processed.
///
/// Collected into [`BatchResult::failed`] via `BatchResult::add_failure`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct BatchFailure {
    /// Index of the failed element (presumably its position in the submitted batch — confirm with callers).
    pub index: usize,
    /// Human-readable reason the element failed.
    pub error: String,
}
/// Outcome of a batch operation: the items that were processed successfully
/// and a record of every element that failed.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct BatchResult<T> {
    /// Items produced for the elements that succeeded, in insertion order.
    pub succeeded: Vec<T>,
    /// One entry per element that failed, in insertion order.
    pub failed: Vec<BatchFailure>,
}
impl<T> BatchResult<T> {
    /// Creates a result that holds neither successes nor failures.
    #[must_use]
    pub fn new() -> Self {
        Self {
            succeeded: vec![],
            failed: vec![],
        }
    }

    /// Appends `item` to the list of successfully processed results.
    pub fn add_success(&mut self, item: T) {
        self.succeeded.push(item);
    }

    /// Records that the element at `index` failed with the given `error` text.
    pub fn add_failure(&mut self, index: usize, error: impl Into<String>) {
        let failure = BatchFailure {
            index,
            error: error.into(),
        };
        self.failed.push(failure);
    }

    /// `true` when at least one failure was recorded and nothing succeeded.
    ///
    /// An empty result is *not* considered "all failed".
    #[must_use]
    pub fn all_failed(&self) -> bool {
        !self.failed.is_empty() && self.succeeded.is_empty()
    }

    /// `true` when at least one success was recorded and nothing failed.
    ///
    /// An empty result is *not* considered "all succeeded".
    #[must_use]
    pub fn all_succeeded(&self) -> bool {
        self.failed.is_empty() && !self.succeeded.is_empty()
    }

    /// `true` when the result contains both successes and failures.
    #[must_use]
    pub fn is_partial(&self) -> bool {
        !(self.succeeded.is_empty() || self.failed.is_empty())
    }
}
/// The default `BatchResult` is empty; this simply delegates to
/// [`BatchResult::new`] so both constructors always agree.
impl<T> Default for BatchResult<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Error type returned by API handlers and turned into an HTTP response.
///
/// Each variant maps to a fixed status code (except `Custom`, which carries its own).
/// The JSON body carries only the client-facing message, plus the list of errors
/// for `ValidationFailed`. `internal` fields are never put in the response; they are
/// written to the log instead.
#[derive(Debug)]
pub enum ApiError {
    /// 404 Not Found for `resource`, optionally naming the `id` that was looked up.
    NotFound {
        resource: String,
        id: Option<String>,
    },
    /// 400 Bad Request; `message` is shown to the client verbatim.
    BadRequest {
        message: String,
    },
    /// 401 Unauthorized; `message` is shown to the client verbatim.
    Unauthorized {
        message: String,
    },
    /// 403 Forbidden; `message` is shown to the client verbatim.
    Forbidden {
        message: String,
    },
    /// 409 Conflict; `message` is shown to the client verbatim.
    Conflict {
        message: String,
    },
    /// 422 Unprocessable Entity with one message per failed validation rule.
    ValidationFailed {
        errors: Vec<String>,
    },
    /// 500 caused by the database. `message` is the client-facing text;
    /// `internal` is the original error and is only logged.
    Database {
        message: String,
        internal: DbErr,
    },
    /// 500 Internal Server Error. `internal` holds optional diagnostic details
    /// that are logged but never returned to the client.
    Internal {
        message: String,
        internal: Option<String>,
    },
    /// Error with an arbitrary `status`. `internal` holds optional diagnostic
    /// details that are logged but never returned to the client.
    Custom {
        status: StatusCode,
        message: String,
        internal: Option<String>,
    },
}
impl ApiError {
pub fn not_found(resource: impl Into<String>, id: Option<String>) -> Self {
Self::NotFound {
resource: resource.into(),
id,
}
}
pub fn bad_request(message: impl Into<String>) -> Self {
Self::BadRequest {
message: message.into(),
}
}
pub fn unauthorized(message: impl Into<String>) -> Self {
Self::Unauthorized {
message: message.into(),
}
}
pub fn forbidden(message: impl Into<String>) -> Self {
Self::Forbidden {
message: message.into(),
}
}
pub fn conflict(message: impl Into<String>) -> Self {
Self::Conflict {
message: message.into(),
}
}
#[must_use]
pub fn validation_failed(errors: Vec<String>) -> Self {
Self::ValidationFailed { errors }
}
#[must_use]
pub fn database(err: DbErr) -> Self {
Self::Database {
message: "A database error occurred".to_string(),
internal: err,
}
}
pub fn internal(message: impl Into<String>, internal: Option<String>) -> Self {
Self::Internal {
message: message.into(),
internal,
}
}
pub fn custom(
status: StatusCode,
message: impl Into<String>,
internal: Option<String>,
) -> Self {
Self::Custom {
status,
message: message.into(),
internal,
}
}
fn status_code(&self) -> StatusCode {
match self {
Self::NotFound { .. } => StatusCode::NOT_FOUND,
Self::BadRequest { .. } => StatusCode::BAD_REQUEST,
Self::Unauthorized { .. } => StatusCode::UNAUTHORIZED,
Self::Forbidden { .. } => StatusCode::FORBIDDEN,
Self::Conflict { .. } => StatusCode::CONFLICT,
Self::ValidationFailed { .. } => StatusCode::UNPROCESSABLE_ENTITY,
Self::Database { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR,
Self::Custom { status, .. } => *status,
}
}
fn user_message(&self) -> String {
match self {
Self::NotFound { resource, id } => {
if let Some(id) = id {
format!("{resource} with ID '{id}' not found")
} else {
format!("{resource} not found")
}
}
Self::BadRequest { message } => message.clone(),
Self::Unauthorized { message } => message.clone(),
Self::Forbidden { message } => message.clone(),
Self::Conflict { message } => message.clone(),
Self::ValidationFailed { errors } => {
if errors.len() == 1 {
errors[0].clone()
} else {
format!("Validation failed: {}", errors.join(", "))
}
}
Self::Database { message, .. } => message.clone(),
Self::Internal { message, .. } => message.clone(),
Self::Custom { message, .. } => message.clone(),
}
}
fn log_internal(&self) {
match self {
Self::Database { internal, .. } => {
tracing::error!(
error = ?internal,
"Database error occurred"
);
}
Self::Internal {
internal: Some(details),
..
} => {
tracing::error!(
details = %details,
"Internal error occurred"
);
}
Self::Custom {
internal: Some(details),
status,
..
} => {
tracing::error!(
status = %status,
details = %details,
"Custom error occurred"
);
}
_ => {
tracing::debug!(
error = %self.user_message(),
status = %self.status_code(),
"API error"
);
}
}
}
}
/// JSON body sent to the client for every `ApiError`.
#[derive(Serialize)]
struct ErrorResponse {
    /// Client-facing error message.
    error: String,
    /// Individual validation messages. The `IntoResponse` impl sets this only for
    /// `ValidationFailed`, and the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<Vec<String>>,
}
impl IntoResponse for ApiError {
    /// Logs the error, then renders it as `(status, JSON ErrorResponse)`.
    ///
    /// Validation failures get a fixed headline plus the individual messages in
    /// `details`. Every other variant sends only its client-facing message.
    fn into_response(self) -> Response {
        self.log_internal();
        let status = self.status_code();

        // `self` is owned, so the validation messages can be moved out instead of cloned.
        let body = match self {
            Self::ValidationFailed { errors } => ErrorResponse {
                error: "Validation failed".to_owned(),
                details: Some(errors),
            },
            other => ErrorResponse {
                error: other.user_message(),
                details: None,
            },
        };

        (status, Json(body)).into_response()
    }
}
impl fmt::Display for ApiError {
    /// Writes the client-facing message; internal diagnostic details are never included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.user_message())
    }
}
impl std::error::Error for ApiError {}
impl From<DbErr> for ApiError {
    /// Converts a SeaORM error into an API error.
    ///
    /// `DbErr::RecordNotFound` becomes a 404 `NotFound`. Every other variant becomes
    /// an opaque 500 via [`ApiError::database`], whose details are only logged.
    fn from(err: DbErr) -> Self {
        /// Extracts `<resource>` from a message shaped like `"<resource> not found"`.
        ///
        /// Falls back to `"Resource"` when the message has a different shape, rather
        /// than echoing the first word of unrelated driver text (e.g. turning
        /// "Failed to find inserted item" into "Failed not found").
        fn resource_name(msg: &str) -> &str {
            msg.trim()
                .trim_end_matches('.')
                .strip_suffix("not found")
                .map(str::trim_end)
                .filter(|name| !name.is_empty())
                .unwrap_or("Resource")
        }

        match err {
            DbErr::RecordNotFound(msg) => Self::NotFound {
                resource: resource_name(&msg).to_string(),
                id: None,
            },
            other => Self::database(other),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_not_found_with_id() {
        let err = ApiError::not_found("User", Some("123".to_string()));
        assert_eq!(err.status_code(), StatusCode::NOT_FOUND);
        assert_eq!(err.user_message(), "User with ID '123' not found");
    }

    #[test]
    fn test_not_found_without_id() {
        let err = ApiError::not_found("User", None);
        assert_eq!(err.status_code(), StatusCode::NOT_FOUND);
        assert_eq!(err.user_message(), "User not found");
    }

    #[test]
    fn test_bad_request() {
        let err = ApiError::bad_request("Invalid email format");
        assert_eq!(err.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(err.user_message(), "Invalid email format");
    }

    #[test]
    fn test_unauthorized() {
        let err = ApiError::unauthorized("Invalid credentials");
        assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED);
        assert_eq!(err.user_message(), "Invalid credentials");
    }

    #[test]
    fn test_forbidden() {
        let err = ApiError::forbidden("Insufficient permissions");
        assert_eq!(err.status_code(), StatusCode::FORBIDDEN);
        assert_eq!(err.user_message(), "Insufficient permissions");
    }

    #[test]
    fn test_conflict() {
        let err = ApiError::conflict("Email already exists");
        assert_eq!(err.status_code(), StatusCode::CONFLICT);
        assert_eq!(err.user_message(), "Email already exists");
    }

    #[test]
    fn test_validation_failed_single_error() {
        let err = ApiError::validation_failed(vec!["Email is required".to_string()]);
        assert_eq!(err.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(err.user_message(), "Email is required");
    }

    #[test]
    fn test_validation_failed_multiple_errors() {
        let err = ApiError::validation_failed(vec![
            "Email is required".to_string(),
            "Password too short".to_string(),
        ]);
        assert_eq!(err.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            err.user_message(),
            "Validation failed: Email is required, Password too short"
        );
    }

    #[test]
    fn test_database_error() {
        let db_err = DbErr::Type("Type mismatch error".to_string());
        let err = ApiError::database(db_err);
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.user_message(), "A database error occurred");
    }

    #[test]
    fn test_internal_error_with_details() {
        let err = ApiError::internal(
            "Processing failed",
            Some("Null pointer exception".to_string()),
        );
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.user_message(), "Processing failed");
    }

    #[test]
    fn test_internal_error_without_details() {
        let err = ApiError::internal("Processing failed", None);
        assert_eq!(err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(err.user_message(), "Processing failed");
    }

    #[test]
    fn test_custom_error() {
        let err = ApiError::custom(
            StatusCode::TOO_MANY_REQUESTS,
            "Rate limit exceeded",
            Some("User hit 100 req/min".to_string()),
        );
        assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(err.user_message(), "Rate limit exceeded");
    }

    #[test]
    fn test_dberr_record_not_found_conversion() {
        let db_err = DbErr::RecordNotFound("User not found".to_string());
        let api_err: ApiError = db_err.into();
        assert_eq!(api_err.status_code(), StatusCode::NOT_FOUND);
        assert!(api_err.user_message().contains("not found"));
    }

    #[test]
    fn test_dberr_custom_becomes_internal() {
        let db_err = DbErr::Custom("Something went wrong".to_string());
        let api_err: ApiError = db_err.into();
        assert_eq!(api_err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(api_err.user_message(), "A database error occurred");
    }

    #[test]
    fn test_dberr_type_error() {
        let db_err = DbErr::Type("Type conversion failed".to_string());
        let api_err: ApiError = db_err.into();
        assert_eq!(api_err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(api_err.user_message(), "A database error occurred");
    }

    #[test]
    fn test_dberr_json_error() {
        let db_err = DbErr::Json("JSON parsing failed".to_string());
        let api_err: ApiError = db_err.into();
        assert_eq!(api_err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(api_err.user_message(), "A database error occurred");
    }

    #[test]
    fn test_dberr_record_not_found_becomes_404() {
        // Multi-word resource name in the driver message.
        let db_err = DbErr::RecordNotFound("Blog post not found".to_string());
        let api_err: ApiError = db_err.into();
        assert_eq!(api_err.status_code(), StatusCode::NOT_FOUND);
        assert!(api_err.user_message().contains("not found"));
    }

    #[test]
    fn test_all_other_dberr_become_500() {
        let test_cases = vec![
            DbErr::Custom("Any custom error".to_string()),
            DbErr::Type("Type error".to_string()),
            DbErr::Json("JSON error".to_string()),
        ];
        for db_err in test_cases {
            let api_err: ApiError = db_err.into();
            assert_eq!(api_err.status_code(), StatusCode::INTERNAL_SERVER_ERROR);
            assert_eq!(api_err.user_message(), "A database error occurred");
        }
    }

    #[test]
    fn test_display_trait() {
        let err = ApiError::bad_request("Test error");
        assert_eq!(err.to_string(), "Test error");
    }

    #[test]
    fn test_error_trait() {
        let err = ApiError::bad_request("Test error");
        let _: &dyn std::error::Error = &err;
    }

    #[test]
    fn test_all_status_codes() {
        let test_cases = vec![
            (ApiError::not_found("Test", None), StatusCode::NOT_FOUND),
            (ApiError::bad_request("Test"), StatusCode::BAD_REQUEST),
            (ApiError::unauthorized("Test"), StatusCode::UNAUTHORIZED),
            (ApiError::forbidden("Test"), StatusCode::FORBIDDEN),
            (ApiError::conflict("Test"), StatusCode::CONFLICT),
            (
                ApiError::validation_failed(vec!["Test".to_string()]),
                StatusCode::UNPROCESSABLE_ENTITY,
            ),
            (
                ApiError::database(DbErr::Conn(sea_orm::RuntimeErr::Internal(
                    "Test".to_string(),
                ))),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            (
                ApiError::internal("Test", None),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            (
                ApiError::custom(StatusCode::IM_A_TEAPOT, "Test", None),
                StatusCode::IM_A_TEAPOT,
            ),
        ];
        for (err, expected_status) in test_cases {
            assert_eq!(err.status_code(), expected_status);
        }
    }

    #[test]
    fn test_into_response_uses_status_code() {
        let test_cases = vec![
            (ApiError::not_found("User", None), StatusCode::NOT_FOUND),
            (
                ApiError::validation_failed(vec!["Email is required".to_string()]),
                StatusCode::UNPROCESSABLE_ENTITY,
            ),
            (
                ApiError::custom(StatusCode::TOO_MANY_REQUESTS, "Slow down", None),
                StatusCode::TOO_MANY_REQUESTS,
            ),
        ];
        for (err, expected_status) in test_cases {
            assert_eq!(err.into_response().status(), expected_status);
        }
    }

    #[test]
    fn test_batch_result_empty() {
        let result: BatchResult<u32> = BatchResult::default();
        assert!(!result.all_succeeded());
        assert!(!result.all_failed());
        assert!(!result.is_partial());
    }

    #[test]
    fn test_batch_result_states() {
        let mut result: BatchResult<u32> = BatchResult::new();
        result.add_success(1);
        assert!(result.all_succeeded());
        assert!(!result.is_partial());

        result.add_failure(2, "invalid payload");
        assert!(result.is_partial());
        assert!(!result.all_succeeded());
        assert!(!result.all_failed());
        assert_eq!(result.failed[0].index, 2);
        assert_eq!(result.failed[0].error, "invalid payload");

        let mut failed: BatchResult<u32> = BatchResult::new();
        failed.add_failure(0, String::from("boom"));
        assert!(failed.all_failed());
    }
}