axum_client_ip/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{
3    error::Error,
4    fmt,
5    marker::Sync,
6    net::{IpAddr, SocketAddr},
7    str::FromStr,
8};
9
10use axum::{
11    extract::{ConnectInfo, Extension, FromRequestParts},
12    http::{HeaderMap, HeaderName, StatusCode, request::Parts},
13    response::{IntoResponse, Response},
14};
15use serde::{Deserialize, Serialize};
16
17/// An internal helper trait to extract an IP from headers
18trait IpExtractor {
19    const HEADER_NAME: HeaderName;
20
21    /// Extracts IP from decoded header value. Default implementation assumes
22    /// the header value is just a valid IP.
23    fn ip_from_header_value(header_value: &str) -> Result<IpAddr, Rejection> {
24        header_value
25            .trim()
26            .parse()
27            .map_err(|_| Rejection::MalformedHeaderValue {
28                header_name: Self::HEADER_NAME,
29                header_value: header_value.to_owned(),
30            })
31    }
32
33    /// Extracts an IP from headers.
34    fn ip_from_headers(headers: &HeaderMap) -> Result<IpAddr, Rejection> {
35        let header_value = Self::last_header_value(headers)?;
36        Self::ip_from_header_value(header_value)
37    }
38
39    /// Returns a decoded value of the last occurring header. Can also be used
40    /// for a header occurring only once.
41    fn last_header_value(headers: &HeaderMap) -> Result<&str, Rejection> {
42        headers
43            .get_all(Self::HEADER_NAME)
44            .into_iter()
45            .last()
46            .ok_or_else(|| Rejection::AbsentHeader {
47                header_name: Self::HEADER_NAME,
48            })?
49            .to_str()
50            .map_err(|_| Rejection::NonAsciiHeaderValue {
51                header_name: Self::HEADER_NAME,
52            })
53    }
54}
55
/// Generates the [`IpExtractor`] and [`FromRequestParts`] implementations
/// for a tuple-struct extractor that relies on the default header parsing.
///
/// `$extractor` is the extractor type and `$header_name` is the literal
/// passed to `HeaderName::from_static`.
macro_rules! impl_default_ip_extractor {
    ($extractor:ty, $header_name:literal) => {
        impl IpExtractor for $extractor {
            const HEADER_NAME: HeaderName = HeaderName::from_static($header_name);
        }

        impl<S: Sync> FromRequestParts<S> for $extractor {
            type Rejection = Rejection;

            async fn from_request_parts(
                parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                let ip = Self::ip_from_headers(&parts.headers)?;
                Ok(Self(ip))
            }
        }
    };
}
78
79/// Extracts an IP from `CF-Connecting-IP` (Cloudflare) header
80#[derive(Debug, Clone, Copy)]
81pub struct CfConnectingIp(pub IpAddr);
82
83impl_default_ip_extractor!(CfConnectingIp, "cf-connecting-ip");
84
85/// Extracts an IP from `CloudFront-Viewer-Address` (AWS CloudFront) header
86#[derive(Debug, Clone, Copy)]
87pub struct CloudFrontViewerAddress(pub IpAddr);
88
89impl IpExtractor for CloudFrontViewerAddress {
90    const HEADER_NAME: HeaderName = HeaderName::from_static("cloudfront-viewer-address");
91
92    fn ip_from_header_value(header_value: &str) -> Result<IpAddr, Rejection> {
93        // Spec: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/adding-cloudfront-headers.html#cloudfront-headers-viewer-location
94        // Note: Both IPv4 and IPv6 addresses (in the specified format) do not contain
95        //       non-ascii characters, so no need to handle percent-encoding.
96        //
97        // CloudFront does not use `[::]:12345` style notation for IPv6 (unfortunately),
98        // otherwise parsing via `SocketAddr` would be possible.
99        header_value
100            .rsplit_once(':')
101            .map(|(ip, _port)| ip)
102            .ok_or_else(|| Rejection::MalformedHeaderValue {
103                header_name: Self::HEADER_NAME,
104                header_value: header_value.to_owned(),
105            })?
106            .parse::<IpAddr>()
107            .map_err(|_| Rejection::MalformedHeaderValue {
108                header_name: Self::HEADER_NAME,
109                header_value: header_value.to_owned(),
110            })
111    }
112}
113
114impl<S> FromRequestParts<S> for CloudFrontViewerAddress
115where
116    S: Sync,
117{
118    type Rejection = Rejection;
119
120    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
121        Self::ip_from_headers(&parts.headers).map(Self)
122    }
123}
124
125/// Extracts an IP from `Fly-Client-IP` (Fly.io) header
126///
127/// When [`FlyClientIp`] extractor is run for health check path,
128/// provide required `Fly-Client-IP` header through
129/// [`services.http_checks.headers`](https://fly.io/docs/reference/configuration/#services-http_checks)
130/// or [`http_service.checks.headers`](https://fly.io/docs/reference/configuration/#services-http_checks)
131#[derive(Debug, Clone, Copy)]
132pub struct FlyClientIp(pub IpAddr);
133
134impl_default_ip_extractor!(FlyClientIp, "fly-client-ip");
135
136/// Extracts the rightmost IP from `Forwarded` header
137#[derive(Debug, Clone, Copy)]
138pub struct RightmostForwarded(pub IpAddr);
139
140impl IpExtractor for RightmostForwarded {
141    const HEADER_NAME: HeaderName = HeaderName::from_static("forwarded");
142
143    fn ip_from_header_value(header_value: &str) -> Result<IpAddr, Rejection> {
144        use forwarded_header_value::{ForwardedHeaderValue, Identifier};
145
146        let stanza = ForwardedHeaderValue::from_forwarded(header_value)
147            .map_err(|_| Rejection::MalformedHeaderValue {
148                header_name: Self::HEADER_NAME,
149                header_value: header_value.to_owned(),
150            })?
151            .into_iter()
152            .last()
153            .ok_or_else(|| Rejection::MalformedHeaderValue {
154                header_name: Self::HEADER_NAME,
155                header_value: header_value.to_owned(),
156            })?;
157
158        let forwarded_for = stanza
159            .forwarded_for
160            .ok_or_else(|| Rejection::ForwardedNoFor {
161                header_value: header_value.to_owned(),
162            })?;
163
164        match forwarded_for {
165            Identifier::SocketAddr(a) => Ok(a.ip()),
166            Identifier::IpAddr(ip) => Ok(ip),
167            Identifier::String(_) => Err(Rejection::ForwardedObfuscated {
168                header_value: header_value.to_owned(),
169            }),
170            Identifier::Unknown => Err(Rejection::ForwardedUnknown {
171                header_value: header_value.to_owned(),
172            }),
173        }
174    }
175}
176
177impl<S> FromRequestParts<S> for RightmostForwarded
178where
179    S: Sync,
180{
181    type Rejection = Rejection;
182
183    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
184        Self::ip_from_headers(&parts.headers).map(Self)
185    }
186}
187
188/// Extracts the rightmost IP from `X-Forwarded-For` header
189#[derive(Debug, Clone, Copy)]
190pub struct RightmostXForwardedFor(pub IpAddr);
191
192impl IpExtractor for RightmostXForwardedFor {
193    const HEADER_NAME: HeaderName = HeaderName::from_static("x-forwarded-for");
194
195    fn ip_from_header_value(header_value: &str) -> Result<IpAddr, Rejection> {
196        header_value
197            .split(',')
198            .last()
199            .ok_or_else(|| Rejection::MalformedHeaderValue {
200                header_name: Self::HEADER_NAME,
201                header_value: header_value.to_owned(),
202            })?
203            .trim()
204            .parse::<IpAddr>()
205            .map_err(|_| Rejection::MalformedHeaderValue {
206                header_name: Self::HEADER_NAME,
207                header_value: header_value.to_owned(),
208            })
209    }
210}
211
212impl<S> FromRequestParts<S> for RightmostXForwardedFor
213where
214    S: Sync,
215{
216    type Rejection = Rejection;
217
218    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
219        Self::ip_from_headers(&parts.headers).map(Self)
220    }
221}
222
223/// Extracts an IP from `True-Client-IP` (Akamai, Cloudflare) header
224#[derive(Debug, Clone, Copy)]
225pub struct TrueClientIp(pub IpAddr);
226
227impl_default_ip_extractor!(TrueClientIp, "true-client-ip");
228
229/// Extracts an IP from `X-Real-Ip` (Nginx) header
230#[derive(Debug, Clone, Copy)]
231pub struct XRealIp(pub IpAddr);
232
233impl_default_ip_extractor!(XRealIp, "x-real-ip");
234
/// Client IP extractor with a configurable source
///
/// To configure it, find out which header the last proxy in front of your
/// application (the one you own or the one your cloud provider runs) uses
/// to store the connecting user's IP. Then pass the matching
/// [`ClientIpSource`] variant into [`axum::routing::Router::layer`] as an
/// extension. See the [example][].
///
/// [example]: https://github.com/imbolc/axum-client-ip/blob/main/examples/integration.rs
#[derive(Clone, Copy, Debug)]
pub struct ClientIp(pub IpAddr);
246
247/// [`ClientIp`] source configuration
248#[derive(Clone, Debug, Deserialize, Serialize)]
249pub enum ClientIpSource {
250    /// IP from the `CF-Connecting-IP` header
251    CfConnectingIp,
252    /// IP from the `CloudFront-Viewer-Address` header
253    CloudFrontViewerAddress,
254    /// IP from the [`axum::extract::ConnectInfo`]
255    ConnectInfo,
256    /// IP from the `Fly-Client-IP` header
257    FlyClientIp,
258    /// Rightmost IP from the `Forwarded` header
259    RightmostForwarded,
260    /// Rightmost IP from the `X-Forwarded-For` header
261    RightmostXForwardedFor,
262    /// IP from the `True-Client-IP` header
263    TrueClientIp,
264    /// IP from the `X-Real-Ip` header
265    XRealIp,
266}
267
268impl ClientIpSource {
269    /// Wraps [`ClientIpSource`] into the [`axum::extract::Extension`]
270    /// for passing to [`axum::routing::Router::layer`]
271    pub const fn into_extension(self) -> Extension<Self> {
272        Extension(self)
273    }
274}
275
/// Error returned when a string doesn't name any [`ClientIpSource`] variant
///
/// Holds the rejected input so it can be shown in the message.
#[derive(Debug)]
pub struct ParseClientIpSourceError(String);

impl fmt::Display for ParseClientIpSourceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        f.write_str("Invalid ClientIpSource value ")?;
        f.write_str(rejected)
    }
}

impl Error for ParseClientIpSourceError {}
287
288impl FromStr for ClientIpSource {
289    type Err = ParseClientIpSourceError;
290
291    fn from_str(s: &str) -> Result<Self, Self::Err> {
292        Ok(match s {
293            "RightmostForwarded" => Self::RightmostForwarded,
294            "RightmostXForwardedFor" => Self::RightmostXForwardedFor,
295            "XRealIp" => Self::XRealIp,
296            "FlyClientIp" => Self::FlyClientIp,
297            "TrueClientIp" => Self::TrueClientIp,
298            "CfConnectingIp" => Self::CfConnectingIp,
299            "ConnectInfo" => Self::ConnectInfo,
300            "CloudFrontViewerAddress" => Self::CloudFrontViewerAddress,
301            _ => return Err(ParseClientIpSourceError(s.to_string())),
302        })
303    }
304}
305
306impl<S> FromRequestParts<S> for ClientIp
307where
308    S: Sync,
309{
310    type Rejection = Rejection;
311
312    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
313        let Some(ip_source) = parts.extensions.get() else {
314            return Err(Rejection::NoClientIpSource);
315        };
316
317        match ip_source {
318            ClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(&parts.headers),
319            ClientIpSource::CloudFrontViewerAddress => {
320                CloudFrontViewerAddress::ip_from_headers(&parts.headers)
321            }
322            ClientIpSource::ConnectInfo => parts
323                .extensions
324                .get::<ConnectInfo<SocketAddr>>()
325                .map(|ConnectInfo(addr)| addr.ip())
326                .ok_or_else(|| Rejection::NoConnectInfo),
327            ClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(&parts.headers),
328            ClientIpSource::RightmostForwarded => {
329                RightmostForwarded::ip_from_headers(&parts.headers)
330            }
331            ClientIpSource::RightmostXForwardedFor => {
332                RightmostXForwardedFor::ip_from_headers(&parts.headers)
333            }
334            ClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(&parts.headers),
335            ClientIpSource::XRealIp => XRealIp::ip_from_headers(&parts.headers),
336        }
337        .map(Self)
338    }
339}
340
341/// Rejection type for IP extractors
342#[derive(Debug, PartialEq)]
343pub enum Rejection {
344    /// No [`axum::extract::ConnectInfo`] in extensions
345    NoConnectInfo,
346    /// No [`ClientIpSource`] in extensions
347    NoClientIpSource,
348    /// The IP-related header is missing
349    AbsentHeader {
350        /// Header name
351        header_name: HeaderName,
352    },
353    /// Header value contains not only visible ASCII characters
354    NonAsciiHeaderValue {
355        /// Header name
356        header_name: HeaderName,
357    },
358    /// Header value has an unexpected format
359    MalformedHeaderValue {
360        /// Header name
361        header_name: HeaderName,
362        /// Header value
363        header_value: String,
364    },
365    /// Forwarded header doesn't contain `for` directive
366    ForwardedNoFor {
367        /// Header value
368        header_value: String,
369    },
370    /// RFC 7239 allows to [obfuscate IPs](https://www.rfc-editor.org/rfc/rfc7239.html#section-6.3)
371    ForwardedObfuscated {
372        /// Header value
373        header_value: String,
374    },
375    /// RFC 7239 allows [unknown identifiers](https://www.rfc-editor.org/rfc/rfc7239.html#section-6.2)
376    ForwardedUnknown {
377        /// Header value
378        header_value: String,
379    },
380}
381
382impl fmt::Display for Rejection {
383    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384        match self {
385            Rejection::NoConnectInfo => {
386                write!(f, "Add `axum::extract::ConnectInfo` to request extensions")
387            }
388            Rejection::NoClientIpSource => write!(
389                f,
390                "Add `axum_client_ip::ClientIpSource` to request extensions"
391            ),
392            Rejection::AbsentHeader { header_name } => {
393                write!(f, "Missing required header: {header_name}")
394            }
395            Rejection::NonAsciiHeaderValue { header_name } => write!(
396                f,
397                "Header value contains non-ASCII characters: {header_name}",
398            ),
399            Rejection::MalformedHeaderValue {
400                header_name,
401                header_value,
402            } => write!(
403                f,
404                "Malformed header value for `{header_name}`: {header_value}",
405            ),
406            Rejection::ForwardedNoFor { header_value } => write!(
407                f,
408                "`Forwarded` header missing `for` directive: {header_value}",
409            ),
410            Rejection::ForwardedObfuscated { header_value } => write!(
411                f,
412                "`Forwarded` header contains obfuscated IP: {header_value}",
413            ),
414            Rejection::ForwardedUnknown { header_value } => write!(
415                f,
416                "`Forwarded` header contains unknown identifier: {header_value}",
417            ),
418        }
419    }
420}
421
422impl std::error::Error for Rejection {}
423
424impl IntoResponse for Rejection {
425    fn into_response(self) -> Response {
426        let request_issue = (StatusCode::BAD_REQUEST, "400 Bad Request");
427        let proxy_issue = (
428            StatusCode::INTERNAL_SERVER_ERROR,
429            "500 Proxy Server Misconfiguration",
430        );
431        let axum_issue = (
432            StatusCode::INTERNAL_SERVER_ERROR,
433            "500 Axum Misconfiguration",
434        );
435
436        let (code, title) = match self {
437            Self::NoConnectInfo => axum_issue,
438            Self::NoClientIpSource => axum_issue,
439            Self::AbsentHeader { .. } => proxy_issue,
440            Self::NonAsciiHeaderValue { .. } => proxy_issue,
441            Self::MalformedHeaderValue { .. } => proxy_issue,
442            Self::ForwardedNoFor { .. } => proxy_issue,
443            Self::ForwardedObfuscated { .. } => proxy_issue,
444            Self::ForwardedUnknown { .. } => request_issue,
445        };
446
447        let footer = "(the request is rejected by axum-client-ip)";
448        let text = format!("{title}\n\n{self}\n\n{footer}");
449        (code, text).into_response()
450    }
451}
452
453#[cfg(test)]
454mod tests {
455    use std::net::IpAddr;
456
457    use axum::{
458        Router,
459        body::Body,
460        http::{HeaderMap, HeaderName, Request, StatusCode},
461        routing::get,
462    };
463    use http_body_util::BodyExt;
464    use tower::ServiceExt;
465
466    use super::{
467        CfConnectingIp, FlyClientIp, RightmostForwarded, RightmostXForwardedFor, TrueClientIp,
468        XRealIp,
469    };
470    use crate::{CloudFrontViewerAddress, IpExtractor, Rejection};
471
472    const VALID_IPV4: &str = "1.2.3.4";
473    const VALID_IPV6: &str = "1:23:4567:89ab:c:d:e:f";
474
475    async fn body_to_string(body: Body) -> String {
476        let bytes = body.collect().await.unwrap().to_bytes();
477        String::from_utf8_lossy(&bytes).into()
478    }
479
480    fn headers<'a>(items: impl IntoIterator<Item = (&'a str, &'a str)>) -> HeaderMap {
481        HeaderMap::from_iter(
482            items
483                .into_iter()
484                .map(|(name, value)| (name.parse().unwrap(), value.parse().unwrap())),
485        )
486    }
487
488    #[tokio::test]
489    async fn cf_connecting_ip() {
490        let header = "cf-connecting-ip";
491
492        assert_eq!(
493            CfConnectingIp::ip_from_headers(&headers([])).unwrap_err(),
494            Rejection::AbsentHeader {
495                header_name: HeaderName::from_static(header)
496            }
497        );
498        assert_eq!(
499            CfConnectingIp::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
500            Rejection::NonAsciiHeaderValue {
501                header_name: HeaderName::from_static(header)
502            }
503        );
504        assert_eq!(
505            CfConnectingIp::ip_from_headers(&headers([(header, "foo")])).unwrap_err(),
506            Rejection::MalformedHeaderValue {
507                header_name: HeaderName::from_static(header),
508                header_value: "foo".into(),
509            }
510        );
511
512        assert_eq!(
513            CfConnectingIp::ip_from_headers(&headers([(header, VALID_IPV4)])).unwrap(),
514            VALID_IPV4.parse::<IpAddr>().unwrap()
515        );
516        assert_eq!(
517            CfConnectingIp::ip_from_headers(&headers([(header, VALID_IPV6)])).unwrap(),
518            VALID_IPV6.parse::<IpAddr>().unwrap()
519        );
520
521        fn app() -> Router {
522            Router::new().route(
523                "/",
524                get(|ip: CfConnectingIp| async move { ip.0.to_string() }),
525            )
526        }
527
528        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
529        let resp = app().oneshot(req).await.unwrap();
530        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
531
532        let req = Request::builder()
533            .uri("/")
534            .header(header, VALID_IPV4)
535            .body(Body::empty())
536            .unwrap();
537        let resp = app().oneshot(req).await.unwrap();
538        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
539
540        let req = Request::builder()
541            .uri("/")
542            .header(header, VALID_IPV6)
543            .body(Body::empty())
544            .unwrap();
545        let resp = app().oneshot(req).await.unwrap();
546        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
547    }
548
549    #[tokio::test]
550    async fn cloudfront_viewer_address() {
551        let header = "cloudfront-viewer-address";
552
553        assert_eq!(
554            CloudFrontViewerAddress::ip_from_headers(&headers([])).unwrap_err(),
555            Rejection::AbsentHeader {
556                header_name: HeaderName::from_static(header)
557            }
558        );
559        assert_eq!(
560            CloudFrontViewerAddress::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
561            Rejection::NonAsciiHeaderValue {
562                header_name: HeaderName::from_static(header)
563            }
564        );
565        assert_eq!(
566            CloudFrontViewerAddress::ip_from_headers(&headers([(header, VALID_IPV4)])).unwrap_err(),
567            Rejection::MalformedHeaderValue {
568                header_name: HeaderName::from_static(header),
569                header_value: VALID_IPV4.into(),
570            }
571        );
572        assert_eq!(
573            CloudFrontViewerAddress::ip_from_headers(&headers([(header, "foo:8000")])).unwrap_err(),
574            Rejection::MalformedHeaderValue {
575                header_name: HeaderName::from_static(header),
576                header_value: "foo:8000".into(),
577            }
578        );
579
580        let valid_header_value_v4 = format!("{VALID_IPV4}:8000");
581        let valid_header_value_v6 = format!("{VALID_IPV6}:8000");
582        assert_eq!(
583            CloudFrontViewerAddress::ip_from_headers(&headers([(
584                header,
585                valid_header_value_v4.as_ref()
586            )]))
587            .unwrap(),
588            VALID_IPV4.parse::<IpAddr>().unwrap()
589        );
590        assert_eq!(
591            CloudFrontViewerAddress::ip_from_headers(&headers([(
592                header,
593                valid_header_value_v6.as_ref()
594            )]))
595            .unwrap(),
596            VALID_IPV6.parse::<IpAddr>().unwrap()
597        );
598
599        fn app() -> Router {
600            Router::new().route(
601                "/",
602                get(|ip: CloudFrontViewerAddress| async move { ip.0.to_string() }),
603            )
604        }
605
606        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
607        let resp = app().oneshot(req).await.unwrap();
608        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
609
610        let req = Request::builder()
611            .uri("/")
612            .header(header, &valid_header_value_v4)
613            .body(Body::empty())
614            .unwrap();
615        let resp = app().oneshot(req).await.unwrap();
616        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
617
618        let req = Request::builder()
619            .uri("/")
620            .header(header, &valid_header_value_v6)
621            .body(Body::empty())
622            .unwrap();
623        let resp = app().oneshot(req).await.unwrap();
624        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
625    }
626
627    #[tokio::test]
628    async fn fly_client_ip() {
629        let header = "fly-client-ip";
630
631        assert_eq!(
632            FlyClientIp::ip_from_headers(&headers([])).unwrap_err(),
633            Rejection::AbsentHeader {
634                header_name: HeaderName::from_static(header)
635            }
636        );
637        assert_eq!(
638            FlyClientIp::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
639            Rejection::NonAsciiHeaderValue {
640                header_name: HeaderName::from_static(header)
641            }
642        );
643        assert_eq!(
644            FlyClientIp::ip_from_headers(&headers([(header, "foo")])).unwrap_err(),
645            Rejection::MalformedHeaderValue {
646                header_name: HeaderName::from_static(header),
647                header_value: "foo".into(),
648            }
649        );
650
651        assert_eq!(
652            FlyClientIp::ip_from_headers(&headers([(header, VALID_IPV4)])).unwrap(),
653            VALID_IPV4.parse::<IpAddr>().unwrap()
654        );
655        assert_eq!(
656            FlyClientIp::ip_from_headers(&headers([(header, VALID_IPV6)])).unwrap(),
657            VALID_IPV6.parse::<IpAddr>().unwrap()
658        );
659
660        fn app() -> Router {
661            Router::new().route("/", get(|ip: FlyClientIp| async move { ip.0.to_string() }))
662        }
663
664        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
665        let resp = app().oneshot(req).await.unwrap();
666        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
667
668        let req = Request::builder()
669            .uri("/")
670            .header(header, VALID_IPV4)
671            .body(Body::empty())
672            .unwrap();
673        let resp = app().oneshot(req).await.unwrap();
674        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
675
676        let req = Request::builder()
677            .uri("/")
678            .header(header, VALID_IPV6)
679            .body(Body::empty())
680            .unwrap();
681        let resp = app().oneshot(req).await.unwrap();
682        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
683    }
684
685    #[tokio::test]
686    async fn rightmost_forwarded() {
687        let header = "forwarded";
688
689        assert_eq!(
690            RightmostForwarded::ip_from_headers(&headers([])).unwrap_err(),
691            Rejection::AbsentHeader {
692                header_name: HeaderName::from_static(header)
693            }
694        );
695        assert_eq!(
696            RightmostForwarded::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
697            Rejection::NonAsciiHeaderValue {
698                header_name: HeaderName::from_static(header)
699            }
700        );
701        assert_eq!(
702            RightmostForwarded::ip_from_headers(&headers([(header, "foo")])).unwrap_err(),
703            Rejection::MalformedHeaderValue {
704                header_name: HeaderName::from_static(header),
705                header_value: "foo".into(),
706            }
707        );
708        assert_eq!(
709            RightmostForwarded::ip_from_headers(&headers([
710                (header, format!("for={VALID_IPV4}").as_ref()),
711                (header, "proto=http"),
712            ]))
713            .unwrap_err(),
714            Rejection::ForwardedNoFor {
715                header_value: "proto=http".into(),
716            }
717        );
718        assert_eq!(
719            RightmostForwarded::ip_from_headers(&headers([(header, "for=unknown")])).unwrap_err(),
720            Rejection::ForwardedUnknown {
721                header_value: "for=unknown".into(),
722            }
723        );
724        assert_eq!(
725            RightmostForwarded::ip_from_headers(&headers([(header, "for=_foo")])).unwrap_err(),
726            Rejection::ForwardedObfuscated {
727                header_value: "for=_foo".into(),
728            }
729        );
730
731        assert_eq!(
732            RightmostForwarded::ip_from_headers(&headers([
733                (header, "proto=http"),
734                (header, format!("for={VALID_IPV4};proto=http").as_ref()),
735            ]))
736            .unwrap(),
737            VALID_IPV4.parse::<IpAddr>().unwrap()
738        );
739        assert_eq!(
740            RightmostForwarded::ip_from_headers(&headers([(
741                header,
742                format!("for={VALID_IPV4}:8000").as_ref()
743            ),]))
744            .unwrap(),
745            VALID_IPV4.parse::<IpAddr>().unwrap()
746        );
747
748        assert_eq!(
749            RightmostForwarded::ip_from_headers(&headers([(
750                header,
751                format!("for={VALID_IPV6}").as_ref()
752            ),]))
753            .unwrap(),
754            VALID_IPV6.parse::<IpAddr>().unwrap()
755        );
756        assert_eq!(
757            RightmostForwarded::ip_from_headers(&headers([(
758                header,
759                format!("for=[{VALID_IPV6}]:8000").as_ref()
760            ),]))
761            .unwrap(),
762            VALID_IPV6.parse::<IpAddr>().unwrap()
763        );
764
765        fn app() -> Router {
766            Router::new().route(
767                "/",
768                get(|ip: RightmostForwarded| async move { ip.0.to_string() }),
769            )
770        }
771
772        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
773        let resp = app().oneshot(req).await.unwrap();
774        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
775
776        let req = Request::builder()
777            .uri("/")
778            .header(header, format!("for=[{VALID_IPV6}]:8000"))
779            .body(Body::empty())
780            .unwrap();
781        let resp = app().oneshot(req).await.unwrap();
782        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
783
784        let req = Request::builder()
785            .uri("/")
786            .header("Forwarded", r#"for="_mdn""#)
787            .header("Forwarded", r#"For="[2001:db8:cafe::17]:4711""#)
788            .header("Forwarded", r#"for=192.0.2.60;proto=http;by=203.0.113.43"#)
789            .body(Body::empty())
790            .unwrap();
791        let resp = app().oneshot(req).await.unwrap();
792        assert_eq!(body_to_string(resp.into_body()).await, "192.0.2.60");
793    }
794
795    #[tokio::test]
796    async fn rightmost_x_forwarded_for() {
797        let header = "x-forwarded-for";
798
799        assert_eq!(
800            RightmostXForwardedFor::ip_from_headers(&headers([])).unwrap_err(),
801            Rejection::AbsentHeader {
802                header_name: HeaderName::from_static(header)
803            }
804        );
805        assert_eq!(
806            RightmostXForwardedFor::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
807            Rejection::NonAsciiHeaderValue {
808                header_name: HeaderName::from_static(header)
809            }
810        );
811        assert_eq!(
812            RightmostXForwardedFor::ip_from_headers(&headers([(header, "1.2.3.4,foo")]))
813                .unwrap_err(),
814            Rejection::MalformedHeaderValue {
815                header_name: HeaderName::from_static(header),
816                header_value: "1.2.3.4,foo".into(),
817            }
818        );
819
820        assert_eq!(
821            RightmostXForwardedFor::ip_from_headers(&headers([(
822                header,
823                format!("foo,{VALID_IPV4}").as_ref()
824            )]))
825            .unwrap(),
826            VALID_IPV4.parse::<IpAddr>().unwrap()
827        );
828        assert_eq!(
829            RightmostXForwardedFor::ip_from_headers(&headers([(header, VALID_IPV6)])).unwrap(),
830            VALID_IPV6.parse::<IpAddr>().unwrap()
831        );
832
833        fn app() -> Router {
834            Router::new().route(
835                "/",
836                get(|ip: RightmostXForwardedFor| async move { ip.0.to_string() }),
837            )
838        }
839
840        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
841        let resp = app().oneshot(req).await.unwrap();
842        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
843
844        let req = Request::builder()
845            .uri("/")
846            .header(
847                "X-Forwarded-For",
848                "1.1.1.1, foo, 2001:db8:85a3:8d3:1319:8a2e:370:7348",
849            )
850            .header("X-Forwarded-For", "bar")
851            .header("X-Forwarded-For", format!("2.2.2.2, {VALID_IPV4}"))
852            .body(Body::empty())
853            .unwrap();
854        let resp = app().oneshot(req).await.unwrap();
855        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
856    }
857
858    #[tokio::test]
859    async fn true_client_ip() {
860        let header = "true-client-ip";
861
862        assert_eq!(
863            TrueClientIp::ip_from_headers(&headers([])).unwrap_err(),
864            Rejection::AbsentHeader {
865                header_name: HeaderName::from_static(header)
866            }
867        );
868        assert_eq!(
869            TrueClientIp::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
870            Rejection::NonAsciiHeaderValue {
871                header_name: HeaderName::from_static(header)
872            }
873        );
874        assert_eq!(
875            TrueClientIp::ip_from_headers(&headers([(header, "foo")])).unwrap_err(),
876            Rejection::MalformedHeaderValue {
877                header_name: HeaderName::from_static(header),
878                header_value: "foo".into(),
879            }
880        );
881
882        assert_eq!(
883            TrueClientIp::ip_from_headers(&headers([(header, VALID_IPV4)])).unwrap(),
884            VALID_IPV4.parse::<IpAddr>().unwrap()
885        );
886        assert_eq!(
887            TrueClientIp::ip_from_headers(&headers([(header, VALID_IPV6)])).unwrap(),
888            VALID_IPV6.parse::<IpAddr>().unwrap()
889        );
890
891        fn app() -> Router {
892            Router::new().route("/", get(|ip: TrueClientIp| async move { ip.0.to_string() }))
893        }
894
895        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
896        let resp = app().oneshot(req).await.unwrap();
897        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
898
899        let req = Request::builder()
900            .uri("/")
901            .header(header, VALID_IPV4)
902            .body(Body::empty())
903            .unwrap();
904        let resp = app().oneshot(req).await.unwrap();
905        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
906
907        let req = Request::builder()
908            .uri("/")
909            .header(header, VALID_IPV6)
910            .body(Body::empty())
911            .unwrap();
912        let resp = app().oneshot(req).await.unwrap();
913        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
914    }
915
916    #[tokio::test]
917    async fn x_real_ip() {
918        let header = "x-real-ip";
919
920        assert_eq!(
921            XRealIp::ip_from_headers(&headers([])).unwrap_err(),
922            Rejection::AbsentHeader {
923                header_name: HeaderName::from_static(header)
924            }
925        );
926        assert_eq!(
927            XRealIp::ip_from_headers(&headers([(header, "ы")])).unwrap_err(),
928            Rejection::NonAsciiHeaderValue {
929                header_name: HeaderName::from_static(header)
930            }
931        );
932        assert_eq!(
933            XRealIp::ip_from_headers(&headers([(header, "foo")])).unwrap_err(),
934            Rejection::MalformedHeaderValue {
935                header_name: HeaderName::from_static(header),
936                header_value: "foo".into(),
937            }
938        );
939
940        assert_eq!(
941            XRealIp::ip_from_headers(&headers([(header, VALID_IPV4)])).unwrap(),
942            VALID_IPV4.parse::<IpAddr>().unwrap()
943        );
944        assert_eq!(
945            XRealIp::ip_from_headers(&headers([(header, VALID_IPV6)])).unwrap(),
946            VALID_IPV6.parse::<IpAddr>().unwrap()
947        );
948
949        fn app() -> Router {
950            Router::new().route("/", get(|ip: XRealIp| async move { ip.0.to_string() }))
951        }
952
953        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
954        let resp = app().oneshot(req).await.unwrap();
955        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
956
957        let req = Request::builder()
958            .uri("/")
959            .header(header, VALID_IPV4)
960            .body(Body::empty())
961            .unwrap();
962        let resp = app().oneshot(req).await.unwrap();
963        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
964
965        let req = Request::builder()
966            .uri("/")
967            .header(header, VALID_IPV6)
968            .body(Body::empty())
969            .unwrap();
970        let resp = app().oneshot(req).await.unwrap();
971        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV6);
972    }
973}