use std::net::IpAddr;
/// Parses `ip_str` as an IPv4 or IPv6 address.
///
/// Returns `None` for any text that `IpAddr`'s `FromStr` implementation rejects. The input is
/// not trimmed, so surrounding whitespace makes it invalid; callers trim first if needed.
fn validate_ip_format(ip_str: &str) -> Option<IpAddr> {
    let parsed: Result<IpAddr, _> = ip_str.parse();
    parsed.ok()
}
/// Which directly-connected peers are trusted to report the real client address
/// via forwarding headers.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Peer addresses whose forwarding headers may be honoured.
    pub trusted_proxies: Vec<IpAddr>,
    /// Flag carried alongside the trust list.
    // NOTE(review): no method visible in this file reads this field — confirm whether
    // anything elsewhere depends on it before relying on its value.
    pub require_trusted_proxy: bool,
}
impl ProxyConfig {
    /// Builds a configuration from an explicit list of trusted proxy addresses.
    ///
    /// `require_trusted_proxy` is stored as given; no method in this impl reads it.
    pub fn new(trusted_proxies: Vec<IpAddr>, require_trusted_proxy: bool) -> Self {
        Self {
            trusted_proxies,
            require_trusted_proxy,
        }
    }

    /// Trusts only the IPv4 loopback address `127.0.0.1` as a proxy.
    ///
    /// The IPv6 loopback `::1` is not in the list. A peer reported in IPv4-mapped form
    /// (`::ffff:127.0.0.1`) still matches, because trust checks un-map such addresses.
    pub fn localhost_only() -> Self {
        Self {
            trusted_proxies: vec![IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)],
            require_trusted_proxy: true,
        }
    }

    /// Trusts no proxies, so forwarding headers are never honoured.
    pub fn none() -> Self {
        Self {
            trusted_proxies: vec![],
            require_trusted_proxy: false,
        }
    }

    /// Returns `true` if `ip` parses as an IP address contained in `trusted_proxies`.
    ///
    /// Unparseable input, including the empty string, is never trusted. An IPv4-mapped
    /// IPv6 address (`::ffff:a.b.c.d`) is compared as its IPv4 equivalent.
    pub fn is_trusted_proxy(&self, ip: &str) -> bool {
        validate_ip_format(ip).map_or(false, |addr| self.is_trusted_addr(addr))
    }

    /// Trust check on an already-parsed address (avoids a string round-trip).
    fn is_trusted_addr(&self, addr: IpAddr) -> bool {
        let addr = Self::unmap_ipv4(addr);
        self.trusted_proxies
            .iter()
            .any(|&proxy| Self::unmap_ipv4(proxy) == addr)
    }

    /// Converts an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to plain IPv4.
    ///
    /// Dual-stack listeners report IPv4 peers in mapped form. Other addresses, including
    /// `::1`, are returned unchanged.
    fn unmap_ipv4(addr: IpAddr) -> IpAddr {
        match addr {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(addr, IpAddr::V4),
            IpAddr::V4(_) => addr,
        }
    }

    /// Determines the client IP for a request.
    ///
    /// Headers are only honoured when the directly connected peer (`socket_addr`) is a
    /// trusted proxy. Precedence:
    ///
    /// 1. If any `X-Forwarded-For` header is present:
    ///    - With a trusted peer, the client is the rightmost entry that is not itself a
    ///      trusted proxy; see [`Self::client_from_forwarded_for`].
    ///    - If no such entry can be established, or the peer is untrusted, the peer IP is
    ///      returned.
    ///    - `X-Real-IP` is not consulted in this case.
    /// 2. Otherwise, `X-Real-IP` (trimmed) is used if the peer is trusted and it parses.
    /// 3. Otherwise, the peer IP is returned, or `None` when `socket_addr` is `None`.
    ///
    /// Addresses taken from headers are returned in canonical textual form.
    pub fn extract_client_ip(
        &self,
        headers: &axum::http::HeaderMap,
        socket_addr: Option<std::net::SocketAddr>,
    ) -> Option<String> {
        let peer = socket_addr.map(|addr| addr.ip());
        let direct_ip = peer.map(|ip| ip.to_string());
        let peer_trusted = peer.map_or(false, |ip| self.is_trusted_addr(ip));

        // Multiple X-Forwarded-For lines form one comma-separated list, so all of them
        // must be read. `get` would return only the first line, which is the one most
        // likely to have been supplied by the client.
        let forwarded_for: Vec<Option<&str>> = headers
            .get_all("x-forwarded-for")
            .iter()
            .map(|value| value.to_str().ok())
            .collect();
        if !forwarded_for.is_empty() {
            if peer_trusted {
                if let Some(client) = self.client_from_forwarded_for(&forwarded_for) {
                    return Some(client.to_string());
                }
            }
            return direct_ip;
        }

        if peer_trusted {
            let real_ip = headers
                .get("x-real-ip")
                .and_then(|value| value.to_str().ok())
                .and_then(|value| validate_ip_format(value.trim()));
            if let Some(addr) = real_ip {
                return Some(addr.to_string());
            }
        }
        direct_ip
    }

    /// Finds the client address in the `X-Forwarded-For` header lines `values`.
    ///
    /// `values` is in header order; `None` marks a line that was not visible ASCII.
    /// Entries are walked from right to left, because each proxy appends the address it
    /// saw while the leftmost entries come from the client and can be forged.
    ///
    /// Returns:
    /// - the first entry that is not a trusted proxy;
    /// - the leftmost entry, if every entry is a trusted proxy;
    /// - `None`, if an unparseable entry or line is reached first, since nothing to its
    ///   left can be trusted.
    fn client_from_forwarded_for(&self, values: &[Option<&str>]) -> Option<IpAddr> {
        let mut leftmost_trusted = None;
        for value in values.iter().rev() {
            let value = (*value)?;
            for entry in value.rsplit(',') {
                let addr = validate_ip_format(entry.trim())?;
                if !self.is_trusted_addr(addr) {
                    return Some(addr);
                }
                leftmost_trusted = Some(addr);
            }
        }
        leftmost_trusted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config that trusts exactly one proxy address, with `require_trusted_proxy` set.
    fn trusting(proxy: &str) -> ProxyConfig {
        ProxyConfig::new(vec![proxy.parse::<IpAddr>().unwrap()], true)
    }

    /// Peer socket address for `ip` on port 8000 (`None` if `ip` does not parse).
    fn peer(ip: &str) -> Option<std::net::SocketAddr> {
        ip.parse::<IpAddr>()
            .ok()
            .map(|addr| std::net::SocketAddr::new(addr, 8000))
    }

    /// Header map holding a single `name: value` header.
    fn headers_with(name: &'static str, value: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    #[test]
    fn test_proxy_config_localhost_only() {
        let config = ProxyConfig::localhost_only();
        assert!(config.is_trusted_proxy("127.0.0.1"));
        assert!(!config.is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_valid_ip() {
        assert!(trusting("10.0.0.1").is_trusted_proxy("10.0.0.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_untrusted_ip() {
        assert!(!trusting("10.0.0.1").is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_invalid_ip() {
        assert!(!trusting("10.0.0.1").is_trusted_proxy("invalid_ip"));
    }

    #[test]
    fn test_extract_client_ip_from_trusted_proxy_x_forwarded_for() {
        let headers = headers_with("x-forwarded-for", "192.0.2.1, 10.0.0.1");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(client.as_deref(), Some("192.0.2.1"));
    }

    #[test]
    fn test_extract_client_ip_from_untrusted_proxy_x_forwarded_for() {
        let headers = headers_with("x-forwarded-for", "192.0.2.1, 10.0.0.1");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_no_headers() {
        let headers = axum::http::HeaderMap::new();
        let client =
            ProxyConfig::localhost_only().extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_empty_headers() {
        let headers = axum::http::HeaderMap::new();
        let client = ProxyConfig::localhost_only().extract_client_ip(&headers, None);
        assert!(client.is_none());
    }

    #[test]
    fn test_extract_client_ip_spoofing_attempt() {
        let headers = headers_with("x-forwarded-for", "1.2.3.4");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_forwarded_for() {
        let headers = headers_with("x-forwarded-for", "not-a-valid-ip-address, 10.0.0.1");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(client.as_deref(), Some("10.0.0.1"));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_real_ip() {
        let headers = headers_with("x-real-ip", "256.256.256.256");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(client.as_deref(), Some("10.0.0.1"));
    }

    #[test]
    fn test_extract_client_ip_valid_ipv6() {
        let headers = headers_with("x-forwarded-for", "2001:db8::1, ::1");
        let client = trusting("::1").extract_client_ip(&headers, peer("::1"));
        assert_eq!(client.as_deref(), Some("2001:db8::1"));
    }
}