rama_http/layer/forwarded/
set_forwarded.rs

1use crate::Request;
2use crate::headers::HeaderMapExt;
3use crate::headers::forwarded::{
4    ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
5};
6use rama_core::error::BoxError;
7use rama_core::{Context, Layer, Service};
8use rama_http_headers::forwarded::Forwarded;
9use rama_net::address::Domain;
10use rama_net::forwarded::{ForwardedElement, NodeId};
11use rama_net::http::RequestContext;
12use rama_net::stream::SocketInfo;
13use std::fmt;
14use std::marker::PhantomData;
15
16/// Layer to write [`Forwarded`] information for this proxy,
17/// added to the end of the chain of forwarded information already known.
18///
19/// This layer can set any header as long as you have a [`ForwardHeader`] implementation
20/// for the header you want to set. You can pass it as the type to the layer when creating
21/// the layer using [`SetForwardedHeaderLayer::new`].
22///
23/// The following headers are supported out of the box with each their own constructor:
24///
25/// - [`SetForwardedHeaderLayer::forwarded`]: the standard [`Forwarded`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239);
26/// - [`SetForwardedHeaderLayer::via`]: the canonical [`Via`] header (non-standard);
27/// - [`SetForwardedHeaderLayer::x_forwarded_for`]: the canonical [`X-Forwarded-For`][`XForwardedFor`] header (non-standard);
28/// - [`SetForwardedHeaderLayer::x_forwarded_host`]: the canonical [`X-Forwarded-Host`][`XForwardedHost`] header (non-standard);
29/// - [`SetForwardedHeaderLayer::x_forwarded_proto`]: the canonical [`X-Forwarded-Proto`][`XForwardedProto`] header (non-standard).
30///
31/// The "by" property is set to `rama` by default. Use [`SetForwardedHeaderLayer::forward_by`] to overwrite this,
32/// typically with the actual [`IPv4`]/[`IPv6`] address of your proxy.
33///
34/// [`IPv4`]: std::net::Ipv4Addr
35/// [`IPv6`]: std::net::Ipv6Addr
36///
37/// Rama also has the following headers already implemented for you to use:
38///
39/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`CF-Connecting-Ip`] and [`True-Client-Ip`].
40///
41/// There are no [`SetForwardedHeaderLayer`] constructors for these headers,
42/// but you can use the [`SetForwardedHeaderLayer::new`] constructor and pass the header type as a type parameter,
43/// alone or in a tuple with other headers.
44///
45/// [`X-Real-Ip`]: crate::headers::XRealIp
46/// [`X-Client-Ip`]: crate::headers::XClientIp
47/// [`Client-Ip`]: crate::headers::ClientIp
48/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
49/// [`True-Client-Ip`]: crate::headers::TrueClientIp
50///
51/// ## Example
52///
53/// This example shows how you could expose the real Client IP using the [`X-Real-IP`][`crate::headers::XRealIp`] header.
54///
55/// ```rust
56/// use rama_net::stream::SocketInfo;
57/// use rama_http::Request;
58/// use rama_core::service::service_fn;
59/// use rama_http::{headers::forwarded::XRealIp, layer::forwarded::SetForwardedHeaderLayer};
60/// use rama_core::{Context, Service, Layer};
61/// use std::convert::Infallible;
62///
63/// # type Body = ();
64/// # type State = ();
65///
66/// # #[tokio::main]
67/// # async fn main() {
68/// async fn svc(_ctx: Context<State>, request: Request<Body>) -> Result<(), Infallible> {
69///     // ...
70///     # assert_eq!(
71///     #     request.headers().get("X-Real-Ip").unwrap(),
72///     #     "42.37.100.50:62345",
73///     # );
74///     # Ok(())
75/// }
76///
77/// let service = SetForwardedHeaderLayer::<XRealIp>::new()
78///     .into_layer(service_fn(svc));
79///
80/// # let req = Request::builder().uri("example.com").body(()).unwrap();
81/// # let mut ctx = Context::default();
82/// # ctx.insert(SocketInfo::new(None, "42.37.100.50:62345".parse().unwrap()));
83/// service.serve(ctx, req).await.unwrap();
84/// # }
85/// ```
86pub struct SetForwardedHeaderLayer<T = Forwarded> {
87    by_node: NodeId,
88    _headers: PhantomData<fn() -> T>,
89}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeaderLayer<T> {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        f.debug_struct("SetForwardedHeaderLayer")
94            .field("by_node", &self.by_node)
95            .field(
96                "_headers",
97                &format_args!("{}", std::any::type_name::<fn() -> T>()),
98            )
99            .finish()
100    }
101}
102
103impl<T: Clone> Clone for SetForwardedHeaderLayer<T> {
104    fn clone(&self) -> Self {
105        Self {
106            by_node: self.by_node.clone(),
107            _headers: PhantomData,
108        }
109    }
110}
111
112impl<T> SetForwardedHeaderLayer<T> {
113    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
114    ///
115    /// Default of `None` will be set to `rama` otherwise.
116    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117        self.by_node = node_id.into();
118        self
119    }
120
121    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
122    ///
123    /// Default of `None` will be set to `rama` otherwise.
124    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125        self.by_node = node_id.into();
126        self
127    }
128}
129
130impl<T> SetForwardedHeaderLayer<T> {
131    /// Create a new `SetForwardedHeaderLayer` for the specified headers `T`.
132    pub fn new() -> Self {
133        Self {
134            by_node: Domain::from_static("rama").into(),
135            _headers: PhantomData,
136        }
137    }
138}
139
140impl Default for SetForwardedHeaderLayer {
141    fn default() -> Self {
142        Self::forwarded()
143    }
144}
145
146impl SetForwardedHeaderLayer {
147    #[inline]
148    /// Create a new `SetForwardedHeaderLayer` for the standard [`Forwarded`] header.
149    pub fn forwarded() -> Self {
150        Self::new()
151    }
152}
153
154impl SetForwardedHeaderLayer<Via> {
155    #[inline]
156    /// Create a new `SetForwardedHeaderLayer` for the canonical [`Via`] header.
157    pub fn via() -> Self {
158        Self::new()
159    }
160}
161
162impl SetForwardedHeaderLayer<XForwardedFor> {
163    #[inline]
164    /// Create a new `SetForwardedHeaderLayer` for the canonical [`X-Forwarded-For`] header.
165    pub fn x_forwarded_for() -> Self {
166        Self::new()
167    }
168}
169
170impl SetForwardedHeaderLayer<XForwardedHost> {
171    #[inline]
172    /// Create a new `SetForwardedHeaderLayer` for the canonical [`X-Forwarded-Host`] header.
173    pub fn x_forwarded_host() -> Self {
174        Self::new()
175    }
176}
177
178impl SetForwardedHeaderLayer<XForwardedProto> {
179    #[inline]
180    /// Create a new `SetForwardedHeaderLayer` for the canonical [`X-Forwarded-Proto`] header.
181    pub fn x_forwarded_proto() -> Self {
182        Self::new()
183    }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeaderLayer<H> {
187    type Service = SetForwardedHeaderService<S, H>;
188
189    fn layer(&self, inner: S) -> Self::Service {
190        Self::Service {
191            inner,
192            by_node: self.by_node.clone(),
193            _headers: PhantomData,
194        }
195    }
196
197    fn into_layer(self, inner: S) -> Self::Service {
198        Self::Service {
199            inner,
200            by_node: self.by_node,
201            _headers: PhantomData,
202        }
203    }
204}
205
206/// Middleware [`Service`] to write [`Forwarded`] information for this proxy,
207/// added to the end of the chain of forwarded information already known.
208///
209/// See [`SetForwardedHeaderLayer`] for more information.
210pub struct SetForwardedHeaderService<S, T = Forwarded> {
211    inner: S,
212    by_node: NodeId,
213    _headers: PhantomData<fn() -> T>,
214}
215
216impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeaderService<S, T> {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.debug_struct("SetForwardedHeaderService")
219            .field("inner", &self.inner)
220            .field("by_node", &self.by_node)
221            .field(
222                "_headers",
223                &format_args!("{}", std::any::type_name::<fn() -> T>()),
224            )
225            .finish()
226    }
227}
228
229impl<S: Clone, T> Clone for SetForwardedHeaderService<S, T> {
230    fn clone(&self) -> Self {
231        SetForwardedHeaderService {
232            inner: self.inner.clone(),
233            by_node: self.by_node.clone(),
234            _headers: PhantomData,
235        }
236    }
237}
238
239impl<S, T> SetForwardedHeaderService<S, T> {
240    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
241    ///
242    /// Default of `None` will be set to `rama` otherwise.
243    pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
244        self.by_node = node_id.into();
245        self
246    }
247
248    /// Set the given [`NodeId`] as the "by" property, identifying this proxy.
249    ///
250    /// Default of `None` will be set to `rama` otherwise.
251    pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
252        self.by_node = node_id.into();
253        self
254    }
255}
256
257impl<S, T> SetForwardedHeaderService<S, T> {
258    /// Create a new `SetForwardedHeaderService` for the specified headers `T`.
259    pub fn new(inner: S) -> Self {
260        Self {
261            inner,
262            by_node: Domain::from_static("rama").into(),
263            _headers: PhantomData,
264        }
265    }
266}
267
268impl<S> SetForwardedHeaderService<S> {
269    #[inline]
270    /// Create a new `SetForwardedHeaderService` for the standard [`Forwarded`] header.
271    pub fn forwarded(inner: S) -> Self {
272        Self::new(inner)
273    }
274}
275
276impl<S> SetForwardedHeaderService<S, Via> {
277    #[inline]
278    /// Create a new `SetForwardedHeaderService` for the canonical [`Via`] header.
279    pub fn via(inner: S) -> Self {
280        Self::new(inner)
281    }
282}
283
284impl<S> SetForwardedHeaderService<S, XForwardedFor> {
285    #[inline]
286    /// Create a new `SetForwardedHeaderService` for the canonical [`X-Forwarded-For`] header.
287    pub fn x_forwarded_for(inner: S) -> Self {
288        Self::new(inner)
289    }
290}
291
292impl<S> SetForwardedHeaderService<S, XForwardedHost> {
293    #[inline]
294    /// Create a new `SetForwardedHeaderService` for the canonical [`X-Forwarded-Host`] header.
295    pub fn x_forwarded_host(inner: S) -> Self {
296        Self::new(inner)
297    }
298}
299
300impl<S> SetForwardedHeaderService<S, XForwardedProto> {
301    #[inline]
302    /// Create a new `SetForwardedHeaderService` for the canonical [`X-Forwarded-Proto`] header.
303    pub fn x_forwarded_proto(inner: S) -> Self {
304        Self::new(inner)
305    }
306}
307
308impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeaderService<S, H>
309where
310    S: Service<State, Request<Body>, Error: Into<BoxError>>,
311    H: ForwardHeader + Send + Sync + 'static,
312    Body: Send + 'static,
313    State: Clone + Send + Sync + 'static,
314{
315    type Response = S::Response;
316    type Error = BoxError;
317
318    async fn serve(
319        &self,
320        mut ctx: Context<State>,
321        mut req: Request<Body>,
322    ) -> Result<Self::Response, Self::Error> {
323        let forwarded: Option<rama_net::forwarded::Forwarded> = ctx.get().cloned();
324
325        let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
326
327        if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
328            forwarded_element.set_forwarded_for(peer_addr);
329        }
330        let request_ctx: &mut RequestContext =
331            ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
332
333        forwarded_element.set_forwarded_host(request_ctx.authority.clone());
334
335        if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
336            forwarded_element.set_forwarded_proto(forwarded_proto);
337        }
338
339        let forwarded = match forwarded {
340            None => Some(rama_net::forwarded::Forwarded::new(forwarded_element)),
341            Some(mut forwarded) => {
342                forwarded.append(forwarded_element);
343                Some(forwarded)
344            }
345        };
346
347        if let Some(forwarded) = forwarded {
348            if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
349                req.headers_mut().typed_insert(header);
350            }
351        }
352
353        self.inner.serve(ctx, req).await.map_err(Into::into)
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{TrueClientIp, XRealIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        assert_is_service(SetForwardedHeaderService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::via(service_fn(dummy_service_fn)));
        assert_is_service(SetForwardedHeaderService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(SetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            SetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=rama;host=\"example.com:80\";proto=http"
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        let req = Request::builder().uri("example.com").body(()).unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(rama_net::forwarded::Forwarded::new(
            ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
        ));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "12.23.34.45, 127.0.0.1",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::x_forwarded_for(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(rama_net::forwarded::Forwarded::new(
            ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
        ));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([12, 23, 34, 45]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    // Default "by" node, peer address known, no pre-existing chain.
    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_default_by_with_peer_addr() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    // Custom "by" node, peer address known, AND a pre-existing chain to append to.
    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=98.76.54.32;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([98, 76, 54, 32]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(rama_net::forwarded::Forwarded::new(
            ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
        ));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }
}