faucet_server/client/load_balancing/
ip_extractor.rs

1use crate::error::{BadRequestReason, FaucetError, FaucetResult};
2use hyper::{http::HeaderValue, Request};
3use std::net::IpAddr;
4
5#[derive(Clone, Copy, Debug, serde::Deserialize)]
6#[serde(rename = "snake_case")]
7pub enum IpExtractor {
8    ClientAddr,
9    XForwardedFor,
10    XRealIp,
11}
12
13const MISSING_X_FORWARDED_FOR: FaucetError =
14    FaucetError::BadRequest(BadRequestReason::MissingHeader("X-Forwarded-For"));
15
16const INVALID_X_FORWARDED_FOR: FaucetError =
17    FaucetError::BadRequest(BadRequestReason::InvalidHeader("X-Forwarded-For"));
18
19fn extract_ip_from_x_forwarded_for(x_forwarded_for: &HeaderValue) -> FaucetResult<IpAddr> {
20    let x_forwarded_for = x_forwarded_for
21        .to_str()
22        .map_err(|_| MISSING_X_FORWARDED_FOR)?;
23    let ip_str = x_forwarded_for
24        .split(',')
25        .next()
26        .map(|ip| ip.trim())
27        .ok_or(INVALID_X_FORWARDED_FOR)?;
28    ip_str.parse().map_err(|_| INVALID_X_FORWARDED_FOR)
29}
30
31const MISSING_X_REAL_IP: FaucetError =
32    FaucetError::BadRequest(BadRequestReason::MissingHeader("X-Real-IP"));
33
34const INVALID_X_REAL_IP: FaucetError =
35    FaucetError::BadRequest(BadRequestReason::InvalidHeader("X-Real-IP"));
36
37fn extract_ip_from_x_real_ip(x_real_ip: &HeaderValue) -> FaucetResult<IpAddr> {
38    let x_real_ip = x_real_ip.to_str().map_err(|_| MISSING_X_REAL_IP)?;
39    x_real_ip.parse().map_err(|_| INVALID_X_REAL_IP)
40}
41
42impl IpExtractor {
43    pub fn extract<B>(self, req: &Request<B>, client_addr: Option<IpAddr>) -> FaucetResult<IpAddr> {
44        use IpExtractor::*;
45        let ip = match self {
46            ClientAddr => client_addr.expect("Unable to get client address"),
47            XForwardedFor => match req.headers().get("X-Forwarded-For") {
48                Some(header) => extract_ip_from_x_forwarded_for(header)?,
49                None => return Err(MISSING_X_FORWARDED_FOR),
50            },
51            XRealIp => match req.headers().get("X-Real-IP") {
52                Some(header) => extract_ip_from_x_real_ip(header)?,
53                None => return Err(MISSING_X_REAL_IP),
54            },
55        };
56        Ok(ip)
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // Header names exercised by the request-level tests.
    const X_FORWARDED_FOR: &str = "X-Forwarded-For";
    const X_REAL_IP: &str = "X-Real-IP";

    // Addresses shared across the tests.
    const LOOPBACK_V4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    const LOOPBACK_V6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);
    const PRIVATE_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1));
    const UNSPECIFIED_V4: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

    /// Builds a body-less request carrying a single header.
    fn request_with_header(name: &'static str, value: &'static str) -> Request<()> {
        Request::builder()
            .header(name, HeaderValue::from_static(value))
            .body(())
            .unwrap()
    }

    /// Builds a body-less request without any headers.
    fn bare_request() -> Request<()> {
        Request::builder().body(()).unwrap()
    }

    /// Runs the `X-Forwarded-For` parser directly on a static header value.
    fn parse_x_forwarded_for(value: &'static str) -> IpAddr {
        extract_ip_from_x_forwarded_for(&HeaderValue::from_static(value)).unwrap()
    }

    #[test]
    fn extract_ip_from_x_forwarded_for_ipv4() {
        assert_eq!(parse_x_forwarded_for("127.0.0.1"), LOOPBACK_V4);
    }

    #[test]
    fn extract_ip_from_x_forwarded_for_ipv6() {
        assert_eq!(parse_x_forwarded_for("::1"), LOOPBACK_V6);
    }

    #[test]
    fn extract_ip_from_x_forwarded_for_multiple() {
        assert_eq!(parse_x_forwarded_for("192.168.0.1, 127.0.0.1"), PRIVATE_V4);
    }

    #[test]
    fn extract_x_real_ip_ipv4_from_request() {
        let request = request_with_header(X_REAL_IP, "127.0.0.1");
        let extracted = IpExtractor::XRealIp.extract(&request, Some(UNSPECIFIED_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V4);
    }

    #[test]
    fn extract_x_real_ip_ipv6_from_request() {
        let request = request_with_header(X_REAL_IP, "::1");
        let extracted = IpExtractor::XRealIp.extract(&request, Some(UNSPECIFIED_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V6);
    }

    #[test]
    fn extract_x_forwarded_for_ipv4_from_request() {
        let request = request_with_header(X_FORWARDED_FOR, "127.0.0.1");
        let extracted = IpExtractor::XForwardedFor.extract(&request, Some(UNSPECIFIED_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V4);
    }

    #[test]
    fn extract_x_forwarded_for_ipv6_from_request() {
        let request = request_with_header(X_FORWARDED_FOR, "::1");
        let extracted = IpExtractor::XForwardedFor.extract(&request, Some(UNSPECIFIED_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V6);
    }

    #[test]
    fn extract_x_forwarded_for_ipv4_from_request_multiple() {
        let request = request_with_header(X_FORWARDED_FOR, "192.168.0.1, 127.0.0.1");
        let extracted = IpExtractor::XForwardedFor.extract(&request, Some(UNSPECIFIED_V4));
        assert_eq!(extracted.unwrap(), PRIVATE_V4);
    }

    #[test]
    fn extract_client_addr_ipv4_from_request() {
        let request = bare_request();
        let extracted = IpExtractor::ClientAddr.extract(&request, Some(LOOPBACK_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V4);
    }

    #[test]
    fn extract_client_addr_ipv6_from_request() {
        let request = bare_request();
        let extracted = IpExtractor::ClientAddr.extract(&request, Some(LOOPBACK_V6));
        assert_eq!(extracted.unwrap(), LOOPBACK_V6);
    }

    #[test]
    fn extract_client_addr_ipv4_with_x_forwarded_for_from_request() {
        // The header must be ignored when the strategy is `ClientAddr`.
        let request = request_with_header(X_FORWARDED_FOR, "192.168.0.1");
        let extracted = IpExtractor::ClientAddr.extract(&request, Some(LOOPBACK_V4));
        assert_eq!(extracted.unwrap(), LOOPBACK_V4);
    }
}