1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3use std::{
4 error::Error,
5 fmt,
6 marker::Sync,
7 net::{IpAddr, SocketAddr},
8 str::FromStr,
9};
10
11use axum::{
12 extract::{ConnectInfo, Extension, FromRequestParts},
13 http::{StatusCode, request::Parts},
14 response::{IntoResponse, Response},
15};
16
/// Defines a public newtype extractor around [`std::net::IpAddr`] that is backed by one of the
/// header-parsing functions of the `client_ip` crate.
///
/// - `$meta`: attributes (typically `///` doc comments) forwarded to the generated struct.
/// - `$newtype`: name of the generated tuple struct.
/// - `$extractor`: path to a function taking `&HeaderMap` and returning a `Result<IpAddr, E>`
///   whose error converts into [`Rejection`] (in practice `client_ip::Error`).
///
/// The generated type implements [`FromRequestParts`]. It also gets a private
/// `ip_from_headers` helper, which [`ClientIp`] reuses for its header-based sources.
macro_rules! define_extractor {
    (
        $(#[$meta:meta])*
        $newtype:ident,
        $extractor:path
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy)]
        pub struct $newtype(pub std::net::IpAddr);

        impl $newtype {
            /// Runs the wrapped `client_ip` extractor on `headers` and converts its error
            /// into a [`Rejection`].
            fn ip_from_headers(
                headers: &axum::http::HeaderMap,
            ) -> Result<std::net::IpAddr, Rejection> {
                $extractor(headers).map_err(Rejection::from)
            }
        }

        impl<S> axum::extract::FromRequestParts<S> for $newtype
        where
            S: Sync,
        {
            type Rejection = Rejection;

            /// Extracts the IP from the request headers; any extractor failure becomes
            /// [`Rejection::ClientIp`].
            async fn from_request_parts(
                parts: &mut axum::http::request::Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                Self::ip_from_headers(&parts.headers).map(Self)
            }
        }
    };
}
49
50define_extractor!(
51 CfConnectingIp,
53 client_ip::cf_connecting_ip
54);
55
56define_extractor!(
57 CloudFrontViewerAddress,
59 client_ip::cloudfront_viewer_address
60);
61
62define_extractor!(
63 FlyClientIp,
70 client_ip::fly_client_ip
71);
72
#[cfg(feature = "forwarded-header")]
define_extractor!(
    /// Extracts the client IP from the rightmost `for=` entry of the `Forwarded` header
    /// (RFC 7239), considering all `Forwarded` header lines.
    ///
    /// The rightmost entry is the one appended by the proxy closest to this service, which
    /// makes it the only one that proxy vouches for. The request is rejected with
    /// [`Rejection`] when no usable entry is found (for example, when the header is missing).
    ///
    /// Only available with the `forwarded-header` feature.
    RightmostForwarded,
    client_ip::rightmost_forwarded
);
79
80define_extractor!(
81 RightmostXForwardedFor,
83 client_ip::rightmost_x_forwarded_for
84);
85
86define_extractor!(
87 TrueClientIp,
89 client_ip::true_client_ip
90);
91
92define_extractor!(
93 XRealIp,
95 client_ip::x_real_ip
96);
97
/// Extracts the client IP from whichever source a [`ClientIpSource`] request extension
/// selects.
///
/// Insert the source with [`ClientIpSource::into_extension`] (for example as a router layer).
/// This lets the source be chosen at runtime, e.g. from configuration. Requests without the
/// extension are rejected with [`Rejection::NoClientIpSource`].
#[derive(Debug, Clone, Copy)]
pub struct ClientIp(pub IpAddr);
109
/// Selects where [`ClientIp`] reads the client IP from.
///
/// Put it into request extensions with [`ClientIpSource::into_extension`]. It can also be
/// parsed from its exact variant name via `FromStr`, and (de)serialized with the `serde`
/// feature, which makes it convenient to load from configuration.
#[non_exhaustive]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ClientIpSource {
    /// Read the `CF-Connecting-IP` header, like [`CfConnectingIp`].
    CfConnectingIp,
    /// Read the `CloudFront-Viewer-Address` header, like [`CloudFrontViewerAddress`].
    CloudFrontViewerAddress,
    /// Use the peer address from the `axum::extract::ConnectInfo<SocketAddr>` request
    /// extension. Requests without it are rejected with [`Rejection::NoConnectInfo`].
    ConnectInfo,
    /// Read the `Fly-Client-IP` header, like [`FlyClientIp`].
    FlyClientIp,
    /// Read the rightmost `for=` entry of the `Forwarded` header, like
    /// [`RightmostForwarded`]. Requires the `forwarded-header` feature.
    #[cfg(feature = "forwarded-header")]
    RightmostForwarded,
    /// Read the rightmost entry of the `X-Forwarded-For` header, like
    /// [`RightmostXForwardedFor`].
    RightmostXForwardedFor,
    /// Read the `True-Client-IP` header, like [`TrueClientIp`].
    TrueClientIp,
    /// Read the `X-Real-Ip` header, like [`XRealIp`].
    XRealIp,
}
133
134impl ClientIpSource {
135 pub const fn into_extension(self) -> Extension<Self> {
138 Extension(self)
139 }
140}
141
/// Error returned when a string does not name any [`ClientIpSource`] variant.
///
/// Holds the rejected input, which is echoed in the `Display` message.
#[derive(Debug)]
pub struct ParseClientIpSourceError(String);

impl fmt::Display for ParseClientIpSourceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(f, "Invalid ClientIpSource value {rejected}")
    }
}

// No underlying cause to expose, so `source()` keeps its default.
impl Error for ParseClientIpSourceError {}
153
154impl FromStr for ClientIpSource {
155 type Err = ParseClientIpSourceError;
156
157 fn from_str(s: &str) -> Result<Self, Self::Err> {
158 Ok(match s {
159 "CfConnectingIp" => Self::CfConnectingIp,
160 "CloudFrontViewerAddress" => Self::CloudFrontViewerAddress,
161 "ConnectInfo" => Self::ConnectInfo,
162 "FlyClientIp" => Self::FlyClientIp,
163 #[cfg(feature = "forwarded-header")]
164 "RightmostForwarded" => Self::RightmostForwarded,
165 "RightmostXForwardedFor" => Self::RightmostXForwardedFor,
166 "TrueClientIp" => Self::TrueClientIp,
167 "XRealIp" => Self::XRealIp,
168 _ => return Err(ParseClientIpSourceError(s.to_string())),
169 })
170 }
171}
172
173impl<S> FromRequestParts<S> for ClientIp
174where
175 S: Sync,
176{
177 type Rejection = Rejection;
178
179 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
180 let Some(ip_source) = parts.extensions.get() else {
181 return Err(Rejection::NoClientIpSource);
182 };
183
184 match ip_source {
185 ClientIpSource::CfConnectingIp => CfConnectingIp::ip_from_headers(&parts.headers),
186 ClientIpSource::CloudFrontViewerAddress => {
187 CloudFrontViewerAddress::ip_from_headers(&parts.headers)
188 }
189 ClientIpSource::ConnectInfo => parts
190 .extensions
191 .get::<ConnectInfo<SocketAddr>>()
192 .map(|ConnectInfo(addr)| addr.ip())
193 .ok_or_else(|| Rejection::NoConnectInfo),
194 ClientIpSource::FlyClientIp => FlyClientIp::ip_from_headers(&parts.headers),
195 #[cfg(feature = "forwarded-header")]
196 ClientIpSource::RightmostForwarded => {
197 RightmostForwarded::ip_from_headers(&parts.headers)
198 }
199 ClientIpSource::RightmostXForwardedFor => {
200 RightmostXForwardedFor::ip_from_headers(&parts.headers)
201 }
202 ClientIpSource::TrueClientIp => TrueClientIp::ip_from_headers(&parts.headers),
203 ClientIpSource::XRealIp => XRealIp::ip_from_headers(&parts.headers),
204 }
205 .map(Self)
206 }
207}
208
209#[non_exhaustive]
211#[derive(Debug, PartialEq)]
212pub enum Rejection {
213 NoConnectInfo,
215 NoClientIpSource,
217 ClientIp(client_ip::Error),
219}
220
221impl From<client_ip::Error> for Rejection {
222 fn from(value: client_ip::Error) -> Self {
223 Self::ClientIp(value)
224 }
225}
226
227impl fmt::Display for Rejection {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 match self {
230 Rejection::NoConnectInfo => {
231 write!(f, "Add `axum::extract::ConnectInfo` to request extensions")
232 }
233 Rejection::NoClientIpSource => write!(
234 f,
235 "Add `axum_client_ip::ClientIpSource` to request extensions"
236 ),
237 Rejection::ClientIp(e) => write!(f, "{e}"),
238 }
239 }
240}
241
242impl std::error::Error for Rejection {}
243
244impl IntoResponse for Rejection {
245 fn into_response(self) -> Response {
246 let title = match self {
247 Self::NoConnectInfo | Self::NoClientIpSource => "500 Axum Misconfiguration",
248 Self::ClientIp { .. } => "500 Proxy Server Misconfiguration",
249 };
250 let footer = "(the request is rejected by axum-client-ip)";
251 let text = format!("{title}\n\n{self}\n\n{footer}");
252 (StatusCode::INTERNAL_SERVER_ERROR, text).into_response()
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        body::Body,
        extract::ConnectInfo,
        http::{Request, StatusCode},
        routing::get,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    #[cfg(feature = "forwarded-header")]
    use super::RightmostForwarded;
    use super::{
        CfConnectingIp, ClientIp, ClientIpSource, CloudFrontViewerAddress, FlyClientIp,
        RightmostXForwardedFor, TrueClientIp, XRealIp,
    };

    const VALID_IPV4: &str = "1.2.3.4";
    const VALID_IPV6: &str = "1:23:4567:89ab:c:d:e:f";

    /// Collects a response body into a (lossily decoded) UTF-8 string.
    async fn body_to_string(body: Body) -> String {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8_lossy(&bytes).into()
    }

    /// Builds a `GET /` request with `headers` added in order; repeated names append lines.
    fn request(headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(Request::builder().uri("/"), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .unwrap()
    }

    /// Sends `req` through `app` and returns the response status and body text.
    async fn send(app: Router, req: Request<Body>) -> (StatusCode, String) {
        let resp = app.oneshot(req).await.unwrap();
        let status = resp.status();
        (status, body_to_string(resp.into_body()).await)
    }

    /// Checks an extractor that reads a bare IP from `header`. A request without the header
    /// must be rejected with a 500, and IPv4/IPv6 values must be echoed back by the handler.
    async fn check_single_ip_header(app: fn() -> Router, header: &str) {
        let (status, _) = send(app(), request(&[])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        for ip in [VALID_IPV4, VALID_IPV6] {
            let (status, body) = send(app(), request(&[(header, ip)])).await;
            assert_eq!(status, StatusCode::OK);
            assert_eq!(body, ip);
        }
    }

    /// Router whose handler uses the configurable [`ClientIp`] extractor with `source`.
    fn client_ip_app(source: ClientIpSource) -> Router {
        Router::new()
            .route("/", get(|ip: ClientIp| async move { ip.0.to_string() }))
            .layer(source.into_extension())
    }

    #[tokio::test]
    async fn cf_connecting_ip() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: CfConnectingIp| async move { ip.0.to_string() }),
            )
        }

        check_single_ip_header(app, "cf-connecting-ip").await;
    }

    #[tokio::test]
    async fn cloudfront_viewer_address() {
        let header = "cloudfront-viewer-address";

        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: CloudFrontViewerAddress| async move { ip.0.to_string() }),
            )
        }

        let (status, _) = send(app(), request(&[])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        // The header carries `ip:port`; only the IP is extracted.
        for ip in [VALID_IPV4, VALID_IPV6] {
            let value = format!("{ip}:8000");
            let (_, body) = send(app(), request(&[(header, value.as_str())])).await;
            assert_eq!(body, ip);
        }
    }

    #[tokio::test]
    async fn fly_client_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: FlyClientIp| async move { ip.0.to_string() }))
        }

        check_single_ip_header(app, "fly-client-ip").await;
    }

    #[cfg(feature = "forwarded-header")]
    #[tokio::test]
    async fn rightmost_forwarded() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: RightmostForwarded| async move { ip.0.to_string() }),
            )
        }

        let (status, _) = send(app(), request(&[])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        let value = format!("for=[{VALID_IPV6}]:8000");
        let (_, body) = send(app(), request(&[("forwarded", value.as_str())])).await;
        assert_eq!(body, VALID_IPV6);

        // With several header lines, the `for=` entry of the last line wins.
        let req = request(&[
            ("Forwarded", r#"for="_mdn""#),
            ("Forwarded", r#"For="[2001:db8:cafe::17]:4711""#),
            ("Forwarded", r#"for=192.0.2.60;proto=http;by=203.0.113.43"#),
        ]);
        let (_, body) = send(app(), req).await;
        assert_eq!(body, "192.0.2.60");
    }

    #[tokio::test]
    async fn rightmost_x_forwarded_for() {
        fn app() -> Router {
            Router::new().route(
                "/",
                get(|ip: RightmostXForwardedFor| async move { ip.0.to_string() }),
            )
        }

        let (status, _) = send(app(), request(&[])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

        // The rightmost entry of the last header line wins.
        let last = format!("2.2.2.2, {VALID_IPV4}");
        let req = request(&[
            (
                "X-Forwarded-For",
                "1.1.1.1, foo, 2001:db8:85a3:8d3:1319:8a2e:370:7348",
            ),
            ("X-Forwarded-For", "bar"),
            ("X-Forwarded-For", last.as_str()),
        ]);
        let (_, body) = send(app(), req).await;
        assert_eq!(body, VALID_IPV4);
    }

    #[tokio::test]
    async fn true_client_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: TrueClientIp| async move { ip.0.to_string() }))
        }

        check_single_ip_header(app, "true-client-ip").await;
    }

    #[tokio::test]
    async fn x_real_ip() {
        fn app() -> Router {
            Router::new().route("/", get(|ip: XRealIp| async move { ip.0.to_string() }))
        }

        check_single_ip_header(app, "x-real-ip").await;
    }

    #[tokio::test]
    async fn client_ip_without_source_is_rejected() {
        let app = Router::new().route("/", get(|ip: ClientIp| async move { ip.0.to_string() }));

        let (status, body) = send(app, request(&[("x-real-ip", VALID_IPV4)])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(body.contains("axum_client_ip::ClientIpSource"));
    }

    #[tokio::test]
    async fn client_ip_uses_configured_header_source() {
        let (status, body) = send(
            client_ip_app(ClientIpSource::XRealIp),
            request(&[("x-real-ip", VALID_IPV4)]),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, VALID_IPV4);

        // Only the configured header is consulted.
        let (status, _) = send(
            client_ip_app(ClientIpSource::XRealIp),
            request(&[("true-client-ip", VALID_IPV4)]),
        )
        .await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn client_ip_connect_info() {
        // `oneshot` does not insert `ConnectInfo`, so the extractor must reject the request.
        let (status, body) = send(client_ip_app(ClientIpSource::ConnectInfo), request(&[])).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(body.contains("axum::extract::ConnectInfo"));

        let req = Request::builder()
            .uri("/")
            .extension(ConnectInfo(SocketAddr::from(([1, 2, 3, 4], 8000))))
            .body(Body::empty())
            .unwrap();
        let (status, body) = send(client_ip_app(ClientIpSource::ConnectInfo), req).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, VALID_IPV4);
    }

    #[test]
    fn client_ip_source_from_str() {
        assert!(matches!(
            "XRealIp".parse::<ClientIpSource>(),
            Ok(ClientIpSource::XRealIp)
        ));
        assert!(matches!(
            "ConnectInfo".parse::<ClientIpSource>(),
            Ok(ClientIpSource::ConnectInfo)
        ));

        // Parsing is exact: header names and other spellings are rejected.
        let err = "x-real-ip".parse::<ClientIpSource>().unwrap_err();
        assert_eq!(err.to_string(), "Invalid ClientIpSource value x-real-ip");
    }
}