rama_http/layer/forwarded/
set_forwarded.rs

1use crate::headers::{
2    ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
3};
4use crate::Request;
5use rama_core::error::BoxError;
6use rama_core::{Context, Layer, Service};
7use rama_net::address::Domain;
8use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
9use rama_net::http::RequestContext;
10use rama_net::stream::SocketInfo;
11use rama_utils::macros::all_the_tuples_no_last_special_case;
12use std::fmt;
13use std::marker::PhantomData;
14
15/// Layer to write [`Forwarded`] information for this proxy,
16/// added to the end of the chain of forwarded information already known.
17///
18/// This layer can set any header as long as you have a [`ForwardHeader`] implementation
19/// for the header you want to set. You can pass it as the type to the layer when creating
20/// the layer using [`SetForwardedHeadersLayer::new`]. Multiple headers (in order) can also be set
21/// by specifying multiple types as a tuple.
22///
/// The following headers are supported out of the box, each with its own constructor:
24///
25/// - [`SetForwardedHeadersLayer::forwarded`]: the standard [`Forwarded`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239);
/// - [`SetForwardedHeadersLayer::via`]: the standard [`Via`] header ([RFC 9110, section 7.6.3](https://www.rfc-editor.org/rfc/rfc9110#section-7.6.3));
27/// - [`SetForwardedHeadersLayer::x_forwarded_for`]: the canonical [`X-Forwarded-For`][`XForwardedFor`] header (non-standard);
28/// - [`SetForwardedHeadersLayer::x_forwarded_host`]: the canonical [`X-Forwarded-Host`][`XForwardedHost`] header (non-standard);
29/// - [`SetForwardedHeadersLayer::x_forwarded_proto`]: the canonical [`X-Forwarded-Proto`][`XForwardedProto`] header (non-standard).
30///
31/// The "by" property is set to `rama` by default. Use [`SetForwardedHeadersLayer::forward_by`] to overwrite this,
32/// typically with the actual [`IPv4`]/[`IPv6`] address of your proxy.
33///
34/// [`IPv4`]: std::net::Ipv4Addr
35/// [`IPv6`]: std::net::Ipv6Addr
36///
37/// Rama also has the following headers already implemented for you to use:
38///
39/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`CF-Connecting-Ip`] and [`True-Client-Ip`].
40///
41/// There are no [`SetForwardedHeadersLayer`] constructors for these headers,
42/// but you can use the [`SetForwardedHeadersLayer::new`] constructor and pass the header type as a type parameter,
43/// alone or in a tuple with other headers.
44///
45/// [`X-Real-Ip`]: crate::headers::XRealIp
46/// [`X-Client-Ip`]: crate::headers::XClientIp
47/// [`Client-Ip`]: crate::headers::ClientIp
48/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
49/// [`True-Client-Ip`]: crate::headers::TrueClientIp
50///
51/// ## Example
52///
53/// This example shows how you could expose the real Client IP using the [`X-Real-IP`][`crate::headers::XRealIp`] header.
54///
55/// ```rust
56/// use rama_net::stream::SocketInfo;
57/// use rama_http::Request;
58/// use rama_core::service::service_fn;
59/// use rama_http::{headers::XRealIp, layer::forwarded::SetForwardedHeadersLayer};
60/// use rama_core::{Context, Service, Layer};
61/// use std::convert::Infallible;
62///
63/// # type Body = ();
64/// # type State = ();
65///
66/// # #[tokio::main]
67/// # async fn main() {
68/// async fn svc(_ctx: Context<State>, request: Request<Body>) -> Result<(), Infallible> {
69///     // ...
70///     # assert_eq!(
71///     #     request.headers().get("X-Real-Ip").unwrap(),
72///     #     "42.37.100.50:62345",
73///     # );
74///     # Ok(())
75/// }
76///
77/// let service = SetForwardedHeadersLayer::<XRealIp>::new()
78///     .layer(service_fn(svc));
79///
80/// # let req = Request::builder().uri("example.com").body(()).unwrap();
81/// # let mut ctx = Context::default();
82/// # ctx.insert(SocketInfo::new(None, "42.37.100.50:62345".parse().unwrap()));
83/// service.serve(ctx, req).await.unwrap();
84/// # }
85/// ```
pub struct SetForwardedHeadersLayer<T = Forwarded> {
    /// Node identifier handed to every service created by this layer,
    /// where it is used as the "by" property of the forwarded element.
    by_node: NodeId,
    /// Marker for the header type(s) `T` to write; no value of `T` is stored.
    _headers: PhantomData<fn() -> T>,
}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        f.debug_struct("SetForwardedHeadersLayer")
94            .field("by_node", &self.by_node)
95            .field(
96                "_headers",
97                &format_args!("{}", std::any::type_name::<fn() -> T>()),
98            )
99            .finish()
100    }
101}
102
103impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
104    fn clone(&self) -> Self {
105        Self {
106            by_node: self.by_node.clone(),
107            _headers: PhantomData,
108        }
109    }
110}
111
112impl<T> SetForwardedHeadersLayer<T> {
113    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
114    ///
115    /// Default of `None` will be set to `rama` otherwise.
116    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117        self.by_node = node_id.into();
118        self
119    }
120
121    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
122    ///
123    /// Default of `None` will be set to `rama` otherwise.
124    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125        self.by_node = node_id.into();
126        self
127    }
128}
129
130impl<T> SetForwardedHeadersLayer<T> {
131    /// Create a new `SetForwardedHeadersLayer` for the specified headers `T`.
132    pub fn new() -> Self {
133        Self {
134            by_node: Domain::from_static("rama").into(),
135            _headers: PhantomData,
136        }
137    }
138}
139
140impl Default for SetForwardedHeadersLayer {
141    fn default() -> Self {
142        Self::forwarded()
143    }
144}
145
146impl SetForwardedHeadersLayer {
147    #[inline]
148    /// Create a new `SetForwardedHeadersLayer` for the standard [`Forwarded`] header.
149    pub fn forwarded() -> Self {
150        Self::new()
151    }
152}
153
154impl SetForwardedHeadersLayer<Via> {
155    #[inline]
156    /// Create a new `SetForwardedHeadersLayer` for the canonical [`Via`] header.
157    pub fn via() -> Self {
158        Self::new()
159    }
160}
161
162impl SetForwardedHeadersLayer<XForwardedFor> {
163    #[inline]
164    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-For`] header.
165    pub fn x_forwarded_for() -> Self {
166        Self::new()
167    }
168}
169
170impl SetForwardedHeadersLayer<XForwardedHost> {
171    #[inline]
172    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-Host`] header.
173    pub fn x_forwarded_host() -> Self {
174        Self::new()
175    }
176}
177
178impl SetForwardedHeadersLayer<XForwardedProto> {
179    #[inline]
180    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-Proto`] header.
181    pub fn x_forwarded_proto() -> Self {
182        Self::new()
183    }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
187    type Service = SetForwardedHeadersService<S, H>;
188
189    fn layer(&self, inner: S) -> Self::Service {
190        Self::Service {
191            inner,
192            by_node: self.by_node.clone(),
193            _headers: PhantomData,
194        }
195    }
196}
197
/// Middleware [`Service`] to write [`Forwarded`] information for this proxy,
/// added to the end of the chain of forwarded information already known.
///
/// See [`SetForwardedHeadersLayer`] for more information.
pub struct SetForwardedHeadersService<S, T = Forwarded> {
    /// The wrapped service, called once the forwarded header(s) have been written.
    inner: S,
    /// Node identifier used as the "by" property of the forwarded element for this proxy.
    by_node: NodeId,
    /// Marker for the header type(s) `T` to write; no value of `T` is stored.
    _headers: PhantomData<fn() -> T>,
}
207
208impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        f.debug_struct("SetForwardedHeadersService")
211            .field("inner", &self.inner)
212            .field("by_node", &self.by_node)
213            .field(
214                "_headers",
215                &format_args!("{}", std::any::type_name::<fn() -> T>()),
216            )
217            .finish()
218    }
219}
220
221impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
222    fn clone(&self) -> Self {
223        SetForwardedHeadersService {
224            inner: self.inner.clone(),
225            by_node: self.by_node.clone(),
226            _headers: PhantomData,
227        }
228    }
229}
230
231impl<S, T> SetForwardedHeadersService<S, T> {
232    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
233    ///
234    /// Default of `None` will be set to `rama` otherwise.
235    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
236        self.by_node = node_id.into();
237        self
238    }
239
240    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
241    ///
242    /// Default of `None` will be set to `rama` otherwise.
243    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
244        self.by_node = node_id.into();
245        self
246    }
247}
248
249impl<S, T> SetForwardedHeadersService<S, T> {
250    /// Create a new `SetForwardedHeadersService` for the specified headers `T`.
251    pub fn new(inner: S) -> Self {
252        Self {
253            inner,
254            by_node: Domain::from_static("rama").into(),
255            _headers: PhantomData,
256        }
257    }
258}
259
260impl<S> SetForwardedHeadersService<S> {
261    #[inline]
262    /// Create a new `SetForwardedHeadersService` for the standard [`Forwarded`] header.
263    pub fn forwarded(inner: S) -> Self {
264        Self::new(inner)
265    }
266}
267
268impl<S> SetForwardedHeadersService<S, Via> {
269    #[inline]
270    /// Create a new `SetForwardedHeadersService` for the canonical [`Via`] header.
271    pub fn via(inner: S) -> Self {
272        Self::new(inner)
273    }
274}
275
276impl<S> SetForwardedHeadersService<S, XForwardedFor> {
277    #[inline]
278    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-For`] header.
279    pub fn x_forwarded_for(inner: S) -> Self {
280        Self::new(inner)
281    }
282}
283
284impl<S> SetForwardedHeadersService<S, XForwardedHost> {
285    #[inline]
286    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-Host`] header.
287    pub fn x_forwarded_host(inner: S) -> Self {
288        Self::new(inner)
289    }
290}
291
292impl<S> SetForwardedHeadersService<S, XForwardedProto> {
293    #[inline]
294    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-Proto`] header.
295    pub fn x_forwarded_proto(inner: S) -> Self {
296        Self::new(inner)
297    }
298}
299
300impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, H>
301where
302    S: Service<State, Request<Body>, Error: Into<BoxError>>,
303    H: ForwardHeader + Send + Sync + 'static,
304    Body: Send + 'static,
305    State: Clone + Send + Sync + 'static,
306{
307    type Response = S::Response;
308    type Error = BoxError;
309
310    async fn serve(
311        &self,
312        mut ctx: Context<State>,
313        mut req: Request<Body>,
314    ) -> Result<Self::Response, Self::Error> {
315        let forwarded: Option<Forwarded> = ctx.get().cloned();
316
317        let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
318
319        if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
320            forwarded_element.set_forwarded_for(peer_addr);
321        }
322        let request_ctx: &mut RequestContext =
323            ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
324
325        forwarded_element.set_forwarded_host(request_ctx.authority.clone());
326
327        if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
328            forwarded_element.set_forwarded_proto(forwarded_proto);
329        }
330
331        let forwarded = match forwarded {
332            None => Some(Forwarded::new(forwarded_element)),
333            Some(mut forwarded) => {
334                forwarded.append(forwarded_element);
335                Some(forwarded)
336            }
337        };
338
339        if let Some(forwarded) = forwarded {
340            if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
341                req.headers_mut().typed_insert(header);
342            }
343        }
344
345        self.inner.serve(ctx, req).await.map_err(Into::into)
346    }
347}
348
349macro_rules! set_forwarded_service_for_tuple {
350    ( $($ty:ident),* $(,)? ) => {
351        #[allow(non_snake_case)]
352        impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
353        where
354            $( $ty: ForwardHeader + Send + Sync + 'static, )*
355            S: Service<State, Request<Body>, Error: Into<BoxError>>,
356            Body: Send + 'static,
357            State: Clone + Send + Sync + 'static,
358        {
359            type Response = S::Response;
360            type Error = BoxError;
361
362            async fn serve(
363                &self,
364                mut ctx: Context<State>,
365                mut req: Request<Body>,
366            ) -> Result<Self::Response, Self::Error> {
367                let forwarded: Option<Forwarded> = ctx.get().cloned();
368
369                let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
370
371                if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
372                    forwarded_element.set_forwarded_for(peer_addr);
373                }
374
375                let request_ctx: &mut RequestContext =
376                    ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
377
378                forwarded_element.set_forwarded_host(request_ctx.authority.clone());
379
380                if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
381                    forwarded_element.set_forwarded_proto(forwarded_proto);
382                }
383
384                let forwarded = match forwarded {
385                    None => Some(Forwarded::new(forwarded_element)),
386                    Some(mut forwarded) => {
387                        forwarded.append(forwarded_element);
388                        Some(forwarded)
389                    }
390                };
391
392                if let Some(forwarded) = forwarded {
393                    $(
394                        if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
395                            req.headers_mut().typed_insert(header);
396                        }
397                    )*
398                }
399
400                self.inner.serve(ctx, req).await.map_err(Into::into)
401            }
402        }
403    };
404}
405all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
406
407#[cfg(test)]
408mod tests {
409    use super::*;
410    use crate::{
411        headers::{TrueClientIp, XClientIp, XRealIp},
412        IntoResponse, Response, StatusCode,
413    };
414    use rama_core::{error::OpaqueError, service::service_fn, Layer};
415    use std::{convert::Infallible, net::IpAddr};
416
417    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}
418
419    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
420        Ok(StatusCode::OK.into_response())
421    }
422
423    #[test]
424    fn test_set_forwarded_service_is_service() {
425        assert_is_service(SetForwardedHeadersService::forwarded(service_fn(
426            dummy_service_fn,
427        )));
428        assert_is_service(SetForwardedHeadersService::via(service_fn(
429            dummy_service_fn,
430        )));
431        assert_is_service(SetForwardedHeadersService::x_forwarded_for(service_fn(
432            dummy_service_fn,
433        )));
434        assert_is_service(SetForwardedHeadersService::x_forwarded_proto(service_fn(
435            dummy_service_fn,
436        )));
437        assert_is_service(SetForwardedHeadersService::x_forwarded_host(service_fn(
438            dummy_service_fn,
439        )));
440        assert_is_service(SetForwardedHeadersService::<_, TrueClientIp>::new(
441            service_fn(dummy_service_fn),
442        ));
443        assert_is_service(SetForwardedHeadersService::<_, (TrueClientIp,)>::new(
444            service_fn(dummy_service_fn),
445        ));
446        assert_is_service(
447            SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
448                dummy_service_fn,
449            )),
450        );
451        assert_is_service(SetForwardedHeadersLayer::via().layer(service_fn(dummy_service_fn)));
452        assert_is_service(
453            SetForwardedHeadersLayer::<XRealIp>::new().layer(service_fn(dummy_service_fn)),
454        );
455        assert_is_service(
456            SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
457                .layer(service_fn(dummy_service_fn)),
458        );
459    }
460
461    #[tokio::test]
462    async fn test_set_forwarded_service_forwarded() {
463        async fn svc(request: Request<()>) -> Result<(), Infallible> {
464            assert_eq!(
465                request.headers().get("Forwarded").unwrap(),
466                "by=rama;host=\"example.com:80\";proto=http"
467            );
468            Ok(())
469        }
470
471        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
472        let req = Request::builder().uri("example.com").body(()).unwrap();
473        service.serve(Context::default(), req).await.unwrap();
474    }
475
476    #[tokio::test]
477    async fn test_set_forwarded_service_forwarded_with_chain() {
478        async fn svc(request: Request<()>) -> Result<(), Infallible> {
479            assert_eq!(
480                request.headers().get("Forwarded").unwrap(),
481                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
482            );
483            Ok(())
484        }
485
486        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
487        let req = Request::builder()
488            .uri("https://www.example.com")
489            .body(())
490            .unwrap();
491        let mut ctx = Context::default();
492        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
493            IpAddr::from([12, 23, 34, 45]),
494        )));
495        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
496        service.serve(ctx, req).await.unwrap();
497    }
498
499    #[tokio::test]
500    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
501        async fn svc(request: Request<()>) -> Result<(), Infallible> {
502            assert_eq!(
503                request.headers().get("X-Forwarded-For").unwrap(),
504                "12.23.34.45, 127.0.0.1",
505            );
506            Ok(())
507        }
508
509        let service = SetForwardedHeadersService::x_forwarded_for(service_fn(svc));
510        let req = Request::builder()
511            .uri("https://www.example.com")
512            .body(())
513            .unwrap();
514        let mut ctx = Context::default();
515        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
516            IpAddr::from([12, 23, 34, 45]),
517        )));
518        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
519        service.serve(ctx, req).await.unwrap();
520    }
521
522    #[tokio::test]
523    async fn test_set_forwarded_service_forwarded_fully_defined() {
524        async fn svc(request: Request<()>) -> Result<(), Infallible> {
525            assert_eq!(
526                request.headers().get("Forwarded").unwrap(),
527                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
528            );
529            Ok(())
530        }
531
532        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
533            .forward_by(IpAddr::from([12, 23, 34, 45]));
534        let req = Request::builder()
535            .uri("https://www.example.com")
536            .body(())
537            .unwrap();
538        let mut ctx = Context::default();
539        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
540        service.serve(ctx, req).await.unwrap();
541    }
542
543    #[tokio::test]
544    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
545        async fn svc(request: Request<()>) -> Result<(), Infallible> {
546            assert_eq!(
547                request.headers().get("Forwarded").unwrap(),
548                "by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
549            );
550            Ok(())
551        }
552
553        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
554        let req = Request::builder()
555            .uri("https://www.example.com")
556            .body(())
557            .unwrap();
558        let mut ctx = Context::default();
559        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
560        service.serve(ctx, req).await.unwrap();
561    }
562}