Skip to main content

ic_bn_lib/http/middleware/mod.rs

1use std::{net::IpAddr, str::FromStr, sync::Arc};
2
3use http::Request;
4use ic_bn_lib_common::types::http::ConnInfo;
5
6use crate::http::headers::X_REAL_IP;
7
8pub mod rate_limiter;
9pub mod waf;
10
11/// Extracts IP address from `x-real-ip` header or `ConnInfo` extension
12pub fn extract_ip_from_request<B>(req: &Request<B>) -> Option<IpAddr> {
13    // Try to extract from the header first
14    req.headers()
15        .get(X_REAL_IP)
16        .and_then(|x| x.to_str().ok())
17        .and_then(|x| IpAddr::from_str(x).ok())
18        .or_else(|| {
19            // Then, if that failed, from the ConnInfo extension
20            req.extensions()
21                .get::<Arc<ConnInfo>>()
22                .map(|x| x.remote_addr.ip())
23        })
24}
25
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use ic_bn_lib_common::types::http::Addr;

    use super::*;

    #[test]
    fn test_extract_ip_from_request() {
        let addr1 = IpAddr::from_str("10.0.0.1").unwrap();
        let addr2 = IpAddr::from_str("192.168.0.1").unwrap();
        let addr_v6 = IpAddr::from_str("2001:db8::1").unwrap();

        let mut ci = ConnInfo::default();
        ci.remote_addr = Addr::Tcp(SocketAddr::new(addr1, 31337));
        let ci = Arc::new(ci);

        // Header takes precedence
        let req = Request::builder()
            .extension(Arc::clone(&ci))
            .header(X_REAL_IP, addr2.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr2));

        // Unparsable header falls back to ConnInfo
        let req = Request::builder()
            .extension(Arc::clone(&ci))
            .header(X_REAL_IP, "not-an-ip")
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr1));

        // Only ConnInfo
        let req = Request::builder().extension(ci).body("").unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr1));

        // Only header
        let req = Request::builder()
            .header(X_REAL_IP, addr2.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr2));

        // IPv6 address in the header
        let req = Request::builder()
            .header(X_REAL_IP, addr_v6.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr_v6));

        // Unparsable header and no ConnInfo: nothing to fall back to
        let req = Request::builder()
            .header(X_REAL_IP, "not-an-ip")
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), None);

        // Neither
        let req = Request::builder().body("").unwrap();
        assert_eq!(extract_ip_from_request(&req), None);
    }
}