Skip to main content

netray_common/
error.rs

1use axum::http::StatusCode;
2use axum::response::{IntoResponse, Response};
3use serde::Serialize;
4
5/// JSON body returned for all error responses.
6///
7/// Wire format: `{"error": {"code": "...", "message": "..."}}`
8#[derive(Debug, Serialize)]
9#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
10pub struct ErrorResponse {
11    pub error: ErrorInfo,
12}
13
14/// Error detail contained in an error response.
15#[derive(Debug, Serialize)]
16#[cfg_attr(feature = "schema", derive(utoipa::ToSchema))]
17pub struct ErrorInfo {
18    /// Machine-readable error code (e.g. `INVALID_DOMAIN`).
19    pub code: &'static str,
20    /// Human-readable error message.
21    pub message: String,
22}
23
24/// Trait for application-specific error types that can be rendered as
25/// structured JSON error responses.
26///
27/// Each project defines its own error enum and implements this trait.
28/// The shared `IntoResponse` implementation (via [`into_error_response`])
29/// handles JSON serialization, status codes, and the `Retry-After` header
30/// for rate-limited responses.
31pub trait ApiError: std::fmt::Display {
32    /// HTTP status code for this error variant.
33    fn status_code(&self) -> StatusCode;
34
35    /// Machine-readable error code string (e.g. `"INVALID_DOMAIN"`).
36    fn error_code(&self) -> &'static str;
37
38    /// If this is a rate-limited error, return the retry-after duration in seconds.
39    fn retry_after_secs(&self) -> Option<u64> {
40        None
41    }
42}
43
44/// Convert any [`ApiError`] into an axum [`Response`].
45///
46/// Produces a JSON body of the form:
47/// ```json
48/// {"error": {"code": "ERROR_CODE", "message": "human-readable message"}}
49/// ```
50///
51/// For rate-limited responses (when `retry_after_secs()` returns `Some`),
52/// includes the `Retry-After` header per RFC 6585.
53pub fn into_error_response(err: &impl ApiError) -> Response {
54    let status = err.status_code();
55
56    if status.is_server_error() {
57        tracing::error!(error = %err, "internal server error");
58    } else if status.is_client_error() {
59        tracing::warn!(error = %err, "client error");
60    }
61
62    build_error_response(err)
63}
64
65/// Build an error response without logging (caller handles logging).
66///
67/// Use this when the caller already logs with tool-specific context.
68/// Use [`into_error_response`] for the default behaviour (logs + response).
69pub fn build_error_response(err: &impl ApiError) -> Response {
70    let status = err.status_code();
71
72    let body = ErrorResponse {
73        error: ErrorInfo {
74            code: err.error_code(),
75            message: err.to_string(),
76        },
77    };
78
79    let mut response = (status, axum::Json(body)).into_response();
80
81    if let Some(secs) = err.retry_after_secs() {
82        response.headers_mut().insert(
83            axum::http::header::RETRY_AFTER,
84            axum::http::HeaderValue::from(secs),
85        );
86    }
87
88    response
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::header::{CONTENT_TYPE, RETRY_AFTER};
    use axum::http::HeaderMap;

    #[derive(Debug)]
    enum TestError {
        BadInput(String),
        RateLimited { retry_after: u64 },
        Internal(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::BadInput(msg) => write!(f, "bad input: {msg}"),
                Self::RateLimited { .. } => write!(f, "rate limited"),
                Self::Internal(msg) => write!(f, "internal error: {msg}"),
            }
        }
    }

    impl ApiError for TestError {
        fn status_code(&self) -> StatusCode {
            match self {
                Self::BadInput(_) => StatusCode::BAD_REQUEST,
                Self::RateLimited { .. } => StatusCode::TOO_MANY_REQUESTS,
                Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        fn error_code(&self) -> &'static str {
            match self {
                Self::BadInput(_) => "BAD_INPUT",
                Self::RateLimited { .. } => "RATE_LIMITED",
                Self::Internal(_) => "INTERNAL_ERROR",
            }
        }

        fn retry_after_secs(&self) -> Option<u64> {
            match self {
                Self::RateLimited { retry_after } => Some(*retry_after),
                _ => None,
            }
        }
    }

    /// Split a rendered response into its status, headers and parsed JSON body.
    async fn response_parts(response: Response) -> (StatusCode, HeaderMap, serde_json::Value) {
        let (parts, body) = response.into_parts();
        let bytes = to_bytes(body, usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        (parts.status, parts.headers, json)
    }

    /// Render `err` via `into_error_response` and split the result.
    async fn into_parts(err: TestError) -> (StatusCode, HeaderMap, serde_json::Value) {
        response_parts(into_error_response(&err)).await
    }

    /// Render `err` via `into_error_response` and return only the JSON body.
    async fn body_json(err: TestError) -> serde_json::Value {
        let (_, _, body) = into_parts(err).await;
        body
    }

    // Tests that only inspect status or headers never read the body stream,
    // so they do not need an async runtime.

    #[test]
    fn bad_input_is_400() {
        let response = into_error_response(&TestError::BadInput("oops".into()));
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn rate_limited_is_429() {
        let response = into_error_response(&TestError::RateLimited { retry_after: 5 });
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn internal_is_500() {
        let response = into_error_response(&TestError::Internal("boom".into()));
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn body_has_error_code_and_message() {
        let body = body_json(TestError::BadInput("test".into())).await;
        assert_eq!(body["error"]["code"], "BAD_INPUT");
        assert!(body["error"]["message"].as_str().unwrap().contains("test"));
        assert_eq!(body.as_object().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn internal_error_body_has_internal_code() {
        let body = body_json(TestError::Internal("boom".into())).await;
        assert_eq!(body["error"]["code"], "INTERNAL_ERROR");
        assert_eq!(body["error"]["message"], "internal error: boom");
    }

    #[tokio::test]
    async fn rate_limited_includes_retry_after_header() {
        let (status, headers, _) = into_parts(TestError::RateLimited { retry_after: 42 }).await;
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
        let retry_after = headers
            .get(RETRY_AFTER)
            .expect("Retry-After header must be present");
        let value: u64 = retry_after.to_str().unwrap().parse().unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn non_rate_limited_has_no_retry_after() {
        let (_, headers, _) = into_parts(TestError::BadInput("x".into())).await;
        assert!(headers.get(RETRY_AFTER).is_none());
    }

    #[tokio::test]
    async fn build_error_response_matches_into_error_response() {
        let err = TestError::RateLimited { retry_after: 7 };
        let (built_status, built_headers, built_body) =
            response_parts(build_error_response(&err)).await;
        let (status, headers, body) = response_parts(into_error_response(&err)).await;

        assert_eq!(built_status, status);
        assert_eq!(built_headers.get(RETRY_AFTER), headers.get(RETRY_AFTER));
        assert_eq!(built_body, body);
        assert_eq!(
            built_headers
                .get(RETRY_AFTER)
                .expect("Retry-After header must be present")
                .to_str()
                .unwrap(),
            "7"
        );
    }

    #[test]
    fn error_response_has_json_content_type() {
        let response = into_error_response(&TestError::BadInput("test".into()));
        let ct = response
            .headers()
            .get(CONTENT_TYPE)
            .expect("Content-Type header must be present")
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
    }
}