use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::extract::ConnectInfo;
use axum::http::Request;
use tower::{Layer, Service};
use crate::security::config::TrustedProxiesConfig;
/// A trusted proxy network: a base address plus a CIDR prefix length.
///
/// A bare address (no `/prefix`) is stored with the full prefix length of its
/// family, so it matches exactly one address.
#[derive(Debug, Clone, Copy)]
struct TrustedProxy {
    /// Base address exactly as configured; host bits are not required to be zero
    /// because only the leading `prefix_len` bits are ever compared.
    network: IpAddr,
    /// Number of leading bits that must match (0..=32 for IPv4, 0..=128 for IPv6).
    prefix_len: u8,
}

impl TrustedProxy {
    /// Largest valid prefix length for the address family of `addr`.
    const fn max_prefix_len(addr: IpAddr) -> u8 {
        match addr {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        }
    }

    /// Parses `"addr"` or `"addr/prefix"` (surrounding whitespace is ignored).
    ///
    /// Returns `None` for empty input, an unparseable address or prefix, or a
    /// prefix longer than the address family allows.
    fn parse(value: &str) -> Option<Self> {
        let trimmed = value.trim();
        if trimmed.is_empty() {
            return None;
        }

        let (network, prefix_len) = match trimmed.split_once('/') {
            Some((addr, prefix)) => (
                addr.trim().parse::<IpAddr>().ok()?,
                prefix.trim().parse::<u8>().ok()?,
            ),
            None => {
                let addr = trimmed.parse::<IpAddr>().ok()?;
                (addr, Self::max_prefix_len(addr))
            }
        };

        (prefix_len <= Self::max_prefix_len(network)).then_some(Self {
            network,
            prefix_len,
        })
    }

    /// Returns `true` when `ip` falls inside this network.
    ///
    /// An IPv4-mapped IPv6 candidate (`::ffff:a.b.c.d`) is compared against
    /// IPv4 networks using its embedded IPv4 address, so an IPv4 peer seen
    /// through a dual-stack socket still matches an IPv4 range. Any other
    /// cross-family comparison is `false`.
    fn contains(&self, ip: IpAddr) -> bool {
        match (self.network, ip) {
            (IpAddr::V4(network), IpAddr::V4(candidate)) => {
                let shift = u32::from(32_u8.saturating_sub(self.prefix_len));
                let differing = u32::from(network) ^ u32::from(candidate);
                // `checked_shr` yields `None` when shifting by the full width
                // (prefix length 0); every address matches in that case.
                differing.checked_shr(shift).unwrap_or(0) == 0
            }
            (IpAddr::V6(network), IpAddr::V6(candidate)) => {
                let shift = u32::from(128_u8.saturating_sub(self.prefix_len));
                let differing = u128::from(network) ^ u128::from(candidate);
                differing.checked_shr(shift).unwrap_or(0) == 0
            }
            (IpAddr::V4(_), IpAddr::V6(candidate)) => candidate
                .to_ipv4_mapped()
                .is_some_and(|mapped| self.contains(IpAddr::V4(mapped))),
            (IpAddr::V6(_), IpAddr::V4(_)) => false,
        }
    }
}
/// Resolves the client address, host and scheme of a request from the
/// connection peer and the `X-Forwarded-*` / `X-Real-IP` headers, according to
/// a trusted-proxy configuration.
#[derive(Debug, Clone)]
pub struct ProxyResolver {
    /// Successfully parsed trusted ranges.
    ranges: Vec<TrustedProxy>,
    /// Whether the configuration listed any range at all, valid or not.
    /// Kept separately so that a configuration whose ranges all failed to
    /// parse trusts nothing instead of trusting everything.
    ranges_configured: bool,
    /// When set, forwarded entries are selected by position rather than by range.
    trusted_hops: Option<u32>,
    /// Master switch: when `false`, forwarded headers are never consulted.
    trust_forwarded_headers: bool,
}

impl ProxyResolver {
    /// Builds a resolver from configuration.
    ///
    /// Ranges that fail to parse are skipped with a warning.
    #[must_use]
    pub fn from_config(config: &TrustedProxiesConfig) -> Self {
        let ranges_configured = !config.ranges.is_empty();
        let ranges = config
            .ranges
            .iter()
            .filter_map(|proxy| {
                let parsed = TrustedProxy::parse(proxy);
                if parsed.is_none() {
                    tracing::warn!(
                        range = %proxy,
                        "ignoring invalid trusted_proxies range"
                    );
                }
                parsed
            })
            .collect();

        Self {
            ranges,
            ranges_configured,
            trusted_hops: config.trusted_hops,
            trust_forwarded_headers: config.trust_forwarded_headers,
        }
    }

    /// Returns whether `ip` counts as a trusted proxy.
    ///
    /// With no ranges configured every address is trusted; otherwise the
    /// address must fall inside one of the parsed ranges.
    fn is_trusted_ip(&self, ip: IpAddr) -> bool {
        if !self.ranges_configured {
            return true;
        }
        self.ranges.iter().any(|r| r.contains(ip))
    }

    /// Resolver that honours forwarded headers only from loopback peers
    /// (`127.0.0.0/8` and `::1`).
    #[must_use]
    pub fn loopback_only() -> Self {
        Self::from_config(&TrustedProxiesConfig {
            ranges: vec!["127.0.0.0/8".to_owned(), "::1/128".to_owned()],
            trusted_hops: None,
            trust_forwarded_headers: true,
        })
    }

    /// Resolver that never consults forwarded headers.
    #[must_use]
    pub const fn no_trust() -> Self {
        Self {
            ranges: Vec::new(),
            ranges_configured: false,
            trusted_hops: None,
            trust_forwarded_headers: false,
        }
    }

    /// Resolves the client IP address for `req`.
    ///
    /// * Forwarded headers disabled: the connection peer.
    /// * `trusted_hops = Some(n)`: the valid `X-Forwarded-For` entry `n`
    ///   positions from the right (0 = rightmost), else the peer.
    /// * Ranges configured: if the peer is trusted, the first
    ///   `X-Forwarded-For` entry from the right that is not in a trusted
    ///   range; the peer if every entry is trusted or the peer is untrusted.
    /// * No ranges: the rightmost `X-Forwarded-For` entry, or the one before
    ///   it when the rightmost equals the peer and a previous entry exists.
    ///
    /// When the peer is trusted and `X-Forwarded-For` is absent (or, without
    /// ranges, holds no valid entry), `X-Real-IP` is tried before falling
    /// back to the peer.
    pub fn resolve_client_addr<B>(&self, req: &Request<B>) -> Option<IpAddr> {
        let peer_ip = Self::peer_ip(req);
        if !self.trust_forwarded_headers {
            return peer_ip;
        }

        let xff = Self::x_forwarded_for(req);

        if let Some(hops) = self.trusted_hops {
            return xff
                .as_deref()
                .and_then(|list| Self::forwarded_ips(list).nth(Self::hop_index(hops)))
                .or(peer_ip);
        }

        let peer_is_trusted =
            !self.ranges_configured || peer_ip.is_some_and(|ip| self.is_trusted_ip(ip));
        if !peer_is_trusted {
            return peer_ip;
        }

        if let Some(list) = xff.as_deref() {
            if self.ranges_configured {
                // Walk right to left past our own proxies; the first address
                // outside the trusted ranges is the client.
                return Self::forwarded_ips(list)
                    .find(|ip| !self.is_trusted_ip(*ip))
                    .or(peer_ip);
            }

            let mut entries = Self::forwarded_ips(list);
            if let Some(rightmost) = entries.next() {
                // If the rightmost entry is just the peer itself, prefer the
                // entry before it when there is one.
                let client = if peer_ip == Some(rightmost) {
                    entries.next().unwrap_or(rightmost)
                } else {
                    rightmost
                };
                return Some(client);
            }
        }

        req.headers()
            .get("x-real-ip")
            .and_then(|v| v.to_str().ok())
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .and_then(|s| s.parse::<IpAddr>().ok())
            .or(peer_ip)
    }

    /// Resolves the host the client addressed.
    ///
    /// Returns the first (leftmost) non-empty `X-Forwarded-Host` value when
    /// forwarded metadata is trusted for this request, otherwise the trimmed
    /// `Host` header.
    pub fn resolve_client_host<B>(&self, req: &Request<B>) -> Option<String> {
        if self.trust_forwarded_headers && self.forwarded_metadata_trusted(req) {
            if let Some(host) = Self::leading_list_value(req, "x-forwarded-host") {
                return Some(host.to_owned());
            }
        }
        Self::host_header(req)
    }

    /// Resolves the scheme the client used.
    ///
    /// Returns the first (leftmost) non-empty `X-Forwarded-Proto` value,
    /// lower-cased, when forwarded metadata is trusted for this request,
    /// otherwise the scheme of the request URI if it has one.
    pub fn resolve_client_scheme<B>(&self, req: &Request<B>) -> Option<String> {
        if self.trust_forwarded_headers && self.forwarded_metadata_trusted(req) {
            if let Some(proto) = Self::leading_list_value(req, "x-forwarded-proto") {
                return Some(proto.to_ascii_lowercase());
            }
        }
        req.uri().scheme_str().map(ToOwned::to_owned)
    }

    /// Decides whether `X-Forwarded-Host` / `X-Forwarded-Proto` may be used.
    ///
    /// In hop mode the request must carry more than `trusted_hops` valid
    /// `X-Forwarded-For` entries; otherwise the peer must be trusted (which
    /// is always the case when no ranges are configured).
    fn forwarded_metadata_trusted<B>(&self, req: &Request<B>) -> bool {
        match self.trusted_hops {
            Some(hops) => Self::x_forwarded_for(req).is_some_and(|list| {
                Self::forwarded_ips(&list)
                    .nth(Self::hop_index(hops))
                    .is_some()
            }),
            None => {
                !self.ranges_configured
                    || Self::peer_ip(req).is_some_and(|ip| self.is_trusted_ip(ip))
            }
        }
    }

    /// Converts a hop count to an iterator index, saturating instead of
    /// truncating on targets where `usize` is narrower than `u32`.
    fn hop_index(hops: u32) -> usize {
        usize::try_from(hops).unwrap_or(usize::MAX)
    }

    /// Iterates the addresses of a comma-separated forwarded list from right
    /// to left, skipping empty and unparseable entries.
    fn forwarded_ips(list: &str) -> impl Iterator<Item = IpAddr> + '_ {
        list.rsplit(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .filter_map(Self::parse_forwarded_ip)
    }

    /// Returns the trimmed first element of a comma-separated header, if the
    /// header is present, readable as a string, and that element is non-empty.
    fn leading_list_value<'r, B>(req: &'r Request<B>, header: &str) -> Option<&'r str> {
        req.headers()
            .get(header)
            .and_then(|v| v.to_str().ok())
            .and_then(|raw| raw.split(',').next())
            .map(str::trim)
            .filter(|s| !s.is_empty())
    }

    /// Returns the trimmed `Host` header, if present and readable as a string.
    fn host_header<B>(req: &Request<B>) -> Option<String> {
        req.headers()
            .get(axum::http::header::HOST)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim().to_owned())
    }

    /// Parses a forwarded entry as a bare IP address or as `ip:port` /
    /// `[ipv6]:port`, discarding the port.
    fn parse_forwarded_ip(s: &str) -> Option<IpAddr> {
        s.parse::<IpAddr>()
            .ok()
            .or_else(|| s.parse::<SocketAddr>().map(|sa| sa.ip()).ok())
    }

    /// Returns the connection peer address recorded in the request
    /// extensions as `ConnectInfo<SocketAddr>`, if any.
    fn peer_ip<B>(req: &Request<B>) -> Option<IpAddr> {
        req.extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map(|ConnectInfo(addr)| addr.ip())
    }

    /// Joins every readable `X-Forwarded-For` header value into one
    /// comma-separated list, preserving header order.
    fn x_forwarded_for<B>(req: &Request<B>) -> Option<String> {
        let values: Vec<&str> = req
            .headers()
            .get_all("x-forwarded-for")
            .iter()
            .filter_map(|v| v.to_str().ok())
            .collect();
        if values.is_empty() {
            None
        } else {
            Some(values.join(", "))
        }
    }
}
/// Client identity resolved for a request; `TrustedProxiesService` inserts one
/// into the request extensions before calling the inner service.
#[derive(Debug, Clone)]
pub struct ResolvedClientIdentity {
/// Result of `ProxyResolver::resolve_client_addr` for the request.
pub addr: Option<IpAddr>,
/// Result of `ProxyResolver::resolve_client_host` for the request.
pub host: Option<String>,
/// Result of `ProxyResolver::resolve_client_scheme` for the request.
pub scheme: Option<String>,
}
/// Tower layer that wraps services in [`TrustedProxiesService`], sharing one
/// resolver between every service it produces.
#[derive(Clone, Debug)]
pub struct TrustedProxiesLayer {
    resolver: Arc<ProxyResolver>,
}

impl TrustedProxiesLayer {
    /// Creates a layer whose resolver is built from `config`.
    #[must_use]
    pub fn from_config(config: &TrustedProxiesConfig) -> Self {
        let resolver = Arc::new(ProxyResolver::from_config(config));
        Self { resolver }
    }
}

impl<S> Layer<S> for TrustedProxiesLayer {
    type Service = TrustedProxiesService<S>;

    /// Wraps `inner`, handing it another reference to the shared resolver.
    fn layer(&self, inner: S) -> Self::Service {
        let resolver = Arc::clone(&self.resolver);
        TrustedProxiesService { inner, resolver }
    }
}
/// Service that attaches a [`ResolvedClientIdentity`] to each request's
/// extensions and then forwards the request to the wrapped service.
#[derive(Clone, Debug)]
pub struct TrustedProxiesService<S> {
    inner: S,
    resolver: Arc<ProxyResolver>,
}

impl<S, B> Service<Request<B>> for TrustedProxiesService<S>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Resolves address, host and scheme (in that order), stores them in the
    /// request extensions, and calls the wrapped service.
    fn call(&mut self, mut req: Request<B>) -> Self::Future {
        let resolver = &self.resolver;
        let addr = resolver.resolve_client_addr(&req);
        let host = resolver.resolve_client_host(&req);
        let scheme = resolver.resolve_client_scheme(&req);
        req.extensions_mut()
            .insert(ResolvedClientIdentity { addr, host, scheme });

        // Call the instance `poll_ready` was invoked on and leave a fresh
        // clone in `self` for the next request.
        let replacement = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, replacement);
        Box::pin(ready.call(req))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Builds a resolver from string ranges so each test states only what varies.
    fn resolver(
        ranges: &[&str],
        trusted_hops: Option<u32>,
        trust_forwarded_headers: bool,
    ) -> ProxyResolver {
        ProxyResolver::from_config(&TrustedProxiesConfig {
            ranges: ranges.iter().map(|r| (*r).to_owned()).collect(),
            trusted_hops,
            trust_forwarded_headers,
        })
    }

    /// Parses an IP literal used as an expected value.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Builds a bodiless request carrying the given headers.
    fn request(headers: &[(&str, &str)]) -> Request<()> {
        let mut builder = Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(()).unwrap()
    }

    /// Records `peer` (an IPv4 literal) as the connection peer of `req`.
    fn with_peer(mut req: Request<()>, peer: &str) -> Request<()> {
        let addr: SocketAddr = format!("{peer}:1234").parse().unwrap();
        req.extensions_mut().insert(ConnectInfo(addr));
        req
    }

    fn req_with_xff(xff: &str) -> Request<()> {
        request(&[("x-forwarded-for", xff)])
    }

    fn req_with_peer_and_xff(peer: &str, xff: &str) -> Request<()> {
        with_peer(req_with_xff(xff), peer)
    }

    #[test]
    fn trusted_proxy_parse_exact_ipv4() {
        let p = TrustedProxy::parse("10.0.0.1").unwrap();
        assert!(p.contains(ip("10.0.0.1")));
        assert!(!p.contains(ip("10.0.0.2")));
    }

    #[test]
    fn trusted_proxy_parse_cidr() {
        let p = TrustedProxy::parse("10.0.0.0/24").unwrap();
        assert!(p.contains(ip("10.0.0.1")));
        assert!(p.contains(ip("10.0.0.254")));
        assert!(!p.contains(ip("10.0.1.0")));
    }

    #[test]
    fn trusted_proxy_parse_invalid_returns_none() {
        assert!(TrustedProxy::parse("not-an-ip").is_none());
        assert!(TrustedProxy::parse("").is_none());
        assert!(TrustedProxy::parse("10.0.0.0/33").is_none());
        assert!(TrustedProxy::parse("::1/129").is_none());
        assert!(TrustedProxy::parse("10.0.0.0/").is_none());
    }

    #[test]
    fn trusted_proxy_contains_ipv6() {
        let proxy = TrustedProxy::parse("2001:db8::/32").unwrap();
        assert!(proxy.contains(ip("2001:db8:1234::1")));
        assert!(!proxy.contains(ip("2001:db9::1")));
        let proxy_exact = TrustedProxy::parse("2001:db8::1").unwrap();
        assert!(proxy_exact.contains(ip("2001:db8::1")));
        assert!(!proxy_exact.contains(ip("2001:db8::2")));
    }

    #[test]
    fn trusted_hops_one_rejects_attacker_controlled_leading_value() {
        let resolver = resolver(&[], Some(1), true);
        let req = req_with_xff("1.2.3.4, 5.6.7.8, 10.0.0.1");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("5.6.7.8")));
    }

    #[test]
    fn trusted_hops_zero_uses_rightmost_entry() {
        let resolver = resolver(&[], Some(0), true);
        let req = req_with_xff("1.2.3.4, 5.6.7.8");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("5.6.7.8")));
    }

    #[test]
    fn trusted_hops_with_ranges_does_not_require_peer_in_ranges() {
        let resolver = resolver(&["203.0.113.0/24"], Some(1), true);
        let req = req_with_peer_and_xff("10.0.1.200", "192.0.2.1, 10.0.1.200");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("192.0.2.1")));
    }

    #[test]
    fn two_hop_cdn_alb_chain_identifies_real_client() {
        let resolver = resolver(&["203.0.113.0/24", "10.0.0.0/8"], None, true);
        let req = req_with_peer_and_xff("10.0.1.100", "192.0.2.1, 203.0.113.10, 10.0.1.100");
        let addr = resolver.resolve_client_addr(&req).unwrap();
        assert_eq!(
            addr,
            ip("192.0.2.1"),
            "real client must be identified correctly in a two-hop CDN+ALB chain"
        );
    }

    #[test]
    fn trust_forwarded_headers_false_ignores_xff() {
        let resolver = resolver(&["10.0.0.0/8"], None, false);
        let req = req_with_peer_and_xff("10.0.0.1", "192.0.2.1, 203.0.113.10");
        let addr = resolver.resolve_client_addr(&req).unwrap();
        assert_eq!(
            addr,
            ip("10.0.0.1"),
            "trust_forwarded_headers=false must ignore X-Forwarded-For"
        );
    }

    #[test]
    fn trust_forwarded_headers_false_ignores_x_forwarded_host() {
        let resolver = resolver(&[], None, false);
        let req = with_peer(
            request(&[
                ("host", "real.example"),
                ("x-forwarded-host", "attacker.example"),
            ]),
            "127.0.0.1",
        );
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("real.example")
        );
    }

    #[test]
    fn trust_forwarded_headers_false_ignores_x_forwarded_proto() {
        let resolver = resolver(&[], None, false);
        let req = Request::builder()
            .header("x-forwarded-proto", "https")
            .uri("http://example.com/")
            .body(())
            .unwrap();
        assert_eq!(
            resolver.resolve_client_scheme(&req).as_deref(),
            Some("http")
        );
    }

    #[test]
    fn resolve_scheme_from_forwarded_proto_leftmost() {
        let resolver = resolver(&[], None, true);
        let req = request(&[("x-forwarded-proto", "https, http")]);
        assert_eq!(
            resolver.resolve_client_scheme(&req).as_deref(),
            Some("https")
        );
    }

    #[test]
    fn resolve_host_prefers_forwarded_host_when_trusted() {
        let resolver = resolver(&[], None, true);
        let req = request(&[
            ("host", "internal.cluster.local"),
            ("x-forwarded-host", "public.example.com"),
        ]);
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("public.example.com")
        );
    }

    #[test]
    fn resolve_host_returns_leftmost_when_chained() {
        let resolver = resolver(&[], None, true);
        let req = request(&[(
            "x-forwarded-host",
            "public.example.com, internal.cluster.local",
        )]);
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("public.example.com")
        );
    }

    #[test]
    fn resolve_host_falls_back_to_host_header() {
        let resolver = resolver(&[], None, true);
        let req = request(&[("host", "app.example.com")]);
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("app.example.com")
        );
    }

    #[test]
    fn trusted_hops_with_ranges_resolves_forwarded_host() {
        let resolver = resolver(&["203.0.113.0/24"], Some(1), true);
        let req = with_peer(
            request(&[
                ("host", "internal.cluster.local"),
                ("x-forwarded-host", "public.example.com"),
                ("x-forwarded-for", "192.0.2.1, 10.0.1.200"),
            ]),
            "10.0.1.200",
        );
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("public.example.com")
        );
    }

    #[test]
    fn trusted_hops_with_no_xff_does_not_trust_forwarded_host() {
        let resolver = resolver(&["203.0.113.0/24"], Some(1), true);
        let req = with_peer(
            request(&[
                ("host", "real.example.com"),
                ("x-forwarded-host", "attacker.example.com"),
            ]),
            "10.0.1.200",
        );
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("real.example.com")
        );
    }

    #[test]
    fn trusted_hops_with_ranges_resolves_forwarded_scheme() {
        let resolver = resolver(&["203.0.113.0/24"], Some(1), true);
        let req = with_peer(
            request(&[
                ("x-forwarded-proto", "https"),
                ("x-forwarded-for", "192.0.2.1, 10.0.1.200"),
            ]),
            "10.0.1.200",
        );
        assert_eq!(
            resolver.resolve_client_scheme(&req).as_deref(),
            Some("https")
        );
    }

    #[test]
    fn trusted_hops_with_no_xff_does_not_trust_forwarded_scheme() {
        let resolver = resolver(&["203.0.113.0/24"], Some(1), true);
        let req = Request::builder()
            .header("x-forwarded-proto", "https")
            .uri("http://example.com/")
            .body(())
            .unwrap();
        assert_eq!(
            resolver.resolve_client_scheme(&req).as_deref(),
            Some("http")
        );
    }

    #[test]
    fn untrusted_peer_ignores_forwarding_headers() {
        let resolver = resolver(&["10.0.0.0/8"], None, true);
        let req = req_with_peer_and_xff("203.0.113.1", "192.0.2.1");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("203.0.113.1")));
    }

    #[test]
    fn no_peer_ip_with_trust_enabled_falls_back_to_xff() {
        let resolver = resolver(&[], None, true);
        let req = req_with_xff("192.0.2.1, 10.0.0.1");
        // Without a peer or configured ranges the rightmost entry is used.
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("10.0.0.1")));
    }

    #[test]
    fn loopback_only_trusts_loopback_xff() {
        let resolver = ProxyResolver::loopback_only();
        let req = req_with_peer_and_xff("127.0.0.1", "192.0.2.1");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("192.0.2.1")));
    }

    #[test]
    fn loopback_only_ignores_xff_from_non_loopback_peer() {
        let resolver = ProxyResolver::loopback_only();
        let req = req_with_peer_and_xff("10.0.0.1", "192.0.2.1");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("10.0.0.1")));
    }

    #[test]
    fn no_trust_always_returns_peer_ip() {
        let resolver = ProxyResolver::no_trust();
        let req = req_with_peer_and_xff("10.0.0.1", "192.0.2.1");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("10.0.0.1")));
    }

    #[test]
    fn xff_entry_with_ipv4_port_is_parsed() {
        let resolver = resolver(&[], Some(1), true);
        let req = req_with_xff("192.0.2.1:54321, 10.0.1.200:80");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("192.0.2.1")));
    }

    #[test]
    fn xff_entry_with_ipv6_port_is_parsed() {
        let resolver = resolver(&[], Some(1), true);
        let req = req_with_xff("2001:db8::1, [::1]:8080");
        assert_eq!(resolver.resolve_client_addr(&req), Some(ip("2001:db8::1")));
    }

    #[test]
    fn trusted_hops_with_junk_xff_does_not_trust_forwarded_host() {
        let resolver = resolver(&[], Some(1), true);
        let req = with_peer(
            request(&[
                ("host", "real.example.com"),
                ("x-forwarded-host", "attacker.example.com"),
                ("x-forwarded-for", "junk, more-junk"),
            ]),
            "10.0.1.200",
        );
        assert_eq!(
            resolver.resolve_client_host(&req).as_deref(),
            Some("real.example.com")
        );
    }
}