rama_http/layer/forwarded/
get_forwarded.rs

1use crate::Request;
2use crate::headers::forwarded::{
3    ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::{Context, Layer, Service};
6use rama_http_headers::HeaderMapExt;
7use rama_http_headers::forwarded::Forwarded;
8use rama_net::forwarded::ForwardedElement;
9use std::fmt;
10use std::marker::PhantomData;
11
12/// Layer to extract [`Forwarded`] information from the specified `T` headers.
13///
14/// This layer can be used to extract the [`Forwarded`] information from any specified header `T`,
15/// as long as the header implements the [`ForwardHeader`] trait.
16///
17/// The following headers are supported by default:
18///
19/// - [`GetForwardedHeaderLayer::forwarded`]: The standard [`Forwarded`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239).
20/// - [`GetForwardedHeaderLayer::via`]: The canonical [`Via`] header [`RFC 7230`](https://tools.ietf.org/html/rfc7230#section-5.7.1).
21/// - [`GetForwardedHeaderLayer::x_forwarded_for`]: The canonical [`X-Forwarded-For`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.2).
22/// - [`GetForwardedHeaderLayer::x_forwarded_host`]: The canonical [`X-Forwarded-Host`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.4).
23/// - [`GetForwardedHeaderLayer::x_forwarded_proto`]: The canonical [`X-Forwarded-Proto`] header [`RFC 7239`](https://tools.ietf.org/html/rfc7239#section-5.3).
24///
25/// Rama also has the following headers already implemented for you to use:
26///
27/// > [`X-Real-Ip`], [`X-Client-Ip`], [`Client-Ip`], [`Cf-Connecting-Ip`] and [`True-Client-Ip`].
28///
29/// There are no [`GetForwardedHeaderLayer`] constructors for these headers,
30/// but you can use the [`GetForwardedHeaderLayer::new`] constructor and pass the header type as a type parameter,
31/// alone or in a tuple with other headers.
32///
33/// [`X-Real-Ip`]: crate::headers::XRealIp
34/// [`X-Client-Ip`]: crate::headers::XClientIp
35/// [`Client-Ip`]: crate::headers::ClientIp
36/// [`CF-Connecting-Ip`]: crate::headers::CFConnectingIp
37/// [`True-Client-Ip`]: crate::headers::TrueClientIp
38///
39/// ## Example
40///
41/// This example shows you can extract the client IP from the `X-Forwarded-For`
42/// header in case your application is behind a proxy which sets this header.
43///
44/// ```rust
45/// use rama_core::{
46///     service::service_fn,
47///     Context, Service, Layer,
48/// };
49/// use rama_http::{headers::forwarded::Forwarded, layer::forwarded::GetForwardedHeaderLayer, Request};
50/// use std::{convert::Infallible, net::IpAddr};
51///
52/// #[tokio::main]
53/// async fn main() {
54///     let service = GetForwardedHeaderLayer::x_forwarded_for()
55///         .into_layer(service_fn(async |ctx: Context<()>, _| {
56///             let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
57///             assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
58///             assert!(forwarded.client_proto().is_none());
59///
60///             // ...
61///
62///             Ok::<_, Infallible>(())
63///         }));
64///
65///     let req = Request::builder()
66///         .header("X-Forwarded-For", "12.23.34.45")
67///         .body(())
68///         .unwrap();
69///
70///     service.serve(Context::default(), req).await.unwrap();
71/// }
72/// ```
73pub struct GetForwardedHeaderLayer<T = rama_http_headers::forwarded::Forwarded> {
74    _headers: PhantomData<fn() -> T>,
75}
76
77impl<T: fmt::Debug> fmt::Debug for GetForwardedHeaderLayer<T> {
78    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79        f.debug_struct("GetForwardedHeaderLayer")
80            .field(
81                "_headers",
82                &format_args!("{}", std::any::type_name::<fn() -> T>()),
83            )
84            .finish()
85    }
86}
87
88impl<T: Clone> Clone for GetForwardedHeaderLayer<T> {
89    fn clone(&self) -> Self {
90        Self {
91            _headers: PhantomData,
92        }
93    }
94}
95
96impl Default for GetForwardedHeaderLayer {
97    fn default() -> Self {
98        Self::forwarded()
99    }
100}
101
102impl<T> GetForwardedHeaderLayer<T> {
103    /// Create a new `GetForwardedHeaderLayer` for the specified headers `T`.
104    pub const fn new() -> Self {
105        Self {
106            _headers: PhantomData,
107        }
108    }
109}
110
111impl GetForwardedHeaderLayer {
112    #[inline]
113    /// Create a new `GetForwardedHeaderLayer` for the standard [`Forwarded`] header.
114    pub fn forwarded() -> Self {
115        Self::new()
116    }
117}
118
119impl GetForwardedHeaderLayer<Via> {
120    #[inline]
121    /// Create a new `GetForwardedHeaderLayer` for the canonical [`Via`] header.
122    pub fn via() -> Self {
123        Self::new()
124    }
125}
126
127impl GetForwardedHeaderLayer<XForwardedFor> {
128    #[inline]
129    /// Create a new `GetForwardedHeaderLayer` for the canonical [`X-Forwarded-For`] header.
130    pub fn x_forwarded_for() -> Self {
131        Self::new()
132    }
133}
134
135impl GetForwardedHeaderLayer<XForwardedHost> {
136    #[inline]
137    /// Create a new `GetForwardedHeaderLayer` for the canonical [`X-Forwarded-Host`] header.
138    pub fn x_forwarded_host() -> Self {
139        Self::new()
140    }
141}
142
143impl GetForwardedHeaderLayer<XForwardedProto> {
144    #[inline]
145    /// Create a new `GetForwardedHeaderLayer` for the canonical [`X-Forwarded-Proto`] header.
146    pub fn x_forwarded_proto() -> Self {
147        Self::new()
148    }
149}
150
151impl<H, S> Layer<S> for GetForwardedHeaderLayer<H> {
152    type Service = GetForwardedHeaderService<S, H>;
153
154    fn layer(&self, inner: S) -> Self::Service {
155        Self::Service {
156            inner,
157            _headers: PhantomData,
158        }
159    }
160}
161
162/// Middleware service to extract [`Forwarded`] information from the specified `T` headers.
163///
164/// See [`GetForwardedHeaderLayer`] for more information.
165pub struct GetForwardedHeaderService<S, T = Forwarded> {
166    inner: S,
167    _headers: PhantomData<fn() -> T>,
168}
169
170impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeaderService<S, T> {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        f.debug_struct("GetForwardedHeaderService")
173            .field("inner", &self.inner)
174            .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
175            .finish()
176    }
177}
178
179impl<S: Clone, T> Clone for GetForwardedHeaderService<S, T> {
180    fn clone(&self) -> Self {
181        GetForwardedHeaderService {
182            inner: self.inner.clone(),
183            _headers: PhantomData,
184        }
185    }
186}
187
188impl<S, T> GetForwardedHeaderService<S, T> {
189    /// Create a new `GetForwardedHeaderService` for the specified headers `T`.
190    pub const fn new(inner: S) -> Self {
191        Self {
192            inner,
193            _headers: PhantomData,
194        }
195    }
196}
197
198impl<S> GetForwardedHeaderService<S> {
199    #[inline]
200    /// Create a new `GetForwardedHeaderService` for the standard [`Forwarded`] header.
201    pub fn forwarded(inner: S) -> Self {
202        Self::new(inner)
203    }
204}
205
206impl<S> GetForwardedHeaderService<S, Via> {
207    #[inline]
208    /// Create a new `GetForwardedHeaderService` for the canonical [`Via`] header.
209    pub fn via(inner: S) -> Self {
210        Self::new(inner)
211    }
212}
213
214impl<S> GetForwardedHeaderService<S, XForwardedFor> {
215    #[inline]
216    /// Create a new `GetForwardedHeaderService` for the canonical [`X-Forwarded-For`] header.
217    pub fn x_forwarded_for(inner: S) -> Self {
218        Self::new(inner)
219    }
220}
221
222impl<S> GetForwardedHeaderService<S, XForwardedHost> {
223    #[inline]
224    /// Create a new `GetForwardedHeaderService` for the canonical [`X-Forwarded-Host`] header.
225    pub fn x_forwarded_host(inner: S) -> Self {
226        Self::new(inner)
227    }
228}
229
230impl<S> GetForwardedHeaderService<S, XForwardedProto> {
231    #[inline]
232    /// Create a new `GetForwardedHeaderService` for the canonical [`X-Forwarded-Proto`] header.
233    pub fn x_forwarded_proto(inner: S) -> Self {
234        Self::new(inner)
235    }
236}
237
238impl<H, S, State, Body> Service<State, Request<Body>> for GetForwardedHeaderService<S, H>
239where
240    H: ForwardHeader + Send + Sync + 'static,
241    S: Service<State, Request<Body>>,
242    Body: Send + 'static,
243    State: Clone + Send + Sync + 'static,
244{
245    type Response = S::Response;
246    type Error = S::Error;
247
248    fn serve(
249        &self,
250        mut ctx: Context<State>,
251        req: Request<Body>,
252    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
253        let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
254
255        if let Some(header) = req.headers().typed_get::<H>() {
256            forwarded_elements.extend(header);
257        }
258
259        if !forwarded_elements.is_empty() {
260            match ctx.get_mut::<Forwarded>() {
261                Some(ref mut f) => {
262                    f.extend(forwarded_elements);
263                }
264                None => {
265                    let mut it = forwarded_elements.into_iter();
266                    let mut forwarded = rama_net::forwarded::Forwarded::new(it.next().unwrap());
267                    forwarded.extend(it);
268                    ctx.insert(forwarded);
269                }
270            }
271        }
272
273        self.inner.serve(ctx, req)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Response, StatusCode, service::web::response::IntoResponse};
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_http_headers::forwarded::{TrueClientIp, XRealIp};
    use rama_net::forwarded::{ForwardedProtocol, ForwardedVersion};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        assert_is_service(GetForwardedHeaderService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::via(service_fn(dummy_service_fn)));
        assert_is_service(GetForwardedHeaderService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            GetForwardedHeaderLayer::forwarded().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(GetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            GetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_get_forwarded_header_forwarded() {
        let service = GetForwardedHeaderLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_via() {
        let service =
            GetForwardedHeaderLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert!(forwarded.client_ip().is_none());
                assert_eq!(
                    forwarded.iter().next().unwrap().ref_forwarded_by(),
                    Some(&(IpAddr::from([12, 23, 34, 45]), 5000).into())
                );
                assert!(forwarded.client_proto().is_none());
                assert_eq!(forwarded.client_version(), Some(ForwardedVersion::HTTP_11));
                Ok::<_, Infallible>(())
            }));

        let req = Request::builder()
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_x_forwarded_for() {
        let service = GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert!(forwarded.client_proto().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45, 127.0.0.1")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    // Without the header, the layer must not insert anything into the context.
    #[tokio::test]
    async fn test_get_forwarded_header_missing_header() {
        let service = GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                assert!(ctx.get::<rama_net::forwarded::Forwarded>().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder().body(()).unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    // Stacked layers must extend the `Forwarded` value already in the context
    // instead of overwriting it.
    #[tokio::test]
    async fn test_get_forwarded_header_layers_stack() {
        let service = GetForwardedHeaderLayer::x_forwarded_for().into_layer(
            GetForwardedHeaderLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.iter().count(), 2);
                Ok::<_, Infallible>(())
            })),
        );

        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45")
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }
}