1use axum::response::{IntoResponse, Response};
4use http::StatusCode;
5use std::fmt;
6
7pub type Result<T> = std::result::Result<T, Error>;
9
10pub struct Error {
37 status: StatusCode,
38 message: String,
39 source: Option<Box<dyn std::error::Error + Send + Sync>>,
40 error_code: Option<&'static str>,
41 details: Option<serde_json::Value>,
42 lagged: bool,
43}
44
45impl Error {
46 pub fn new(status: StatusCode, message: impl Into<String>) -> Self {
48 Self {
49 status,
50 message: message.into(),
51 source: None,
52 error_code: None,
53 details: None,
54 lagged: false,
55 }
56 }
57
58 pub fn with_source(
62 status: StatusCode,
63 message: impl Into<String>,
64 source: impl std::error::Error + Send + Sync + 'static,
65 ) -> Self {
66 Self {
67 status,
68 message: message.into(),
69 source: Some(Box::new(source)),
70 error_code: None,
71 details: None,
72 lagged: false,
73 }
74 }
75
76 pub fn status(&self) -> StatusCode {
78 self.status
79 }
80
81 pub fn message(&self) -> &str {
83 &self.message
84 }
85
86 pub fn details(&self) -> Option<&serde_json::Value> {
88 self.details.as_ref()
89 }
90
91 pub fn with_details(mut self, details: serde_json::Value) -> Self {
93 self.details = Some(details);
94 self
95 }
96
97 pub fn chain(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
99 self.source = Some(Box::new(source));
100 self
101 }
102
103 pub fn with_code(mut self, code: &'static str) -> Self {
108 self.error_code = Some(code);
109 self
110 }
111
112 pub fn error_code(&self) -> Option<&str> {
114 self.error_code
115 }
116
117 pub fn source_as<T: std::error::Error + 'static>(&self) -> Option<&T> {
121 self.source.as_ref()?.downcast_ref::<T>()
122 }
123
124 pub fn bad_request(msg: impl Into<String>) -> Self {
126 Self::new(StatusCode::BAD_REQUEST, msg)
127 }
128
129 pub fn unauthorized(msg: impl Into<String>) -> Self {
131 Self::new(StatusCode::UNAUTHORIZED, msg)
132 }
133
134 pub fn forbidden(msg: impl Into<String>) -> Self {
136 Self::new(StatusCode::FORBIDDEN, msg)
137 }
138
139 pub fn not_found(msg: impl Into<String>) -> Self {
141 Self::new(StatusCode::NOT_FOUND, msg)
142 }
143
144 pub fn conflict(msg: impl Into<String>) -> Self {
146 Self::new(StatusCode::CONFLICT, msg)
147 }
148
149 pub fn payload_too_large(msg: impl Into<String>) -> Self {
151 Self::new(StatusCode::PAYLOAD_TOO_LARGE, msg)
152 }
153
154 pub fn unprocessable_entity(msg: impl Into<String>) -> Self {
156 Self::new(StatusCode::UNPROCESSABLE_ENTITY, msg)
157 }
158
159 pub fn too_many_requests(msg: impl Into<String>) -> Self {
161 Self::new(StatusCode::TOO_MANY_REQUESTS, msg)
162 }
163
164 pub fn internal(msg: impl Into<String>) -> Self {
166 Self::new(StatusCode::INTERNAL_SERVER_ERROR, msg)
167 }
168
169 pub fn bad_gateway(msg: impl Into<String>) -> Self {
171 Self::new(StatusCode::BAD_GATEWAY, msg)
172 }
173
174 pub fn gateway_timeout(msg: impl Into<String>) -> Self {
176 Self::new(StatusCode::GATEWAY_TIMEOUT, msg)
177 }
178
179 pub fn lagged(skipped: u64) -> Self {
184 Self {
185 status: StatusCode::INTERNAL_SERVER_ERROR,
186 message: format!("SSE subscriber lagged, skipped {skipped} messages"),
187 source: None,
188 error_code: None,
189 details: None,
190 lagged: true,
191 }
192 }
193
194 pub fn is_lagged(&self) -> bool {
196 self.lagged
197 }
198}
199
200impl Clone for Error {
204 fn clone(&self) -> Self {
205 Self {
206 status: self.status,
207 message: self.message.clone(),
208 source: None, error_code: self.error_code,
210 details: self.details.clone(),
211 lagged: self.lagged,
212 }
213 }
214}
215
216impl fmt::Display for Error {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 write!(f, "{}", self.message)
219 }
220}
221
222impl fmt::Debug for Error {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 f.debug_struct("Error")
225 .field("status", &self.status)
226 .field("message", &self.message)
227 .field("source", &self.source)
228 .field("error_code", &self.error_code)
229 .field("details", &self.details)
230 .field("lagged", &self.lagged)
231 .finish()
232 }
233}
234
235impl std::error::Error for Error {
236 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
237 self.source
238 .as_ref()
239 .map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
240 }
241}
242
243impl IntoResponse for Error {
256 fn into_response(self) -> Response {
257 let status = self.status;
258 let message = self.message.clone();
259 let details = self.details.clone();
260
261 let mut body = serde_json::json!({
262 "error": {
263 "status": status.as_u16(),
264 "message": &message
265 }
266 });
267 if let Some(ref d) = details {
268 body["error"]["details"] = d.clone();
269 }
270
271 let ext_error = Error {
273 status,
274 message,
275 source: None, error_code: self.error_code,
277 details,
278 lagged: self.lagged,
279 };
280
281 let mut response = (status, axum::Json(body)).into_response();
282 response.extensions_mut().insert(ext_error);
283 response
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lagged_error_has_internal_status() {
        let err = Error::lagged(5);
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(err.message().contains('5'));
    }

    #[test]
    fn is_lagged_returns_true_for_lagged_error() {
        let err = Error::lagged(10);
        assert!(err.is_lagged());
    }

    #[test]
    fn is_lagged_returns_false_for_other_errors() {
        let err = Error::internal("something else");
        assert!(!err.is_lagged());
    }

    #[test]
    fn payload_too_large_error_has_413_status() {
        let err = Error::payload_too_large("file too big");
        assert_eq!(err.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(err.message(), "file too big");
    }

    #[test]
    fn chain_sets_source() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::internal("failed").chain(io::Error::other("disk"));
        assert!(err.source().is_some());
    }

    #[test]
    fn source_as_downcasts_correctly() {
        use std::io;
        let io_err = io::Error::new(io::ErrorKind::NotFound, "missing");
        let err = Error::internal("failed").chain(io_err);
        let downcasted = err.source_as::<io::Error>();
        assert!(downcasted.is_some());
        assert_eq!(downcasted.unwrap().kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn source_as_returns_none_for_wrong_type() {
        use std::io;
        let err = Error::internal("failed").chain(io::Error::other("x"));
        let downcasted = err.source_as::<std::num::ParseIntError>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn source_as_returns_none_when_no_source() {
        let err = Error::internal("no source");
        let downcasted = err.source_as::<std::io::Error>();
        assert!(downcasted.is_none());
    }

    #[test]
    fn with_code_sets_error_code() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn error_code_is_none_by_default() {
        let err = Error::internal("plain");
        assert!(err.error_code().is_none());
    }

    #[test]
    fn error_code_survives_clone() {
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let cloned = err.clone();
        assert_eq!(cloned.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn error_code_survives_into_response() {
        use axum::response::IntoResponse;
        let err = Error::unauthorized("denied").with_code("jwt:expired");
        let response = err.into_response();
        let ext_err = response.extensions().get::<Error>().unwrap();
        assert_eq!(ext_err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn bad_gateway_error_has_502_status() {
        let err = Error::bad_gateway("upstream failed");
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream failed");
    }

    #[test]
    fn gateway_timeout_error_has_504_status() {
        let err = Error::gateway_timeout("timed out");
        assert_eq!(err.status(), StatusCode::GATEWAY_TIMEOUT);
        assert_eq!(err.message(), "timed out");
    }

    #[test]
    fn with_source_sets_status_message_and_source() {
        use std::error::Error as _;
        use std::io;
        let err = super::Error::with_source(
            StatusCode::BAD_GATEWAY,
            "upstream",
            io::Error::new(io::ErrorKind::ConnectionRefused, "refused"),
        );
        assert_eq!(err.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(err.message(), "upstream");
        assert!(err.source().is_some());
        assert_eq!(
            err.source_as::<io::Error>().map(io::Error::kind),
            Some(io::ErrorKind::ConnectionRefused)
        );
    }

    #[test]
    fn details_are_none_by_default() {
        assert!(Error::bad_request("plain").details().is_none());
    }

    #[test]
    fn with_details_attaches_details() {
        let details = serde_json::json!({ "field": "email" });
        let err = Error::unprocessable_entity("invalid").with_details(details.clone());
        assert_eq!(err.details(), Some(&details));
    }

    #[test]
    fn display_prints_message_only() {
        let err = Error::not_found("no such user").with_code("user:missing");
        assert_eq!(err.to_string(), "no such user");
    }

    #[test]
    fn clone_keeps_fields_but_drops_source() {
        use std::error::Error as _;
        let err = super::Error::conflict("dup")
            .with_code("user:dup")
            .with_details(serde_json::json!({ "id": 7 }))
            .chain(std::io::Error::other("db"));
        let cloned = err.clone();
        assert_eq!(cloned.status(), StatusCode::CONFLICT);
        assert_eq!(cloned.message(), "dup");
        assert_eq!(cloned.error_code(), Some("user:dup"));
        assert_eq!(cloned.details(), err.details());
        assert!(cloned.source().is_none());
        assert!(err.source().is_some());
    }

    #[test]
    fn lagged_flag_survives_clone() {
        assert!(Error::lagged(1).clone().is_lagged());
    }

    #[test]
    fn into_response_uses_status_and_attaches_copy_without_source() {
        use axum::response::IntoResponse;
        use std::error::Error as _;
        let err = super::Error::too_many_requests("slow down")
            .with_details(serde_json::json!({ "retry_after": 30 }))
            .chain(std::io::Error::other("limiter"));
        let response = err.into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        let ext_err = response.extensions().get::<super::Error>().unwrap();
        assert_eq!(ext_err.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(ext_err.message(), "slow down");
        assert_eq!(
            ext_err.details(),
            Some(&serde_json::json!({ "retry_after": 30 }))
        );
        assert!(ext_err.source().is_none());
    }
}