1use std::net::{IpAddr, SocketAddr};
2use std::str::FromStr;
3
4use axum::http::HeaderMap;
5use ip_network::IpNetwork;
6
7#[derive(Debug)]
21pub struct IpExtractor {
22 trusted_proxies: Vec<IpNetwork>,
23}
24
25impl IpExtractor {
26 pub fn new(trusted_proxy_strs: &[String]) -> Result<Self, String> {
32 let mut proxies = Vec::with_capacity(trusted_proxy_strs.len());
33
34 for s in trusted_proxy_strs {
35 if let Ok(net) = s.parse::<IpNetwork>() {
37 proxies.push(net);
38 } else if let Ok(ip) = IpAddr::from_str(s) {
39 proxies.push(IpNetwork::from(ip));
40 } else {
41 tracing::warn!(entry = %s, "trusted_proxies entry is not a valid IP or CIDR range -- skipped");
42 }
43 }
44
45 Ok(Self {
46 trusted_proxies: proxies,
47 })
48 }
49
50 pub fn is_empty(&self) -> bool {
52 self.trusted_proxies.is_empty()
53 }
54
55 pub fn extract(&self, headers: &HeaderMap, peer_addr: SocketAddr) -> IpAddr {
65 if self.trusted_proxies.is_empty() {
66 return peer_addr.ip();
67 }
68
69 if !self.is_trusted(peer_addr.ip()) {
70 return peer_addr.ip();
71 }
72
73 self.extract_cf_connecting_ip(headers)
74 .or_else(|| self.extract_x_real_ip(headers))
75 .or_else(|| self.extract_x_forwarded_for(headers))
76 .unwrap_or_else(|| peer_addr.ip())
77 }
78
79 fn is_trusted(&self, ip: IpAddr) -> bool {
80 self.trusted_proxies.iter().any(|net| net.contains(ip))
81 }
82
83 fn extract_cf_connecting_ip(&self, headers: &HeaderMap) -> Option<IpAddr> {
84 headers
85 .get("cf-connecting-ip")
86 .and_then(|v| v.to_str().ok())
87 .and_then(|s| IpAddr::from_str(s.trim()).ok())
88 }
89
90 fn extract_x_real_ip(&self, headers: &HeaderMap) -> Option<IpAddr> {
91 headers
92 .get("x-real-ip")
93 .and_then(|v| v.to_str().ok())
94 .and_then(|s| IpAddr::from_str(s.trim()).ok())
95 }
96
97 fn extract_x_forwarded_for(&self, headers: &HeaderMap) -> Option<IpAddr> {
100 let value = headers.get("x-forwarded-for")?.to_str().ok()?;
101 value
102 .rsplit(',')
103 .filter_map(|s| IpAddr::from_str(s.trim()).ok())
104 .find(|ip| !self.is_trusted(*ip))
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Parses a `host:port` literal into the socket address of the peer.
    fn peer(addr: &str) -> SocketAddr {
        addr.parse().unwrap()
    }

    /// Builds an extractor from string literals, panicking on failure.
    fn extractor(proxies: &[&str]) -> IpExtractor {
        let owned: Vec<String> = proxies.iter().map(|p| p.to_string()).collect();
        IpExtractor::new(&owned).unwrap()
    }

    /// Parses the IP literal a test expects to get back.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// Builds a header map from `(name, value)` pairs of static strings.
    fn headers_with(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    #[test]
    fn no_proxies_returns_peer_ip() {
        let ext = extractor(&[]);
        let got = ext.extract(&HeaderMap::new(), peer("1.2.3.4:12345"));
        assert_eq!(got, ip("1.2.3.4"));
    }

    #[test]
    fn no_proxies_ignores_all_headers() {
        let ext = extractor(&[]);
        let hdrs = headers_with(&[
            ("cf-connecting-ip", "5.6.7.8"),
            ("x-real-ip", "9.10.11.12"),
            ("x-forwarded-for", "13.14.15.16"),
        ]);
        assert_eq!(ext.extract(&hdrs, peer("1.2.3.4:12345")), ip("1.2.3.4"));
    }

    #[test]
    fn untrusted_peer_returns_peer_ip() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("cf-connecting-ip", "5.6.7.8")]);
        assert_eq!(ext.extract(&hdrs, peer("1.2.3.4:12345")), ip("1.2.3.4"));
    }

    #[test]
    fn trusted_peer_uses_cf_connecting_ip() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("cf-connecting-ip", "203.0.114.50")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("203.0.114.50"));
    }

    #[test]
    fn cf_connecting_ip_with_whitespace() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("cf-connecting-ip", " 203.0.114.50 ")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("203.0.114.50"));
    }

    #[test]
    fn cf_connecting_ip_invalid_falls_through() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[
            ("cf-connecting-ip", "not-an-ip"),
            ("x-real-ip", "5.6.7.8"),
        ]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("5.6.7.8"));
    }

    #[test]
    fn trusted_peer_uses_x_real_ip() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("x-real-ip", "5.6.7.8")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("5.6.7.8"));
    }

    #[test]
    fn cf_connecting_ip_takes_priority_over_x_real_ip() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[
            ("cf-connecting-ip", "1.1.1.1"),
            ("x-real-ip", "2.2.2.2"),
        ]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("1.1.1.1"));
    }

    #[test]
    fn x_forwarded_for_single_ip() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("x-forwarded-for", "203.0.114.50")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("203.0.114.50"));
    }

    #[test]
    fn x_forwarded_for_rightmost_untrusted() {
        let ext = extractor(&["10.0.0.1", "10.0.0.2"]);
        let hdrs = headers_with(&[("x-forwarded-for", "99.99.99.99, 5.6.7.8, 10.0.0.2")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("5.6.7.8"));
    }

    #[test]
    fn x_forwarded_for_all_trusted_returns_peer() {
        let ext = extractor(&["10.0.0.1", "10.0.0.2", "10.0.0.3"]);
        let hdrs = headers_with(&[("x-forwarded-for", "10.0.0.3, 10.0.0.2")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("10.0.0.1"));
    }

    #[test]
    fn x_forwarded_for_with_whitespace() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("x-forwarded-for", " 5.6.7.8 , 10.0.0.1 ")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("5.6.7.8"));
    }

    #[test]
    fn x_forwarded_for_with_invalid_entries() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("x-forwarded-for", "5.6.7.8, garbage, not-ip")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("5.6.7.8"));
    }

    #[test]
    fn no_headers_returns_peer() {
        let ext = extractor(&["10.0.0.1"]);
        let got = ext.extract(&HeaderMap::new(), peer("10.0.0.1:443"));
        assert_eq!(got, ip("10.0.0.1"));
    }

    #[test]
    fn ipv6_peer_and_header() {
        let ext = extractor(&["::1"]);
        let hdrs = headers_with(&[("x-real-ip", "2001:4860:4860::8888")]);
        assert_eq!(
            ext.extract(&hdrs, peer("[::1]:443")),
            ip("2001:4860:4860::8888")
        );
    }

    #[test]
    fn ipv6_in_x_forwarded_for() {
        let ext = extractor(&["::1"]);
        let hdrs = headers_with(&[("x-forwarded-for", "2606:4700::1, ::1")]);
        assert_eq!(ext.extract(&hdrs, peer("[::1]:443")), ip("2606:4700::1"));
    }

    #[test]
    fn invalid_proxy_strings_are_skipped() {
        let raw: Vec<String> = ["10.0.0.1", "not-an-ip", "", "10.0.0.2"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let ext = IpExtractor::new(&raw).unwrap();
        assert_eq!(ext.trusted_proxies.len(), 2);
    }

    #[test]
    fn cidr_trusted_proxy_matches_subnet() {
        let ext = extractor(&["10.0.0.0/8"]);
        let hdrs = headers_with(&[("x-real-ip", "1.2.3.4")]);
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.5:443")), ip("1.2.3.4"));
    }

    #[test]
    fn cidr_xff_skips_trusted_ranges() {
        let ext = extractor(&["10.0.0.0/8", "172.16.0.0/12"]);
        let hdrs = headers_with(&[("x-forwarded-for", "8.8.8.8, 10.0.0.1, 172.16.0.1")]);
        assert_eq!(ext.extract(&hdrs, peer("172.16.0.1:443")), ip("8.8.8.8"));
    }

    #[test]
    fn cidr_mixed_exact_and_range() {
        let ext = extractor(&["10.0.0.0/8", "192.168.1.1"]);
        let hdrs = headers_with(&[("x-real-ip", "5.6.7.8")]);

        // Matched via the exact-address entry.
        assert_eq!(ext.extract(&hdrs, peer("192.168.1.1:443")), ip("5.6.7.8"));
        // Matched via the CIDR entry.
        assert_eq!(ext.extract(&hdrs, peer("10.99.99.99:443")), ip("5.6.7.8"));
    }

    #[test]
    fn bare_ip_auto_promotes_to_host_network() {
        let ext = extractor(&["10.0.0.1"]);
        let hdrs = headers_with(&[("x-real-ip", "1.2.3.4")]);

        // The configured host itself is trusted...
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.1:443")), ip("1.2.3.4"));
        // ...but its neighbour is not, so headers are ignored for it.
        assert_eq!(ext.extract(&hdrs, peer("10.0.0.2:443")), ip("10.0.0.2"));
    }
}