use std::{net::IpAddr, str::FromStr, sync::Arc};
use http::Request;
use ic_bn_lib_common::types::http::ConnInfo;
use crate::http::headers::X_REAL_IP;
pub mod rate_limiter;
pub mod waf;
/// Determines the client IP address associated with `req`.
///
/// Resolution order:
/// 1. The [`X_REAL_IP`] header, if present, readable as a string
///    (`HeaderValue::to_str` succeeds) and parseable as an [`IpAddr`].
/// 2. Otherwise, the IP of `remote_addr` from the `Arc<ConnInfo>` request extension.
///
/// Returns `None` when neither source yields an address. The body type `B` is
/// irrelevant; only headers and extensions are inspected.
///
/// NOTE(review): the header value is trusted as-is. Callers presumably ensure
/// `X-Real-IP` can only be set by a trusted upstream proxy. Confirm this at the call sites.
pub fn extract_ip_from_request<B>(req: &Request<B>) -> Option<IpAddr> {
    // An unreadable or malformed header value is treated the same as a missing one.
    let header_ip = req
        .headers()
        .get(X_REAL_IP)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| IpAddr::from_str(text).ok());

    if let Some(ip) = header_ip {
        return Some(ip);
    }

    // Fall back to the address of the peer that opened the connection.
    req.extensions()
        .get::<Arc<ConnInfo>>()
        .map(|conn| conn.remote_addr.ip())
}
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use ic_bn_lib_common::types::http::Addr;

    use super::*;

    /// Builds a shared `ConnInfo` whose remote address is `ip` on an arbitrary TCP port.
    fn conn_info(ip: IpAddr) -> Arc<ConnInfo> {
        let mut ci = ConnInfo::default();
        ci.remote_addr = Addr::Tcp(SocketAddr::new(ip, 31337));
        Arc::new(ci)
    }

    #[test]
    fn test_extract_ip_from_request() {
        let addr1 = IpAddr::from_str("10.0.0.1").unwrap();
        let addr2 = IpAddr::from_str("192.168.0.1").unwrap();
        let ci = conn_info(addr1);

        // Both sources present: the `X-Real-IP` header takes precedence.
        let req = Request::builder()
            .extension(Arc::clone(&ci))
            .header(X_REAL_IP, addr2.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr2));

        // Only connection info: fall back to the connection's remote address.
        let req = Request::builder().extension(ci).body("").unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr1));

        // Only the header: it is used directly.
        let req = Request::builder()
            .header(X_REAL_IP, addr2.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr2));

        // Neither source: no address can be determined.
        let req = Request::builder().body("").unwrap();
        assert_eq!(extract_ip_from_request(&req), None);
    }

    #[test]
    fn test_extract_ip_from_request_invalid_header() {
        let addr = IpAddr::from_str("10.0.0.1").unwrap();

        // A header that does not parse as an IP is ignored in favour of the connection address.
        let req = Request::builder()
            .extension(conn_info(addr))
            .header(X_REAL_IP, "not-an-ip")
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(addr));

        // With nothing to fall back on, a malformed header yields `None`.
        let req = Request::builder()
            .header(X_REAL_IP, "not-an-ip")
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), None);

        // IPv6 literals in the header are accepted as well.
        let v6 = IpAddr::from_str("2001:db8::1").unwrap();
        let req = Request::builder()
            .extension(conn_info(addr))
            .header(X_REAL_IP, v6.to_string())
            .body("")
            .unwrap();
        assert_eq!(extract_ip_from_request(&req), Some(v6));
    }
}