1use axum::{
4 http::StatusCode,
5 response::{IntoResponse, Response},
6 Json,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
11#[derive(Debug)]
13pub enum ApiError {
14 BadRequest(String),
16
17 Unauthorized(String),
19
20 Forbidden(String),
22
23 NotFound(String),
25
26 Conflict(String),
28
29 UnprocessableEntity(String),
31
32 TooManyRequests(String),
34
35 InternalServerError(String),
37
38 ServiceUnavailable(String),
40
41 ValidationError(Vec<ValidationError>),
43
44 DatabaseError(String),
46
47 DomainError(String),
49}
50
51impl fmt::Display for ApiError {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match self {
54 Self::BadRequest(msg) => write!(f, "Bad request: {}", msg),
55 Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
56 Self::Forbidden(msg) => write!(f, "Forbidden: {}", msg),
57 Self::NotFound(msg) => write!(f, "Not found: {}", msg),
58 Self::Conflict(msg) => write!(f, "Conflict: {}", msg),
59 Self::UnprocessableEntity(msg) => write!(f, "Unprocessable entity: {}", msg),
60 Self::TooManyRequests(msg) => write!(f, "Too many requests: {}", msg),
61 Self::InternalServerError(msg) => write!(f, "Internal server error: {}", msg),
62 Self::ServiceUnavailable(msg) => write!(f, "Service unavailable: {}", msg),
63 Self::ValidationError(errors) => write!(f, "Validation error: {:?}", errors),
64 Self::DatabaseError(msg) => write!(f, "Database error: {}", msg),
65 Self::DomainError(msg) => write!(f, "Domain error: {}", msg),
66 }
67 }
68}
69
70impl std::error::Error for ApiError {}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ValidationError {
75 pub field: String,
76 pub message: String,
77 pub code: Option<String>,
78}
79
80impl ValidationError {
81 pub fn new(field: impl Into<String>, message: impl Into<String>) -> Self {
82 Self {
83 field: field.into(),
84 message: message.into(),
85 code: None,
86 }
87 }
88
89 pub fn with_code(mut self, code: impl Into<String>) -> Self {
90 self.code = Some(code.into());
91 self
92 }
93}
94
95#[derive(Debug, Serialize, Deserialize)]
97pub struct ErrorResponse {
98 pub error: ErrorDetail,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub request_id: Option<String>,
101}
102
103#[derive(Debug, Serialize, Deserialize)]
104pub struct ErrorDetail {
105 pub code: String,
106 pub message: String,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 pub details: Option<serde_json::Value>,
109}
110
111impl ApiError {
112 pub fn status_code(&self) -> StatusCode {
114 match self {
115 Self::BadRequest(_) => StatusCode::BAD_REQUEST,
116 Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
117 Self::Forbidden(_) => StatusCode::FORBIDDEN,
118 Self::NotFound(_) => StatusCode::NOT_FOUND,
119 Self::Conflict(_) => StatusCode::CONFLICT,
120 Self::UnprocessableEntity(_) | Self::ValidationError(_) => {
121 StatusCode::UNPROCESSABLE_ENTITY
122 }
123 Self::TooManyRequests(_) => StatusCode::TOO_MANY_REQUESTS,
124 Self::InternalServerError(_) | Self::DatabaseError(_) | Self::DomainError(_) => {
125 StatusCode::INTERNAL_SERVER_ERROR
126 }
127 Self::ServiceUnavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
128 }
129 }
130
131 pub fn error_code(&self) -> String {
133 match self {
134 Self::BadRequest(_) => "BAD_REQUEST",
135 Self::Unauthorized(_) => "UNAUTHORIZED",
136 Self::Forbidden(_) => "FORBIDDEN",
137 Self::NotFound(_) => "NOT_FOUND",
138 Self::Conflict(_) => "CONFLICT",
139 Self::UnprocessableEntity(_) => "UNPROCESSABLE_ENTITY",
140 Self::ValidationError(_) => "VALIDATION_ERROR",
141 Self::TooManyRequests(_) => "TOO_MANY_REQUESTS",
142 Self::InternalServerError(_) => "INTERNAL_SERVER_ERROR",
143 Self::ServiceUnavailable(_) => "SERVICE_UNAVAILABLE",
144 Self::DatabaseError(_) => "DATABASE_ERROR",
145 Self::DomainError(_) => "DOMAIN_ERROR",
146 }
147 .to_string()
148 }
149
150 pub fn to_response(&self) -> ErrorResponse {
152 let details = match self {
153 Self::ValidationError(errors) => Some(serde_json::to_value(errors).unwrap()),
154 _ => None,
155 };
156
157 ErrorResponse {
158 error: ErrorDetail {
159 code: self.error_code(),
160 message: self.to_string(),
161 details,
162 },
163 request_id: None,
164 }
165 }
166}
167
168impl IntoResponse for ApiError {
169 fn into_response(self) -> Response {
170 let status = self.status_code();
171 let body = Json(self.to_response());
172 (status, body).into_response()
173 }
174}
175
176impl From<llm_cost_ops::CostOpsError> for ApiError {
178 fn from(err: llm_cost_ops::CostOpsError) -> Self {
179 Self::DomainError(err.to_string())
180 }
181}
182
183impl From<sqlx::Error> for ApiError {
184 fn from(err: sqlx::Error) -> Self {
185 match err {
186 sqlx::Error::RowNotFound => Self::NotFound("Resource not found".to_string()),
187 _ => Self::DatabaseError(err.to_string()),
188 }
189 }
190}
191
192impl From<serde_json::Error> for ApiError {
193 fn from(err: serde_json::Error) -> Self {
194 Self::BadRequest(format!("Invalid JSON: {}", err))
195 }
196}
197
198impl From<validator::ValidationErrors> for ApiError {
199 fn from(errors: validator::ValidationErrors) -> Self {
200 let validation_errors: Vec<ValidationError> = errors
201 .field_errors()
202 .iter()
203 .flat_map(|(field, errors)| {
204 errors.iter().map(move |error| {
205 ValidationError::new(
206 field.to_string(),
207 error.message.clone().unwrap_or_default().to_string(),
208 )
209 .with_code(error.code.to_string())
210 })
211 })
212 .collect();
213
214 Self::ValidationError(validation_errors)
215 }
216}
217
218pub type ApiResult<T> = Result<T, ApiError>;
219
#[cfg(test)]
mod tests {
    use super::*;

    // Every variant must map to its documented HTTP status.
    #[test]
    fn test_api_error_status_codes() {
        let msg = || "test".to_string();
        let cases = vec![
            (ApiError::BadRequest(msg()), StatusCode::BAD_REQUEST),
            (ApiError::Unauthorized(msg()), StatusCode::UNAUTHORIZED),
            (ApiError::Forbidden(msg()), StatusCode::FORBIDDEN),
            (ApiError::NotFound(msg()), StatusCode::NOT_FOUND),
            (ApiError::Conflict(msg()), StatusCode::CONFLICT),
            (
                ApiError::UnprocessableEntity(msg()),
                StatusCode::UNPROCESSABLE_ENTITY,
            ),
            (
                ApiError::ValidationError(Vec::new()),
                StatusCode::UNPROCESSABLE_ENTITY,
            ),
            (ApiError::TooManyRequests(msg()), StatusCode::TOO_MANY_REQUESTS),
            (
                ApiError::InternalServerError(msg()),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
            (ApiError::DatabaseError(msg()), StatusCode::INTERNAL_SERVER_ERROR),
            (ApiError::DomainError(msg()), StatusCode::INTERNAL_SERVER_ERROR),
            (
                ApiError::ServiceUnavailable(msg()),
                StatusCode::SERVICE_UNAVAILABLE,
            ),
        ];

        for (error, expected) in cases.iter() {
            assert_eq!(
                error.status_code(),
                *expected,
                "unexpected status for {:?}",
                error
            );
        }
    }

    // Every variant must expose its machine-readable code.
    #[test]
    fn test_api_error_codes() {
        let msg = || "test".to_string();
        let cases = vec![
            (ApiError::BadRequest(msg()), "BAD_REQUEST"),
            (ApiError::Unauthorized(msg()), "UNAUTHORIZED"),
            (ApiError::Forbidden(msg()), "FORBIDDEN"),
            (ApiError::NotFound(msg()), "NOT_FOUND"),
            (ApiError::Conflict(msg()), "CONFLICT"),
            (ApiError::UnprocessableEntity(msg()), "UNPROCESSABLE_ENTITY"),
            (ApiError::ValidationError(Vec::new()), "VALIDATION_ERROR"),
            (ApiError::TooManyRequests(msg()), "TOO_MANY_REQUESTS"),
            (ApiError::InternalServerError(msg()), "INTERNAL_SERVER_ERROR"),
            (ApiError::ServiceUnavailable(msg()), "SERVICE_UNAVAILABLE"),
            (ApiError::DatabaseError(msg()), "DATABASE_ERROR"),
            (ApiError::DomainError(msg()), "DOMAIN_ERROR"),
        ];

        for (error, expected) in cases.iter() {
            assert_eq!(error.error_code(), *expected, "unexpected code for {:?}", error);
        }
    }

    #[test]
    fn test_validation_error_creation() {
        let error = ValidationError::new("email", "Invalid email format")
            .with_code("INVALID_EMAIL");

        assert_eq!(error.field, "email");
        assert_eq!(error.message, "Invalid email format");
        assert_eq!(error.code, Some("INVALID_EMAIL".to_string()));
    }

    // Serializes the response for real and checks that optional members are
    // omitted when unset.
    #[test]
    fn test_error_response_serialization() {
        let api_error = ApiError::BadRequest("Invalid request".to_string());
        let response = api_error.to_response();

        assert_eq!(response.error.code, "BAD_REQUEST");
        assert!(response.error.message.contains("Invalid request"));

        let json = serde_json::to_value(&response).expect("ErrorResponse serializes");
        assert_eq!(json["error"]["code"], "BAD_REQUEST");
        assert_eq!(json["error"]["message"], "Bad request: Invalid request");
        assert!(json["error"].get("details").is_none());
        assert!(json.get("request_id").is_none());
    }

    // Validation errors must be carried into the structured `details` member.
    #[test]
    fn test_validation_error_response_details() {
        let api_error = ApiError::ValidationError(vec![ValidationError::new(
            "email",
            "Invalid email format",
        )
        .with_code("INVALID_EMAIL")]);
        let response = api_error.to_response();

        assert_eq!(response.error.code, "VALIDATION_ERROR");
        let details = response
            .error
            .details
            .expect("validation errors carry details");
        assert_eq!(details[0]["field"], "email");
        assert_eq!(details[0]["message"], "Invalid email format");
        assert_eq!(details[0]["code"], "INVALID_EMAIL");
    }

    #[test]
    fn test_error_conversions() {
        assert!(matches!(
            ApiError::from(sqlx::Error::RowNotFound),
            ApiError::NotFound(_)
        ));

        let json_err = serde_json::from_str::<serde_json::Value>("{not json").unwrap_err();
        match ApiError::from(json_err) {
            ApiError::BadRequest(msg) => assert!(msg.starts_with("Invalid JSON: ")),
            other => panic!("unexpected variant: {:?}", other),
        }
    }
}