use axum::extract::{ConnectInfo, FromRequestParts};
use http::{request::Parts, Request, Response};
use ipnetwork::IpNetwork;
use std::{
env,
future::Future,
net::{IpAddr, SocketAddr},
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::{debug, warn};
/// Set of networks whose peers are allowed to supply `X-Forwarded-For` information.
///
/// Cloning is cheap: the network list is shared behind an [`Arc`].
#[derive(Clone, Debug)]
pub struct TrustedProxyConfig {
    trusted_networks: Arc<Vec<IpNetwork>>,
}

impl TrustedProxyConfig {
    /// Builds a configuration from an explicit list of trusted networks.
    pub fn new(networks: Vec<IpNetwork>) -> Self {
        Self {
            trusted_networks: Arc::new(networks),
        }
    }

    /// Reads the environment variable `env_key` and parses its value with
    /// [`TrustedProxyConfig::parse_str`].
    ///
    /// # Errors
    ///
    /// Returns an error if the variable is unset, if its value is not valid Unicode,
    /// or if any entry is neither an IP address nor a CIDR network.
    pub fn from_env(env_key: &str) -> Result<Self, String> {
        let val = env::var(env_key).map_err(|err| match err {
            env::VarError::NotPresent => format!("Environment variable {} not found", env_key),
            // Distinguish "present but unreadable" from "missing" so the operator
            // is not told to set a variable that is already set.
            env::VarError::NotUnicode(_) => {
                format!("Environment variable {} is not valid Unicode", env_key)
            }
        })?;
        Self::parse_str(&val)
    }

    /// Parses a `;`-separated list of trusted proxies.
    ///
    /// Each entry may be a CIDR network (`10.0.0.0/8`) or a bare IP address
    /// (`192.168.1.1`, treated as a single-host network). Surrounding whitespace is
    /// trimmed, and empty entries (for example, a trailing `;`) are ignored. An empty
    /// input yields a configuration that trusts nothing.
    ///
    /// # Errors
    ///
    /// Returns `Invalid IP or CIDR: <entry>` for the first entry that fails to parse.
    pub fn parse_str(input: &str) -> Result<Self, String> {
        let networks = input
            .split(';')
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(|part| {
                part.parse::<IpNetwork>()
                    .or_else(|_| part.parse::<IpAddr>().map(IpNetwork::from))
                    .map_err(|_| format!("Invalid IP or CIDR: {}", part))
            })
            .collect::<Result<Vec<_>, _>>()?;
        debug!("Loaded {} trusted proxy networks", networks.len());
        Ok(Self::new(networks))
    }

    /// Returns `true` if `ip` lies inside any configured trusted network.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`), as reported by dual-stack
    /// listeners, also match the corresponding IPv4 networks. The address is checked
    /// both as given and in canonical form, so matches that worked before still work.
    pub fn is_trusted(&self, ip: &IpAddr) -> bool {
        let canonical = ip.to_canonical();
        self.trusted_networks
            .iter()
            .any(|net| net.contains(*ip) || net.contains(canonical))
    }
}
/// Client IP address resolved by [`RealIpLayer`], stored as a request extension and
/// extractable in handlers.
///
/// Implements `Hash` so it can key maps and sets (for example, per-client rate limiting).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RealIp(pub IpAddr);
#[derive(Clone)]
pub struct RealIpLayer {
config: TrustedProxyConfig,
}
impl RealIpLayer {
pub fn new(config: TrustedProxyConfig) -> Self {
Self { config }
}
}
impl<S> Layer<S> for RealIpLayer {
type Service = RealIpService<S>;
fn layer(&self, inner: S) -> Self::Service {
RealIpService {
inner,
config: self.config.clone(),
}
}
}
#[derive(Clone)]
pub struct RealIpService<S> {
inner: S,
config: TrustedProxyConfig,
}
impl<S, B> Service<Request<B>> for RealIpService<S>
where
S: Service<Request<B>, Response = Response<B>> + Send + Clone + 'static,
S::Future: Send + 'static,
B: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let remote_addr = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip());
let config = self.config.clone();
let headers = req.headers().clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let mut resolved_ip = remote_addr.unwrap_or_else(|| {
IpAddr::from([0, 0, 0, 0])
});
if let Some(peer_ip) = remote_addr {
if config.is_trusted(&peer_ip)
&& let Some(xff_val) = headers.get("x-forwarded-for")
&& let Ok(xff_str) = xff_val.to_str()
{
let ips: Vec<&str> = xff_str.split(',').map(|s| s.trim()).collect();
for ip_str in ips.iter().rev() {
if let Ok(ip) = IpAddr::from_str(ip_str) {
if !config.is_trusted(&ip) {
resolved_ip = ip;
break;
}
} else {
warn!("Skipping invalid IP in X-Forwarded-For: {}", ip_str);
}
}
}
}
req.extensions_mut().insert(RealIp(resolved_ip));
inner.call(req).await
})
}
}
/// Extracts the [`RealIp`] inserted by the real-IP middleware.
///
/// Rejects with `500 Internal Server Error` when the extension is absent, which
/// means the middleware was not installed in front of this handler.
impl<S> FromRequestParts<S> for RealIp
where
    S: Send + Sync,
{
    type Rejection = (http::StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<RealIp>() {
            Some(&found) => Ok(found),
            None => Err((
                http::StatusCode::INTERNAL_SERVER_ERROR,
                "RealIp middleware is not configured correctly. Missing RealIp extension.",
            )),
        }
    }
}