1use http::HeaderMap;
2use std::net::{IpAddr, Ipv4Addr};
3
4pub fn extract_client_ip(
20 headers: &HeaderMap,
21 trusted_proxies: &[ipnet::IpNet],
22 connect_ip: Option<IpAddr>,
23) -> IpAddr {
24 if let Some(ip) = connect_ip
25 && !trusted_proxies.is_empty()
26 && !trusted_proxies.iter().any(|net| net.contains(&ip))
27 {
28 return ip;
29 }
30
31 if let Some(forwarded) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok())
32 && let Some(first) = forwarded.split(',').next()
33 {
34 let candidate = first.trim();
35 if let Ok(ip) = candidate.parse::<IpAddr>() {
36 return ip;
37 }
38 }
39
40 if let Some(real_ip) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
41 let candidate = real_ip.trim();
42 if let Ok(ip) = candidate.parse::<IpAddr>() {
43 return ip;
44 }
45 }
46
47 connect_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST))
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::net::{IpAddr, Ipv4Addr};

    /// Parses an IP literal used as test data.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// Builds a header map from `(name, value)` pairs.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    /// The single trusted network shared by the proxy-aware tests.
    fn internal_proxies() -> Vec<ipnet::IpNet> {
        vec!["10.0.0.0/24".parse().unwrap()]
    }

    // A peer outside the trusted networks is returned even though a forwarding header is set.
    #[test]
    fn direct_ip_not_in_trusted_proxies() {
        let headers = header_map(&[("x-forwarded-for", "1.2.3.4")]);
        let peer = ip("203.0.113.5");
        let resolved = extract_client_ip(&headers, &internal_proxies(), Some(peer));
        assert_eq!(resolved, peer);
    }

    // A trusted peer's X-Forwarded-For chain yields the forwarded client address.
    #[test]
    fn trusted_proxy_uses_xff() {
        let headers = header_map(&[("x-forwarded-for", "8.8.8.8, 10.0.0.1")]);
        let resolved = extract_client_ip(&headers, &internal_proxies(), Some(ip("10.0.0.1")));
        assert_eq!(resolved, ip("8.8.8.8"));
    }

    // Without X-Forwarded-For, a trusted peer's X-Real-IP is used.
    #[test]
    fn trusted_proxy_uses_x_real_ip_when_no_xff() {
        let headers = header_map(&[("x-real-ip", "9.8.7.6")]);
        let resolved = extract_client_ip(&headers, &internal_proxies(), Some(ip("10.0.0.1")));
        assert_eq!(resolved, ip("9.8.7.6"));
    }

    // With no proxy list and no peer address, X-Forwarded-For is used.
    #[test]
    fn no_trusted_proxies_uses_xff() {
        let headers = header_map(&[("x-forwarded-for", "1.2.3.4")]);
        let resolved = extract_client_ip(&headers, &[], None);
        assert_eq!(resolved, ip("1.2.3.4"));
    }

    // With no proxy list and no peer address, X-Real-IP is used when it is the only header.
    #[test]
    fn no_trusted_proxies_uses_x_real_ip() {
        let headers = header_map(&[("x-real-ip", "9.8.7.6")]);
        let resolved = extract_client_ip(&headers, &[], None);
        assert_eq!(resolved, ip("9.8.7.6"));
    }

    // X-Forwarded-For wins when both forwarding headers are present.
    #[test]
    fn xff_preferred_over_x_real_ip() {
        let headers = header_map(&[
            ("x-forwarded-for", "1.2.3.4"),
            ("x-real-ip", "9.8.7.6"),
        ]);
        let resolved = extract_client_ip(&headers, &[], None);
        assert_eq!(resolved, ip("1.2.3.4"));
    }

    // With no forwarding headers, the peer address is returned.
    #[test]
    fn fallback_to_connect_ip() {
        let peer = ip("192.168.1.1");
        let resolved = extract_client_ip(&HeaderMap::new(), &[], Some(peer));
        assert_eq!(resolved, peer);
    }

    // With neither headers nor a peer address, the loopback address is returned.
    #[test]
    fn fallback_to_localhost() {
        let resolved = extract_client_ip(&HeaderMap::new(), &[], None);
        assert_eq!(resolved, IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    // An unparseable X-Forwarded-For value is skipped in favour of the peer address.
    #[test]
    fn invalid_xff_falls_back() {
        let headers = header_map(&[("x-forwarded-for", "not-an-ip")]);
        let peer = ip("192.168.1.1");
        let resolved = extract_client_ip(&headers, &[], Some(peer));
        assert_eq!(resolved, peer);
    }

    // An empty proxy list means the header is used even when a peer address is known.
    #[test]
    fn empty_trusted_proxies_with_connect_ip_trusts_xff() {
        let headers = header_map(&[("x-forwarded-for", "1.2.3.4")]);
        let resolved = extract_client_ip(&headers, &[], Some(ip("203.0.113.5")));
        assert_eq!(resolved, ip("1.2.3.4"));
    }
}