rama_http/layer/forwarded/
set_forwarded.rs

1use crate::Request;
2use crate::headers::{
3    ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::error::BoxError;
6use rama_core::{Context, Layer, Service};
7use rama_net::address::Domain;
8use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
9use rama_net::http::RequestContext;
10use rama_net::stream::SocketInfo;
11use rama_utils::macros::all_the_tuples_no_last_special_case;
12use std::fmt;
13use std::marker::PhantomData;
14
/// Layer to write [`Forwarded`] information for this proxy,
/// added to the end of the chain of forwarded information already known.
///
/// This layer can set any header as long as you have a [`ForwardHeader`] implementation
/// for the header you want to set. You can pass it as the type to the layer when creating
/// the layer using [`SetForwardedHeadersLayer::new`]. Multiple headers (in order) can also be set
/// by specifying multiple types as a tuple.
///
/// The following headers are supported out of the box, each with its own constructor:
///
/// - [`SetForwardedHeadersLayer::forwarded`]: the standard [`Forwarded`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239);
/// - [`SetForwardedHeadersLayer::via`]: the canonical [`Via`] header (non-standard);
/// - [`SetForwardedHeadersLayer::x_forwarded_for`]: the canonical [`X-Forwarded-For`][`XForwardedFor`] header (non-standard);
/// - [`SetForwardedHeadersLayer::x_forwarded_host`]: the canonical [`X-Forwarded-Host`][`XForwardedHost`] header (non-standard);
/// - [`SetForwardedHeadersLayer::x_forwarded_proto`]: the canonical [`X-Forwarded-Proto`][`XForwardedProto`] header (non-standard).
///
/// The "by" property is set to `rama` by default. Use [`SetForwardedHeadersLayer::forward_by`] to override this,
/// typically with the actual [`IPv4`]/[`IPv6`] address of your proxy.
///
/// [`IPv4`]: std::net::Ipv4Addr
/// [`IPv6`]: std::net::Ipv6Addr
///
/// Rama also has the following headers already implemented for you to use:
///
/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`CF-Connecting-Ip`] and [`True-Client-Ip`].
///
/// There are no [`SetForwardedHeadersLayer`] constructors for these headers,
/// but you can use the [`SetForwardedHeadersLayer::new`] constructor and pass the header type as a type parameter,
/// alone or in a tuple with other headers.
///
/// [`X-Real-Ip`]: crate::headers::XRealIp
/// [`X-Client-Ip`]: crate::headers::XClientIp
/// [`Client-Ip`]: crate::headers::ClientIp
/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
/// [`True-Client-Ip`]: crate::headers::TrueClientIp
///
/// ## Example
///
/// This example shows how you could expose the real client IP using the [`X-Real-IP`][`crate::headers::XRealIp`] header.
///
/// ```rust
/// use rama_net::stream::SocketInfo;
/// use rama_http::Request;
/// use rama_core::service::service_fn;
/// use rama_http::{headers::XRealIp, layer::forwarded::SetForwardedHeadersLayer};
/// use rama_core::{Context, Service, Layer};
/// use std::convert::Infallible;
///
/// # type Body = ();
/// # type State = ();
///
/// # #[tokio::main]
/// # async fn main() {
/// async fn svc(_ctx: Context<State>, request: Request<Body>) -> Result<(), Infallible> {
///     // ...
///     # assert_eq!(
///     #     request.headers().get("X-Real-Ip").unwrap(),
///     #     "42.37.100.50:62345",
///     # );
///     # Ok(())
/// }
///
/// let service = SetForwardedHeadersLayer::<XRealIp>::new()
///     .into_layer(service_fn(svc));
///
/// # let req = Request::builder().uri("example.com").body(()).unwrap();
/// # let mut ctx = Context::default();
/// # ctx.insert(SocketInfo::new(None, "42.37.100.50:62345".parse().unwrap()));
/// service.serve(ctx, req).await.unwrap();
/// # }
/// ```
pub struct SetForwardedHeadersLayer<T = Forwarded> {
    // Node written as the "by" property of the element this proxy appends;
    // copied into every service produced by this layer.
    by_node: NodeId,
    // Type-level marker for the header(s) to write. `fn() -> T` is used so no
    // `T` value is stored; function pointer types are `Send + Sync` for any `T`.
    _headers: PhantomData<fn() -> T>,
}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        f.debug_struct("SetForwardedHeadersLayer")
94            .field("by_node", &self.by_node)
95            .field(
96                "_headers",
97                &format_args!("{}", std::any::type_name::<fn() -> T>()),
98            )
99            .finish()
100    }
101}
102
103impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
104    fn clone(&self) -> Self {
105        Self {
106            by_node: self.by_node.clone(),
107            _headers: PhantomData,
108        }
109    }
110}
111
112impl<T> SetForwardedHeadersLayer<T> {
113    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
114    ///
115    /// Default of `None` will be set to `rama` otherwise.
116    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117        self.by_node = node_id.into();
118        self
119    }
120
121    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
122    ///
123    /// Default of `None` will be set to `rama` otherwise.
124    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125        self.by_node = node_id.into();
126        self
127    }
128}
129
130impl<T> SetForwardedHeadersLayer<T> {
131    /// Create a new `SetForwardedHeadersLayer` for the specified headers `T`.
132    pub fn new() -> Self {
133        Self {
134            by_node: Domain::from_static("rama").into(),
135            _headers: PhantomData,
136        }
137    }
138}
139
140impl Default for SetForwardedHeadersLayer {
141    fn default() -> Self {
142        Self::forwarded()
143    }
144}
145
146impl SetForwardedHeadersLayer {
147    #[inline]
148    /// Create a new `SetForwardedHeadersLayer` for the standard [`Forwarded`] header.
149    pub fn forwarded() -> Self {
150        Self::new()
151    }
152}
153
154impl SetForwardedHeadersLayer<Via> {
155    #[inline]
156    /// Create a new `SetForwardedHeadersLayer` for the canonical [`Via`] header.
157    pub fn via() -> Self {
158        Self::new()
159    }
160}
161
162impl SetForwardedHeadersLayer<XForwardedFor> {
163    #[inline]
164    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-For`] header.
165    pub fn x_forwarded_for() -> Self {
166        Self::new()
167    }
168}
169
170impl SetForwardedHeadersLayer<XForwardedHost> {
171    #[inline]
172    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-Host`] header.
173    pub fn x_forwarded_host() -> Self {
174        Self::new()
175    }
176}
177
178impl SetForwardedHeadersLayer<XForwardedProto> {
179    #[inline]
180    /// Create a new `SetForwardedHeadersLayer` for the canonical [`X-Forwarded-Proto`] header.
181    pub fn x_forwarded_proto() -> Self {
182        Self::new()
183    }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
187    type Service = SetForwardedHeadersService<S, H>;
188
189    fn layer(&self, inner: S) -> Self::Service {
190        Self::Service {
191            inner,
192            by_node: self.by_node.clone(),
193            _headers: PhantomData,
194        }
195    }
196
197    fn into_layer(self, inner: S) -> Self::Service {
198        Self::Service {
199            inner,
200            by_node: self.by_node,
201            _headers: PhantomData,
202        }
203    }
204}
205
/// Middleware [`Service`] to write [`Forwarded`] information for this proxy,
/// added to the end of the chain of forwarded information already known.
///
/// See [`SetForwardedHeadersLayer`] for more information.
pub struct SetForwardedHeadersService<S, T = Forwarded> {
    // The wrapped service that receives the request with the header(s) set.
    inner: S,
    // Node written as the "by" property of the element this proxy appends.
    by_node: NodeId,
    // Type-level marker for the header(s) to write; `fn() -> T` avoids storing
    // a `T` and is `Send + Sync` for any `T`.
    _headers: PhantomData<fn() -> T>,
}
215
216impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.debug_struct("SetForwardedHeadersService")
219            .field("inner", &self.inner)
220            .field("by_node", &self.by_node)
221            .field(
222                "_headers",
223                &format_args!("{}", std::any::type_name::<fn() -> T>()),
224            )
225            .finish()
226    }
227}
228
229impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
230    fn clone(&self) -> Self {
231        SetForwardedHeadersService {
232            inner: self.inner.clone(),
233            by_node: self.by_node.clone(),
234            _headers: PhantomData,
235        }
236    }
237}
238
239impl<S, T> SetForwardedHeadersService<S, T> {
240    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
241    ///
242    /// Default of `None` will be set to `rama` otherwise.
243    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
244        self.by_node = node_id.into();
245        self
246    }
247
248    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
249    ///
250    /// Default of `None` will be set to `rama` otherwise.
251    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
252        self.by_node = node_id.into();
253        self
254    }
255}
256
257impl<S, T> SetForwardedHeadersService<S, T> {
258    /// Create a new `SetForwardedHeadersService` for the specified headers `T`.
259    pub fn new(inner: S) -> Self {
260        Self {
261            inner,
262            by_node: Domain::from_static("rama").into(),
263            _headers: PhantomData,
264        }
265    }
266}
267
268impl<S> SetForwardedHeadersService<S> {
269    #[inline]
270    /// Create a new `SetForwardedHeadersService` for the standard [`Forwarded`] header.
271    pub fn forwarded(inner: S) -> Self {
272        Self::new(inner)
273    }
274}
275
276impl<S> SetForwardedHeadersService<S, Via> {
277    #[inline]
278    /// Create a new `SetForwardedHeadersService` for the canonical [`Via`] header.
279    pub fn via(inner: S) -> Self {
280        Self::new(inner)
281    }
282}
283
284impl<S> SetForwardedHeadersService<S, XForwardedFor> {
285    #[inline]
286    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-For`] header.
287    pub fn x_forwarded_for(inner: S) -> Self {
288        Self::new(inner)
289    }
290}
291
292impl<S> SetForwardedHeadersService<S, XForwardedHost> {
293    #[inline]
294    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-Host`] header.
295    pub fn x_forwarded_host(inner: S) -> Self {
296        Self::new(inner)
297    }
298}
299
300impl<S> SetForwardedHeadersService<S, XForwardedProto> {
301    #[inline]
302    /// Create a new `SetForwardedHeadersService` for the canonical [`X-Forwarded-Proto`] header.
303    pub fn x_forwarded_proto(inner: S) -> Self {
304        Self::new(inner)
305    }
306}
307
308impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, H>
309where
310    S: Service<State, Request<Body>, Error: Into<BoxError>>,
311    H: ForwardHeader + Send + Sync + 'static,
312    Body: Send + 'static,
313    State: Clone + Send + Sync + 'static,
314{
315    type Response = S::Response;
316    type Error = BoxError;
317
318    async fn serve(
319        &self,
320        mut ctx: Context<State>,
321        mut req: Request<Body>,
322    ) -> Result<Self::Response, Self::Error> {
323        let forwarded: Option<Forwarded> = ctx.get().cloned();
324
325        let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
326
327        if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
328            forwarded_element.set_forwarded_for(peer_addr);
329        }
330        let request_ctx: &mut RequestContext =
331            ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
332
333        forwarded_element.set_forwarded_host(request_ctx.authority.clone());
334
335        if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
336            forwarded_element.set_forwarded_proto(forwarded_proto);
337        }
338
339        let forwarded = match forwarded {
340            None => Some(Forwarded::new(forwarded_element)),
341            Some(mut forwarded) => {
342                forwarded.append(forwarded_element);
343                Some(forwarded)
344            }
345        };
346
347        if let Some(forwarded) = forwarded {
348            if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
349                req.headers_mut().typed_insert(header);
350            }
351        }
352
353        self.inner.serve(ctx, req).await.map_err(Into::into)
354    }
355}
356
357macro_rules! set_forwarded_service_for_tuple {
358    ( $($ty:ident),* $(,)? ) => {
359        #[allow(non_snake_case)]
360        impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
361        where
362            $( $ty: ForwardHeader + Send + Sync + 'static, )*
363            S: Service<State, Request<Body>, Error: Into<BoxError>>,
364            Body: Send + 'static,
365            State: Clone + Send + Sync + 'static,
366        {
367            type Response = S::Response;
368            type Error = BoxError;
369
370            async fn serve(
371                &self,
372                mut ctx: Context<State>,
373                mut req: Request<Body>,
374            ) -> Result<Self::Response, Self::Error> {
375                let forwarded: Option<Forwarded> = ctx.get().cloned();
376
377                let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
378
379                if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
380                    forwarded_element.set_forwarded_for(peer_addr);
381                }
382
383                let request_ctx: &mut RequestContext =
384                    ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
385
386                forwarded_element.set_forwarded_host(request_ctx.authority.clone());
387
388                if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
389                    forwarded_element.set_forwarded_proto(forwarded_proto);
390                }
391
392                let forwarded = match forwarded {
393                    None => Some(Forwarded::new(forwarded_element)),
394                    Some(mut forwarded) => {
395                        forwarded.append(forwarded_element);
396                        Some(forwarded)
397                    }
398                };
399
400                if let Some(forwarded) = forwarded {
401                    $(
402                        if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
403                            req.headers_mut().typed_insert(header);
404                        }
405                    )*
406                }
407
408                self.inner.serve(ctx, req).await.map_err(Into::into)
409            }
410        }
411    };
412}
413all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        IntoResponse, Response, StatusCode,
        headers::{TrueClientIp, XClientIp, XRealIp},
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        assert_is_service(SetForwardedHeadersService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::via(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(SetForwardedHeadersService::<_, (TrueClientIp,)>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
                dummy_service_fn,
            )),
        );
        assert_is_service(SetForwardedHeadersLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            SetForwardedHeadersLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(
            SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
                .into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=rama;host=\"example.com:80\";proto=http"
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder().uri("example.com").body(()).unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "12.23.34.45, 127.0.0.1",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::x_forwarded_for(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([12, 23, 34, 45]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
        // A custom "by" node combined with an already known chain: the earlier
        // hop must be kept and this proxy's fully populated element appended.
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=10.0.0.1;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([10, 0, 0, 1]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_tuple_sets_each_header() {
        // The tuple impl must write every listed header from the same chain.
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "127.0.0.1",
            );
            assert_eq!(
                request.headers().get("X-Forwarded-Proto").unwrap(),
                "https",
            );
            Ok(())
        }

        let service =
            SetForwardedHeadersService::<_, (XForwardedFor, XForwardedProto)>::new(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }
}
570}