Skip to main content

modo/error/
core.rs

//! Core [`Error`] type and [`Result`] alias.
2
3use axum::response::{IntoResponse, Response};
4use http::StatusCode;
5use std::fmt;
6
7/// A type alias for `std::result::Result<T, Error>`.
8pub type Result<T> = std::result::Result<T, Error>;
9
10/// The primary error type for the modo framework.
11///
12/// `Error` carries:
13/// - an HTTP [`StatusCode`] that will be used as the response status,
14/// - a human-readable `message` string,
15/// - an optional structured `details` payload (arbitrary JSON),
16/// - an optional boxed `source` error for causal chaining,
17/// - an optional static `error_code` string that survives the response pipeline,
18/// - a `lagged` flag used by the SSE broadcaster to signal that a subscriber dropped messages.
19///
20/// # Conversion to HTTP response
21///
22/// Calling `into_response()` produces a JSON body:
23///
24/// ```json
25/// { "error": { "status": 404, "message": "user not found" } }
26/// ```
27///
28/// If [`with_details`](Error::with_details) was called, a `"details"` field is included.
29/// A copy of the error (without `source`) is also stored in response extensions so middleware
30/// can inspect it after the fact.
31///
32/// # Clone behaviour
33///
34/// Cloning an `Error` drops the `source` field because `Box<dyn Error>` is not `Clone`.
35/// The `error_code`, `details`, and all other fields are preserved.
36pub struct Error {
37    status: StatusCode,
38    message: String,
39    source: Option<Box<dyn std::error::Error + Send + Sync>>,
40    error_code: Option<&'static str>,
41    details: Option<serde_json::Value>,
42    lagged: bool,
43}
44
45impl Error {
46    /// Create a new error with the given HTTP status code and message.
47    pub fn new(status: StatusCode, message: impl Into<String>) -> Self {
48        Self {
49            status,
50            message: message.into(),
51            source: None,
52            error_code: None,
53            details: None,
54            lagged: false,
55        }
56    }
57
58    /// Create a new error with a status code, message, and a boxed source error.
59    ///
60    /// Use [`chain`](Error::chain) instead when constructing errors with the builder pattern.
61    pub fn with_source(
62        status: StatusCode,
63        message: impl Into<String>,
64        source: impl std::error::Error + Send + Sync + 'static,
65    ) -> Self {
66        Self {
67            status,
68            message: message.into(),
69            source: Some(Box::new(source)),
70            error_code: None,
71            details: None,
72            lagged: false,
73        }
74    }
75
76    /// Returns the HTTP status code of this error.
77    pub fn status(&self) -> StatusCode {
78        self.status
79    }
80
81    /// Returns the human-readable error message.
82    pub fn message(&self) -> &str {
83        &self.message
84    }
85
86    /// Returns the optional structured details payload.
87    pub fn details(&self) -> Option<&serde_json::Value> {
88        self.details.as_ref()
89    }
90
91    /// Attach a structured JSON details payload (builder-style).
92    pub fn with_details(mut self, details: serde_json::Value) -> Self {
93        self.details = Some(details);
94        self
95    }
96
97    /// Attach a source error (builder-style).
98    pub fn chain(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
99        self.source = Some(Box::new(source));
100        self
101    }
102
103    /// Attach a static error code to preserve error identity through the response pipeline.
104    ///
105    /// The error code is included in the copy stored in response extensions and can be retrieved
106    /// after `into_response()` via [`Error::error_code`].
107    pub fn with_code(mut self, code: &'static str) -> Self {
108        self.error_code = Some(code);
109        self
110    }
111
112    /// Returns the error code, if one was set.
113    pub fn error_code(&self) -> Option<&str> {
114        self.error_code
115    }
116
117    /// Downcast the source error to a concrete type.
118    ///
119    /// Returns `None` if no source is set or if the source is not of type `T`.
120    pub fn source_as<T: std::error::Error + 'static>(&self) -> Option<&T> {
121        self.source.as_ref()?.downcast_ref::<T>()
122    }
123
124    /// Create a `400 Bad Request` error.
125    pub fn bad_request(msg: impl Into<String>) -> Self {
126        Self::new(StatusCode::BAD_REQUEST, msg)
127    }
128
129    /// Create a `401 Unauthorized` error.
130    pub fn unauthorized(msg: impl Into<String>) -> Self {
131        Self::new(StatusCode::UNAUTHORIZED, msg)
132    }
133
134    /// Create a `403 Forbidden` error.
135    pub fn forbidden(msg: impl Into<String>) -> Self {
136        Self::new(StatusCode::FORBIDDEN, msg)
137    }
138
139    /// Create a `404 Not Found` error.
140    pub fn not_found(msg: impl Into<String>) -> Self {
141        Self::new(StatusCode::NOT_FOUND, msg)
142    }
143
144    /// Create a `409 Conflict` error.
145    pub fn conflict(msg: impl Into<String>) -> Self {
146        Self::new(StatusCode::CONFLICT, msg)
147    }
148
149    /// Create a `413 Payload Too Large` error.
150    pub fn payload_too_large(msg: impl Into<String>) -> Self {
151        Self::new(StatusCode::PAYLOAD_TOO_LARGE, msg)
152    }
153
154    /// Create a `422 Unprocessable Entity` error.
155    pub fn unprocessable_entity(msg: impl Into<String>) -> Self {
156        Self::new(StatusCode::UNPROCESSABLE_ENTITY, msg)
157    }
158
159    /// Create a `429 Too Many Requests` error.
160    pub fn too_many_requests(msg: impl Into<String>) -> Self {
161        Self::new(StatusCode::TOO_MANY_REQUESTS, msg)
162    }
163
164    /// Create a `500 Internal Server Error`.
165    pub fn internal(msg: impl Into<String>) -> Self {
166        Self::new(StatusCode::INTERNAL_SERVER_ERROR, msg)
167    }
168
169    /// Create a `502 Bad Gateway` error.
170    pub fn bad_gateway(msg: impl Into<String>) -> Self {
171        Self::new(StatusCode::BAD_GATEWAY, msg)
172    }
173
174    /// Create a `504 Gateway Timeout` error.
175    pub fn gateway_timeout(msg: impl Into<String>) -> Self {
176        Self::new(StatusCode::GATEWAY_TIMEOUT, msg)
177    }
178
179    /// Create an error indicating a broadcast subscriber lagged behind.
180    ///
181    /// The resulting error has a `500 Internal Server Error` status and [`is_lagged`](Error::is_lagged)
182    /// returns `true`. `skipped` is the number of messages that were dropped.
183    pub fn lagged(skipped: u64) -> Self {
184        Self {
185            status: StatusCode::INTERNAL_SERVER_ERROR,
186            message: format!("SSE subscriber lagged, skipped {skipped} messages"),
187            source: None,
188            error_code: None,
189            details: None,
190            lagged: true,
191        }
192    }
193
194    /// Returns `true` if this error represents a broadcast lag.
195    pub fn is_lagged(&self) -> bool {
196        self.lagged
197    }
198}
199
200/// Clones the error, dropping the `source` field (which is not `Clone`).
201///
202/// All other fields — `status`, `message`, `error_code`, `details`, and `lagged` — are preserved.
203impl Clone for Error {
204    fn clone(&self) -> Self {
205        Self {
206            status: self.status,
207            message: self.message.clone(),
208            source: None, // source (Box<dyn Error>) can't be cloned
209            error_code: self.error_code,
210            details: self.details.clone(),
211            lagged: self.lagged,
212        }
213    }
214}
215
216impl fmt::Display for Error {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        write!(f, "{}", self.message)
219    }
220}
221
222impl fmt::Debug for Error {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        f.debug_struct("Error")
225            .field("status", &self.status)
226            .field("message", &self.message)
227            .field("source", &self.source)
228            .field("error_code", &self.error_code)
229            .field("details", &self.details)
230            .field("lagged", &self.lagged)
231            .finish()
232    }
233}
234
235impl std::error::Error for Error {
236    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
237        self.source
238            .as_ref()
239            .map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
240    }
241}
242
243/// Converts `Error` into an axum [`Response`].
244///
245/// Produces a JSON body of the form:
246///
247/// ```json
248/// { "error": { "status": 422, "message": "validation failed" } }
249/// ```
250///
251/// If [`with_details`](Error::with_details) was called, a `"details"` key is added under `"error"`.
252///
253/// A copy of the error (without the `source` field) is stored in response extensions under
254/// the type `Error` so that downstream middleware can inspect it.
255impl IntoResponse for Error {
256    fn into_response(self) -> Response {
257        let status = self.status;
258        let message = self.message.clone();
259        let details = self.details.clone();
260
261        let mut body = serde_json::json!({
262            "error": {
263                "status": status.as_u16(),
264                "message": &message
265            }
266        });
267        if let Some(ref d) = details {
268            body["error"]["details"] = d.clone();
269        }
270
271        // Store a copy in extensions so error_handler middleware can read it
272        let ext_error = Error {
273            status,
274            message,
275            source: None, // source can't be cloned
276            error_code: self.error_code,
277            details,
278            lagged: self.lagged,
279        };
280
281        let mut response = (status, axum::Json(body)).into_response();
282        response.extensions_mut().insert(ext_error);
283        response
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lagged_error_has_internal_status() {
        let err = Error::lagged(5);
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(err.message().contains('5'));
    }

    #[test]
    fn is_lagged_returns_true_for_lagged_error() {
        let err = Error::lagged(10);
        assert!(err.is_lagged());
    }

    #[test]
    fn is_lagged_returns_false_for_other_errors() {
        let err = Error::internal("something else");
        assert!(!err.is_lagged());
    }

    #[test]
    fn payload_too_large_error_has_413_status() {
        let err = Error::payload_too_large("file too big");
        assert_eq!(err.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(err.message(), "file too big");
    }

    #[test]
    fn chain_sets_source() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::internal("failed").chain(io::Error::other("disk"));
        assert!(err.source().is_some());
    }

    #[test]
    fn with_source_sets_status_message_and_source() {
        use std::io;
        let err = Error::with_source(
            StatusCode::BAD_GATEWAY,
            "upstream",
            io::Error::new(io::ErrorKind::ConnectionRefused, "refused"),
        );
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream");
        assert_eq!(
            err.source_as::<io::Error>().map(io::Error::kind),
            Some(io::ErrorKind::ConnectionRefused)
        );
    }

    #[test]
    fn source_as_downcasts_correctly() {
        use std::io;
        let io_err = io::Error::new(io::ErrorKind::NotFound, "missing");
        let err = Error::internal("failed").chain(io_err);
        let downcasted = err.source_as::<io::Error>();
        assert!(downcasted.is_some());
        assert_eq!(downcasted.unwrap().kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn source_as_returns_none_for_wrong_type() {
        use std::io;
        let err = Error::internal("failed").chain(io::Error::other("x"));
        let downcasted = err.source_as::<std::num::ParseIntError>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn source_as_returns_none_when_no_source() {
        let err = Error::internal("no source");
        let downcasted = err.source_as::<std::io::Error>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn with_code_sets_error_code() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn error_code_is_none_by_default() {
        let err = Error::internal("plain");
        assert!(err.error_code().is_none());
    }

    #[test]
    fn error_code_survives_clone() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let cloned = err.clone();
        assert_eq!(cloned.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn clone_drops_source_but_keeps_other_fields() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::conflict("dup")
            .chain(io::Error::other("db"))
            .with_code("user:duplicate")
            .with_details(serde_json::json!({ "field": "email" }));
        let cloned = err.clone();
        assert!(err.source().is_some());
        assert!(cloned.source().is_none());
        assert_eq!(cloned.status(), StatusCode::CONFLICT);
        assert_eq!(cloned.message(), "dup");
        assert_eq!(cloned.error_code(), Some("user:duplicate"));
        assert_eq!(cloned.details(), err.details());
    }

    #[test]
    fn error_code_survives_into_response() {
        use axum::response::IntoResponse;
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let response = err.into_response();
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert_eq!(ext_err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn into_response_uses_error_status() {
        use axum::response::IntoResponse;
        let response = Error::not_found("user not found").into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn into_response_stores_copy_without_source() {
        use axum::response::IntoResponse;
        use std::error::Error as _;
        let details = serde_json::json!({ "field": "name" });
        let err = super::Error::unprocessable_entity("validation failed")
            .chain(std::io::Error::other("inner"))
            .with_details(details.clone());
        let response = err.into_response();
        let ext_err = response.extensions().get::<super::Error>().unwrap();
        assert_eq!(ext_err.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(ext_err.message(), "validation failed");
        assert_eq!(ext_err.details(), Some(&details));
        assert!(ext_err.source().is_none());
    }

    #[test]
    fn lagged_flag_survives_into_response() {
        use axum::response::IntoResponse;
        let response = Error::lagged(3).into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert!(ext_err.is_lagged());
    }

    #[test]
    fn display_prints_message_only() {
        let err = Error::forbidden("no access").with_code("acl:denied");
        assert_eq!(err.to_string(), "no access");
    }

    #[test]
    fn bad_gateway_error_has_502_status() {
        let err = Error::bad_gateway("upstream failed");
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream failed");
    }

    #[test]
    fn gateway_timeout_error_has_504_status() {
        let err = Error::gateway_timeout("timed out");
        assert_eq!(err.status(), StatusCode::GATEWAY_TIMEOUT);
        assert_eq!(err.message(), "timed out");
    }
}