rama_http/layer/forwarded/
get_forwarded.rs

1use crate::Request;
2use crate::headers::{
3    ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::{Context, Layer, Service};
6use rama_net::forwarded::Forwarded;
7use rama_net::forwarded::ForwardedElement;
8use rama_utils::macros::all_the_tuples_no_last_special_case;
9use std::fmt;
10use std::marker::PhantomData;
11
12/// Layer to extract [`Forwarded`] information from the specified `T` headers.
13///
14/// This layer can be used to extract the [`Forwarded`] information from any specified header `T`,
15/// as long as the header implements the [`ForwardHeader`] trait. Multiple headers can be specified
16/// as a tuple, and the layer will extract information from them all, and combine the information.
17///
18/// Please take into consideration the following when combining headers:
19///
20/// - The last header in the tuple will take precedence over the previous headers,
21///   if the same information is present in multiple headers.
22/// - Headers that can contain multiple elements, (e.g. X-Forwarded-For, Via)
23///   will combine their elements in the order as specified. That does however mean that in
24///   case one header has less elements then the other, that the combination down the line
25///   will not be accurate.
26///
27/// The following headers are supported by default:
28///
29/// - [`GetForwardedHeadersLayer::forwarded`]: The standard [`Forwarded`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239).
30/// - [`GetForwardedHeadersLayer::via`]: The canonical [`Via`] header [`RFC 7230`](https://tools.ietf.org/html/rfc7230#section-5.7.1).
31/// - [`GetForwardedHeadersLayer::x_forwarded_for`]: The canonical [`X-Forwarded-For`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.2).
32/// - [`GetForwardedHeadersLayer::x_forwarded_host`]: The canonical [`X-Forwarded-Host`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.4).
33/// - [`GetForwardedHeadersLayer::x_forwarded_proto`]: The canonical [`X-Forwarded-Proto`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.3).
34///
35/// Rama also has the following headers already implemented for you to use:
36///
37/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`Cf-Connecting-Ip`] and [`True-Client-Ip`].
38///
39/// There are no [`GetForwardedHeadersLayer`] constructors for these headers,
40/// but you can use the [`GetForwardedHeadersLayer::new`] constructor and pass the header type as a type parameter,
41/// alone or in a tuple with other headers.
42///
43/// [`X-Real-Ip`]: crate::headers::XRealIp
44/// [`X-Client-Ip`]: crate::headers::XClientIp
45/// [`Client-Ip`]: crate::headers::ClientIp
46/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
47/// [`True-Client-Ip`]: crate::headers::TrueClientIp
48///
49/// ## Example
50///
51/// This example shows you can extract the client IP from the `X-Forwarded-For`
52/// header in case your application is behind a proxy which sets this header.
53///
54/// ```rust
55/// use rama_core::{
56///     service::service_fn,
57///     Context, Service, Layer,
58/// };
59/// use rama_http::{headers::Forwarded, layer::forwarded::GetForwardedHeadersLayer, Request};
60/// use std::{convert::Infallible, net::IpAddr};
61///
62/// #[tokio::main]
63/// async fn main() {
64///     let service = GetForwardedHeadersLayer::x_forwarded_for()
65///         .into_layer(service_fn(async |ctx: Context<()>, _| {
66///             let forwarded = ctx.get::<Forwarded>().unwrap();
67///             assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
68///             assert!(forwarded.client_proto().is_none());
69///
70///             // ...
71///
72///             Ok::<_, Infallible>(())
73///         }));
74///
75///     let req = Request::builder()
76///         .header("X-Forwarded-For", "12.23.34.45")
77///         .body(())
78///         .unwrap();
79///
80///     service.serve(Context::default(), req).await.unwrap();
81/// }
82/// ```
83pub struct GetForwardedHeadersLayer<T = Forwarded> {
84    _headers: PhantomData<fn() -> T>,
85}
86
87impl<T: fmt::Debug> fmt::Debug for GetForwardedHeadersLayer<T> {
88    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89        f.debug_struct("GetForwardedHeadersLayer")
90            .field(
91                "_headers",
92                &format_args!("{}", std::any::type_name::<fn() -> T>()),
93            )
94            .finish()
95    }
96}
97
98impl<T: Clone> Clone for GetForwardedHeadersLayer<T> {
99    fn clone(&self) -> Self {
100        Self {
101            _headers: PhantomData,
102        }
103    }
104}
105
106impl Default for GetForwardedHeadersLayer {
107    fn default() -> Self {
108        Self::forwarded()
109    }
110}
111
112impl<T> GetForwardedHeadersLayer<T> {
113    /// Create a new `GetForwardedHeadersLayer` for the specified headers `T`.
114    pub const fn new() -> Self {
115        Self {
116            _headers: PhantomData,
117        }
118    }
119}
120
121impl GetForwardedHeadersLayer {
122    #[inline]
123    /// Create a new `GetForwardedHeadersLayer` for the standard [`Forwarded`] header.
124    pub fn forwarded() -> Self {
125        Self::new()
126    }
127}
128
129impl GetForwardedHeadersLayer<Via> {
130    #[inline]
131    /// Create a new `GetForwardedHeadersLayer` for the canonical [`Via`] header.
132    pub fn via() -> Self {
133        Self::new()
134    }
135}
136
137impl GetForwardedHeadersLayer<XForwardedFor> {
138    #[inline]
139    /// Create a new `GetForwardedHeadersLayer` for the canonical [`X-Forwarded-For`] header.
140    pub fn x_forwarded_for() -> Self {
141        Self::new()
142    }
143}
144
145impl GetForwardedHeadersLayer<XForwardedHost> {
146    #[inline]
147    /// Create a new `GetForwardedHeadersLayer` for the canonical [`X-Forwarded-Host`] header.
148    pub fn x_forwarded_host() -> Self {
149        Self::new()
150    }
151}
152
153impl GetForwardedHeadersLayer<XForwardedProto> {
154    #[inline]
155    /// Create a new `GetForwardedHeadersLayer` for the canonical [`X-Forwarded-Proto`] header.
156    pub fn x_forwarded_proto() -> Self {
157        Self::new()
158    }
159}
160
161impl<H, S> Layer<S> for GetForwardedHeadersLayer<H> {
162    type Service = GetForwardedHeadersService<S, H>;
163
164    fn layer(&self, inner: S) -> Self::Service {
165        Self::Service {
166            inner,
167            _headers: PhantomData,
168        }
169    }
170}
171
172/// Middleware service to extract [`Forwarded`] information from the specified `T` headers.
173///
174/// See [`GetForwardedHeadersLayer`] for more information.
175pub struct GetForwardedHeadersService<S, T = Forwarded> {
176    inner: S,
177    _headers: PhantomData<fn() -> T>,
178}
179
180impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeadersService<S, T> {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        f.debug_struct("GetForwardedHeadersService")
183            .field("inner", &self.inner)
184            .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
185            .finish()
186    }
187}
188
189impl<S: Clone, T> Clone for GetForwardedHeadersService<S, T> {
190    fn clone(&self) -> Self {
191        GetForwardedHeadersService {
192            inner: self.inner.clone(),
193            _headers: PhantomData,
194        }
195    }
196}
197
198impl<S, T> GetForwardedHeadersService<S, T> {
199    /// Create a new `GetForwardedHeadersService` for the specified headers `T`.
200    pub const fn new(inner: S) -> Self {
201        Self {
202            inner,
203            _headers: PhantomData,
204        }
205    }
206}
207
208impl<S> GetForwardedHeadersService<S> {
209    #[inline]
210    /// Create a new `GetForwardedHeadersService` for the standard [`Forwarded`] header.
211    pub fn forwarded(inner: S) -> Self {
212        Self::new(inner)
213    }
214}
215
216impl<S> GetForwardedHeadersService<S, Via> {
217    #[inline]
218    /// Create a new `GetForwardedHeadersService` for the canonical [`Via`] header.
219    pub fn via(inner: S) -> Self {
220        Self::new(inner)
221    }
222}
223
224impl<S> GetForwardedHeadersService<S, XForwardedFor> {
225    #[inline]
226    /// Create a new `GetForwardedHeadersService` for the canonical [`X-Forwarded-For`] header.
227    pub fn x_forwarded_for(inner: S) -> Self {
228        Self::new(inner)
229    }
230}
231
232impl<S> GetForwardedHeadersService<S, XForwardedHost> {
233    #[inline]
234    /// Create a new `GetForwardedHeadersService` for the canonical [`X-Forwarded-Host`] header.
235    pub fn x_forwarded_host(inner: S) -> Self {
236        Self::new(inner)
237    }
238}
239
240impl<S> GetForwardedHeadersService<S, XForwardedProto> {
241    #[inline]
242    /// Create a new `GetForwardedHeadersService` for the canonical [`X-Forwarded-Proto`] header.
243    pub fn x_forwarded_proto(inner: S) -> Self {
244        Self::new(inner)
245    }
246}
247
// Generates a `Service` impl for `GetForwardedHeadersService<S, (T1, T2, ...)>`.
//
// For every header type in the tuple, in order, the typed header is read from the
// request. Its elements are merged position by position into the elements gathered
// from the earlier headers; elements beyond the current length are appended.
// The gathered elements are then stored in the context as `Forwarded` (extending an
// already present `Forwarded`, if any) before the inner service is called.
macro_rules! get_forwarded_service_for_tuple {
    ( $($ty:ident),* $(,)? ) => {
        // Each `$ty` identifier is reused as the binding name of its parsed header value.
        #[allow(non_snake_case)]
        impl<$($ty,)* S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, ($($ty,)*)>
        where
            $( $ty: ForwardHeader + Send + Sync + 'static, )*
            S: Service<State, Request<Body>>,
            Body: Send + 'static,
            State: Clone + Send + Sync + 'static,
        {
            type Response = S::Response;
            type Error = S::Error;

            fn serve(
                &self,
                mut ctx: Context<State>,
                req: Request<Body>,
            ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
                let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);

                $(
                    if let Some($ty) = req.headers().typed_get::<$ty>() {
                        let mut iter = $ty.into_iter();
                        // Merge in lock-step with what was gathered so far. `zip` polls the
                        // existing elements first, so it never takes an item from `iter`
                        // that it cannot pair up.
                        for (element, other) in forwarded_elements.iter_mut().zip(iter.by_ref()) {
                            element.merge(other);
                        }
                        // Whatever is left has no counterpart yet: append it.
                        forwarded_elements.extend(iter);
                    }
                )*

                if !forwarded_elements.is_empty() {
                    match ctx.get_mut::<Forwarded>() {
                        Some(existing) => {
                            existing.extend(forwarded_elements);
                        }
                        None => {
                            // `Forwarded::new` takes the first element; the rest is added after.
                            let mut rest = forwarded_elements.into_iter();
                            if let Some(first) = rest.next() {
                                let mut forwarded = Forwarded::new(first);
                                forwarded.extend(rest);
                                ctx.insert(forwarded);
                            }
                        }
                    }
                }

                self.inner.serve(ctx, req)
            }
        }
    }
}
305
306impl<H, S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, H>
307where
308    H: ForwardHeader + Send + Sync + 'static,
309    S: Service<State, Request<Body>>,
310    Body: Send + 'static,
311    State: Clone + Send + Sync + 'static,
312{
313    type Response = S::Response;
314    type Error = S::Error;
315
316    fn serve(
317        &self,
318        mut ctx: Context<State>,
319        req: Request<Body>,
320    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
321        let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
322
323        if let Some(header) = req.headers().typed_get::<H>() {
324            forwarded_elements.extend(header);
325        }
326
327        if !forwarded_elements.is_empty() {
328            match ctx.get_mut::<Forwarded>() {
329                Some(ref mut f) => {
330                    f.extend(forwarded_elements);
331                }
332                None => {
333                    let mut it = forwarded_elements.into_iter();
334                    let mut forwarded = Forwarded::new(it.next().unwrap());
335                    forwarded.extend(it);
336                    ctx.insert(forwarded);
337                }
338            }
339        }
340
341        self.inner.serve(ctx, req)
342    }
343}
344
345all_the_tuples_no_last_special_case!(get_forwarded_service_for_tuple);
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        IntoResponse, Response, StatusCode,
        headers::{ClientIp, TrueClientIp, XClientIp, XRealIp},
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_net::forwarded::{ForwardedProtocol, ForwardedVersion};
    use std::{convert::Infallible, net::IpAddr};

    /// Compile-time assertion: only type-checks for services over `Request<()>`.
    fn assert_is_service<T>(_: T)
    where
        T: Service<(), Request<()>>,
    {
    }

    /// Trivial handler that always answers `200 OK`.
    async fn ok_handler() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        // Dedicated service constructors.
        assert_is_service(GetForwardedHeadersService::forwarded(service_fn(ok_handler)));
        assert_is_service(GetForwardedHeadersService::via(service_fn(ok_handler)));
        assert_is_service(GetForwardedHeadersService::x_forwarded_for(service_fn(
            ok_handler,
        )));
        assert_is_service(GetForwardedHeadersService::x_forwarded_proto(service_fn(
            ok_handler,
        )));
        assert_is_service(GetForwardedHeadersService::x_forwarded_host(service_fn(
            ok_handler,
        )));

        // Generic service constructor: single header, 1-tuple and 2-tuple.
        assert_is_service(GetForwardedHeadersService::<_, TrueClientIp>::new(
            service_fn(ok_handler),
        ));
        assert_is_service(GetForwardedHeadersService::<_, (TrueClientIp,)>::new(
            service_fn(ok_handler),
        ));
        assert_is_service(
            GetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
                ok_handler,
            )),
        );

        // Services produced through the layer.
        assert_is_service(GetForwardedHeadersLayer::forwarded().into_layer(service_fn(ok_handler)));
        assert_is_service(GetForwardedHeadersLayer::via().into_layer(service_fn(ok_handler)));
        assert_is_service(
            GetForwardedHeadersLayer::<XRealIp>::new().into_layer(service_fn(ok_handler)),
        );
        assert_is_service(
            GetForwardedHeadersLayer::<(ClientIp, TrueClientIp)>::new()
                .into_layer(service_fn(ok_handler)),
        );
    }

    #[tokio::test]
    async fn test_get_forwarded_header_forwarded() {
        let request = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        let svc = GetForwardedHeadersLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let fwd = ctx.get::<Forwarded>().unwrap();
                let expected_ip = Some(IpAddr::from([12, 23, 34, 45]));
                assert_eq!(fwd.client_ip(), expected_ip);
                assert_eq!(fwd.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            },
        ));

        svc.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_via() {
        let request = Request::builder()
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();

        let svc =
            GetForwardedHeadersLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let fwd = ctx.get::<Forwarded>().unwrap();
                assert!(fwd.client_ip().is_none());
                assert_eq!(
                    fwd.iter().next().unwrap().ref_forwarded_by(),
                    Some(&(IpAddr::from([12, 23, 34, 45]), 5000).into())
                );
                assert!(fwd.client_proto().is_none());
                assert_eq!(fwd.client_version(), Some(ForwardedVersion::HTTP_11));
                Ok::<_, Infallible>(())
            }));

        svc.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_x_forwarded_for() {
        let request = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45, 127.0.0.1")
            .body(())
            .unwrap();

        let svc = GetForwardedHeadersLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let fwd = ctx.get::<Forwarded>().unwrap();
                let expected_ip = Some(IpAddr::from([12, 23, 34, 45]));
                assert_eq!(fwd.client_ip(), expected_ip);
                assert!(fwd.client_proto().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        svc.serve(Context::default(), request).await.unwrap();
    }
}