use std::net::IpAddr;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::{ConnectInfo, Request};
use axum::middleware::Next;
use axum::Router;
/// Client IP address resolved by the real-IP middleware installed with
/// [`RealIpRouterExt::real_ip`].
///
/// The middleware inserts this into the request extensions only when an
/// address was resolved; handlers can read it with `Extension<RealIp>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RealIp(pub IpAddr);
/// Where the middleware looks for the client IP address.
///
/// Security: the header-based strategies read plain request headers. The
/// extraction code does not check which peer set them. Unless a trusted proxy
/// in front of this service overwrites the chosen header, a client can report
/// any address it likes.
#[derive(Clone, Debug)]
pub enum HeaderStrategy {
    /// First `for=` pair of the leftmost element of the RFC 7239 `Forwarded`
    /// header; surrounding quotes, IPv6 brackets and a port are stripped.
    ForwardedRfc7239,
    /// Leftmost entry of `X-Forwarded-For`; an optional port is stripped.
    XForwardedFor,
    /// Whole `X-Real-IP` value, trimmed and parsed as a bare IP address.
    XRealIp,
    /// Whole `CF-Connecting-IP` value, trimmed and parsed as a bare IP address.
    CfConnectingIp,
    /// Tries `CfConnectingIp`, `ForwardedRfc7239`, `XForwardedFor` and
    /// `XRealIp` in that order. If none resolves, falls back to the socket peer
    /// address from the `ConnectInfo<SocketAddr>` request extension, when present.
    Auto,
}
/// Configuration for the middleware installed by [`RealIpRouterExt::real_ip`].
#[derive(Clone, Debug)]
pub struct RealIpLayer {
    /// How the client IP is resolved; `RealIpLayer::default()` uses
    /// [`HeaderStrategy::Auto`].
    pub strategy: HeaderStrategy,
}
impl Default for RealIpLayer {
fn default() -> Self {
Self {
strategy: HeaderStrategy::Auto,
}
}
}
impl RealIpLayer {
#[must_use]
pub fn new(strategy: HeaderStrategy) -> Self {
Self { strategy }
}
}
/// Extension trait that installs the real-IP middleware on an axum `Router`.
pub trait RealIpRouterExt {
    /// Adds middleware that resolves the client IP using `layer.strategy`.
    ///
    /// On success, the address is inserted into the request extensions as
    /// [`RealIp`]. When nothing resolves, the request is passed on without a
    /// `RealIp` extension.
    ///
    /// Note: axum's `Router::layer` only wraps routes that were already added,
    /// so call this after registering routes (as the tests in this file do).
    #[must_use]
    fn real_ip(self, layer: RealIpLayer) -> Self;
}
impl<S: Clone + Send + Sync + 'static> RealIpRouterExt for Router<S> {
    /// Wraps the router in a `from_fn` middleware that tags each request with
    /// a [`RealIp`] extension whenever `layer.strategy` resolves an address.
    fn real_ip(self, layer: RealIpLayer) -> Self {
        // Read-only config shared by every request; each request bumps the
        // `Arc` refcount instead of cloning the layer itself.
        let shared = Arc::new(layer);
        let tag_request = move |mut request: Request<Body>, next: Next| {
            let shared = Arc::clone(&shared);
            async move {
                let resolved = extract(&request, &shared.strategy);
                if let Some(addr) = resolved {
                    request.extensions_mut().insert(RealIp(addr));
                }
                next.run(request).await
            }
        };
        self.layer(axum::middleware::from_fn(tag_request))
    }
}
fn extract(req: &Request<Body>, strategy: &HeaderStrategy) -> Option<IpAddr> {
let h = req.headers();
match strategy {
HeaderStrategy::ForwardedRfc7239 => {
parse_forwarded_rfc7239(h.get("forwarded").and_then(|v| v.to_str().ok())?)
}
HeaderStrategy::XForwardedFor => {
parse_x_forwarded_for(h.get("x-forwarded-for").and_then(|v| v.to_str().ok())?)
}
HeaderStrategy::XRealIp => h
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse().ok()),
HeaderStrategy::CfConnectingIp => h
.get("cf-connecting-ip")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse().ok()),
HeaderStrategy::Auto => extract(req, &HeaderStrategy::CfConnectingIp)
.or_else(|| extract(req, &HeaderStrategy::ForwardedRfc7239))
.or_else(|| extract(req, &HeaderStrategy::XForwardedFor))
.or_else(|| extract(req, &HeaderStrategy::XRealIp))
.or_else(|| {
req.extensions()
.get::<ConnectInfo<std::net::SocketAddr>>()
.map(|ci| ci.ip())
}),
}
}
/// Returns the first non-empty, trimmed element of a comma-separated header list.
///
/// RFC 7230 §7 requires recipients to accept and ignore empty list elements.
/// So `", 192.0.2.1"` yields `"192.0.2.1"` instead of an empty string.
fn first_list_element(s: &str) -> Option<&str> {
    s.split(',').map(str::trim).find(|element| !element.is_empty())
}

/// Extracts the client IP from an RFC 7239 `Forwarded` header value.
///
/// Only the leftmost non-empty element (the hop nearest the client) is
/// inspected. Within it, the first `for=` pair is used; the key is matched
/// case-insensitively. Surrounding double quotes, IPv6 brackets and a port
/// are stripped from the value.
///
/// Returns `None` when that element has no `for=` pair, or when its value is
/// not an IP literal (e.g. `unknown` or an obfuscated `_token`).
///
/// Empty or `=`-less pairs are skipped rather than aborting the search.
/// RFC 7239's grammar makes every pair in an element optional, so
/// `proto=http;;for=192.0.2.60` is valid and must still yield `192.0.2.60`.
fn parse_forwarded_rfc7239(s: &str) -> Option<IpAddr> {
    first_list_element(s)?
        .split(';')
        .filter_map(|pair| pair.split_once('='))
        .find(|&(key, _)| key.trim().eq_ignore_ascii_case("for"))
        .and_then(|(_, value)| parse_ip_with_optional_port(value.trim().trim_matches('"')))
}

/// Extracts the leftmost address from an `X-Forwarded-For` header value,
/// stripping an optional port (`1.2.3.4:80`, `[::1]:443`).
///
/// The leftmost entry is whatever the original client (or first proxy) wrote.
/// It is only trustworthy if a trusted proxy rewrites the header.
fn parse_x_forwarded_for(s: &str) -> Option<IpAddr> {
    first_list_element(s).and_then(parse_ip_with_optional_port)
}

/// Parses an IP address that may carry a port.
///
/// Accepted forms: `1.2.3.4`, `1.2.3.4:80`, `::1`, `[::1]`, `[::1]:443`.
/// The port, when present, is discarded without validation.
fn parse_ip_with_optional_port(s: &str) -> Option<IpAddr> {
    let s = s.trim();
    if let Some(rest) = s.strip_prefix('[') {
        // Bracketed IPv6: the address ends at the first `]`; anything after
        // it (normally `:port`) is ignored.
        let (addr, _after) = rest.split_once(']')?;
        return addr.parse().ok();
    }
    match s.split_once(':') {
        // Exactly one colon can only be `ipv4:port`; IPv6 literals always
        // contain at least two colons and fall through to a plain parse.
        Some((ip, port)) if !port.contains(':') => ip.parse().ok(),
        _ => s.parse().ok(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Builds an empty-bodied request carrying a single header.
    fn req_with_header(name: &'static str, value: &str) -> Request<Body> {
        Request::builder()
            .header(name, value)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a header-less request whose extensions hold `peer` as the
    /// `ConnectInfo<SocketAddr>` value that `extract` falls back to.
    fn req_with_peer(peer: std::net::SocketAddr) -> Request<Body> {
        let mut r = Request::builder().body(Body::empty()).unwrap();
        r.extensions_mut().insert(ConnectInfo(peer));
        r
    }

    #[test]
    fn x_forwarded_for_picks_leftmost() {
        let r = req_with_header("x-forwarded-for", "1.2.3.4, 10.0.0.1, 172.16.0.5");
        let ip = extract(&r, &HeaderStrategy::XForwardedFor).unwrap();
        assert_eq!(ip.to_string(), "1.2.3.4");
    }

    #[test]
    fn x_forwarded_for_strips_ipv4_port() {
        let r = req_with_header("x-forwarded-for", "203.0.113.7:51234, 10.0.0.1");
        let ip = extract(&r, &HeaderStrategy::XForwardedFor).unwrap();
        assert_eq!(ip.to_string(), "203.0.113.7");
    }

    #[test]
    fn x_forwarded_for_handles_ipv6_brackets() {
        let r = req_with_header("x-forwarded-for", "[2001:db8::1]:443, 10.0.0.1");
        let ip = extract(&r, &HeaderStrategy::XForwardedFor).unwrap();
        assert_eq!(ip.to_string(), "2001:db8::1");
    }

    #[test]
    fn x_forwarded_for_bare_ipv6() {
        let r = req_with_header("x-forwarded-for", "2001:db8::1");
        let ip = extract(&r, &HeaderStrategy::XForwardedFor).unwrap();
        assert_eq!(ip.to_string(), "2001:db8::1");
    }

    #[test]
    fn x_real_ip_strategy() {
        let r = req_with_header("x-real-ip", "198.51.100.42");
        let ip = extract(&r, &HeaderStrategy::XRealIp).unwrap();
        assert_eq!(ip.to_string(), "198.51.100.42");
    }

    #[test]
    fn cf_connecting_ip_strategy() {
        let r = req_with_header("cf-connecting-ip", "2606:4700::1");
        let ip = extract(&r, &HeaderStrategy::CfConnectingIp).unwrap();
        assert_eq!(ip.to_string(), "2606:4700::1");
    }

    #[test]
    fn rfc7239_for_token_parses() {
        let r = req_with_header("forwarded", "for=192.0.2.43, for=198.51.100.17");
        let ip = extract(&r, &HeaderStrategy::ForwardedRfc7239).unwrap();
        assert_eq!(ip.to_string(), "192.0.2.43");
    }

    #[test]
    fn rfc7239_with_quoted_ipv6_port() {
        let r = req_with_header("forwarded", r#"for="[2001:db8:cafe::17]:4711""#);
        let ip = extract(&r, &HeaderStrategy::ForwardedRfc7239).unwrap();
        assert_eq!(ip.to_string(), "2001:db8:cafe::17");
    }

    #[test]
    fn rfc7239_ignores_other_keys() {
        let r = req_with_header("forwarded", "by=10.0.0.1;for=203.0.113.7;proto=https");
        let ip = extract(&r, &HeaderStrategy::ForwardedRfc7239).unwrap();
        assert_eq!(ip.to_string(), "203.0.113.7");
    }

    #[test]
    fn auto_picks_cloudflare_first_when_present() {
        let r = Request::builder()
            .header("x-forwarded-for", "1.1.1.1")
            .header("cf-connecting-ip", "9.9.9.9")
            .body(Body::empty())
            .unwrap();
        let ip = extract(&r, &HeaderStrategy::Auto).unwrap();
        assert_eq!(ip.to_string(), "9.9.9.9");
    }

    #[test]
    fn auto_falls_through_to_xff_when_no_cf_or_forwarded() {
        let r = req_with_header("x-forwarded-for", "1.1.1.1");
        let ip = extract(&r, &HeaderStrategy::Auto).unwrap();
        assert_eq!(ip.to_string(), "1.1.1.1");
    }

    #[test]
    fn no_headers_returns_none_when_no_connect_info() {
        let r = Request::builder().body(Body::empty()).unwrap();
        assert!(extract(&r, &HeaderStrategy::Auto).is_none());
    }

    // Covers the last step of `Auto`: with no usable headers, the socket peer
    // address from `ConnectInfo` is used.
    #[test]
    fn auto_falls_back_to_connect_info_without_headers() {
        let r = req_with_peer(std::net::SocketAddr::from(([203, 0, 113, 9], 4242)));
        let ip = extract(&r, &HeaderStrategy::Auto).unwrap();
        assert_eq!(ip.to_string(), "203.0.113.9");
    }

    // Any header hit must win over the `ConnectInfo` fallback.
    #[test]
    fn auto_prefers_headers_over_connect_info() {
        let mut r = req_with_header("x-real-ip", "198.51.100.42");
        r.extensions_mut().insert(ConnectInfo(std::net::SocketAddr::from((
            [203, 0, 113, 9],
            4242,
        ))));
        let ip = extract(&r, &HeaderStrategy::Auto).unwrap();
        assert_eq!(ip.to_string(), "198.51.100.42");
    }

    // Only `Auto` consults `ConnectInfo`; explicit strategies do not.
    #[test]
    fn explicit_strategy_does_not_use_connect_info() {
        let r = req_with_peer(std::net::SocketAddr::from(([203, 0, 113, 9], 4242)));
        assert!(extract(&r, &HeaderStrategy::XForwardedFor).is_none());
    }

    #[test]
    fn malformed_header_returns_none() {
        let r = req_with_header("x-real-ip", "not-an-ip");
        assert!(extract(&r, &HeaderStrategy::XRealIp).is_none());
    }

    #[tokio::test]
    async fn middleware_inserts_realip_into_extensions() {
        use axum::routing::get;
        use axum::Extension;
        use tower::ServiceExt;
        async fn handler(Extension(RealIp(ip)): Extension<RealIp>) -> String {
            ip.to_string()
        }
        let app = Router::new()
            .route("/", get(handler))
            .real_ip(RealIpLayer::default());
        let resp = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header("x-forwarded-for", "192.0.2.1, 10.0.0.1")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let bytes = axum::body::to_bytes(resp.into_body(), 1 << 16)
            .await
            .unwrap();
        assert_eq!(std::str::from_utf8(&bytes).unwrap(), "192.0.2.1");
    }
}