1use axum::response::{IntoResponse, Response};
4use http::StatusCode;
5use std::fmt;
6
7pub type Result<T> = std::result::Result<T, Error>;
9
10pub struct Error {
39 status: StatusCode,
40 message: String,
41 source: Option<Box<dyn std::error::Error + Send + Sync>>,
42 error_code: Option<&'static str>,
43 locale_key: Option<&'static str>,
44 details: Option<serde_json::Value>,
45 lagged: bool,
46}
47
48impl Error {
49 pub fn new(status: StatusCode, message: impl Into<String>) -> Self {
51 Self {
52 status,
53 message: message.into(),
54 source: None,
55 error_code: None,
56 locale_key: None,
57 details: None,
58 lagged: false,
59 }
60 }
61
62 pub fn with_source(
66 status: StatusCode,
67 message: impl Into<String>,
68 source: impl std::error::Error + Send + Sync + 'static,
69 ) -> Self {
70 Self {
71 status,
72 message: message.into(),
73 source: Some(Box::new(source)),
74 error_code: None,
75 locale_key: None,
76 details: None,
77 lagged: false,
78 }
79 }
80
81 pub fn localized(status: StatusCode, key: &'static str) -> Self {
113 Self {
114 status,
115 message: key.to_string(),
116 source: None,
117 error_code: None,
118 locale_key: Some(key),
119 details: None,
120 lagged: false,
121 }
122 }
123
124 pub fn status(&self) -> StatusCode {
126 self.status
127 }
128
129 pub fn message(&self) -> &str {
131 &self.message
132 }
133
134 pub fn details(&self) -> Option<&serde_json::Value> {
136 self.details.as_ref()
137 }
138
139 pub fn with_details(mut self, details: serde_json::Value) -> Self {
141 self.details = Some(details);
142 self
143 }
144
145 pub fn chain(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
147 self.source = Some(Box::new(source));
148 self
149 }
150
151 pub fn with_code(mut self, code: &'static str) -> Self {
160 self.error_code = Some(code);
161 self
162 }
163
164 pub fn error_code(&self) -> Option<&'static str> {
166 self.error_code
167 }
168
169 pub fn with_locale_key(mut self, key: &'static str) -> Self {
183 self.locale_key = Some(key);
184 self
185 }
186
187 pub fn locale_key(&self) -> Option<&'static str> {
190 self.locale_key
191 }
192
193 pub fn source_as<T: std::error::Error + 'static>(&self) -> Option<&T> {
197 self.source.as_ref()?.downcast_ref::<T>()
198 }
199
200 pub fn bad_request(msg: impl Into<String>) -> Self {
202 Self::new(StatusCode::BAD_REQUEST, msg)
203 }
204
205 pub fn unauthorized(msg: impl Into<String>) -> Self {
207 Self::new(StatusCode::UNAUTHORIZED, msg)
208 }
209
210 pub fn forbidden(msg: impl Into<String>) -> Self {
212 Self::new(StatusCode::FORBIDDEN, msg)
213 }
214
215 pub fn not_found(msg: impl Into<String>) -> Self {
217 Self::new(StatusCode::NOT_FOUND, msg)
218 }
219
220 pub fn conflict(msg: impl Into<String>) -> Self {
222 Self::new(StatusCode::CONFLICT, msg)
223 }
224
225 pub fn payload_too_large(msg: impl Into<String>) -> Self {
227 Self::new(StatusCode::PAYLOAD_TOO_LARGE, msg)
228 }
229
230 pub fn unprocessable_entity(msg: impl Into<String>) -> Self {
232 Self::new(StatusCode::UNPROCESSABLE_ENTITY, msg)
233 }
234
235 pub fn too_many_requests(msg: impl Into<String>) -> Self {
237 Self::new(StatusCode::TOO_MANY_REQUESTS, msg)
238 }
239
240 pub fn internal(msg: impl Into<String>) -> Self {
242 Self::new(StatusCode::INTERNAL_SERVER_ERROR, msg)
243 }
244
245 pub fn bad_gateway(msg: impl Into<String>) -> Self {
247 Self::new(StatusCode::BAD_GATEWAY, msg)
248 }
249
250 pub fn gateway_timeout(msg: impl Into<String>) -> Self {
252 Self::new(StatusCode::GATEWAY_TIMEOUT, msg)
253 }
254
255 pub fn lagged(skipped: u64) -> Self {
260 Self {
261 status: StatusCode::INTERNAL_SERVER_ERROR,
262 message: format!("SSE subscriber lagged, skipped {skipped} messages"),
263 source: None,
264 error_code: None,
265 locale_key: None,
266 details: None,
267 lagged: true,
268 }
269 }
270
271 pub fn is_lagged(&self) -> bool {
273 self.lagged
274 }
275}
276
277impl Clone for Error {
282 fn clone(&self) -> Self {
283 Self {
284 status: self.status,
285 message: self.message.clone(),
286 source: None, error_code: self.error_code,
288 locale_key: self.locale_key,
289 details: self.details.clone(),
290 lagged: self.lagged,
291 }
292 }
293}
294
295impl fmt::Display for Error {
296 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
297 write!(f, "{}", self.message)
298 }
299}
300
301impl fmt::Debug for Error {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 f.debug_struct("Error")
304 .field("status", &self.status)
305 .field("message", &self.message)
306 .field("source", &self.source)
307 .field("error_code", &self.error_code)
308 .field("locale_key", &self.locale_key)
309 .field("details", &self.details)
310 .field("lagged", &self.lagged)
311 .finish()
312 }
313}
314
315impl std::error::Error for Error {
316 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
317 self.source
318 .as_ref()
319 .map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
320 }
321}
322
323pub(crate) fn render_error_body(
330 status: StatusCode,
331 message: &str,
332 details: Option<&serde_json::Value>,
333) -> serde_json::Value {
334 let mut body = serde_json::json!({
335 "error": {
336 "status": status.as_u16(),
337 "message": message,
338 }
339 });
340 if let Some(d) = details {
341 body["error"]["details"] = d.clone();
342 }
343 body
344}
345
346impl IntoResponse for Error {
359 fn into_response(self) -> Response {
360 let status = self.status;
361 let message = self.message.clone();
362 let details = self.details.clone();
363
364 let body = render_error_body(status, &message, details.as_ref());
365
366 let ext_error = Error {
368 status,
369 message,
370 source: None, error_code: self.error_code,
372 locale_key: self.locale_key,
373 details,
374 lagged: self.lagged,
375 };
376
377 let mut response = (status, axum::Json(body)).into_response();
378 response.extensions_mut().insert(ext_error);
379 response
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lagged_error_has_internal_status() {
        let err = Error::lagged(5);
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(err.message().contains('5'));
    }

    #[test]
    fn is_lagged_returns_true_for_lagged_error() {
        let err = Error::lagged(10);
        assert!(err.is_lagged());
    }

    #[test]
    fn is_lagged_returns_false_for_other_errors() {
        let err = Error::internal("something else");
        assert!(!err.is_lagged());
    }

    #[test]
    fn payload_too_large_error_has_413_status() {
        let err = Error::payload_too_large("file too big");
        assert_eq!(err.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(err.message(), "file too big");
    }

    #[test]
    fn chain_sets_source() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::internal("failed").chain(io::Error::other("disk"));
        assert!(err.source().is_some());
    }

    #[test]
    fn with_source_sets_status_message_and_source() {
        use std::io;
        let err = Error::with_source(
            StatusCode::BAD_GATEWAY,
            "upstream failed",
            io::Error::other("connection reset"),
        );
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream failed");
        assert!(err.source_as::<io::Error>().is_some());
    }

    #[test]
    fn source_as_downcasts_correctly() {
        use std::io;
        let io_err = io::Error::new(io::ErrorKind::NotFound, "missing");
        let err = Error::internal("failed").chain(io_err);
        let downcasted = err.source_as::<io::Error>();
        assert!(downcasted.is_some());
        assert_eq!(downcasted.unwrap().kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn source_as_returns_none_for_wrong_type() {
        use std::io;
        let err = Error::internal("failed").chain(io::Error::other("x"));
        let downcasted = err.source_as::<std::num::ParseIntError>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn source_as_returns_none_when_no_source() {
        let err = Error::internal("no source");
        let downcasted = err.source_as::<std::io::Error>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn display_renders_message() {
        let err = Error::not_found("user 42 not found");
        assert_eq!(err.to_string(), "user 42 not found");
    }

    #[test]
    fn clone_drops_source_but_keeps_other_fields() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::lagged(7).chain(io::Error::other("disk"));
        let cloned = err.clone();
        assert!(cloned.source().is_none());
        assert!(cloned.is_lagged());
        assert_eq!(cloned.message(), err.message());
        assert_eq!(cloned.status(), err.status());
    }

    #[test]
    fn with_code_sets_error_code() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn error_code_is_none_by_default() {
        let err = Error::internal("plain");
        assert!(err.error_code().is_none());
    }

    #[test]
    fn error_code_survives_clone() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let cloned = err.clone();
        assert_eq!(cloned.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn error_code_survives_into_response() {
        use axum::response::IntoResponse;
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let response = err.into_response();
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert_eq!(ext_err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn into_response_uses_error_status() {
        use axum::response::IntoResponse;
        let response = Error::conflict("duplicate").into_response();
        assert_eq!(response.status(), StatusCode::CONFLICT);
    }

    #[test]
    fn into_response_extension_keeps_details_and_message() {
        use axum::response::IntoResponse;
        let details = serde_json::json!({"field": "email"});
        let response = Error::unprocessable_entity("invalid")
            .with_details(details.clone())
            .into_response();
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert_eq!(ext_err.details(), Some(&details));
        assert_eq!(ext_err.message(), "invalid");
    }

    #[test]
    fn bad_gateway_error_has_502_status() {
        let err = Error::bad_gateway("upstream failed");
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream failed");
    }

    #[test]
    fn gateway_timeout_error_has_504_status() {
        let err = Error::gateway_timeout("timed out");
        assert_eq!(err.status(), StatusCode::GATEWAY_TIMEOUT);
        assert_eq!(err.message(), "timed out");
    }

    #[test]
    fn localized_sets_key_and_falls_back_to_key_as_message() {
        let err = Error::localized(StatusCode::NOT_FOUND, "errors.user.not_found");
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert_eq!(err.locale_key(), Some("errors.user.not_found"));
        assert_eq!(err.message(), "errors.user.not_found");
        assert!(err.error_code().is_none());
        assert!(err.details().is_none());
    }

    #[test]
    fn with_locale_key_tags_existing_error() {
        let err = Error::bad_request("boom").with_locale_key("errors.validation.generic");
        assert_eq!(err.message(), "boom");
        assert_eq!(err.locale_key(), Some("errors.validation.generic"));
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn clone_preserves_locale_key() {
        let err = Error::localized(StatusCode::CONFLICT, "errors.email.in_use");
        let cloned = err.clone();
        assert_eq!(cloned.locale_key(), Some("errors.email.in_use"));
        assert_eq!(cloned.status(), StatusCode::CONFLICT);
        assert_eq!(cloned.message(), "errors.email.in_use");
    }

    #[test]
    fn response_extensions_clone_preserves_locale_key() {
        use axum::response::IntoResponse;
        let err = Error::localized(StatusCode::UNAUTHORIZED, "errors.auth.expired");
        let response = err.into_response();
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert_eq!(ext_err.locale_key(), Some("errors.auth.expired"));
        assert_eq!(ext_err.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn render_error_body_without_details() {
        let body = render_error_body(StatusCode::NOT_FOUND, "user not found", None);
        assert_eq!(
            body,
            serde_json::json!({
                "error": {
                    "status": 404,
                    "message": "user not found",
                }
            })
        );
    }

    #[test]
    fn render_error_body_with_details() {
        let details = serde_json::json!({"field": "email"});
        let body = render_error_body(StatusCode::UNPROCESSABLE_ENTITY, "invalid", Some(&details));
        assert_eq!(
            body,
            serde_json::json!({
                "error": {
                    "status": 422,
                    "message": "invalid",
                    "details": {"field": "email"},
                }
            })
        );
    }
}