Skip to main content

netray_common/
error.rs

1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5/// JSON body returned for all error responses.
6///
7/// Wire format: `{"error": {"code": "...", "message": "..."}}`
8#[derive(Debug, Serialize)]
9pub struct ErrorResponse {
10    pub error: ErrorInfo,
11}
12
13/// Error detail contained in an error response.
14#[derive(Debug, Serialize)]
15pub struct ErrorInfo {
16    /// Machine-readable error code (e.g. `INVALID_DOMAIN`).
17    pub code: &'static str,
18    /// Human-readable error message.
19    pub message: String,
20}
21
22/// Trait for application-specific error types that can be rendered as
23/// structured JSON error responses.
24///
25/// Each project defines its own error enum and implements this trait.
26/// The shared `IntoResponse` implementation (via [`into_error_response`])
27/// handles JSON serialization, status codes, and the `Retry-After` header
28/// for rate-limited responses.
29pub trait ApiError: std::fmt::Display {
30    /// HTTP status code for this error variant.
31    fn status_code(&self) -> StatusCode;
32
33    /// Machine-readable error code string (e.g. `"INVALID_DOMAIN"`).
34    fn error_code(&self) -> &'static str;
35
36    /// If this is a rate-limited error, return the retry-after duration in seconds.
37    fn retry_after_secs(&self) -> Option<u64> {
38        None
39    }
40}
41
42/// Convert any [`ApiError`] into an axum [`Response`].
43///
44/// Produces a JSON body of the form:
45/// ```json
46/// {"error": {"code": "ERROR_CODE", "message": "human-readable message"}}
47/// ```
48///
49/// For rate-limited responses (when `retry_after_secs()` returns `Some`),
50/// includes the `Retry-After` header per RFC 6585.
51pub fn into_error_response(err: &impl ApiError) -> Response {
52    let status = err.status_code();
53
54    if status.is_server_error() {
55        tracing::error!(error = %err, "internal server error");
56    } else if status.is_client_error() {
57        tracing::warn!(error = %err, "client error");
58    }
59
60    let body = ErrorResponse {
61        error: ErrorInfo {
62            code: err.error_code(),
63            message: err.to_string(),
64        },
65    };
66
67    let mut response = (status, axum::Json(body)).into_response();
68
69    if let Some(secs) = err.retry_after_secs() {
70        response.headers_mut().insert(
71            axum::http::header::RETRY_AFTER,
72            axum::http::HeaderValue::from(secs),
73        );
74    }
75
76    response
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::{header, HeaderMap};

    /// Error type exercising the three interesting shapes: a plain client
    /// error, a rate-limit error carrying a retry hint, and a server error.
    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => write!(f, "rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                _ => None,
            }
        }
    }

    /// Render `err` and split the response into status, headers and the
    /// parsed JSON body.
    ///
    /// Async only because collecting the body with `to_bytes` is async;
    /// `into_error_response` itself is synchronous.
    async fn render(err: TestError) -> (StatusCode, HeaderMap, serde_json::Value) {
        let (parts, raw_body) = into_error_response(&err).into_parts();
        let bytes = to_bytes(raw_body, usize::MAX)
            .await
            .expect("error body must be readable");
        let json: serde_json::Value =
            serde_json::from_slice(&bytes).expect("error body must be valid JSON");
        (parts.status, parts.headers, json)
    }

    /// Assert the body is exactly `{"error": {"code": ..., "message": ...}}`
    /// with string values and no extra keys at either level.
    fn assert_envelope(body: &serde_json::Value) {
        let top = body.as_object().expect("body must be a JSON object");
        assert_eq!(top.len(), 1, "unexpected top-level keys: {body}");
        let inner = body["error"]
            .as_object()
            .expect("`error` must be a JSON object");
        assert_eq!(inner.len(), 2, "unexpected keys under `error`: {body}");
        assert!(body["error"]["code"].is_string(), "`code` must be a string: {body}");
        assert!(
            body["error"]["message"].is_string(),
            "`message` must be a string: {body}"
        );
    }

    #[test]
    fn bad_input_is_400() {
        let response = into_error_response(&TestError::BadInput("oops".into()));
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn rate_limited_is_429() {
        let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn internal_is_500() {
        let response = into_error_response(&TestError::Internal("boom".into()));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let (_, _, body) = render(TestError::BadInput("test".into())).await;
        assert_envelope(&body);
        assert_eq!(body["error"]["code"], "BAD_INPUT");
        assert!(body["error"]["message"].as_str().unwrap().contains("test"));
    }

    #[tokio::test]
    async fn internal_error_body_uses_its_code() {
        let (status, headers, body) = render(TestError::Internal("boom".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(headers.get(header::RETRY_AFTER).is_none());
        assert_envelope(&body);
        assert_eq!(body["error"]["code"], "INTERNAL_ERROR");
        assert_eq!(body["error"]["message"], "internal error: boom");
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, body) = render(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        let retry_after = headers
            .get(header::RETRY_AFTER)
            .expect("Retry-After header must be present");
        let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
        assert_eq!(value, 42);
        assert_envelope(&body);
        assert_eq!(body["error"]["code"], "RATE_LIMITED");
    }

    #[test]
    fn non_rate_limited_has_no_retry_after() {
        let response = into_error_response(&TestError::BadInput("x".into()));
        assert!(response.headers().get(header::RETRY_AFTER).is_none());
    }

    #[test]
    fn error_response_has_json_content_type() {
        let response = into_error_response(&TestError::BadInput("test".into()));
        let ct = response
            .headers()
            .get(header::CONTENT_TYPE)
            .expect("Content-Type header must be present")
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
    }
}