// netray_common/error.rs

1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5/// JSON body returned for all error responses.
6///
7/// Wire format: `{"error": {"code": "...", "message": "..."}}`
8#[derive(Debug, Serialize)]
9pub struct ErrorResponse {
10    pub error: ErrorInfo,
11}
12
13/// Error detail contained in an error response.
14#[derive(Debug, Serialize)]
15pub struct ErrorInfo {
16    /// Machine-readable error code (e.g. `INVALID_DOMAIN`).
17    pub code: &'static str,
18    /// Human-readable error message.
19    pub message: String,
20}
21
22/// Trait for application-specific error types that can be rendered as
23/// structured JSON error responses.
24///
25/// Each project defines its own error enum and implements this trait.
26/// The shared `IntoResponse` implementation (via [`into_error_response`])
27/// handles JSON serialization, status codes, and the `Retry-After` header
28/// for rate-limited responses.
29pub trait ApiError: std::fmt::Display {
30    /// HTTP status code for this error variant.
31    fn status_code(&self) -> StatusCode;
32
33    /// Machine-readable error code string (e.g. `"INVALID_DOMAIN"`).
34    fn error_code(&self) -> &'static str;
35
36    /// If this is a rate-limited error, return the retry-after duration in seconds.
37    fn retry_after_secs(&self) -> Option<u64> {
38        None
39    }
40}
41
42/// Convert any [`ApiError`] into an axum [`Response`].
43///
44/// Produces a JSON body of the form:
45/// ```json
46/// {"error": {"code": "ERROR_CODE", "message": "human-readable message"}}
47/// ```
48///
49/// For rate-limited responses (when `retry_after_secs()` returns `Some`),
50/// includes the `Retry-After` header per RFC 6585.
51pub fn into_error_response(err: &(impl ApiError + std::fmt::Display)) -> Response {
52    let status = err.status_code();
53
54    if status == StatusCode::INTERNAL_SERVER_ERROR {
55        tracing::error!(error = %err, "internal server error");
56    }
57
58    let body = ErrorResponse {
59        error: ErrorInfo {
60            code: err.error_code(),
61            message: err.to_string(),
62        },
63    };
64
65    let mut response = (status, axum::Json(body)).into_response();
66
67    if let Some(secs) = err.retry_after_secs() {
68        response.headers_mut().insert(
69            axum::http::header::RETRY_AFTER,
70            axum::http::HeaderValue::from(secs),
71        );
72    }
73
74    response
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;

    /// Minimal error enum covering the three rendering paths:
    /// client error, rate limiting (with `Retry-After`), and internal error.
    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => write!(f, "rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                _ => None,
            }
        }
    }

    /// Render `err` and split the response into status, headers, and parsed JSON body.
    async fn into_parts(
        err: TestError,
    ) -> (StatusCode, axum::http::HeaderMap, serde_json::Value) {
        // Destructure instead of cloning the header map.
        let (parts, body) = into_error_response(&err).into_parts();
        let bytes = to_bytes(body, usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (parts.status, parts.headers, json)
    }

    /// Render `err` and return only its parsed JSON body.
    async fn body_json(err: TestError) -> serde_json::Value {
        let (_, _, body) = into_parts(err).await;
        body
    }

    // Status-only checks need no body reading, so no async runtime is required.
    #[test]
    fn bad_input_is_400() {
        let response = into_error_response(&TestError::BadInput("oops".into()));
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn rate_limited_is_429() {
        let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn internal_is_500() {
        let response = into_error_response(&TestError::Internal("boom".into()));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let body = body_json(TestError::BadInput("test".into())).await;
        assert_eq!(body["error"]["code"], "BAD_INPUT");
        assert_eq!(body["error"]["message"], "bad input: test");
        // Wire format is exactly {"error": {"code": ..., "message": ...}}.
        assert_eq!(body.as_object().unwrap().len(), 1);
        assert_eq!(body["error"].as_object().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn internal_error_body_and_headers() {
        let (status, headers, body) = into_parts(TestError::Internal("boom".into())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body["error"]["code"], "INTERNAL_ERROR");
        assert_eq!(body["error"]["message"], "internal error: boom");
        assert!(headers.get(axum::http::header::RETRY_AFTER).is_none());
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, body) =
            into_parts(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(body["error"]["code"], "RATE_LIMITED");
        let retry_after = headers
            .get(axum::http::header::RETRY_AFTER)
            .expect("Retry-After header must be present");
        let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn non_rate_limited_has_no_retry_after() {
        let (status, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(headers.get(axum::http::header::RETRY_AFTER).is_none());
    }
}