Skip to main content

gateway_runtime/layers/
headers.rs

//! # Header Processing Layer
//!
//! This layer intercepts incoming requests and outgoing responses to filter or transform
//! HTTP headers using configured [`HeaderMatcher`] functions.
//!
//! This allows for renaming headers (e.g., `authorization` -> `x-auth-token`) or stripping
//! sensitive or unwanted headers before they reach the application logic or the client.
//! Matchers receive header names in lowercase, as normalized by the `http` crate.
8
9use crate::alloc::boxed::Box;
10use crate::alloc::string::String;
11use crate::alloc::sync::Arc;
12use crate::{GatewayRequest, GatewayResponse};
13use core::task::{Context, Poll};
14use std::future::Future;
15use std::pin::Pin;
16use tower::Service;
17
/// Decides what happens to each header that passes through [`HeaderLayer`].
///
/// The function is called with a header name as stored by the `http` crate, which is always
/// lowercase (e.g. `"authorization"`, never `"Authorization"`). Its result means:
///
/// *   `Some(new_name)`: keep the header under `new_name`. Returning the input unchanged keeps
///     the header as is; returning a different string renames it.
/// *   `None`: remove the header.
///
/// A `Some(new_name)` that is not a valid HTTP header name (for example one containing
/// spaces) is not forwarded: the header is removed, exactly as if `None` had been returned.
///
/// The matcher runs on every request/response handled by the layer, so it should be cheap
/// and should decide based only on the name it is given.
pub type HeaderMatcher = Arc<dyn Fn(&str) -> Option<String> + Send + Sync>;
24
25/// A Tower middleware that applies header matching logic.
26#[derive(Clone)]
27pub struct HeaderLayer<S> {
28    inner: S,
29    incoming_matcher: Option<HeaderMatcher>,
30    outgoing_matcher: Option<HeaderMatcher>,
31}
32
33impl<S> HeaderLayer<S> {
34    /// Creates a new `HeaderLayer`.
35    ///
36    /// # Parameters
37    /// *   `inner`: The inner service.
38    /// *   `incoming`: Optional matcher for request headers.
39    /// *   `outgoing`: Optional matcher for response headers.
40    pub fn new(inner: S, incoming: Option<HeaderMatcher>, outgoing: Option<HeaderMatcher>) -> Self {
41        Self {
42            inner,
43            incoming_matcher: incoming,
44            outgoing_matcher: outgoing,
45        }
46    }
47}
48
49impl<S> Service<GatewayRequest> for HeaderLayer<S>
50where
51    S: Service<GatewayRequest, Response = GatewayResponse>,
52    S::Future: Send + 'static,
53{
54    type Response = GatewayResponse;
55    type Error = S::Error;
56    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
57
58    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59        self.inner.poll_ready(cx)
60    }
61
62    fn call(&mut self, mut req: GatewayRequest) -> Self::Future {
63        // Process Incoming Headers
64        if let Some(matcher) = &self.incoming_matcher {
65            let mut new_headers = http::HeaderMap::new();
66            for (key, value) in req.headers() {
67                if let Some(new_key) = matcher(key.as_str()) {
68                    if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
69                        new_headers.insert(k, value.clone());
70                    }
71                }
72            }
73            *req.headers_mut() = new_headers;
74        }
75
76        let outgoing_matcher = self.outgoing_matcher.clone();
77        let fut = self.inner.call(req);
78
79        Box::pin(async move {
80            let mut resp = fut.await?;
81
82            // Process Outgoing Headers
83            if let Some(matcher) = outgoing_matcher {
84                let mut new_headers = http::HeaderMap::new();
85                for (key, value) in resp.headers() {
86                    if let Some(new_key) = matcher(key.as_str()) {
87                        if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
88                            new_headers.insert(k, value.clone());
89                        }
90                    }
91                }
92                *resp.headers_mut() = new_headers;
93            }
94
95            Ok(resp)
96        })
97    }
98}
99
100#[cfg(test)]
101mod tests {
102    use super::*;
103    use crate::GatewayError;
104    use http_body_util::BodyExt;
105    use http_body_util::Full;
106
107    #[tokio::test]
108    async fn test_header_layer_filters_incoming() {
109        let matcher: HeaderMatcher = Arc::new(|key| {
110            if key == "x-allowed" {
111                Some(key.to_string())
112            } else {
113                None
114            }
115        });
116
117        let service = tower::service_fn(|req: GatewayRequest| async move {
118            assert!(req.headers().contains_key("x-allowed"));
119            assert!(!req.headers().contains_key("x-forbidden"));
120            Ok::<GatewayResponse, GatewayError>(http::Response::new(BodyExt::boxed_unsync(
121                Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
122            )))
123        });
124
125        let mut layer = HeaderLayer::new(service, Some(matcher), None);
126        let req = http::Request::builder()
127            .header("x-allowed", "true")
128            .header("x-forbidden", "true")
129            .body(crate::alloc::vec::Vec::new())
130            .unwrap();
131
132        layer.call(req).await.unwrap();
133    }
134
135    #[tokio::test]
136    async fn test_header_layer_renames_incoming() {
137        let matcher: HeaderMatcher = Arc::new(|key| {
138            if key == "old-name" {
139                Some("new-name".to_string())
140            } else {
141                Some(key.to_string())
142            }
143        });
144
145        let service = tower::service_fn(|req: GatewayRequest| async move {
146            assert!(req.headers().contains_key("new-name"));
147            assert!(!req.headers().contains_key("old-name"));
148            Ok::<GatewayResponse, GatewayError>(http::Response::new(BodyExt::boxed_unsync(
149                Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
150            )))
151        });
152
153        let mut layer = HeaderLayer::new(service, Some(matcher), None);
154        let req = http::Request::builder()
155            .header("old-name", "value")
156            .body(crate::alloc::vec::Vec::new())
157            .unwrap();
158
159        layer.call(req).await.unwrap();
160    }
161
162    #[tokio::test]
163    async fn test_header_layer_transforms_outgoing() {
164        let matcher: HeaderMatcher = Arc::new(|key| {
165            if key == "x-secret" {
166                None
167            } else if key == "x-internal" {
168                Some("x-public".to_string())
169            } else {
170                Some(key.to_string())
171            }
172        });
173
174        let service = tower::service_fn(|_req: GatewayRequest| async {
175            let mut resp = http::Response::new(BodyExt::boxed_unsync(
176                Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
177            ));
178            resp.headers_mut()
179                .insert("x-secret", "shhh".parse().unwrap());
180            resp.headers_mut()
181                .insert("x-internal", "val".parse().unwrap());
182            resp.headers_mut()
183                .insert("content-type", "application/json".parse().unwrap());
184            Ok::<GatewayResponse, GatewayError>(resp)
185        });
186
187        let mut layer = HeaderLayer::new(service, None, Some(matcher));
188        let req = http::Request::builder()
189            .body(crate::alloc::vec::Vec::new())
190            .unwrap();
191
192        let resp = layer.call(req).await.unwrap();
193        assert!(!resp.headers().contains_key("x-secret"));
194        assert!(!resp.headers().contains_key("x-internal"));
195        assert!(resp.headers().contains_key("x-public"));
196        assert!(resp.headers().contains_key("content-type"));
197    }
198}