rama_http/layer/forwarded/
set_forwarded_multi.rs

1use crate::Request;
2use crate::headers::HeaderMapExt;
3use crate::headers::forwarded::ForwardHeader;
4use rama_core::error::BoxError;
5use rama_core::{Context, Layer, Service};
6use rama_net::address::Domain;
7use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
8use rama_net::http::RequestContext;
9use rama_net::stream::SocketInfo;
10use rama_utils::macros::all_the_tuples_no_last_special_case;
11use std::fmt;
12use std::marker::PhantomData;
13
/// Layer to write [`Forwarded`] information for this proxy,
/// added to the end of the chain of forwarded information already known.
///
/// Use [`SetForwardedHeaderLayer`] if you only need a single header.
///
/// This layer can set any headers as long as you have a [`ForwardHeader`] implementation
/// for the headers you want to set. You can pass it as the type to the layer when creating
/// the layer using [`SetForwardedHeadersLayer::new`], with the headers in a single tuple.
pub struct SetForwardedHeadersLayer<T = Forwarded> {
    /// Node identity handed to every service produced by this layer.
    by_node: NodeId,
    /// Type-level marker for the header type(s) `T`; no value of `T` is stored.
    _headers: PhantomData<fn() -> T>,
}
26
27impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
28    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
29        f.debug_struct("SetForwardedHeadersLayer")
30            .field("by_node", &self.by_node)
31            .field(
32                "_headers",
33                &format_args!("{}", std::any::type_name::<fn() -> T>()),
34            )
35            .finish()
36    }
37}
38
39impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
40    fn clone(&self) -> Self {
41        Self {
42            by_node: self.by_node.clone(),
43            _headers: PhantomData,
44        }
45    }
46}
47
48impl<T> Default for SetForwardedHeadersLayer<T> {
49    #[inline]
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl<T> SetForwardedHeadersLayer<T> {
56    /// Create a new `SetForwardedHeadersLayer` for the specified headers `T`.
57    pub fn new() -> Self {
58        Self {
59            by_node: Domain::from_static("rama").into(),
60            _headers: PhantomData,
61        }
62    }
63}
64
65impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
66    type Service = SetForwardedHeadersService<S, H>;
67
68    fn layer(&self, inner: S) -> Self::Service {
69        Self::Service {
70            inner,
71            by_node: self.by_node.clone(),
72            _headers: PhantomData,
73        }
74    }
75
76    fn into_layer(self, inner: S) -> Self::Service {
77        Self::Service {
78            inner,
79            by_node: self.by_node,
80            _headers: PhantomData,
81        }
82    }
83}
84
/// Middleware [`Service`] to write [`Forwarded`] information for this proxy,
/// added to the end of the chain of forwarded information already known.
///
/// See [`SetForwardedHeadersLayer`] for more information.
pub struct SetForwardedHeadersService<S, T = Forwarded> {
    /// Wrapped service that receives the request after the headers are set.
    inner: S,
    /// Node identity used as the "forwarded by" node of the element this service adds.
    by_node: NodeId,
    /// Type-level marker for the header type(s) `T`; no value of `T` is stored.
    _headers: PhantomData<fn() -> T>,
}
94
95impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        f.debug_struct("SetForwardedHeadersService")
98            .field("inner", &self.inner)
99            .field("by_node", &self.by_node)
100            .field(
101                "_headers",
102                &format_args!("{}", std::any::type_name::<fn() -> T>()),
103            )
104            .finish()
105    }
106}
107
108impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
109    fn clone(&self) -> Self {
110        SetForwardedHeadersService {
111            inner: self.inner.clone(),
112            by_node: self.by_node.clone(),
113            _headers: PhantomData,
114        }
115    }
116}
117
118impl<S, T> SetForwardedHeadersService<S, T> {
119    /// Create a new `SetForwardedHeadersService` for the specified headers `T`.
120    pub fn new(inner: S) -> Self {
121        Self {
122            inner,
123            by_node: Domain::from_static("rama").into(),
124            _headers: PhantomData,
125        }
126    }
127}
128
129macro_rules! set_forwarded_service_for_tuple {
130    ( $($ty:ident),* $(,)? ) => {
131        #[allow(non_snake_case)]
132        impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
133        where
134            $( $ty: ForwardHeader + Send + Sync + 'static, )*
135            S: Service<State, Request<Body>, Error: Into<BoxError>>,
136            Body: Send + 'static,
137            State: Clone + Send + Sync + 'static,
138        {
139            type Response = S::Response;
140            type Error = BoxError;
141
142            async fn serve(
143                &self,
144                mut ctx: Context<State>,
145                mut req: Request<Body>,
146            ) -> Result<Self::Response, Self::Error> {
147                let forwarded: Option<Forwarded> = ctx.get().cloned();
148
149                let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
150
151                if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
152                    forwarded_element.set_forwarded_for(peer_addr);
153                }
154
155                let request_ctx: &mut RequestContext =
156                    ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
157
158                forwarded_element.set_forwarded_host(request_ctx.authority.clone());
159
160                if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
161                    forwarded_element.set_forwarded_proto(forwarded_proto);
162                }
163
164                let forwarded = match forwarded {
165                    None => Some(Forwarded::new(forwarded_element)),
166                    Some(mut forwarded) => {
167                        forwarded.append(forwarded_element);
168                        Some(forwarded)
169                    }
170                };
171
172                if let Some(forwarded) = forwarded {
173                    $(
174                        if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
175                            req.headers_mut().typed_insert(header);
176                        }
177                    )*
178                }
179
180                self.inner.serve(ctx, req).await.map_err(Into::into)
181            }
182        }
183    };
184}
185all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{TrueClientIp, XClientIp, XRealIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_http_headers::forwarded::XForwardedProto;
    use std::convert::Infallible;

    /// Compile-time check that `T` is a service over `Request<()>` with `()` state.
    fn assert_is_service<T>(_: T)
    where
        T: Service<(), Request<()>>,
    {
    }

    /// Inner service that always answers `200 OK`.
    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        // One header in the tuple.
        let single_header =
            SetForwardedHeadersService::<_, (TrueClientIp,)>::new(service_fn(dummy_service_fn));
        assert_is_service(single_header);

        // Two headers in the tuple.
        let two_headers = SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(
            service_fn(dummy_service_fn),
        );
        assert_is_service(two_headers);

        // Service built through the layer.
        let via_layer = SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
            .into_layer(service_fn(dummy_service_fn));
        assert_is_service(via_layer);
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        // Inner service asserting on the `Forwarded` header written by the middleware.
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let forwarded_header = request.headers().get("Forwarded").unwrap();
            assert_eq!(
                forwarded_header,
                "by=rama;host=\"example.com:80\";proto=http"
            );
            Ok(())
        }

        let request = Request::builder().uri("example.com").body(()).unwrap();
        let service =
            SetForwardedHeadersService::<_, (rama_http_headers::forwarded::Forwarded,)>::new(
                service_fn(svc),
            );
        service.serve(Context::default(), request).await.unwrap();
    }
}