rama_http/layer/forwarded/
get_forwarded_multi.rs

1use crate::Request;
2use crate::headers::forwarded::ForwardHeader;
3use rama_core::{Context, Layer, Service};
4use rama_http_headers::HeaderMapExt;
5use rama_net::forwarded::Forwarded;
6use rama_net::forwarded::ForwardedElement;
7use rama_utils::macros::all_the_tuples_no_last_special_case;
8use std::fmt;
9use std::marker::PhantomData;
10
11/// Layer to extract [`Forwarded`] information from the specified `T` headers.
12///
13/// Use [`GetForwardedHeaderLayer`] if you only need a single a header.
14///
15/// [`GetForwardedHeaderLayer`]: super::GetForwardedHeaderLayer
16///
17/// This layer can be used to extract the [`Forwarded`] information from any specified header `T`,
18/// as long as the header implements the [`ForwardHeader`] trait. Multiple headers can be specified
19/// as a tuple, and the layer will extract information from them all, and combine the information.
20///
21/// Please take into consideration the following when combining headers:
22///
23/// - The last header in the tuple will take precedence over the previous headers,
24///   if the same information is present in multiple headers.
25/// - Headers that can contain multiple elements, (e.g. X-Forwarded-For, Via)
26///   will combine their elements in the order as specified. That does however mean that in
27///   case one header has less elements then the other, that the combination down the line
28///   will not be accurate.
29///
30/// Rama also has the following headers already implemented for you to use:
31///
32/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`Cf-Connecting-Ip`] and [`True-Client-Ip`].
33///
34/// There are no [`GetForwardedHeadersLayer`] constructors for these headers,
35/// but you can use the [`GetForwardedHeadersLayer::new`] constructor and pass the header type as a type parameter in a tuple with other headers.
36///
37/// [`X-Real-Ip`]: crate::headers::XRealIp
38/// [`X-Client-Ip`]: crate::headers::XClientIp
39/// [`Client-Ip`]: crate::headers::ClientIp
40/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
41/// [`True-Client-Ip`]: crate::headers::TrueClientIp
42pub struct GetForwardedHeadersLayer<T = Forwarded> {
43    _headers: PhantomData<fn() -> T>,
44}
45
46impl<T: fmt::Debug> fmt::Debug for GetForwardedHeadersLayer<T> {
47    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        f.debug_struct("GetForwardedHeadersLayer")
49            .field(
50                "_headers",
51                &format_args!("{}", std::any::type_name::<fn() -> T>()),
52            )
53            .finish()
54    }
55}
56
57impl<T: Clone> Clone for GetForwardedHeadersLayer<T> {
58    fn clone(&self) -> Self {
59        Self {
60            _headers: PhantomData,
61        }
62    }
63}
64
65impl<T> Default for GetForwardedHeadersLayer<T> {
66    #[inline]
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl<T> GetForwardedHeadersLayer<T> {
73    /// Create a new `GetForwardedHeadersLayer` for the specified headers `T`.
74    pub const fn new() -> Self {
75        Self {
76            _headers: PhantomData,
77        }
78    }
79}
80
81impl<H, S> Layer<S> for GetForwardedHeadersLayer<H> {
82    type Service = GetForwardedHeadersService<S, H>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        Self::Service {
86            inner,
87            _headers: PhantomData,
88        }
89    }
90}
91
92/// Middleware service to extract [`Forwarded`] information from the specified `T` headers.
93///
94/// See [`GetForwardedHeadersLayer`] for more information.
95pub struct GetForwardedHeadersService<S, T = Forwarded> {
96    inner: S,
97    _headers: PhantomData<fn() -> T>,
98}
99
100impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeadersService<S, T> {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        f.debug_struct("GetForwardedHeadersService")
103            .field("inner", &self.inner)
104            .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
105            .finish()
106    }
107}
108
109impl<S: Clone, T> Clone for GetForwardedHeadersService<S, T> {
110    fn clone(&self) -> Self {
111        GetForwardedHeadersService {
112            inner: self.inner.clone(),
113            _headers: PhantomData,
114        }
115    }
116}
117
118impl<S, T> GetForwardedHeadersService<S, T> {
119    /// Create a new `GetForwardedHeadersService` for the specified headers `T`.
120    pub const fn new(inner: S) -> Self {
121        Self {
122            inner,
123            _headers: PhantomData,
124        }
125    }
126}
127
128macro_rules! get_forwarded_service_for_tuple {
129    ( $($ty:ident),* $(,)? ) => {
130        #[allow(non_snake_case)]
131        impl<$($ty,)* S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, ($($ty,)*)>
132        where
133            $( $ty: ForwardHeader + Send + Sync + 'static, )*
134            S: Service<State, Request<Body>>,
135            Body: Send + 'static,
136            State: Clone + Send + Sync + 'static,
137        {
138            type Response = S::Response;
139            type Error = S::Error;
140
141            fn serve(
142                &self,
143                mut ctx: Context<State>,
144                req: Request<Body>,
145            ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
146                let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
147
148                $(
149                    if let Some($ty) = req.headers().typed_get::<$ty>() {
150                        let mut iter = $ty.into_iter();
151                        for element in forwarded_elements.iter_mut() {
152                            let other = iter.next();
153                            match other {
154                                Some(other) => {
155                                    element.merge(other);
156                                }
157                                None => break,
158                            }
159                        }
160                        for other in iter {
161                            forwarded_elements.push(other);
162                        }
163                    }
164                )*
165
166                if !forwarded_elements.is_empty() {
167                    match ctx.get_mut::<Forwarded>() {
168                        Some(ref mut f) => {
169                            f.extend(forwarded_elements);
170                        }
171                        None => {
172                            let mut it = forwarded_elements.into_iter();
173                            let mut forwarded = Forwarded::new(it.next().unwrap());
174                            forwarded.extend(it);
175                            ctx.insert(forwarded);
176                        }
177                    }
178                }
179
180                self.inner.serve(ctx, req)
181            }
182        }
183    }
184}
185
186all_the_tuples_no_last_special_case!(get_forwarded_service_for_tuple);
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{ClientIp, TrueClientIp, XClientIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_net::forwarded::ForwardedProtocol;
    use std::{convert::Infallible, net::IpAddr};

    /// Compiles only if the argument implements `Service<(), Request<()>>`.
    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    /// Inner service stub that always responds with `200 OK`.
    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        // One header, two headers, and a service produced through the layer.
        let single_header =
            GetForwardedHeadersService::<_, (TrueClientIp,)>::new(service_fn(dummy_service_fn));
        assert_is_service(single_header);

        let two_headers = GetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(
            service_fn(dummy_service_fn),
        );
        assert_is_service(two_headers);

        let from_layer = GetForwardedHeadersLayer::<(ClientIp, TrueClientIp)>::new()
            .into_layer(service_fn(dummy_service_fn));
        assert_is_service(from_layer);
    }

    #[tokio::test]
    async fn test_get_forwarded_headers() {
        let request = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        let service = GetForwardedHeadersLayer::<(rama_http_headers::forwarded::Forwarded,)>::new()
            .into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            }));

        service.serve(Context::default(), request).await.unwrap();
    }
}
239}