Skip to main content

gateway_runtime/
handlers.rs

1//! # Custom Handlers
2//!
3//! This module provides specific implementations of handlers mirroring the `grpc-gateway` (Go) ecosystem.
4//! These handlers offer extended functionality beyond the `defaults` module, such as
5//! advanced error formatting and context extraction.
6//!
7//! ## Components
//! -   `custom_http_error`: Returns JSON error responses with `message`, `status_code`, and `title` fields.
9//! -   `http_response_modifier`: Modifies the HTTP response status code based on gRPC metadata (`x-http-code`).
10//! -   Header matchers: specialized filtering for Auth/Refresh tokens.
11
12use crate::alloc::string::{String, ToString};
13use crate::defaults::default_error_handler;
14use crate::errors::GatewayError;
15use crate::{GatewayRequest, GatewayResponse};
16use http::StatusCode;
17use http_body_util::BodyExt;
18
19/// A custom HTTP error handler that returns JSON error responses.
20///
21/// This handler maps gRPC status codes to HTTP status codes and returns a JSON body with:
22/// - `message`: The error message.
23/// - `status_code`: The HTTP status code.
24/// - `title`: A title for the error (default "Error").
25///
26/// It falls back to `default_error_handler` for 2xx codes or codes outside the valid HTTP range [200, 505].
27pub fn custom_http_error(req: &GatewayRequest, err: GatewayError) -> GatewayResponse {
28    let status = match &err {
29        GatewayError::Upstream(s) => crate::errors::map_code_to_status(s.code()),
30        GatewayError::Http(_) => StatusCode::INTERNAL_SERVER_ERROR,
31        GatewayError::Custom(s, _) => *s,
32        GatewayError::MethodNotAllowed => StatusCode::METHOD_NOT_ALLOWED,
33        GatewayError::NotFound => StatusCode::NOT_FOUND,
34        GatewayError::Encoding(_) => StatusCode::BAD_REQUEST,
35    };
36
37    let code = status.as_u16();
38
39    // Delegate to default handler for success codes or invalid HTTP ranges
40    if code <= 300 || code > 505 {
41        return default_error_handler(req, err);
42    }
43
44    // JSON Error Response Body
45    #[derive(serde::Serialize)]
46    struct ErrorMessage {
47        message: String,
48        status_code: u16,
49        title: String,
50    }
51
52    let msg = ErrorMessage {
53        message: err.to_string(),
54        status_code: code,
55        title: "Error".to_string(),
56    };
57
58    let body_bytes = serde_json::to_vec(&msg).unwrap_or_default();
59    let body = http_body_util::BodyExt::boxed_unsync(
60        http_body_util::Full::new(crate::bytes::Bytes::from(body_bytes))
61            .map_err(|_| unreachable!()),
62    );
63
64    http::Response::builder()
65        .status(status)
66        .header("Content-Type", "application/json")
67        .body(body)
68        .unwrap()
69}
70
71/// Modifies the HTTP response based on metadata headers.
72///
73/// Specifically, it looks for an `x-http-code` header (which may have been mapped from
74/// gRPC metadata by the upstream service) and uses it to override the HTTP status code.
75/// It then removes the header to prevent leaking implementation details.
76pub fn http_response_modifier(_req: &GatewayRequest, resp: &mut GatewayResponse) {
77    if let Some(val) = resp.headers().get("x-http-code") {
78        if let Ok(s) = val.to_str() {
79            if let Ok(code) = s.parse::<u16>() {
80                if let Ok(status) = StatusCode::from_u16(code) {
81                    *resp.status_mut() = status;
82                }
83            }
84        }
85        // Cleanup headers
86        resp.headers_mut().remove("x-http-code");
87        resp.headers_mut()
88            .remove("grpc-metadata-x-http-status-code");
89    }
90}
91
/// Maps an incoming HTTP header name to the name it is forwarded under.
///
/// Matching is case-insensitive:
/// - `Authorization` is dropped (returns `None`).
/// - `Refresh` is renamed to `x-refresh-token`.
/// - Every other header is forwarded under its lowercased name.
pub fn incoming_header_matcher(key: &str) -> Option<String> {
    let normalized = key.to_lowercase();

    if normalized == "authorization" {
        return None;
    }

    let forwarded = if normalized == "refresh" {
        String::from("x-refresh-token")
    } else {
        normalized
    };
    Some(forwarded)
}
105
/// Outgoing header matcher (pass-through).
///
/// Keeps every header and returns its name lowercased; never returns `None`.
pub fn outgoing_header_matcher(key: &str) -> Option<String> {
    Some(key).map(str::to_lowercase)
}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::vec::Vec;

    /// Builds a request with an empty body.
    fn empty_request() -> GatewayRequest {
        http::Request::builder().body(Vec::new()).unwrap()
    }

    /// Builds a `200 OK` response with an empty body and, when `code` is
    /// `Some`, an `x-http-code` header carrying that value.
    fn ok_response(code: Option<&str>) -> GatewayResponse {
        let mut builder = http::Response::builder().status(StatusCode::OK);
        if let Some(value) = code {
            builder = builder.header("x-http-code", value);
        }
        builder
            .body(BodyExt::boxed_unsync(
                http_body_util::Full::new(crate::bytes::Bytes::new())
                    .map_err(|never| match never {}),
            ))
            .unwrap()
    }

    #[test]
    fn test_custom_http_error_json() {
        let resp = custom_http_error(&empty_request(), GatewayError::NotFound);

        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    /// 500 lies inside the JSON range, so it must be rendered by the custom
    /// handler rather than falling back to the default one.
    #[test]
    fn test_custom_http_error_json_for_server_error() {
        let err = GatewayError::Custom(
            StatusCode::INTERNAL_SERVER_ERROR,
            "internal error".to_string(),
        );
        let resp = custom_http_error(&empty_request(), err);

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[test]
    fn test_http_response_modifier() {
        let mut resp = ok_response(Some("400"));
        http_response_modifier(&empty_request(), &mut resp);

        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        assert!(resp.headers().get("x-http-code").is_none());
    }

    /// Unparseable or out-of-range codes keep the original status, but the
    /// header is still stripped.
    #[test]
    fn test_http_response_modifier_invalid_code() {
        for value in &["not-a-number", "42", "70000"] {
            let mut resp = ok_response(Some(value));
            http_response_modifier(&empty_request(), &mut resp);

            assert_eq!(resp.status(), StatusCode::OK, "value {:?}", value);
            assert!(resp.headers().get("x-http-code").is_none());
        }
    }

    #[test]
    fn test_http_response_modifier_without_header() {
        let mut resp = ok_response(None);
        http_response_modifier(&empty_request(), &mut resp);

        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn test_incoming_header_matcher() {
        assert_eq!(incoming_header_matcher("Authorization"), None);
        assert_eq!(incoming_header_matcher("AUTHORIZATION"), None);
        assert_eq!(
            incoming_header_matcher("Refresh"),
            Some("x-refresh-token".to_string())
        );
        assert_eq!(
            incoming_header_matcher("refresh"),
            Some("x-refresh-token".to_string())
        );
        assert_eq!(incoming_header_matcher("Other"), Some("other".to_string()));
    }

    #[test]
    fn test_outgoing_header_matcher() {
        assert_eq!(
            outgoing_header_matcher("X-Custom-Header"),
            Some("x-custom-header".to_string())
        );
        assert_eq!(
            outgoing_header_matcher("already-lower"),
            Some("already-lower".to_string())
        );
    }
}