use crate::rudimental::{
CfConnectingIp, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader, TrueClientIp,
XForwardedFor, XRealIp,
};
use axum::{
async_trait,
extract::{ConnectInfo, FromRequestParts},
http::{request::Parts, Extensions, StatusCode},
};
use std::{
marker::Sync,
net::{IpAddr, SocketAddr},
};
/// Client IP address obtained without trusting any particular proxy setup.
///
/// The wrapped [`IpAddr`] is public so handlers can destructure the extractor
/// directly, e.g. `InsecureClientIp(ip): InsecureClientIp`.
///
/// As the name signals, the value may originate from request headers and
/// therefore should not be relied upon for security decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InsecureClientIp(pub IpAddr);
/// Extractor implementation: resolves the client IP from the request headers,
/// falling back to the connection's socket address.
///
/// Sources are tried in this order and the first one yielding an IP wins:
///
/// 1. `XForwardedFor` (leftmost IP)
/// 2. `Forwarded` (leftmost IP)
/// 3. `XRealIp`
/// 4. `FlyClientIp`
/// 5. `TrueClientIp`
/// 6. `CfConnectingIp`
/// 7. [`ConnectInfo<SocketAddr>`] stored in the request extensions
#[async_trait]
impl<S> FromRequestParts<S> for InsecureClientIp
where
    S: Sync,
{
    /// Status code plus a static, human-readable explanation.
    type Rejection = (StatusCode, &'static str);

    /// Runs the lookup chain described on the `impl`.
    ///
    /// # Errors
    ///
    /// Rejects with `500 Internal Server Error` when no header produced an IP
    /// and no `ConnectInfo<SocketAddr>` extension is present.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        XForwardedFor::maybe_leftmost_ip(&parts.headers)
            .or_else(|| Forwarded::maybe_leftmost_ip(&parts.headers))
            .or_else(|| XRealIp::maybe_ip_from_headers(&parts.headers))
            .or_else(|| FlyClientIp::maybe_ip_from_headers(&parts.headers))
            .or_else(|| TrueClientIp::maybe_ip_from_headers(&parts.headers))
            .or_else(|| CfConnectingIp::maybe_ip_from_headers(&parts.headers))
            // Last resort: the peer address recorded in the request extensions.
            .or_else(|| maybe_connect_info(&parts.extensions))
            .map(Self)
            .ok_or((
                StatusCode::INTERNAL_SERVER_ERROR,
                // The message must name this extractor (`InsecureClientIp`).
                "Can't extract `InsecureClientIp`, provide `axum::extract::ConnectInfo`",
            ))
    }
}
/// Looks up a `ConnectInfo<SocketAddr>` in the request extensions and returns
/// the IP part of its socket address, or `None` when no such extension exists.
fn maybe_connect_info(extensions: &Extensions) -> Option<IpAddr> {
    let ConnectInfo(peer) = extensions.get::<ConnectInfo<SocketAddr>>()?;
    Some(peer.ip())
}
#[cfg(test)]
mod tests {
    use super::InsecureClientIp;
    use axum::{
        body::{Body, BoxBody},
        http::Request,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Builds a router whose only route (`/`) responds with the extracted IP
    /// rendered as text.
    fn app() -> Router {
        let echo_ip = |InsecureClientIp(ip): InsecureClientIp| async move { ip.to_string() };
        Router::new().route("/", get(echo_ip))
    }

    /// Collects a response body into a `String` (lossy UTF-8 conversion).
    async fn body_string(body: BoxBody) -> String {
        let collected = hyper::body::to_bytes(body).await.unwrap();
        String::from_utf8_lossy(&collected).into_owned()
    }

    /// Sends `GET /` carrying `headers` (appended in the given order) through
    /// [`app`] and returns the response body as text.
    async fn extracted_ip(headers: &[(&str, &str)]) -> String {
        let mut builder = Request::builder().uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let request = builder.body(Body::empty()).unwrap();
        let response = app().oneshot(request).await.unwrap();
        body_string(response.into_body()).await
    }

    #[tokio::test]
    async fn x_forwarded_for() {
        let ip = extracted_ip(&[("X-Forwarded-For", "1.1.1.1, 2.2.2.2")]).await;
        assert_eq!(ip, "1.1.1.1");
    }

    #[tokio::test]
    async fn x_real_ip() {
        let ip = extracted_ip(&[("X-Real-Ip", "1.2.3.4")]).await;
        assert_eq!(ip, "1.2.3.4");
    }

    #[tokio::test]
    async fn forwarded() {
        let ip = extracted_ip(&[("Forwarded", "For=\"[2001:db8:cafe::17]:4711\"")]).await;
        assert_eq!(ip, "2001:db8:cafe::17");
    }

    #[tokio::test]
    async fn malformed() {
        // Unparsable values must be skipped until a valid entry is found.
        let ip = extracted_ip(&[
            ("X-Forwarded-For", "foo"),
            ("X-Real-Ip", "foo"),
            ("Forwarded", "foo"),
            ("Forwarded", "for=1.1.1.1;proto=https;by=2.2.2.2"),
        ])
        .await;
        assert_eq!(ip, "1.1.1.1");
    }
}