covert_types/
error.rs

1use std::fmt::Display;
2
3use http::header::CONTENT_TYPE;
4use serde::Serialize;
5use serde_with::{serde_as, DisplayFromStr};
6use thiserror::Error;
7
8pub use http::StatusCode;
9use tracing_error::SpanTrace;
10
11use crate::state::StorageState;
12
13/// A shares errod type used to produce public error and add additional context
14/// for internal diagnostics. A public error will be produced by using the inner
15/// error [`Display`] implementation and `status_code` field. The internal error
16/// report will be created used the [`Debug`] implementation and `span_trace` field.
17#[serde_as]
18#[derive(Error, Debug, Serialize)]
19pub struct ApiError {
20    // Only the Display format of the source error will be returned to the client.
21    #[serde_as(as = "DisplayFromStr")]
22    #[source]
23    pub error: anyhow::Error,
24    #[serde(skip)]
25    pub status_code: StatusCode,
26    // TODO: make it non-optional
27    #[serde(skip)]
28    pub span_trace: Option<SpanTrace>,
29}
30
31impl Display for ApiError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        let report = self.report();
34        // Using Debug impl here in Display impl because ApiError
35        // doesn't need the Display impl
36        write!(f, "{report:?}")
37    }
38}
39
/// Internal diagnostic report for an [`ApiError`], produced by [`ApiError::report`].
///
/// Intended for logs and internal diagnostics rather than for client responses.
#[derive(Debug)]
pub struct Report {
    /// `Debug` rendering of the innermost (root-cause) error in the chain.
    pub cause: String,
    /// Span trace copied from the originating `ApiError`, if one was captured.
    // TODO: make it non-optional
    pub span_trace: Option<SpanTrace>,
}
46
47impl ApiError {
48    #[must_use]
49    pub fn bad_request() -> Self {
50        Self {
51            error: anyhow::Error::msg("Bad request"),
52            status_code: StatusCode::BAD_REQUEST,
53            span_trace: Some(SpanTrace::capture()),
54        }
55    }
56
57    #[must_use]
58    pub fn internal_error() -> Self {
59        Self {
60            error: anyhow::Error::msg("Internal error"),
61            status_code: StatusCode::INTERNAL_SERVER_ERROR,
62            span_trace: Some(SpanTrace::capture()),
63        }
64    }
65
66    #[must_use]
67    pub fn timeout() -> Self {
68        Self {
69            error: anyhow::Error::msg("Request timed out"),
70            status_code: StatusCode::REQUEST_TIMEOUT,
71            span_trace: Some(SpanTrace::capture()),
72        }
73    }
74
75    #[must_use]
76    pub fn invalid_state(current_state: StorageState) -> Self {
77        Self {
78            error: anyhow::Error::msg(format!(
79                "This operation is not allowed when the current state is `{current_state}`"
80            )),
81            status_code: StatusCode::FORBIDDEN,
82            span_trace: Some(SpanTrace::capture()),
83        }
84    }
85
86    #[must_use]
87    pub fn unauthorized() -> Self {
88        Self {
89            error: anyhow::Error::msg("User is not authorized to perform this operation"),
90            status_code: StatusCode::UNAUTHORIZED,
91            span_trace: Some(SpanTrace::capture()),
92        }
93    }
94
95    #[must_use]
96    pub fn not_found() -> Self {
97        Self {
98            error: anyhow::Error::msg("Not found"),
99            status_code: StatusCode::NOT_FOUND,
100            span_trace: Some(SpanTrace::capture()),
101        }
102    }
103
104    #[must_use]
105    pub fn report(&self) -> Report {
106        Report {
107            cause: format!("{:?}", self.error.root_cause()),
108            span_trace: self.span_trace.clone(),
109        }
110    }
111}
112
113impl From<ApiError> for hyper::Response<hyper::Body> {
114    fn from(err: ApiError) -> Self {
115        match serde_json::to_vec(&err) {
116            Ok(err_body) => hyper::Response::builder()
117                .header(CONTENT_TYPE, "application/json")
118                .status(err.status_code)
119                .body(err_body.into())
120                .expect("a valid response"),
121            Err(_) => hyper::Response::builder()
122                .header(CONTENT_TYPE, "application/json")
123                .status(StatusCode::INTERNAL_SERVER_ERROR)
124                .body("Internal error. Unable to return the error response.".into())
125                .expect("a valid response"),
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Test error whose `Debug` and `Display` renderings differ on purpose. This lets the test
    /// tell which one reached the public body and which one reached the internal report.
    #[derive(Debug)]
    pub struct DummyError {
        pub debug_field: String,
        pub display_field: String,
    }

    impl std::error::Error for DummyError {}

    impl Display for DummyError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.display_field)
        }
    }

    #[test]
    fn serialize_api_error() {
        let source = DummyError {
            debug_field: String::from("debug error"),
            display_field: String::from("display error"),
        };
        let api_err = ApiError {
            error: anyhow::Error::from(source),
            status_code: StatusCode::INTERNAL_SERVER_ERROR,
            span_trace: None,
        };

        // Public JSON body: only the Display form of the source error is exposed.
        let expected_body = r#"{"error":"display error"}"#;
        let body = serde_json::to_string(&api_err).expect("ApiError serializes to JSON");
        assert_eq!(body, expected_body);

        // Internal report: built from the Debug form of the root cause.
        let expected_report = r#"Report { cause: "DummyError { debug_field: \"debug error\", display_field: \"display error\" }", span_trace: None }"#;
        let report = format!("{:?}", api_err.report());
        assert_eq!(report, expected_report);
    }
}