use std::{
fmt,
iter::Iterator,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::OnceLock,
task::{Context, Poll},
};
use axum::{
body::Body,
extract::{ConnectInfo, FromRequestParts, Request},
http::{header::HeaderMap, request::Parts},
response::Response,
Router as AXRouter,
};
use futures_util::future::BoxFuture;
use ipnetwork::IpNetwork;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use tracing::error;
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
static LOCAL_TRUSTED_PROXIES: OnceLock<Vec<IpNetwork>> = OnceLock::new();
fn get_local_trusted_proxies() -> &'static Vec<IpNetwork> {
LOCAL_TRUSTED_PROXIES.get_or_init(|| {
[
"127.0.0.0/8", "::1", "fc00::/7", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16",
]
.iter()
.map(|ip| IpNetwork::from_str(ip).unwrap())
.collect()
})
}
/// Header inspected for proxy-forwarded client addresses (looked up via
/// `HeaderMap::get_all`, which matches header names case-insensitively).
const X_FORWARDED_FOR: &str = "X-Forwarded-For";
/// Configuration for the `remote_ip` middleware.
///
/// When applied, every request receives a [`RemoteIP`] extension derived from the
/// `X-Forwarded-For` header, falling back to the socket address.
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct RemoteIpMiddleware {
    /// Whether the middleware is active; `false` when omitted from the config.
    #[serde(default)]
    pub enable: bool,
    /// Networks (CIDR strings or bare addresses) whose `X-Forwarded-For` entries are
    /// skipped as proxies. `None` uses the built-in loopback/private ranges; an empty
    /// list makes `is_enabled` return `false`.
    pub trusted_proxies: Option<Vec<String>>,
}
impl MiddlewareLayer for RemoteIpMiddleware {
    /// Identifier of this middleware.
    fn name(&self) -> &'static str {
        "remote_ip"
    }

    /// Active only when `enable` is set and `trusted_proxies` is either unset
    /// (built-in defaults) or a non-empty list; an explicitly empty list turns it off.
    fn is_enabled(&self) -> bool {
        if !self.enable {
            return false;
        }
        match self.trusted_proxies.as_deref() {
            Some(proxies) => !proxies.is_empty(),
            None => true,
        }
    }

    /// Serializes the current configuration to JSON.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Wraps the router in a [`RemoteIPLayer`] built from this configuration.
    ///
    /// # Errors
    /// Fails when a configured trusted proxy cannot be parsed.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        let layer = RemoteIPLayer::new(self)?;
        Ok(app.layer(layer))
    }
}
fn maybe_get_forwarded(
headers: &HeaderMap,
trusted_proxies: Option<&Vec<IpNetwork>>,
) -> Option<IpAddr> {
let xffs = headers
.get_all(X_FORWARDED_FOR)
.iter()
.map(|hdr| hdr.to_str())
.filter_map(Result::ok)
.collect::<Vec<_>>();
if xffs.is_empty() {
return None;
}
let forwarded = xffs.join(",");
forwarded
.split(',')
.map(str::trim)
.map(str::parse)
.filter_map(Result::ok)
.filter(|ip| {
let proxies = trusted_proxies.unwrap_or_else(|| get_local_trusted_proxies());
!proxies
.iter()
.any(|trusted_proxy| trusted_proxy.contains(*ip))
})
.next_back()
}
/// Client address resolved by the remote IP middleware and stored in the request
/// extensions.
#[derive(Copy, Clone, Debug)]
pub enum RemoteIP {
    /// Taken from the `X-Forwarded-For` header.
    Forwarded(IpAddr),
    /// Taken from the connection's `ConnectInfo<SocketAddr>` extension.
    Socket(IpAddr),
    /// No address could be determined.
    None,
}
impl<S> FromRequestParts<S> for RemoteIP
where
    S: Send + Sync,
{
    type Rejection = ();

    /// Returns the [`RemoteIP`] stored by the middleware, or `RemoteIP::None` when the
    /// extension is absent. Never rejects.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let stored = parts.extensions.get::<Self>().copied();
        Ok(stored.unwrap_or(Self::None))
    }
}
impl fmt::Display for RemoteIP {
    /// Renders the address with its origin: `remote: <ip>`, `socket: <ip>`, or `--`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::None => f.write_str("--"),
            Self::Socket(addr) => write!(f, "socket: {addr}"),
            Self::Forwarded(addr) => write!(f, "remote: {addr}"),
        }
    }
}
/// Tower layer that wraps services in [`RemoteIPMiddleware`].
#[derive(Clone, Debug)]
struct RemoteIPLayer {
    /// Parsed trusted proxy networks; `None` means `maybe_get_forwarded` falls back to
    /// the built-in loopback/private ranges.
    trusted_proxies: Option<Vec<IpNetwork>>,
}
impl RemoteIPLayer {
pub fn new(config: &RemoteIpMiddleware) -> Result<Self> {
Ok(Self {
trusted_proxies: config
.trusted_proxies
.as_ref()
.map(|proxies| {
proxies
.iter()
.map(|proxy| {
IpNetwork::from_str(proxy).map_err(|err| {
Error::Message(format!(
"remote ip middleare cannot parse trusted proxy \
configuration: `{proxy}`, reason: `{err}`",
))
})
})
.collect::<Result<Vec<_>>>()
})
.transpose()?,
})
}
}
impl<S> Layer<S> for RemoteIPLayer {
    type Service = RemoteIPMiddleware<S>;

    /// Wraps `inner`, giving the new service its own copy of the layer configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let layer = self.clone();
        RemoteIPMiddleware { inner, layer }
    }
}
/// Service produced by [`RemoteIPLayer`]: resolves the client IP for each request and
/// forwards the request to `inner`.
#[derive(Clone, Debug)]
#[must_use]
pub struct RemoteIPMiddleware<S> {
    /// The wrapped service.
    inner: S,
    /// Layer configuration (trusted proxies).
    layer: RemoteIPLayer,
}
impl<S> Service<Request<Body>> for RemoteIPMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Resolves the client IP, stores it as a [`RemoteIP`] request extension, and
    /// forwards the request to the inner service.
    ///
    /// The address comes from one of three sources, in this order:
    /// - `X-Forwarded-For` (filtered through the trusted proxies), if it yields one;
    /// - otherwise the socket address from `ConnectInfo<SocketAddr>`;
    /// - if that extension is missing too, `RemoteIP::None`, and an error is logged.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        // Borrow the configured proxies; cloning the layer per request is unnecessary.
        let forwarded = maybe_get_forwarded(req.headers(), self.layer.trusted_proxies.as_ref());
        let remote_ip = match forwarded {
            Some(ip) => RemoteIP::Forwarded(ip),
            None => match req.extensions().get::<ConnectInfo<SocketAddr>>() {
                Some(info) => RemoteIP::Socket(info.ip()),
                None => {
                    error!(
                        "remote ip middleware cannot get socket IP (not set in axum \
                         extensions): setting remote IP to `RemoteIP::None`"
                    );
                    RemoteIP::None
                }
            },
        };
        req.extensions_mut().insert(remote_ip);
        Box::pin(self.inner.call(req))
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use axum::http::{HeaderMap, HeaderName, HeaderValue};
    use insta::assert_debug_snapshot;
    use ipnetwork::IpNetwork;
    use super::maybe_get_forwarded;

    /// Builds a header map containing a single `x-forwarded-for` value.
    fn xff(val: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("x-forwarded-for"),
            HeaderValue::from_str(val).unwrap(),
        );
        headers
    }

    /// Snapshot-tests `maybe_get_forwarded` over a range of header values.
    /// Snapshot names are derived from call order, so new cases go at the end.
    #[test]
    pub fn test_parsing() {
        // Empty header value.
        let res = maybe_get_forwarded(&xff(""), None);
        assert_debug_snapshot!(res);
        // Entry that is not an IP address.
        let res = maybe_get_forwarded(&xff("foobar"), None);
        assert_debug_snapshot!(res);
        // Single public address.
        let res = maybe_get_forwarded(&xff("192.1.1.1"), None);
        assert_debug_snapshot!(res);
        // Public client followed by private (default-trusted) hops.
        let res = maybe_get_forwarded(&xff("51.50.51.50,10.0.0.1,192.168.1.1"), None);
        assert_debug_snapshot!(res);
        let res = maybe_get_forwarded(&xff("19.84.19.84,192.168.0.1"), None);
        assert_debug_snapshot!(res);
        // Malformed entries mixed with a trusted hop.
        let res = maybe_get_forwarded(&xff("b51.50.51.50b,/10.0.0.1-,192.168.1.1"), None);
        assert_debug_snapshot!(res);
        // Custom trusted list whose network covers the rightmost hop.
        let res = maybe_get_forwarded(
            &xff("51.50.51.50,192.1.1.1"),
            Some(&vec![IpNetwork::from_str("192.1.1.1/8").unwrap()]),
        );
        assert_debug_snapshot!(res);
        // A custom list replaces the defaults, so a private address outside it is
        // not treated as a trusted proxy.
        let res = maybe_get_forwarded(
            &xff("51.50.51.50,192.168.1.1"),
            Some(&vec![IpNetwork::from_str("192.1.1.1/16").unwrap()]),
        );
        assert_debug_snapshot!(res);
    }
}