use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
use crate::services::{AnalysisError, AudioError, EmbeddingError, InterpretationError, VectorError};
/// Error type for the HTTP API layer.
///
/// Each variant maps to an HTTP status via [`ApiError::status_code`] and to a
/// machine-readable identifier via [`ApiError::error_type`]. Service-layer
/// errors convert into this type automatically through the `#[from]` variants,
/// so `?` can be used on their results inside handlers.
#[derive(Debug, Error)]
pub enum ApiError {
    /// Requested resource does not exist (404).
    #[error("Resource not found: {0}")]
    NotFound(String),
    /// Malformed or invalid request (400).
    #[error("Bad request: {0}")]
    BadRequest(String),
    /// Missing or invalid authentication (401).
    #[error("Unauthorized: {0}")]
    Unauthorized(String),
    /// Authenticated but not permitted (403).
    #[error("Forbidden: {0}")]
    Forbidden(String),
    /// Request conflicts with current resource state (409).
    #[error("Conflict: {0}")]
    Conflict(String),
    /// Request body exceeds the allowed size (413).
    #[error("Payload too large: {0}")]
    PayloadTooLarge(String),
    /// Request content type is not accepted (415).
    #[error("Unsupported media type: {0}")]
    UnsupportedMediaType(String),
    /// Caller exceeded its rate limit (429). Carries no message.
    #[error("Rate limit exceeded")]
    RateLimitExceeded,
    /// Generic server-side failure (500).
    #[error("Internal error: {0}")]
    Internal(String),
    /// A dependency is temporarily unavailable (503).
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),
    /// Audio processing failed (422).
    #[error("Audio processing error: {0}")]
    AudioProcessing(#[from] AudioError),
    /// Embedding service failure (500).
    #[error("Embedding error: {0}")]
    Embedding(#[from] EmbeddingError),
    /// Vector index failure (500).
    #[error("Vector index error: {0}")]
    VectorIndex(#[from] VectorError),
    /// Analysis service failure (500).
    #[error("Analysis error: {0}")]
    Analysis(#[from] AnalysisError),
    /// Interpretation service failure (500).
    #[error("Interpretation error: {0}")]
    Interpretation(#[from] InterpretationError),
    /// Catch-all for `anyhow` errors (500); `Display` delegates to the inner error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
/// JSON body returned to clients for every [`ApiError`].
#[derive(Debug, Serialize, ToSchema)]
pub struct ErrorResponse {
    /// Machine-readable error identifier (see [`ApiError::error_type`]).
    #[schema(example = "not_found")]
    pub error: String,
    /// Human-readable description of the error.
    #[schema(example = "Recording with ID xyz not found")]
    pub message: String,
    /// Optional structured details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
    /// Optional request identifier; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
impl ApiError {
    /// Returns the HTTP status code this error is reported with.
    #[must_use]
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::NotFound(_) => StatusCode::NOT_FOUND,
            Self::BadRequest(_) => StatusCode::BAD_REQUEST,
            Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
            Self::Forbidden(_) => StatusCode::FORBIDDEN,
            Self::Conflict(_) => StatusCode::CONFLICT,
            Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,
            Self::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE,
            Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
            Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
            // Audio failures are reported as 422 rather than 500.
            Self::AudioProcessing(_) => StatusCode::UNPROCESSABLE_ENTITY,
            Self::Internal(_)
            | Self::Embedding(_)
            | Self::VectorIndex(_)
            | Self::Analysis(_)
            | Self::Interpretation(_)
            | Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Returns the machine-readable identifier placed in the `error` field of
    /// [`ErrorResponse`].
    ///
    /// Note that [`ApiError::Internal`] and [`ApiError::Other`] share
    /// `"internal_error"`.
    #[must_use]
    pub fn error_type(&self) -> &'static str {
        match self {
            Self::NotFound(_) => "not_found",
            Self::BadRequest(_) => "bad_request",
            Self::Unauthorized(_) => "unauthorized",
            Self::Forbidden(_) => "forbidden",
            Self::Conflict(_) => "conflict",
            Self::PayloadTooLarge(_) => "payload_too_large",
            Self::UnsupportedMediaType(_) => "unsupported_media_type",
            Self::RateLimitExceeded => "rate_limit_exceeded",
            Self::Internal(_) | Self::Other(_) => "internal_error",
            Self::ServiceUnavailable(_) => "service_unavailable",
            Self::AudioProcessing(_) => "audio_processing_error",
            Self::Embedding(_) => "embedding_error",
            Self::VectorIndex(_) => "vector_index_error",
            Self::Analysis(_) => "analysis_error",
            Self::Interpretation(_) => "interpretation_error",
        }
    }

    /// Builds an [`ApiError::NotFound`] with the message
    /// `"{resource} with ID {id} not found"`.
    #[must_use]
    pub fn not_found<T: std::fmt::Display>(resource: &str, id: T) -> Self {
        Self::NotFound(format!("{resource} with ID {id} not found"))
    }

    /// Builds an [`ApiError::BadRequest`] whose message is `"{message}: {details}"`,
    /// with `details` rendered as compact JSON.
    ///
    /// If `details` cannot be serialized (e.g. a map with non-string keys), the
    /// message is used on its own instead of ending in a dangling `": "`.
    #[must_use]
    pub fn validation<T: Serialize>(message: &str, details: T) -> Self {
        match serde_json::to_string(&details) {
            Ok(json) => Self::BadRequest(format!("{message}: {json}")),
            Err(_) => Self::BadRequest(message.to_string()),
        }
    }
}
impl IntoResponse for ApiError {
    /// Converts the error into a JSON [`ErrorResponse`] with the matching status.
    ///
    /// Server-side failures (`Internal`, `Other`, `Embedding`, `VectorIndex`,
    /// `Analysis`, `Interpretation`) are logged at `error` level with their full
    /// message. The client receives only a generic message, because the
    /// `Display` text of these variants embeds raw wrapped-error text that should
    /// not leave the server. All other errors are logged at `debug` level and
    /// returned to the client verbatim.
    fn into_response(self) -> Response {
        const GENERIC_INTERNAL_MESSAGE: &str = "An internal server error occurred";

        let status = self.status_code();
        let error_type = self.error_type();
        let is_internal = matches!(
            &self,
            Self::Internal(_)
                | Self::Other(_)
                | Self::Embedding(_)
                | Self::VectorIndex(_)
                | Self::Analysis(_)
                | Self::Interpretation(_)
        );

        let message = if is_internal {
            tracing::error!(error = %self, "Internal API error");
            GENERIC_INTERNAL_MESSAGE.to_string()
        } else {
            tracing::debug!(error = %self, "API error response");
            self.to_string()
        };

        let body = ErrorResponse {
            error: error_type.to_string(),
            message,
            details: None,
            request_id: None,
        };
        (status, Json(body)).into_response()
    }
}
/// `Result` alias whose error type is [`ApiError`].
pub type ApiResult<T> = Result<T, ApiError>;
/// Extension methods for turning `Result`/`Option` values into [`ApiResult`].
pub trait ResultExt<T> {
    /// Maps a failure (an `Err` or `None`) to [`ApiError::Internal`], prefixing
    /// the message with `context`.
    fn api_context(self, context: &str) -> ApiResult<T>;
    /// Maps a missing value to an API error.
    ///
    /// For `Option`, `None` becomes [`ApiError::NotFound`] built from `resource`
    /// and `id`. For `Result`, an `Err` becomes [`ApiError::Internal`].
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T>;
}
/// [`ResultExt`] for fallible results: every `Err` becomes [`ApiError::Internal`].
impl<T, E: std::error::Error + Send + Sync + 'static> ResultExt<T> for Result<T, E> {
    /// Wraps the error as `Internal("{context}: {error}")`.
    fn api_context(self, context: &str) -> ApiResult<T> {
        self.map_err(|e| ApiError::Internal(format!("{context}: {e}")))
    }

    /// Wraps the error as an internal error that names the resource and ID.
    ///
    /// An `Err` means the lookup itself failed rather than confirming the resource
    /// is absent. It is therefore still reported as internal, but with the lookup
    /// target attached so the failure can be diagnosed.
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T> {
        self.map_err(|e| ApiError::Internal(format!("{resource} lookup for ID {id} failed: {e}")))
    }
}
/// [`ResultExt`] for optional values: `None` is the failure case.
impl<T> ResultExt<T> for Option<T> {
    /// Returns the contained value, or `Internal("{context}: value was None")`.
    fn api_context(self, context: &str) -> ApiResult<T> {
        match self {
            Some(inner) => Ok(inner),
            None => Err(ApiError::Internal(format!("{context}: value was None"))),
        }
    }

    /// Returns the contained value, or a `NotFound` error naming `resource` and `id`.
    fn or_not_found(self, resource: &str, id: &str) -> ApiResult<T> {
        if let Some(inner) = self {
            return Ok(inner);
        }
        Err(ApiError::not_found(resource, id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the status mapping for a few representative variants.
    #[test]
    fn test_error_status_codes() {
        let cases = [
            (ApiError::NotFound("test".into()), StatusCode::NOT_FOUND),
            (ApiError::BadRequest("test".into()), StatusCode::BAD_REQUEST),
            (ApiError::RateLimitExceeded, StatusCode::TOO_MANY_REQUESTS),
        ];
        for (err, expected) in cases {
            assert_eq!(err.status_code(), expected);
        }
    }

    /// The `not_found` helper must mention both the resource and its ID.
    #[test]
    fn test_not_found_helper() {
        let rendered = ApiError::not_found("Recording", "abc-123").to_string();
        for needle in ["Recording", "abc-123"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }

    /// Serialized responses carry the error type and the request ID.
    #[test]
    fn test_error_response_serialization() {
        let json = serde_json::to_string(&ErrorResponse {
            error: "not_found".into(),
            message: "Resource not found".into(),
            details: None,
            request_id: Some("req-123".into()),
        })
        .expect("ErrorResponse should always serialize");

        assert!(json.contains("not_found"));
        assert!(json.contains("req-123"));
    }
}