gateway_runtime/layers/
headers.rs1use crate::alloc::boxed::Box;
10use crate::alloc::string::String;
11use crate::alloc::sync::Arc;
12use crate::{GatewayRequest, GatewayResponse};
13use core::task::{Context, Poll};
14use std::future::Future;
15use std::pin::Pin;
16use tower::Service;
17
/// Callback that decides what happens to one header name.
///
/// It is called with a header name as a string and returns:
/// - `Some(name)` to keep the header under `name` (a different string renames it);
/// - `None` to drop the header.
///
/// Stored behind an `Arc` and required to be `Send + Sync` so it can be cheaply
/// cloned and moved into boxed `Send` futures.
pub type HeaderMatcher = Arc<dyn Fn(&str) -> Option<String> + Sync + Send>;
24
25#[derive(Clone)]
27pub struct HeaderLayer<S> {
28 inner: S,
29 incoming_matcher: Option<HeaderMatcher>,
30 outgoing_matcher: Option<HeaderMatcher>,
31}
32
33impl<S> HeaderLayer<S> {
34 pub fn new(inner: S, incoming: Option<HeaderMatcher>, outgoing: Option<HeaderMatcher>) -> Self {
41 Self {
42 inner,
43 incoming_matcher: incoming,
44 outgoing_matcher: outgoing,
45 }
46 }
47}
48
49impl<S> Service<GatewayRequest> for HeaderLayer<S>
50where
51 S: Service<GatewayRequest, Response = GatewayResponse>,
52 S::Future: Send + 'static,
53{
54 type Response = GatewayResponse;
55 type Error = S::Error;
56 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
57
58 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59 self.inner.poll_ready(cx)
60 }
61
62 fn call(&mut self, mut req: GatewayRequest) -> Self::Future {
63 if let Some(matcher) = &self.incoming_matcher {
65 let mut new_headers = http::HeaderMap::new();
66 for (key, value) in req.headers() {
67 if let Some(new_key) = matcher(key.as_str()) {
68 if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
69 new_headers.insert(k, value.clone());
70 }
71 }
72 }
73 *req.headers_mut() = new_headers;
74 }
75
76 let outgoing_matcher = self.outgoing_matcher.clone();
77 let fut = self.inner.call(req);
78
79 Box::pin(async move {
80 let mut resp = fut.await?;
81
82 if let Some(matcher) = outgoing_matcher {
84 let mut new_headers = http::HeaderMap::new();
85 for (key, value) in resp.headers() {
86 if let Some(new_key) = matcher(key.as_str()) {
87 if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
88 new_headers.insert(k, value.clone());
89 }
90 }
91 }
92 *resp.headers_mut() = new_headers;
93 }
94
95 Ok(resp)
96 })
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GatewayError;
    use http_body_util::BodyExt;
    use http_body_util::Full;

    /// Empty-bodied response returned by every stub inner service below.
    fn empty_response() -> GatewayResponse {
        http::Response::new(BodyExt::boxed_unsync(
            Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
        ))
    }

    /// Body-less request carrying the given `(name, value)` headers, in order.
    fn request_with(headers: &[(&str, &str)]) -> GatewayRequest {
        let mut builder = http::Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(crate::alloc::vec::Vec::new()).unwrap()
    }

    #[tokio::test]
    async fn test_header_layer_filters_incoming() {
        // Keep only `x-allowed`; every other request header is dropped.
        let only_allowed: HeaderMatcher =
            Arc::new(|name: &str| (name == "x-allowed").then(|| name.to_string()));

        let inner = tower::service_fn(|req: GatewayRequest| async move {
            assert!(req.headers().contains_key("x-allowed"));
            assert!(!req.headers().contains_key("x-forbidden"));
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut svc = HeaderLayer::new(inner, Some(only_allowed), None);
        let req = request_with(&[("x-allowed", "true"), ("x-forbidden", "true")]);

        svc.call(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_layer_renames_incoming() {
        // Rename `old-name` to `new-name`; pass everything else through.
        let rename_old: HeaderMatcher = Arc::new(|name: &str| {
            let target = if name == "old-name" { "new-name" } else { name };
            Some(target.to_string())
        });

        let inner = tower::service_fn(|req: GatewayRequest| async move {
            assert!(req.headers().contains_key("new-name"));
            assert!(!req.headers().contains_key("old-name"));
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut svc = HeaderLayer::new(inner, Some(rename_old), None);
        let req = request_with(&[("old-name", "value")]);

        svc.call(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_layer_transforms_outgoing() {
        // Drop `x-secret`, rename `x-internal` to `x-public`, keep the rest.
        let outgoing: HeaderMatcher = Arc::new(|name: &str| match name {
            "x-secret" => None,
            "x-internal" => Some("x-public".to_string()),
            other => Some(other.to_string()),
        });

        let inner = tower::service_fn(|_req: GatewayRequest| async {
            let mut resp = empty_response();
            let headers = resp.headers_mut();
            headers.insert("x-secret", http::HeaderValue::from_static("shhh"));
            headers.insert("x-internal", http::HeaderValue::from_static("val"));
            headers.insert(
                "content-type",
                http::HeaderValue::from_static("application/json"),
            );
            Ok::<GatewayResponse, GatewayError>(resp)
        });

        let mut svc = HeaderLayer::new(inner, None, Some(outgoing));

        let resp = svc.call(request_with(&[])).await.unwrap();
        let headers = resp.headers();
        assert!(!headers.contains_key("x-secret"));
        assert!(!headers.contains_key("x-internal"));
        assert!(headers.contains_key("x-public"));
        assert!(headers.contains_key("content-type"));
    }
}