use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
str::FromStr,
};
use http::{HeaderMap, HeaderName};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use motore::{Service, layer::Layer};
use volo::{context::Context, net::Address};
use crate::{context::ServerContext, request::Request};
/// Layer that wraps an inner service in a [`ClientIpService`].
///
/// The layer only carries a [`ClientIpConfig`]; it is handed over to the
/// produced service when the layer is applied.
#[derive(Clone, Debug, Default)]
pub struct ClientIpLayer {
// Configuration passed on to every `ClientIpService` built by this layer.
config: ClientIpConfig,
}
impl ClientIpLayer {
pub fn new() -> Self {
Default::default()
}
pub fn with_config(self, config: ClientIpConfig) -> Self {
Self { config }
}
}
impl<S> Layer<S> for ClientIpLayer
where
    S: Send + Sync + 'static,
{
    type Service = ClientIpService<S>;

    /// Consumes the layer and wraps `inner` in a [`ClientIpService`] that
    /// owns this layer's configuration.
    fn layer(self, inner: S) -> Self::Service {
        let Self { config } = self;
        ClientIpService {
            service: inner,
            config,
        }
    }
}
/// Configuration consumed by [`ClientIpService`] when resolving the client IP.
#[derive(Clone, Debug)]
pub struct ClientIpConfig {
// Request headers that are searched, in this order, for candidate client IPs.
remote_ip_headers: Vec<HeaderName>,
// Networks an address must belong to in order to be accepted; checked both
// for the connecting peer and for addresses read from the headers above.
trusted_cidrs: Vec<IpNet>,
}
impl Default for ClientIpConfig {
    /// Builds the default configuration.
    ///
    /// * headers: `x-real-ip`, then `x-forwarded-for`
    /// * trusted networks: `0.0.0.0/0` and `::/0`, i.e. every IPv4 and IPv6
    ///   address is trusted
    fn default() -> Self {
        // A prefix length of 0 makes the network cover the whole address family.
        let any_v4 = Ipv4Net::new_assert(Ipv4Addr::UNSPECIFIED, 0);
        let any_v6 = Ipv6Net::new_assert(Ipv6Addr::UNSPECIFIED, 0);

        let remote_ip_headers = ["x-real-ip", "x-forwarded-for"]
            .into_iter()
            .map(HeaderName::from_static)
            .collect();

        Self {
            remote_ip_headers,
            trusted_cidrs: vec![IpNet::V4(any_v4), IpNet::V6(any_v6)],
        }
    }
}
impl ClientIpConfig {
    /// Creates a configuration with the default headers and trusted networks.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the list of headers that are searched for the client IP.
    ///
    /// The headers are searched in the order they are yielded by `headers`.
    /// The trusted networks are left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`http::header::InvalidHeaderName`] for the first item that is
    /// not a valid header name.
    pub fn with_remote_ip_headers<I>(
        self,
        headers: I,
    ) -> Result<Self, http::header::InvalidHeaderName>
    where
        I: IntoIterator,
        I::Item: AsRef<str>,
    {
        // Collecting into `Result` stops at the first invalid name.
        let remote_ip_headers = headers
            .into_iter()
            .map(|name| HeaderName::from_str(name.as_ref()))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self {
            remote_ip_headers,
            ..self
        })
    }

    /// Replaces the list of trusted networks.
    ///
    /// The configured headers are left untouched.
    pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
    where
        H: IntoIterator<Item = IpNet>,
    {
        let trusted_cidrs = cidrs.into_iter().collect();
        Self {
            trusted_cidrs,
            ..self
        }
    }
}
/// Client IP address resolved for a request.
///
/// [`ClientIpService`] inserts a value of this type into the context
/// extensions before calling the inner service. The wrapped value is `None`
/// when no address could be determined or the connecting peer is not inside
/// the trusted networks.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientIp(pub Option<IpAddr>);
/// Service that resolves a [`ClientIp`] for each request and stores it in the
/// context extensions before delegating to the wrapped service.
#[derive(Clone, Debug)]
pub struct ClientIpService<S> {
// The wrapped (inner) service that handles the request afterwards.
service: S,
// Headers to inspect and networks to trust while resolving the client IP.
config: ClientIpConfig,
}
impl<S> ClientIpService<S> {
    /// Returns `true` when `ip` belongs to at least one configured trusted network.
    fn is_trusted(&self, ip: &IpAddr) -> bool {
        self.config
            .trusted_cidrs
            .iter()
            .any(|cidr| cidr.contains(ip))
    }

    /// Resolves the client IP for a request.
    ///
    /// Resolution order:
    ///
    /// 1. If the caller address is unknown, the result is `ClientIp(None)`.
    /// 2. If the caller connected over IP and its address is not inside the
    ///    trusted networks, the result is `ClientIp(None)` and the headers are
    ///    not consulted.
    /// 3. The configured headers are scanned in configuration order. Every
    ///    field line of a header is split on `,`; the first entry that parses
    ///    as an IP address and lies inside the trusted networks is returned.
    /// 4. Otherwise the caller's own IP is returned (`None` for a Unix socket
    ///    peer).
    ///
    /// # Panics
    ///
    /// Panics via `unimplemented!` if the caller address is of a kind other
    /// than the ones matched explicitly below.
    fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp {
        let remote_ip = match &cx.rpc_info().caller().address {
            Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
            #[cfg(target_family = "unix")]
            Some(Address::Unix(_)) => None,
            // Presumably the arms above already cover every variant on some
            // targets, which would make this arm unreachable there; the
            // `allow` silences that lint.
            #[allow(unreachable_patterns)]
            Some(_) => unimplemented!("unsupported type of address"),
            None => return ClientIp(None),
        };

        // Headers are only honoured when the direct peer itself is trusted.
        if let Some(peer_ip) = &remote_ip {
            if !self.is_trusted(peer_ip) {
                return ClientIp(None);
            }
        }

        // A header may be sent as several field lines (e.g. one
        // `X-Forwarded-For` line per proxy); `get_all` visits every line in
        // order, whereas `get` would only see the first one.
        let from_headers = self
            .config
            .remote_ip_headers
            .iter()
            .flat_map(|name| headers.get_all(name))
            // Values that are not visible ASCII are skipped.
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            // Entries that do not parse as an IP address are skipped.
            .filter_map(|candidate| IpAddr::from_str(candidate.trim()).ok())
            .find(|candidate| self.is_trusted(candidate));

        ClientIp(from_headers.or(remote_ip))
    }
}
impl<S, B> Service<ServerContext, Request<B>> for ClientIpService<S>
where
    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
    B: Send,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Resolves the [`ClientIp`] of the request, stores it in the context
    /// extensions and then forwards the request to the inner service.
    ///
    /// The inner service's response or error is returned unchanged.
    async fn call(
        &self,
        cx: &mut ServerContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        // Resolve first: this needs a shared borrow of `cx`, which must end
        // before the extensions are borrowed mutably.
        let resolved = {
            let headers = req.headers();
            self.get_client_ip(cx, headers)
        };
        cx.extensions_mut().insert(resolved);

        self.service.call(cx, req).await
    }
}
#[cfg(test)]
mod client_ip_tests {
    use std::{net::SocketAddr, str::FromStr};
    use http::{HeaderValue, Method};
    use motore::{Service, layer::Layer};
    use volo::net::Address;
    use crate::{
        body::BodyConversion,
        context::ServerContext,
        server::{
            route::{Route, get},
            utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
        },
        utils::test_helpers::simple_req,
    };

    /// Checks client IP resolution for a peer at `10.0.0.1` when only
    /// `10.0.0.0/8` is trusted.
    #[tokio::test]
    async fn test_client_ip() {
        // Echoes the resolved client IP back as the response body.
        async fn handler(ClientIp(client_ip): ClientIp) -> String {
            client_ip.unwrap().to_string()
        }

        let route: Route<&str> = Route::new(get(handler));
        let service = ClientIpLayer::new()
            .with_config(
                ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
            )
            .layer(route);
        let mut cx = ServerContext::new(Address::from(
            SocketAddr::from_str("10.0.0.1:8080").unwrap(),
        ));

        // (optional request header, expected client IP in the response body)
        let cases: [(Option<(&'static str, &'static str)>, &'static str); 4] = [
            // No header: falls back to the peer address.
            (None, "10.0.0.1"),
            // Trusted address in `X-Real-IP` wins over the peer address.
            (Some(("X-Real-IP", "10.0.0.2")), "10.0.0.2"),
            // Trusted address in `X-Forwarded-For` wins over the peer address.
            (Some(("X-Forwarded-For", "10.0.1.0")), "10.0.1.0"),
            // Untrusted header address is ignored; the peer address is used.
            (Some(("X-Real-IP", "11.0.0.1")), "10.0.0.1"),
        ];

        for (header, expected) in cases {
            let mut req = simple_req(Method::GET, "/", "");
            if let Some((name, value)) = header {
                req.headers_mut()
                    .insert(name, HeaderValue::from_static(value));
            }
            let resp = service.call(&mut cx, req).await.unwrap();
            assert_eq!(expected, resp.into_string().await.unwrap());
        }
    }
}