gateway_runtime/layers/
headers.rs1use crate::alloc::boxed::Box;
10use crate::gateway::HeaderMatcher;
11use crate::{GatewayRequest, GatewayResponse};
12use core::task::{Context, Poll};
13use std::future::Future;
14use std::pin::Pin;
15use tower::Service;
16
17#[derive(Clone)]
19pub struct HeaderLayer<S> {
20 inner: S,
21 incoming_matcher: Option<HeaderMatcher>,
22 outgoing_matcher: Option<HeaderMatcher>,
23}
24
25impl<S> HeaderLayer<S> {
26 pub fn new(inner: S, incoming: Option<HeaderMatcher>, outgoing: Option<HeaderMatcher>) -> Self {
33 Self {
34 inner,
35 incoming_matcher: incoming,
36 outgoing_matcher: outgoing,
37 }
38 }
39}
40
41impl<S> Service<GatewayRequest> for HeaderLayer<S>
42where
43 S: Service<GatewayRequest, Response = GatewayResponse>,
44 S::Future: Send + 'static,
45{
46 type Response = GatewayResponse;
47 type Error = S::Error;
48 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
49
50 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51 self.inner.poll_ready(cx)
52 }
53
54 fn call(&mut self, mut req: GatewayRequest) -> Self::Future {
55 if let Some(matcher) = &self.incoming_matcher {
57 let mut new_headers = http::HeaderMap::new();
58 for (key, value) in req.headers() {
59 if let Some(new_key) = matcher(key.as_str()) {
60 if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
61 new_headers.insert(k, value.clone());
62 }
63 }
64 }
65 *req.headers_mut() = new_headers;
66 }
67
68 let outgoing_matcher = self.outgoing_matcher.clone();
69 let fut = self.inner.call(req);
70
71 Box::pin(async move {
72 let mut resp = fut.await?;
73
74 if let Some(matcher) = outgoing_matcher {
76 let mut new_headers = http::HeaderMap::new();
77 for (key, value) in resp.headers() {
78 if let Some(new_key) = matcher(key.as_str()) {
79 if let Ok(k) = http::header::HeaderName::from_bytes(new_key.as_bytes()) {
80 new_headers.insert(k, value.clone());
81 }
82 }
83 }
84 *resp.headers_mut() = new_headers;
85 }
86
87 Ok(resp)
88 })
89 }
90}