use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use std::fmt;
/// API error that can be turned into an HTTP response through the
/// `IntoResponse` impl in this file.
///
/// Every variant wraps a free-form detail string that is interpolated into
/// the `Display` text declared by its `#[error(...)]` attribute. The HTTP
/// status and the machine-readable code attached to each variant are defined
/// by `ApiError::status_code` and `ApiError::error_code`.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
/// Reported as `400 Bad Request` with code `INVALID_REQUEST`.
#[error("Invalid request: {0}")]
InvalidRequest(String),
/// Reported as `401 Unauthorized` with code `UNAUTHORIZED`.
#[error("Unauthorized: {0}")]
Unauthorized(String),
/// Reported as `403 Forbidden` with code `FORBIDDEN`.
#[error("Forbidden: {0}")]
Forbidden(String),
/// Reported as `404 Not Found` with code `NOT_FOUND`.
#[error("Not found: {0}")]
NotFound(String),
/// Reported as `429 Too Many Requests` with code `RATE_LIMIT_EXCEEDED`.
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
/// Reported as `500 Internal Server Error` with code `INTERNAL_ERROR`.
#[error("Internal server error: {0}")]
InternalError(String),
/// Reported as `503 Service Unavailable` with code `SERVICE_UNAVAILABLE`.
#[error("Service unavailable: {0}")]
ServiceUnavailable(String),
/// Reported as `422 Unprocessable Entity` with code `VALIDATION_ERROR`.
#[error("Validation error: {0}")]
ValidationError(String),
/// Reported as `500 Internal Server Error` with code `SCANNER_ERROR`.
/// Also the target of the `From<llm_shield_core::Error>` conversion.
#[error("Scanner error: {0}")]
ScannerError(String),
/// Reported as `500 Internal Server Error` with code `MODEL_ERROR`.
#[error("Model error: {0}")]
ModelError(String),
/// Reported as `500 Internal Server Error` with code `CONFIG_ERROR`.
/// Also the target of the `From<config::ConfigError>` conversion.
#[error("Configuration error: {0}")]
ConfigError(String),
}
impl ApiError {
    /// Returns the HTTP status this error is reported with.
    ///
    /// Client-side problems map to their matching 4xx status; the generic
    /// internal, scanner, model and configuration failures all share
    /// `500 Internal Server Error`.
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::InvalidRequest(_) => StatusCode::BAD_REQUEST,
            Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
            Self::Forbidden(_) => StatusCode::FORBIDDEN,
            Self::NotFound(_) => StatusCode::NOT_FOUND,
            Self::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            Self::RateLimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS,
            Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
            Self::InternalError(_)
            | Self::ScannerError(_)
            | Self::ModelError(_)
            | Self::ConfigError(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Returns the stable, upper-snake-case identifier for this error.
    ///
    /// Unlike the status code, every variant has its own distinct identifier.
    pub fn error_code(&self) -> &'static str {
        let code = match self {
            Self::InvalidRequest(_) => "INVALID_REQUEST",
            Self::Unauthorized(_) => "UNAUTHORIZED",
            Self::Forbidden(_) => "FORBIDDEN",
            Self::NotFound(_) => "NOT_FOUND",
            Self::ValidationError(_) => "VALIDATION_ERROR",
            Self::RateLimitExceeded(_) => "RATE_LIMIT_EXCEEDED",
            Self::ServiceUnavailable(_) => "SERVICE_UNAVAILABLE",
            Self::InternalError(_) => "INTERNAL_ERROR",
            Self::ScannerError(_) => "SCANNER_ERROR",
            Self::ModelError(_) => "MODEL_ERROR",
            Self::ConfigError(_) => "CONFIG_ERROR",
        };
        code
    }

    /// Builds the serializable payload describing this error.
    ///
    /// `code` comes from [`ApiError::error_code`], `message` is the error's
    /// `Display` text (prefix plus the wrapped detail string), and `details`
    /// is always `None`.
    pub fn to_error_response(&self) -> ErrorResponse {
        let code = self.error_code().to_owned();
        let message = self.to_string();
        ErrorResponse {
            code,
            message,
            details: None,
        }
    }
}
/// Error payload made of a machine-readable `code`, a human-readable
/// `message` and optional structured `details`.
///
/// `ApiError::into_response` nests this value under the `"error"` key of the
/// JSON body it produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Error identifier; `ApiError::to_error_response` fills it from
/// `ApiError::error_code`.
pub code: String,
/// Description text; `ApiError::to_error_response` fills it with the error's
/// `Display` output.
pub message: String,
/// Optional extra JSON data; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
/// Renders the response as `<code>: <message>`; `details` is not included.
impl fmt::Display for ErrorResponse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { code, message, .. } = self;
        // Emit the three pieces in order; any width/fill requested by the
        // caller is ignored, exactly as with a nested `write!`.
        f.write_str(code)?;
        f.write_str(": ")?;
        f.write_str(message)
    }
}
/// Converts the error into an HTTP response.
///
/// The response status is [`ApiError::status_code`] and the JSON body has the
/// shape `{"success": false, "error": <ErrorResponse>}`.
impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        // NOTE(review): the payload message is the full `Display` text, so the
        // wrapped detail string of 5xx variants reaches the client verbatim.
        // Confirm that this is intended.
        let payload = serde_json::json!({
            "success": false,
            "error": self.to_error_response(),
        });
        (self.status_code(), Json(payload)).into_response()
    }
}
impl From<llm_shield_core::Error> for ApiError {
fn from(err: llm_shield_core::Error) -> Self {
ApiError::ScannerError(err.to_string())
}
}
/// Wraps a `config::ConfigError` as [`ApiError::ConfigError`], keeping only
/// its `Display` text.
impl From<config::ConfigError> for ApiError {
    fn from(source: config::ConfigError) -> Self {
        let detail = source.to_string();
        Self::ConfigError(detail)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One instance of every `ApiError` variant, paired with the status and
    /// error code it must map to. Keeping the table exhaustive means a wrong
    /// mapping for any variant fails the tests below.
    fn expected_mappings() -> Vec<(ApiError, StatusCode, &'static str)> {
        let detail = || "test".to_string();
        vec![
            (
                ApiError::InvalidRequest(detail()),
                StatusCode::BAD_REQUEST,
                "INVALID_REQUEST",
            ),
            (
                ApiError::Unauthorized(detail()),
                StatusCode::UNAUTHORIZED,
                "UNAUTHORIZED",
            ),
            (
                ApiError::Forbidden(detail()),
                StatusCode::FORBIDDEN,
                "FORBIDDEN",
            ),
            (
                ApiError::NotFound(detail()),
                StatusCode::NOT_FOUND,
                "NOT_FOUND",
            ),
            (
                ApiError::RateLimitExceeded(detail()),
                StatusCode::TOO_MANY_REQUESTS,
                "RATE_LIMIT_EXCEEDED",
            ),
            (
                ApiError::InternalError(detail()),
                StatusCode::INTERNAL_SERVER_ERROR,
                "INTERNAL_ERROR",
            ),
            (
                ApiError::ServiceUnavailable(detail()),
                StatusCode::SERVICE_UNAVAILABLE,
                "SERVICE_UNAVAILABLE",
            ),
            (
                ApiError::ValidationError(detail()),
                StatusCode::UNPROCESSABLE_ENTITY,
                "VALIDATION_ERROR",
            ),
            (
                ApiError::ScannerError(detail()),
                StatusCode::INTERNAL_SERVER_ERROR,
                "SCANNER_ERROR",
            ),
            (
                ApiError::ModelError(detail()),
                StatusCode::INTERNAL_SERVER_ERROR,
                "MODEL_ERROR",
            ),
            (
                ApiError::ConfigError(detail()),
                StatusCode::INTERNAL_SERVER_ERROR,
                "CONFIG_ERROR",
            ),
        ]
    }

    #[test]
    fn test_error_status_codes() {
        for (error, status, _) in expected_mappings() {
            assert_eq!(error.status_code(), status, "wrong status for {:?}", error);
        }
    }

    #[test]
    fn test_error_codes() {
        for (error, _, code) in expected_mappings() {
            assert_eq!(error.error_code(), code, "wrong code for {:?}", error);
        }
    }

    #[test]
    fn test_into_response_status() {
        // The HTTP response must carry the same status `status_code` reports.
        for (error, status, _) in expected_mappings() {
            assert_eq!(error.into_response().status(), status);
        }
    }

    #[test]
    fn test_error_response_creation() {
        let error = ApiError::InvalidRequest("Missing field".to_string());
        let response = error.to_error_response();
        assert_eq!(response.code, "INVALID_REQUEST");
        assert_eq!(response.message, "Invalid request: Missing field");
        assert!(response.details.is_none());
    }

    #[test]
    fn test_error_response_serialization() {
        let error_response = ErrorResponse {
            code: "TEST_ERROR".to_string(),
            message: "Test message".to_string(),
            details: None,
        };
        let json = serde_json::to_string(&error_response).unwrap();
        assert!(json.contains("TEST_ERROR"));
        assert!(json.contains("Test message"));

        // `details` is skipped entirely when it is `None`...
        let value = serde_json::to_value(&error_response).unwrap();
        assert!(value.get("details").is_none());

        // ...and emitted as-is when present.
        let with_details = ErrorResponse {
            details: Some(serde_json::json!({ "field": "name" })),
            ..error_response
        };
        let value = serde_json::to_value(&with_details).unwrap();
        assert_eq!(value["details"]["field"], "name");
    }

    #[test]
    fn test_error_display() {
        let error = ApiError::InvalidRequest("test error".to_string());
        assert_eq!(error.to_string(), "Invalid request: test error");
    }

    #[test]
    fn test_error_response_display() {
        let error_response = ErrorResponse {
            code: "TEST_ERROR".to_string(),
            message: "Test message".to_string(),
            details: None,
        };
        assert_eq!(error_response.to_string(), "TEST_ERROR: Test message");
    }
}