Skip to main content

gateway_runtime/
defaults.rs

1//! # Default Handlers
2//!
3//! This module provides standard implementations for common gateway hooks, offering sane defaults
4//! that cover most use cases.
5//!
6//! ## Components
7//! -   `default_error_handler`: Returns simple JSON error responses.
//! -   `default_metadata_annotator`: Injects a gateway-generated request ID (`x-request-id`, `std` feature only) and a placeholder gateway IP (`x-gateway-ip`).
9//! -   `default_response_modifier`: Honors the `x-http-code` header for status code overrides.
10//! -   `default_incoming_header_matcher`: Filters `Authorization`, `x-request-id` (security) and renames `Refresh`.
11//! -   `default_outgoing_header_matcher`: Lowercases all headers.
12//! -   `default_auth_verifier`: Basic API Key presence check.
13
14use crate::alloc::string::{String, ToString};
15use crate::errors::GatewayError;
16use crate::router::{AuthLocation, RouteMetadata};
17use crate::{GatewayRequest, GatewayResponse};
18use http::StatusCode;
19use http_body_util::BodyExt;
20use tonic::metadata::MetadataMap;
21
22/// A custom HTTP error handler that returns JSON error responses.
23///
24/// This handler maps gRPC status codes to HTTP status codes and returns a JSON body
25/// with the error details.
26pub fn default_error_handler(_req: &GatewayRequest, err: GatewayError) -> GatewayResponse {
27    let status = match &err {
28        GatewayError::Upstream(s) => crate::errors::map_code_to_status(s.code()),
29        GatewayError::Http(_) => StatusCode::INTERNAL_SERVER_ERROR,
30        GatewayError::Custom(s, _) => *s,
31        GatewayError::MethodNotAllowed => StatusCode::METHOD_NOT_ALLOWED,
32        GatewayError::NotFound => StatusCode::NOT_FOUND,
33        GatewayError::Encoding(_) => StatusCode::BAD_REQUEST,
34    };
35
36    #[derive(serde::Serialize)]
37    struct ErrorMessage {
38        message: String,
39        status_code: u16,
40        title: String,
41    }
42
43    let msg = ErrorMessage {
44        message: err.to_string(),
45        status_code: status.as_u16(),
46        title: "Error".to_string(),
47    };
48
49    let body_bytes = serde_json::to_vec(&msg).unwrap_or_default();
50    let body = http_body_util::BodyExt::boxed_unsync(
51        http_body_util::Full::new(crate::bytes::Bytes::from(body_bytes))
52            .map_err(|_| unreachable!()),
53    );
54
55    http::Response::builder()
56        .status(status)
57        .header("Content-Type", "application/json")
58        .body(body)
59        .unwrap()
60}
61
62/// A metadata annotator that injects request context information.
63///
64/// This annotator:
65/// 1. Generates and injects a unique request ID (`x-request-id`).
66///    - It assumes incoming `x-request-id` headers have been stripped by the `HeaderLayer`
67///      to ensure the ID is trusted and generated by the gateway.
68/// 2. Injects the gateway IP as `x-gateway-ip` (placeholder).
69pub fn default_metadata_annotator(_req: &GatewayRequest) -> MetadataMap {
70    let mut map = MetadataMap::new();
71
72    // 1. Request ID Generation
73    // We strictly generate a new ID.
74    #[cfg(feature = "std")]
75    {
76        // Simple ID generation using timestamp + random-ish suffix (or just nanoseconds for now)
77        // In a real production system, use the `uuid` crate.
78        let req_id = format!(
79            "req-{}",
80            std::time::SystemTime::now()
81                .duration_since(std::time::UNIX_EPOCH)
82                .unwrap_or_default()
83                .as_nanos()
84        );
85        if let Ok(v) = tonic::metadata::MetadataValue::try_from(req_id.as_str()) {
86            map.insert("x-request-id", v);
87        }
88    }
89
90    // 2. Gateway IP (Placeholder)
91    map.insert(
92        "x-gateway-ip",
93        tonic::metadata::MetadataValue::from_static("127.0.0.1"),
94    );
95
96    map
97}
98
99/// A response modifier that sets the HTTP status code based on the `x-http-code` header.
100pub fn default_response_modifier(_req: &GatewayRequest, resp: &mut GatewayResponse) {
101    if let Some(val) = resp.headers().get("x-http-code") {
102        if let Ok(s) = val.to_str() {
103            if let Ok(code) = s.parse::<u16>() {
104                if let Ok(status) = StatusCode::from_u16(code) {
105                    *resp.status_mut() = status;
106                }
107            }
108        }
109        resp.headers_mut().remove("x-http-code");
110        resp.headers_mut()
111            .remove("grpc-metadata-x-http-status-code");
112    }
113}
114
/// Maps an incoming request header name to the name it is forwarded under, or drops it.
///
/// Matching is case-insensitive, and forwarded names are lowercased:
/// - `authorization`: dropped, so credentials are never forwarded implicitly.
/// - `x-request-id`: dropped, because the gateway generates its own request IDs.
/// - `refresh`: renamed to `x-refresh-token`.
/// - anything else: kept, lowercased.
///
/// HTTP field names are case-insensitive ASCII tokens, so ASCII case folding is used rather
/// than full Unicode lowercasing.
pub fn default_incoming_header_matcher(key: &str) -> Option<String> {
    let key_lower = key.to_ascii_lowercase();
    match key_lower.as_str() {
        "authorization" => None,
        "x-request-id" => None,
        "refresh" => Some("x-refresh-token".to_string()),
        _ => Some(key_lower),
    }
}
130
/// Maps an outgoing header name to the name it is forwarded under.
///
/// Every header is kept; only the name is lowercased. HTTP field names are case-insensitive
/// ASCII tokens, so ASCII case folding is used rather than full Unicode lowercasing.
pub fn default_outgoing_header_matcher(key: &str) -> Option<String> {
    Some(key.to_ascii_lowercase())
}
135
136/// A default authentication verifier.
137///
138/// Checks if the route requires an API Key and verifies its presence in the configured location.
139/// Note: This implementation only checks for *presence*, not validity of the key value.
140///
141/// # Errors
142/// Returns `GatewayError::Upstream` with `Unauthenticated` status if the key is missing.
143pub fn default_auth_verifier(
144    req: &GatewayRequest,
145    metadata: &RouteMetadata,
146) -> Result<(), GatewayError> {
147    if let Some(auth_config) = &metadata.auth_required {
148        // Only checking ApiKey for now as per requirement
149        if auth_config.scheme == "ApiKey" {
150            let present = match auth_config.location {
151                AuthLocation::Header => req.headers().contains_key(&auth_config.name),
152                AuthLocation::Query => {
153                    if let Some(query) = req.uri().query() {
154                        // Simple substring check for query param key presence
155                        // Robust implementation would parse query string
156                        query.contains(&format!("{}=", auth_config.name))
157                    } else {
158                        false
159                    }
160                }
161                AuthLocation::Cookie => {
162                    if let Some(cookie_header) = req.headers().get(http::header::COOKIE) {
163                        if let Ok(s) = cookie_header.to_str() {
164                            s.contains(&format!("{}=", auth_config.name))
165                        } else {
166                            false
167                        }
168                    } else {
169                        false
170                    }
171                }
172            };
173
174            if !present {
175                return Err(GatewayError::Upstream(tonic::Status::unauthenticated(
176                    format!("Missing authentication: {}", auth_config.name),
177                )));
178            }
179        }
180    }
181    Ok(())
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::vec::Vec;
    use crate::router::{AuthConfig, RouteMetadata};

    /// A request with no headers and an empty body.
    fn empty_request() -> GatewayRequest {
        http::Request::builder().body(Vec::new()).unwrap()
    }

    /// A `200 OK` response with an empty body and the given `x-http-code` header.
    fn response_with_http_code(code: &'static str) -> GatewayResponse {
        http::Response::builder()
            .status(StatusCode::OK)
            .header("x-http-code", code)
            .body(BodyExt::boxed_unsync(
                http_body_util::Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    /// Route metadata requiring an `X-API-KEY` header under the `ApiKey` scheme.
    fn api_key_header_route() -> RouteMetadata {
        RouteMetadata {
            auth_required: Some(AuthConfig {
                scheme: "ApiKey".to_string(),
                location: AuthLocation::Header,
                name: "X-API-KEY".to_string(),
            }),
        }
    }

    #[test]
    fn test_default_incoming_header_matcher() {
        let cases = [
            ("Authorization", None),
            ("x-request-id", None),
            ("REFRESH", Some("x-refresh-token")),
            ("X-Custom", Some("x-custom")),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(
                default_incoming_header_matcher(input),
                expected.map(String::from),
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_default_outgoing_header_matcher() {
        let mapped = default_outgoing_header_matcher("Authorization");
        assert_eq!(mapped, Some(String::from("authorization")));
    }

    #[test]
    fn test_default_metadata_annotator() {
        let md = default_metadata_annotator(&empty_request());
        assert!(md.get("x-request-id").is_some(), "request id should be injected");
        let gateway_ip = md
            .get("x-gateway-ip")
            .expect("gateway ip should be injected");
        assert_eq!(gateway_ip.to_str().unwrap(), "127.0.0.1");
    }

    #[test]
    fn test_default_response_modifier() {
        let mut resp = response_with_http_code("418");
        default_response_modifier(&empty_request(), &mut resp);
        assert_eq!(resp.status(), StatusCode::IM_A_TEAPOT);
        assert!(!resp.headers().contains_key("x-http-code"));
    }

    #[test]
    fn test_default_response_modifier_invalid() {
        let mut resp = response_with_http_code("invalid");
        default_response_modifier(&empty_request(), &mut resp);
        // An unparsable code leaves the status untouched...
        assert_eq!(resp.status(), StatusCode::OK);
        // ...but the control header is still stripped.
        assert!(!resp.headers().contains_key("x-http-code"));
    }

    #[test]
    fn test_default_error_handler_not_found() {
        let resp = default_error_handler(&empty_request(), GatewayError::NotFound);
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_default_error_handler_method_not_allowed() {
        let resp = default_error_handler(&empty_request(), GatewayError::MethodNotAllowed);
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn test_default_error_handler_upstream() {
        let upstream = GatewayError::Upstream(tonic::Status::invalid_argument("invalid"));
        let resp = default_error_handler(&empty_request(), upstream);
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_default_error_handler_json_body() {
        let resp = default_error_handler(&empty_request(), GatewayError::NotFound);
        let content_type = resp
            .headers()
            .get("content-type")
            .expect("content-type should be set");
        assert_eq!(content_type, "application/json");
    }

    #[test]
    fn test_default_auth_verifier_header() {
        let req = http::Request::builder()
            .header("X-API-KEY", "secret")
            .body(Vec::new())
            .unwrap();
        assert!(default_auth_verifier(&req, &api_key_header_route()).is_ok());
    }

    #[test]
    fn test_default_auth_verifier_missing() {
        let outcome = default_auth_verifier(&empty_request(), &api_key_header_route());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_default_auth_verifier_none_required() {
        let outcome = default_auth_verifier(&empty_request(), &RouteMetadata::default());
        assert!(outcome.is_ok());
    }
}