1#![doc = include_str!("../README.md")]
2use std::net::IpAddr;
3
4pub use error::Error;
5use http::{HeaderMap, HeaderName};
6
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the error type.
type Result<T> = std::result::Result<T, Error>;
8
9pub fn cf_connecting_ip(header_map: &HeaderMap) -> Result<IpAddr> {
11 ip_from_single_header(header_map, &HeaderName::from_static("cf-connecting-ip"))
12}
13
14pub fn cloudfront_viewer_address(header_map: &HeaderMap) -> Result<IpAddr> {
16 const HEADER_NAME: HeaderName = HeaderName::from_static("cloudfront-viewer-address");
17
18 fn ip_from_header_value(header_value: &str) -> Result<IpAddr> {
19 header_value
26 .rsplit_once(':')
27 .map(|(ip, _port)| ip)
28 .ok_or_else(|| Error::MalformedHeaderValue {
29 header_name: HEADER_NAME,
30 header_value: header_value.to_owned(),
31 })?
32 .trim()
33 .parse::<IpAddr>()
34 .map_err(|_| Error::MalformedHeaderValue {
35 header_name: HEADER_NAME,
36 header_value: header_value.to_owned(),
37 })
38 }
39
40 let header_value = AsciiHeaderValue::of_last_header(header_map, &HEADER_NAME)?;
41 ip_from_header_value(header_value.0)
42}
43
44pub fn fly_client_ip(header_map: &HeaderMap) -> Result<IpAddr> {
50 ip_from_single_header(header_map, &HeaderName::from_static("fly-client-ip"))
51}
52
/// Extracts the client IP from the rightmost element of the last `Forwarded` header.
///
/// Only the final occurrence of the header is read, and within it only the last
/// element. That element's `for` parameter must hold an IP address, optionally with a
/// port (e.g. `for=1.2.3.4` or `for=1.2.3.4:8000`).
///
/// # Errors
///
/// - [`Error::AbsentHeader`] if the header is missing.
/// - [`Error::NonAsciiHeaderValue`] if the value cannot be read as a string.
/// - [`Error::MalformedHeaderValue`] if the value cannot be parsed or has no elements.
/// - [`Error::ForwardedNoFor`] if the rightmost element has no `for` parameter.
/// - [`Error::ForwardedObfuscated`] if `for` holds an obfuscated identifier.
/// - [`Error::ForwardedUnknown`] if `for` is `unknown`.
#[cfg(feature = "forwarded-header")]
pub fn rightmost_forwarded(header_map: &HeaderMap) -> Result<IpAddr> {
    const HEADER_NAME: HeaderName = HeaderName::from_static("forwarded");

    fn ip_from_forwarded(header_value: &str) -> Result<IpAddr> {
        use forwarded_header_value::{ForwardedHeaderValue, Identifier};

        let malformed = || Error::MalformedHeaderValue {
            header_name: HEADER_NAME,
            header_value: header_value.to_owned(),
        };

        let parsed = ForwardedHeaderValue::from_forwarded(header_value).map_err(|_| malformed())?;

        let Some(rightmost) = parsed.into_iter().last() else {
            return Err(malformed());
        };

        let Some(identifier) = rightmost.forwarded_for else {
            return Err(Error::ForwardedNoFor {
                header_value: header_value.to_owned(),
            });
        };

        match identifier {
            Identifier::IpAddr(ip) => Ok(ip),
            Identifier::SocketAddr(socket) => Ok(socket.ip()),
            Identifier::String(_) => Err(Error::ForwardedObfuscated {
                header_value: header_value.to_owned(),
            }),
            Identifier::Unknown => Err(Error::ForwardedUnknown {
                header_value: header_value.to_owned(),
            }),
        }
    }

    let raw = AsciiHeaderValue::of_last_header(header_map, &HEADER_NAME)?.0;
    ip_from_forwarded(raw)
}
92
93pub fn rightmost_x_forwarded_for(header_map: &HeaderMap) -> Result<IpAddr> {
96 const HEADER_NAME: HeaderName = HeaderName::from_static("x-forwarded-for");
97
98 fn ip_from_header_value(header_value: &str) -> Result<IpAddr> {
99 header_value
100 .split(',')
101 .next_back()
102 .ok_or_else(|| Error::MalformedHeaderValue {
103 header_name: HEADER_NAME,
104 header_value: header_value.to_owned(),
105 })?
106 .trim()
107 .parse::<IpAddr>()
108 .map_err(|_| Error::MalformedHeaderValue {
109 header_name: HEADER_NAME,
110 header_value: header_value.to_owned(),
111 })
112 }
113
114 let header_value = AsciiHeaderValue::of_last_header(header_map, &HEADER_NAME)?;
115 ip_from_header_value(header_value.0)
116}
117
118pub fn true_client_ip(header_map: &HeaderMap) -> Result<IpAddr> {
120 ip_from_single_header(header_map, &HeaderName::from_static("true-client-ip"))
121}
122
123pub fn x_real_ip(header_map: &HeaderMap) -> Result<IpAddr> {
125 ip_from_single_header(header_map, &HeaderName::from_static("x-real-ip"))
126}
127
/// A header value borrowed from a `HeaderMap` after a successful `HeaderValue::to_str`
/// conversion.
///
/// Instances are built by `AsciiHeaderValue::of_single_header` and
/// `AsciiHeaderValue::of_last_header`, which turn a failed conversion into
/// `Error::NonAsciiHeaderValue`.
#[derive(Debug)]
struct AsciiHeaderValue<'a>(&'a str);
131
132impl<'a> AsciiHeaderValue<'a> {
133 fn of_single_header(header_map: &'a HeaderMap, header_name: &HeaderName) -> Result<Self> {
137 let mut iter = header_map.get_all(header_name).into_iter();
138
139 let Some(header_value) = iter.next() else {
140 return Err(Error::AbsentHeader {
141 header_name: header_name.to_owned(),
142 });
143 };
144
145 if iter.next().is_some() {
146 return Err(Error::SingleHeaderRequired {
147 header_name: header_name.to_owned(),
148 });
149 }
150
151 header_value
152 .to_str()
153 .map_err(|_| Error::NonAsciiHeaderValue {
154 header_name: header_name.to_owned(),
155 })
156 .map(Self)
157 }
158
159 fn of_last_header(header_map: &'a HeaderMap, header_name: &HeaderName) -> Result<Self> {
161 header_map
162 .get_all(header_name)
163 .into_iter()
164 .next_back()
165 .ok_or_else(|| Error::AbsentHeader {
166 header_name: header_name.to_owned(),
167 })?
168 .to_str()
169 .map_err(|_| Error::NonAsciiHeaderValue {
170 header_name: header_name.to_owned(),
171 })
172 .map(Self)
173 }
174
175 fn parse_ip(&self, header_name: &HeaderName) -> Result<IpAddr> {
177 self.0
178 .trim()
179 .parse()
180 .map_err(|_| Error::MalformedHeaderValue {
181 header_name: header_name.to_owned(),
182 header_value: self.0.to_owned(),
183 })
184 }
185}
186
187fn ip_from_single_header(header_map: &HeaderMap, header_name: &HeaderName) -> Result<IpAddr> {
190 AsciiHeaderValue::of_single_header(header_map, header_name)?.parse_ip(header_name)
191}
192
193mod error {
194 use std::fmt;
195
196 use http::HeaderName;
197
198 #[derive(Debug, PartialEq)]
200 pub enum Error {
201 AbsentHeader {
203 header_name: HeaderName,
205 },
206 NonAsciiHeaderValue {
208 header_name: HeaderName,
210 },
211 MalformedHeaderValue {
213 header_name: HeaderName,
215 header_value: String,
217 },
218 SingleHeaderRequired {
225 header_name: HeaderName,
227 },
228 #[cfg(feature = "forwarded-header")]
229 ForwardedNoFor {
231 header_value: String,
233 },
234 #[cfg(feature = "forwarded-header")]
235 ForwardedObfuscated {
237 header_value: String,
239 },
240 #[cfg(feature = "forwarded-header")]
241 ForwardedUnknown {
243 header_value: String,
245 },
246 }
247
248 impl fmt::Display for Error {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 match self {
251 Self::AbsentHeader { header_name } => {
252 write!(f, "Missing required header: {header_name}")
253 }
254 Self::NonAsciiHeaderValue { header_name } => write!(
255 f,
256 "Header value contains non-ASCII characters: {header_name}",
257 ),
258 Self::MalformedHeaderValue {
259 header_name,
260 header_value,
261 } => write!(
262 f,
263 "Malformed header value for `{header_name}`: {header_value}",
264 ),
265 Self::SingleHeaderRequired { header_name } => write!(
266 f,
267 "Multiple occurrences of the header aren't allowed: {header_name}"
268 ),
269 #[cfg(feature = "forwarded-header")]
270 Self::ForwardedNoFor { header_value } => write!(
271 f,
272 "`Forwarded` header missing `for` directive: {header_value}",
273 ),
274 #[cfg(feature = "forwarded-header")]
275 Self::ForwardedObfuscated { header_value } => write!(
276 f,
277 "`Forwarded` header contains obfuscated IP: {header_value}",
278 ),
279 #[cfg(feature = "forwarded-header")]
280 Self::ForwardedUnknown { header_value } => write!(
281 f,
282 "`Forwarded` header contains unknown identifier: {header_value}",
283 ),
284 }
285 }
286 }
287
288 impl std::error::Error for Error {}
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    const VALID_IPV4: &str = "1.2.3.4";
    const VALID_IPV6: &str = "1:23:4567:89ab:c:d:e:f";

    /// Builds a `HeaderMap` from `(name, value)` pairs. Repeated names produce multiple
    /// occurrences of the same header (see the "multiple valid headers" case below).
    fn headers<'a>(items: impl IntoIterator<Item = (&'a str, &'a str)>) -> HeaderMap {
        HeaderMap::from_iter(
            items
                .into_iter()
                .map(|(name, value)| (name.parse().unwrap(), value.parse().unwrap())),
        )
    }

    #[test]
    fn test_ascii_header_value_of_last_header() {
        let header_name_str = "my-header";
        let header_name = HeaderName::from_static(header_name_str);

        assert_eq!(
            AsciiHeaderValue::of_last_header(&headers([]), &header_name).unwrap_err(),
            Error::AbsentHeader {
                header_name: header_name.clone()
            }
        );

        assert_eq!(
            AsciiHeaderValue::of_last_header(&headers([(header_name_str, "ы")]), &header_name)
                .unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: header_name.clone()
            }
        );

        assert_eq!(
            AsciiHeaderValue::of_last_header(&headers([(header_name_str, "foo")]), &header_name)
                .unwrap()
                .0,
            "foo",
            "single valid header"
        );

        assert_eq!(
            AsciiHeaderValue::of_last_header(
                &headers([(header_name_str, "foo"), (header_name_str, "bar")]),
                &header_name
            )
            .unwrap()
            .0,
            "bar",
            "multiple valid headers"
        );
    }

    #[test]
    fn test_ascii_header_value_of_single_header() {
        let header_name_str = "my-header";
        let header_name = HeaderName::from_static(header_name_str);

        assert_eq!(
            AsciiHeaderValue::of_single_header(&headers([]), &header_name).unwrap_err(),
            Error::AbsentHeader {
                header_name: header_name.clone()
            }
        );

        assert_eq!(
            AsciiHeaderValue::of_single_header(&headers([(header_name_str, "ы")]), &header_name)
                .unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: header_name.clone()
            }
        );

        assert_eq!(
            AsciiHeaderValue::of_single_header(
                &headers([(header_name_str, "foo"), (header_name_str, "bar")]),
                &header_name
            )
            .unwrap_err(),
            Error::SingleHeaderRequired {
                header_name: header_name.clone()
            }
        );

        assert_eq!(
            AsciiHeaderValue::of_single_header(&headers([(header_name_str, "foo")]), &header_name)
                .unwrap()
                .0,
            "foo"
        );
    }

    #[test]
    fn test_cf_connecting_ip() {
        let header = "cf-connecting-ip";

        assert_eq!(
            cf_connecting_ip(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            cf_connecting_ip(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            cf_connecting_ip(&headers([(header, "foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo".into(),
            }
        );

        assert_eq!(
            cf_connecting_ip(&headers([(header, VALID_IPV4)])).unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            cf_connecting_ip(&headers([(header, VALID_IPV6)])).unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_cloudfront_viewer_address() {
        let header = "cloudfront-viewer-address";

        assert_eq!(
            cloudfront_viewer_address(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, VALID_IPV4)])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: VALID_IPV4.into(),
            }
        );
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, "foo:8000")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo:8000".into(),
            }
        );

        let valid_header_value_v4 = format!("{VALID_IPV4}:8000");
        let valid_header_value_v6 = format!("{VALID_IPV6}:8000");
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, valid_header_value_v4.as_ref())]))
                .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, valid_header_value_v6.as_ref())]))
                .unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_cloudfront_viewer_address_uses_last_header() {
        let header = "cloudfront-viewer-address";

        // Earlier occurrences are ignored, even if they are malformed.
        assert_eq!(
            cloudfront_viewer_address(&headers([(header, "foo"), (header, "1.2.3.4:8000")]))
                .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_fly_client_ip() {
        let header = "fly-client-ip";

        assert_eq!(
            fly_client_ip(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            fly_client_ip(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            fly_client_ip(&headers([(header, "foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo".into(),
            }
        );

        assert_eq!(
            fly_client_ip(&headers([(header, VALID_IPV4)])).unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            fly_client_ip(&headers([(header, VALID_IPV6)])).unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[cfg(feature = "forwarded-header")]
    #[test]
    fn test_rightmost_forwarded() {
        let header = "forwarded";

        assert_eq!(
            rightmost_forwarded(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            rightmost_forwarded(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            rightmost_forwarded(&headers([(header, "foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo".into(),
            }
        );
        assert_eq!(
            rightmost_forwarded(&headers([
                (header, format!("for={VALID_IPV4}").as_ref()),
                (header, "proto=http"),
            ]))
            .unwrap_err(),
            Error::ForwardedNoFor {
                header_value: "proto=http".into(),
            }
        );
        assert_eq!(
            rightmost_forwarded(&headers([(header, "for=unknown")])).unwrap_err(),
            Error::ForwardedUnknown {
                header_value: "for=unknown".into(),
            }
        );
        assert_eq!(
            rightmost_forwarded(&headers([(header, "for=_foo")])).unwrap_err(),
            Error::ForwardedObfuscated {
                header_value: "for=_foo".into(),
            }
        );

        assert_eq!(
            rightmost_forwarded(&headers([
                (header, "proto=http"),
                (header, format!("for={VALID_IPV4};proto=http").as_ref()),
            ]))
            .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            rightmost_forwarded(&headers([(
                header,
                format!("for={VALID_IPV4}:8000").as_ref()
            ),]))
            .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );

        assert_eq!(
            rightmost_forwarded(&headers([(header, format!("for={VALID_IPV6}").as_ref()),]))
                .unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            rightmost_forwarded(&headers([(
                header,
                format!("for=[{VALID_IPV6}]:8000").as_ref()
            ),]))
            .unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_rightmost_x_forwarded_for() {
        let header = "x-forwarded-for";

        assert_eq!(
            rightmost_x_forwarded_for(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, "1.2.3.4,foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "1.2.3.4,foo".into(),
            }
        );

        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, format!("foo,{VALID_IPV4}").as_ref())]))
                .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, VALID_IPV6)])).unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_rightmost_x_forwarded_for_edge_cases() {
        let header = "x-forwarded-for";

        // Only the last occurrence is consulted; its rightmost entry wins after trimming.
        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, "9.9.9.9"), (header, "foo, 1.2.3.4")]))
                .unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );

        // A trailing comma leaves an empty rightmost entry, which is malformed.
        assert_eq!(
            rightmost_x_forwarded_for(&headers([(header, "1.2.3.4,")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "1.2.3.4,".into(),
            }
        );
    }

    #[test]
    fn test_true_client_ip() {
        let header = "true-client-ip";

        assert_eq!(
            true_client_ip(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            true_client_ip(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            true_client_ip(&headers([(header, "foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo".into(),
            }
        );

        assert_eq!(
            true_client_ip(&headers([(header, VALID_IPV4)])).unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            true_client_ip(&headers([(header, VALID_IPV6)])).unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_x_real_ip() {
        let header = "x-real-ip";

        assert_eq!(
            x_real_ip(&headers([])).unwrap_err(),
            Error::AbsentHeader {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            x_real_ip(&headers([(header, "ы")])).unwrap_err(),
            Error::NonAsciiHeaderValue {
                header_name: HeaderName::from_static(header)
            }
        );
        assert_eq!(
            x_real_ip(&headers([(header, "foo")])).unwrap_err(),
            Error::MalformedHeaderValue {
                header_name: HeaderName::from_static(header),
                header_value: "foo".into(),
            }
        );

        assert_eq!(
            x_real_ip(&headers([(header, VALID_IPV4)])).unwrap(),
            VALID_IPV4.parse::<IpAddr>().unwrap()
        );
        assert_eq!(
            x_real_ip(&headers([(header, VALID_IPV6)])).unwrap(),
            VALID_IPV6.parse::<IpAddr>().unwrap()
        );
    }

    #[test]
    fn test_single_header_extractors_share_contract() {
        let extractors: [(&'static str, fn(&HeaderMap) -> Result<IpAddr>); 4] = [
            ("cf-connecting-ip", cf_connecting_ip),
            ("fly-client-ip", fly_client_ip),
            ("true-client-ip", true_client_ip),
            ("x-real-ip", x_real_ip),
        ];

        for (header, extract) in extractors {
            // Repeated occurrences are rejected rather than silently picking one.
            assert_eq!(
                extract(&headers([(header, VALID_IPV4), (header, VALID_IPV4)])).unwrap_err(),
                Error::SingleHeaderRequired {
                    header_name: HeaderName::from_static(header)
                },
                "{header}"
            );
            // Whitespace around the address is ignored.
            assert_eq!(
                extract(&headers([(header, " 1.2.3.4 ")])).unwrap(),
                VALID_IPV4.parse::<IpAddr>().unwrap(),
                "{header}"
            );
        }
    }
}