Skip to main content

omnia_sdk/
error.rs

//! Domain error types, their HTTP status mapping, and constructor macros.
2
3// use axum::response::{IntoResponse, Response};
4use http::StatusCode;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8/// Result type used across the crate.
9pub type Result<T> = anyhow::Result<T, Error>;
10
11/// Domain level error type returned by the adapter.
12#[derive(Error, Debug, Clone, Serialize, Deserialize)]
13pub enum Error {
14    // --- Client errors ---
15    /// Request payload is invalid or missing required fields.
16    #[error("code: {code}, description: {description}")]
17    BadRequest {
18        /// The error code.
19        code: String,
20        /// The error description.
21        description: String,
22    },
23
24    /// Resource or data not found.
25    #[error("code: {code}, description: {description}")]
26    NotFound {
27        /// The error code.
28        code: String,
29        /// The error description.
30        description: String,
31    },
32
33    // --- Server errors ---
34    /// A non recoverable internal error occurred.
35    #[error("code: {code}, description: {description}")]
36    ServerError {
37        /// The error code.
38        code: String,
39        /// The error description.
40        description: String,
41    },
42
43    /// An upstream dependency failed while fulfilling the request.
44    #[error("code: {code}, description: {description}")]
45    BadGateway {
46        /// The error code.
47        code: String,
48        /// The error description.
49        description: String,
50    },
51}
52
53impl Error {
54    /// Returns the HTTP status code associated with the variant.
55    #[must_use]
56    pub const fn status(&self) -> StatusCode {
57        match self {
58            Self::BadRequest { .. } => StatusCode::BAD_REQUEST,
59            Self::NotFound { .. } => StatusCode::NOT_FOUND,
60            Self::ServerError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
61            Self::BadGateway { .. } => StatusCode::BAD_GATEWAY,
62        }
63    }
64
65    /// Returns the error code for the variant.
66    #[must_use]
67    pub fn code(&self) -> String {
68        match self {
69            Self::BadRequest { code, .. }
70            | Self::NotFound { code, .. }
71            | Self::ServerError { code, .. }
72            | Self::BadGateway { code, .. } => code.clone(),
73        }
74    }
75
76    /// Returns the error description.
77    #[must_use]
78    pub fn description(&self) -> String {
79        match self {
80            Self::BadRequest { description, .. }
81            | Self::NotFound { description, .. }
82            | Self::ServerError { description, .. }
83            | Self::BadGateway { description, .. } => description.clone(),
84        }
85    }
86}
87
88impl From<anyhow::Error> for Error {
89    fn from(err: anyhow::Error) -> Self {
90        let chain = err.chain().map(ToString::to_string).collect::<Vec<_>>().join(": ");
91
92        // if type is Error, return it with the newly added context
93        if let Some(inner) = err.downcast_ref::<Self>() {
94            tracing::debug!("Error: {err}, caused by: {inner}");
95
96            return match inner {
97                Self::BadRequest { code, .. } => Self::BadRequest {
98                    code: code.clone(),
99                    description: chain,
100                },
101                Self::NotFound { code, .. } => Self::NotFound {
102                    code: code.clone(),
103                    description: chain,
104                },
105                Self::ServerError { code, .. } => Self::ServerError {
106                    code: code.clone(),
107                    description: chain,
108                },
109                Self::BadGateway { code, .. } => Self::BadGateway {
110                    code: code.clone(),
111                    description: chain,
112                },
113            };
114        }
115
116        // otherwise, return an Internal error
117        Self::ServerError {
118            code: "server_error".to_string(),
119            description: chain,
120        }
121    }
122}
123
124impl From<serde_json::Error> for Error {
125    fn from(err: serde_json::Error) -> Self {
126        Self::BadRequest {
127            code: "serde_json".to_string(),
128            description: err.to_string(),
129        }
130    }
131}
132
/// Create a new `BadRequest` error with code `"bad_request"`.
///
/// Accepted forms:
/// - a format string literal, with inline captures (`bad_request!("missing {id}")`);
/// - a format string plus arguments, as with `format!` (`bad_request!("{} of {}", a, b)`);
/// - any other single expression, whose `Display` output becomes the description
///   (`bad_request!(err)`). The expression is borrowed, not moved.
#[macro_export]
macro_rules! bad_request {
    ($desc:literal $(,)?) => {
        $crate::Error::BadRequest { code: "bad_request".to_string(), description: format!($desc) }
    };
    ($fmt:expr, $($arg:tt)+) => {
        $crate::Error::BadRequest { code: "bad_request".to_string(), description: format!($fmt, $($arg)+) }
    };
    // Non-literal expressions cannot be used as a `format!` string, so render them via `Display`.
    ($err:expr $(,)?) => {
        $crate::Error::BadRequest {
            code: "bad_request".to_string(),
            description: ::std::string::ToString::to_string(&$err),
        }
    };
}
143
/// Create a new `ServerError` error with code `"server_error"`.
///
/// Accepted forms:
/// - a format string literal, with inline captures (`server_error!("missing {id}")`);
/// - a format string plus arguments, as with `format!` (`server_error!("{} of {}", a, b)`);
/// - any other single expression, whose `Display` output becomes the description
///   (`server_error!(err)`). The expression is borrowed, not moved.
#[macro_export]
macro_rules! server_error {
    ($desc:literal $(,)?) => {
        $crate::Error::ServerError { code: "server_error".to_string(), description: format!($desc) }
    };
    ($fmt:expr, $($arg:tt)+) => {
        $crate::Error::ServerError { code: "server_error".to_string(), description: format!($fmt, $($arg)+) }
    };
    // Non-literal expressions cannot be used as a `format!` string, so render them via `Display`.
    ($err:expr $(,)?) => {
        $crate::Error::ServerError {
            code: "server_error".to_string(),
            description: ::std::string::ToString::to_string(&$err),
        }
    };
}
154
/// Create a new `BadGateway` error with code `"bad_gateway"`.
///
/// Accepted forms:
/// - a format string literal, with inline captures (`bad_gateway!("upstream {name}")`);
/// - a format string plus arguments, as with `format!` (`bad_gateway!("{} of {}", a, b)`);
/// - any other single expression, whose `Display` output becomes the description
///   (`bad_gateway!(err)`). The expression is borrowed, not moved.
#[macro_export]
macro_rules! bad_gateway {
    ($desc:literal $(,)?) => {
        $crate::Error::BadGateway { code: "bad_gateway".to_string(), description: format!($desc) }
    };
    ($fmt:expr, $($arg:tt)+) => {
        $crate::Error::BadGateway { code: "bad_gateway".to_string(), description: format!($fmt, $($arg)+) }
    };
    // Non-literal expressions cannot be used as a `format!` string, so render them via `Display`.
    ($err:expr $(,)?) => {
        $crate::Error::BadGateway {
            code: "bad_gateway".to_string(),
            description: ::std::string::ToString::to_string(&$err),
        }
    };
}
165
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result, anyhow};
    use serde_json::Value;
    use tracing_subscriber::layer::SubscriberExt;
    use tracing_subscriber::util::SubscriberInitExt;
    use tracing_subscriber::{EnvFilter, Registry, fmt};

    use super::Error;

    #[test]
    fn error_display() {
        let err = bad_request!("invalid input");
        assert_eq!(format!("{err}"), "code: bad_request, description: invalid input");
    }

    // Nested `.context(...)` layers are folded into the description; the variant and
    // code of the original domain error are kept.
    #[test]
    fn with_context() {
        // Only one global subscriber can be installed per process; `init` would panic if
        // another test had already installed one, so ignore that case with `try_init`.
        let _ = Registry::default().with(EnvFilter::new("debug")).with(fmt::layer()).try_init();

        let context_error = || -> Result<(), Error> {
            Err(bad_request!("invalid input"))
                .context("doing something")
                .context("more context")?;
            Ok(())
        };

        let result = context_error();
        assert_eq!(
            result.unwrap_err().to_string(),
            bad_request!(
                "more context: doing something: code: bad_request, description: invalid input"
            )
            .to_string()
        );
    }

    // Context added to a domain error keeps its variant and code, and the context
    // message is prepended to the rendered description.
    #[test]
    fn r9k_context() {
        let result = Err::<(), Error>(server_error!("server error")).context("request context");
        let err: Error = result.unwrap_err().into();

        assert_eq!(
            err.to_string(),
            "code: server_error, description: request context: code: server_error, description: server error"
        );
    }

    // Errors that are not domain errors become `ServerError`s carrying the full chain.
    #[test]
    fn anyhow_context() {
        let result = Err::<(), anyhow::Error>(anyhow!("one-off error")).context("error context");
        let err: Error = result.unwrap_err().into();

        assert_eq!(
            err.to_string(),
            "code: server_error, description: error context: one-off error"
        );
    }

    // A `serde_json::Error` that reaches `Error` through `anyhow` is reported as a
    // `ServerError`, not via the direct `From<serde_json::Error>` conversion.
    #[test]
    fn serde_context() {
        let result: Result<Value, anyhow::Error> =
            serde_json::from_str(r#"{"foo": "bar""#).context("error context");
        let err: Error = result.unwrap_err().into();

        assert_eq!(
            err.to_string(),
            "code: server_error, description: error context: EOF while parsing an object at line 1 column 13"
        );
    }
}