axum_client_ip/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3use std::{
4    error::Error,
5    fmt,
6    marker::Sync,
7    net::{IpAddr, SocketAddr},
8    str::FromStr,
9};
10
11use axum::{
12    extract::{ConnectInfo, Extension, FromRequestParts},
13    http::{StatusCode, request::Parts},
14    response::{IntoResponse, Response},
15};
16
/// Defines a header-based IP extractor.
///
/// Expands to:
/// - a public newtype `$newtype(pub IpAddr)` carrying the extracted address;
/// - a private `ip_from_headers` helper that runs `$extractor` on a header map
///   and converts its error into a [`Rejection`];
/// - an [`axum::extract::FromRequestParts`] impl built on that helper.
///
/// `$extractor` must be a path to a function callable with a `&HeaderMap` and
/// returning `Result<IpAddr, E>` where `Rejection: From<E>` (in this crate that
/// is `client_ip::Error`).
macro_rules! define_extractor {
    (
        $(#[$meta:meta])*
        $newtype:ident,
        $extractor:path
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $newtype(pub std::net::IpAddr);

        impl $newtype {
            /// Runs the underlying header parser, mapping its error into [`Rejection`].
            fn ip_from_headers(headers: &axum::http::HeaderMap) -> Result<std::net::IpAddr, Rejection> {
                $extractor(headers).map_err(Rejection::from)
            }
        }

        impl<S> axum::extract::FromRequestParts<S> for $newtype
        where
            // The returned future must be `Send` and holds `&S`, hence `S: Sync`.
            S: Sync,
        {
            type Rejection = Rejection;

            async fn from_request_parts(
                parts: &mut axum::http::request::Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                Self::ip_from_headers(&parts.headers).map(Self)
            }
        }
    };
}
49
50define_extractor!(
51    /// Extracts an IP from `CF-Connecting-IP` (Cloudflare) header
52    CfConnectingIp,
53    client_ip::cf_connecting_ip
54);
55
56define_extractor!(
57    /// Extracts an IP from `CloudFront-Viewer-Address` (AWS CloudFront) header
58    CloudFrontViewerAddress,
59    client_ip::cloudfront_viewer_address
60);
61
62define_extractor!(
63    /// Extracts an IP from `Fly-Client-IP` (Fly.io) header
64    ///
65    /// When [`FlyClientIp`] extractor is run for health check path,
66    /// provide required `Fly-Client-IP` header through
67    /// [`services.http_checks.headers`](https://fly.io/docs/reference/configuration/#services-http_checks)
68    /// or [`http_service.checks.headers`](https://fly.io/docs/reference/configuration/#services-http_checks)
69    FlyClientIp,
70    client_ip::fly_client_ip
71);
72
#[cfg(feature = "forwarded-header")]
define_extractor! {
    /// Client IP taken from the rightmost entry of the `Forwarded` header
    /// (requires the `forwarded-header` feature)
    RightmostForwarded,
    client_ip::rightmost_forwarded
}
79
80define_extractor!(
81    /// Extracts the rightmost IP from `X-Forwarded-For` header
82    RightmostXForwardedFor,
83    client_ip::rightmost_x_forwarded_for
84);
85
86define_extractor!(
87    /// Extracts an IP from `True-Client-IP` (Akamai, Cloudflare) header
88    TrueClientIp,
89    client_ip::true_client_ip
90);
91
92define_extractor!(
93    /// Extracts an IP from `X-Real-Ip` (Nginx) header
94    XRealIp,
95    client_ip::x_real_ip
96);
97
/// Client IP extractor with configurable source
///
/// Unlike the header-specific extractors (e.g. [`XRealIp`]), this extractor
/// decides at request time where to read the IP from, based on the
/// [`ClientIpSource`] found in the request extensions. Pick the variant that
/// matches the header your last proxy (the one you own or the one your cloud
/// provider runs) uses to store the client connection IP, or
/// [`ClientIpSource::ConnectInfo`] to use the peer socket address, and pass it
/// to [`axum::routing::Router::layer`] as an extension (see
/// [`ClientIpSource::into_extension`]). Look at the [example][].
///
/// [example]: https://github.com/imbolc/axum-client-ip/blob/main/examples/integration.rs
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientIp(pub IpAddr);
109
/// [`ClientIp`] source configuration
///
/// Selects where [`ClientIp`] reads the client address from. Each variant can
/// also be parsed from its exact name via [`std::str::FromStr`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ClientIpSource {
    /// IP from the `CF-Connecting-IP` header
    CfConnectingIp,
    /// IP from the `CloudFront-Viewer-Address` header
    CloudFrontViewerAddress,
    /// IP from the [`axum::extract::ConnectInfo`]
    ConnectInfo,
    /// IP from the `Fly-Client-IP` header
    FlyClientIp,
    /// Rightmost IP from the `Forwarded` header (requires the `forwarded-header` feature)
    #[cfg(feature = "forwarded-header")]
    RightmostForwarded,
    /// Rightmost IP from the `X-Forwarded-For` header
    RightmostXForwardedFor,
    /// IP from the `True-Client-IP` header
    TrueClientIp,
    /// IP from the `X-Real-Ip` header
    XRealIp,
}
133
134impl ClientIpSource {
135    /// Wraps [`ClientIpSource`] into the [`axum::extract::Extension`]
136    /// for passing to [`axum::routing::Router::layer`]
137    pub const fn into_extension(self) -> Extension<Self> {
138        Extension(self)
139    }
140}
141
/// Invalid [`ClientIpSource`]
///
/// Returned by the `FromStr` implementation of [`ClientIpSource`]; holds the
/// input string that did not match any variant name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseClientIpSourceError(String);
145
146impl fmt::Display for ParseClientIpSourceError {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        write!(f, "Invalid ClientIpSource value {}", self.0)
149    }
150}
151
152impl Error for ParseClientIpSourceError {}
153
154impl FromStr for ClientIpSource {
155    type Err = ParseClientIpSourceError;
156
157    fn from_str(s: &str) -> Result<Self, Self::Err> {
158        Ok(match s {
159            "CfConnectingIp" => Self::CfConnectingIp,
160            "CloudFrontViewerAddress" => Self::CloudFrontViewerAddress,
161            "ConnectInfo" => Self::ConnectInfo,
162            "FlyClientIp" => Self::FlyClientIp,
163            #[cfg(feature = "forwarded-header")]
164            "RightmostForwarded" => Self::RightmostForwarded,
165            "RightmostXForwardedFor" => Self::RightmostXForwardedFor,
166            "TrueClientIp" => Self::TrueClientIp,
167            "XRealIp" => Self::XRealIp,
168            _ => return Err(ParseClientIpSourceError(s.to_string())),
169        })
170    }
171}
172
173impl<S> FromRequestParts<S> for ClientIp
174where
175    S: Sync,
176{
177    type Rejection = Rejection;
178
179    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
180        let Some(ip_source) = parts.extensions.get() else {
181            return Err(Rejection::NoClientIpSource);
182        };
183
184        match ip_source {
185            ClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(&parts.headers),
186            ClientIpSource::CloudFrontViewerAddress => {
187                CloudFrontViewerAddress::ip_from_headers(&parts.headers)
188            }
189            ClientIpSource::ConnectInfo => parts
190                .extensions
191                .get::<ConnectInfo<SocketAddr>>()
192                .map(|ConnectInfo(addr)| addr.ip())
193                .ok_or_else(|| Rejection::NoConnectInfo),
194            ClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(&parts.headers),
195            #[cfg(feature = "forwarded-header")]
196            ClientIpSource::RightmostForwarded => {
197                RightmostForwarded::ip_from_headers(&parts.headers)
198            }
199            ClientIpSource::RightmostXForwardedFor => {
200                RightmostXForwardedFor::ip_from_headers(&parts.headers)
201            }
202            ClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(&parts.headers),
203            ClientIpSource::XRealIp => XRealIp::ip_from_headers(&parts.headers),
204        }
205        .map(Self)
206    }
207}
208
209/// Rejection type for IP extractors
210#[non_exhaustive]
211#[derive(Debug, PartialEq)]
212pub enum Rejection {
213    /// No [`axum::extract::ConnectInfo`] in extensions
214    NoConnectInfo,
215    /// No [`ClientIpSource`] in extensions
216    NoClientIpSource,
217    /// [`client_ip::Error`]
218    ClientIp(client_ip::Error),
219}
220
221impl From<client_ip::Error> for Rejection {
222    fn from(value: client_ip::Error) -> Self {
223        Self::ClientIp(value)
224    }
225}
226
227impl fmt::Display for Rejection {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        match self {
230            Rejection::NoConnectInfo => {
231                write!(f, "Add `axum::extract::ConnectInfo` to request extensions")
232            }
233            Rejection::NoClientIpSource => write!(
234                f,
235                "Add `axum_client_ip::ClientIpSource` to request extensions"
236            ),
237            Rejection::ClientIp(e) => write!(f, "{e}"),
238        }
239    }
240}
241
242impl std::error::Error for Rejection {}
243
244impl IntoResponse for Rejection {
245    fn into_response(self) -> Response {
246        let title = match self {
247            Self::NoConnectInfo | Self::NoClientIpSource => "500 Axum Misconfiguration",
248            Self::ClientIp { .. } => "500 Proxy Server Misconfiguration",
249        };
250        let footer = "(the request is rejected by axum-client-ip)";
251        let text = format!("{title}\n\n{self}\n\n{footer}");
252        (StatusCode::INTERNAL_SERVER_ERROR, text).into_response()
253    }
254}
255
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        body::Body,
        extract::ConnectInfo,
        http::{Request, StatusCode},
        routing::get,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[cfg(feature = "forwarded-header")]
    use super::RightmostForwarded;
    use super::{
        CfConnectingIp, ClientIp, ClientIpSource, CloudFrontViewerAddress, FlyClientIp,
        RightmostXForwardedFor, TrueClientIp, XRealIp,
    };

    const VALID_IPV4: &str = "1.2.3.4";
    const VALID_IPV6: &str = "1:23:4567:89ab:c:d:e:f";

    async fn body_to_string(body: Body) -> String {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8_lossy(&bytes).into()
    }

    /// Sends `GET /` with the given headers (appended in order, so repeated
    /// names produce multiple header lines) and returns the status and body.
    async fn send(app: Router, headers: &[(&str, &str)]) -> (StatusCode, String) {
        let req = headers
            .iter()
            .fold(Request::builder().uri("/"), |req, &(name, value)| {
                req.header(name, value)
            })
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let status = resp.status();
        (status, body_to_string(resp.into_body()).await)
    }

    /// Asserts that the extractor behind `app` rejects a request without
    /// `header` with a 500, and echoes both test IPs when `header` carries
    /// `format_value(ip)`.
    async fn check_header_extractor(
        app: fn() -> Router,
        header: &str,
        format_value: fn(&str) -> String,
    ) {
        let (status, _) = send(app(), &[]).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        for ip in [VALID_IPV4, VALID_IPV6] {
            let value = format_value(ip);
            let (status, body) = send(app(), &[(header, value.as_str())]).await;
            assert_eq!(status, StatusCode::OK);
            assert_eq!(body, ip);
        }
    }

    #[tokio::test]
    async fn cf_connecting_ip() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: CfConnectingIp| async move { ip.0.to_string() }),
            )
        }
        check_header_extractor(app, "cf-connecting-ip", |ip| ip.to_owned()).await;
    }

    #[tokio::test]
    async fn cloudfront_viewer_address() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: CloudFrontViewerAddress| async move { ip.0.to_string() }),
            )
        }
        check_header_extractor(app, "cloudfront-viewer-address", |ip| {
            format!("{ip}:8000")
        })
        .await;
    }

    #[tokio::test]
    async fn fly_client_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: FlyClientIp| async move { ip.0.to_string() }))
        }
        check_header_extractor(app, "fly-client-ip", |ip| ip.to_owned()).await;
    }

    #[cfg(feature = "forwarded-header")]
    #[tokio::test]
    async fn rightmost_forwarded() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: RightmostForwarded| async move { ip.0.to_string() }),
            )
        }

        let (status, _) = send(app(), &[]).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        let value = format!("for=[{VALID_IPV6}]:8000");
        let (_, body) = send(app(), &[("forwarded", value.as_str())]).await;
        assert_eq!(body, VALID_IPV6);

        let (_, body) = send(
            app(),
            &[
                ("Forwarded", r#"for="_mdn""#),
                ("Forwarded", r#"For="[2001:db8:cafe::17]:4711""#),
                ("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43"),
            ],
        )
        .await;
        assert_eq!(body, "192.0.2.60");
    }

    #[tokio::test]
    async fn rightmost_x_forwarded_for() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: RightmostXForwardedFor| async move { ip.0.to_string() }),
            )
        }

        let (status, _) = send(app(), &[]).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        let last = format!("2.2.2.2, {VALID_IPV4}");
        let (_, body) = send(
            app(),
            &[
                (
                    "X-Forwarded-For",
                    "1.1.1.1, foo, 2001:db8:85a3:8d3:1319:8a2e:370:7348",
                ),
                ("X-Forwarded-For", "bar"),
                ("X-Forwarded-For", last.as_str()),
            ],
        )
        .await;
        assert_eq!(body, VALID_IPV4);
    }

    #[tokio::test]
    async fn true_client_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: TrueClientIp| async move { ip.0.to_string() }))
        }
        check_header_extractor(app, "true-client-ip", |ip| ip.to_owned()).await;
    }

    #[tokio::test]
    async fn x_real_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: XRealIp| async move { ip.0.to_string() }))
        }
        check_header_extractor(app, "x-real-ip", |ip| ip.to_owned()).await;
    }

    /// Router whose handler echoes the [`ClientIp`] resolved via `source`.
    fn client_ip_app(source: ClientIpSource) -> Router {
        Router::new()
            .route("/", get(|ip: ClientIp| async move { ip.0.to_string() }))
            .layer(source.into_extension())
    }

    #[tokio::test]
    async fn client_ip_requires_source_extension() {
        let app = Router::new().route("/", get(|ip: ClientIp| async move { ip.0.to_string() }));
        let (status, body) = send(app, &[("x-real-ip", VALID_IPV4)]).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(body.contains("ClientIpSource"));
    }

    #[tokio::test]
    async fn client_ip_from_header_source() {
        let (status, body) =
            send(client_ip_app(ClientIpSource::XRealIp), &[("x-real-ip", VALID_IPV4)]).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, VALID_IPV4);

        // Only the configured header is consulted.
        let (status, _) = send(
            client_ip_app(ClientIpSource::XRealIp),
            &[("true-client-ip", VALID_IPV4)],
        )
        .await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn client_ip_from_connect_info() {
        let (status, _) = send(client_ip_app(ClientIpSource::ConnectInfo), &[]).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        let addr = SocketAddr::from(([1, 2, 3, 4], 5678));
        let req = Request::builder()
            .uri("/")
            .extension(ConnectInfo(addr))
            .body(Body::empty())
            .unwrap();
        let resp = client_ip_app(ClientIpSource::ConnectInfo)
            .oneshot(req)
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_to_string(resp.into_body()).await, VALID_IPV4);
    }

    #[test]
    fn client_ip_source_from_str() {
        assert!(matches!(
            "ConnectInfo".parse::<ClientIpSource>(),
            Ok(ClientIpSource::ConnectInfo)
        ));
        assert!(matches!(
            "RightmostXForwardedFor".parse::<ClientIpSource>(),
            Ok(ClientIpSource::RightmostXForwardedFor)
        ));
        let err = "x-real-ip".parse::<ClientIpSource>().unwrap_err();
        assert_eq!(err.to_string(), "Invalid ClientIpSource value x-real-ip");
    }
}