use axum::extract::ConnectInfo;
use axum::http::Request;
use std::net::{IpAddr, SocketAddr};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrustedProxy {
pub network: IpAddr,
pub prefix_len: u8,
}
impl TrustedProxy {
#[must_use]
pub fn parse(value: &str) -> Option<Self> {
let trimmed = value.trim();
if trimmed.is_empty() {
return None;
}
let (addr, prefix_len) = if let Some((addr, prefix)) = trimmed.split_once('/') {
let addr = addr.trim().parse::<IpAddr>().ok()?;
let prefix_len = prefix.trim().parse::<u8>().ok()?;
(addr, prefix_len)
} else {
let addr = trimmed.parse::<IpAddr>().ok()?;
let prefix_len = match addr {
IpAddr::V4(_) => 32,
IpAddr::V6(_) => 128,
};
(addr, prefix_len)
};
let max_prefix = match addr {
IpAddr::V4(_) => 32,
IpAddr::V6(_) => 128,
};
(prefix_len <= max_prefix).then_some(Self {
network: addr,
prefix_len,
})
}
#[must_use]
pub fn contains(&self, ip: IpAddr) -> bool {
if self.prefix_len == 0 {
return matches!(
(self.network, ip),
(IpAddr::V4(_), IpAddr::V4(_)) | (IpAddr::V6(_), IpAddr::V6(_))
);
}
match (self.network, ip) {
(IpAddr::V4(network), IpAddr::V4(candidate)) => {
let shift = 32_u8.saturating_sub(self.prefix_len);
(u32::from(network) >> shift) == (u32::from(candidate) >> shift)
}
(IpAddr::V6(network), IpAddr::V6(candidate)) => {
let shift = 128_u8.saturating_sub(self.prefix_len);
(u128::from(network) >> shift) == (u128::from(candidate) >> shift)
}
(IpAddr::V4(_), IpAddr::V6(_)) | (IpAddr::V6(_), IpAddr::V4(_)) => false,
}
}
}
pub fn extract_client_ip<B>(
req: &Request<B>,
trust_forwarded_headers: bool,
trusted_proxies: &[TrustedProxy],
trusted_proxies_configured: bool,
) -> Option<IpAddr> {
let peer_ip = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ConnectInfo(addr)| addr.ip())
.or_else(|| req.extensions().get::<SocketAddr>().map(SocketAddr::ip));
if trust_forwarded_headers {
let allowed = if trusted_proxies_configured {
peer_ip.is_some_and(|ip| is_trusted_proxy(ip, trusted_proxies))
} else {
true
};
if allowed {
let all_xff: Vec<&str> = req
.headers()
.get_all("x-forwarded-for")
.iter()
.filter_map(|v| v.to_str().ok())
.collect();
if !all_xff.is_empty() {
let joined = all_xff.join(", ");
if let Some(xff_ip) = client_ip_from_x_forwarded_for(
&joined,
peer_ip,
trusted_proxies,
trusted_proxies_configured,
) {
return Some(xff_ip);
}
}
if let Some(real_ip) = req
.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|s| !s.is_empty())
.and_then(|s| s.parse::<IpAddr>().ok())
{
return Some(real_ip);
}
}
}
peer_ip
}
fn client_ip_from_x_forwarded_for(
header: &str,
peer_ip: Option<IpAddr>,
trusted_proxies: &[TrustedProxy],
trusted_proxies_configured: bool,
) -> Option<IpAddr> {
if !trusted_proxies_configured {
let mut entries = header.rsplit(',').map(str::trim).filter(|s| !s.is_empty());
let last = entries.next()?;
if let Ok(last_ip) = last.parse::<IpAddr>() {
if peer_ip.is_some_and(|peer_ip| last_ip == peer_ip)
&& let Some(prev) = entries.next()
&& let Ok(prev_ip) = prev.parse::<IpAddr>()
{
return Some(prev_ip);
}
return Some(last_ip);
}
return None;
}
for entry in header.rsplit(',').map(str::trim).filter(|s| !s.is_empty()) {
if let Ok(ip) = entry.parse::<IpAddr>()
&& !is_trusted_proxy(ip, trusted_proxies)
{
return Some(ip);
}
}
None
}
fn is_trusted_proxy(ip: IpAddr, trusted_proxies: &[TrustedProxy]) -> bool {
trusted_proxies
.iter()
.any(|trusted_proxy| trusted_proxy.contains(ip))
}