use std::{
marker::Sync,
net::{IpAddr, SocketAddr},
};
use axum::{
extract::{ConnectInfo, FromRequestParts},
http::{request::Parts, Extensions, HeaderMap, HeaderValue, StatusCode},
};
use crate::rudimental::{
CfConnectingIp, CloudFrontViewerAddress, FlyClientIp, Forwarded, MultiIpHeader, SingleIpHeader,
TrueClientIp, XForwardedFor, XRealIp,
};
/// Client IP address obtained without verifying where it came from.
///
/// The address is resolved from forwarding-related request headers first and
/// from the connection's socket address last; see `InsecureClientIp::from`
/// for the exact lookup order. Header-derived values can be supplied by the
/// sender of the request, so do not rely on this for security decisions.
#[derive(Clone, Copy, Debug)]
pub struct InsecureClientIp(pub IpAddr);
/// Rejection produced when no client IP can be determined: an HTTP status
/// code paired with a static explanation.
type Rejection = (StatusCode, &'static str);

impl InsecureClientIp {
    /// Determines the client IP from request headers, falling back to the
    /// peer socket address stored as `ConnectInfo<SocketAddr>` in `extensions`.
    ///
    /// Sources are tried in this order; the first one that yields an IP wins:
    /// 1. `XForwardedFor` (leftmost IP)
    /// 2. `Forwarded` (leftmost IP)
    /// 3. `XRealIp`
    /// 4. `FlyClientIp`
    /// 5. `TrueClientIp`
    /// 6. `CfConnectingIp`
    /// 7. `CloudFrontViewerAddress`
    /// 8. `ConnectInfo<SocketAddr>` from `extensions`
    ///
    /// Header values are not validated against a trusted proxy list, which is
    /// why the result is "insecure".
    ///
    /// # Errors
    ///
    /// Returns `(StatusCode::INTERNAL_SERVER_ERROR, message)` when none of the
    /// sources above yields an IP. This happens, for example, when no header is
    /// present and the `ConnectInfo` extension was not installed.
    pub fn from(
        headers: &HeaderMap<HeaderValue>,
        extensions: &Extensions,
    ) -> Result<Self, Rejection> {
        XForwardedFor::maybe_leftmost_ip(headers)
            .or_else(|| Forwarded::maybe_leftmost_ip(headers))
            .or_else(|| XRealIp::maybe_ip_from_headers(headers))
            .or_else(|| FlyClientIp::maybe_ip_from_headers(headers))
            .or_else(|| TrueClientIp::maybe_ip_from_headers(headers))
            .or_else(|| CfConnectingIp::maybe_ip_from_headers(headers))
            .or_else(|| CloudFrontViewerAddress::maybe_ip_from_headers(headers))
            // Last resort: the address of the directly connected peer.
            .or_else(|| maybe_connect_info(extensions))
            .map(Self)
            .ok_or((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Can't extract `InsecureClientIp`, provide `axum::extract::ConnectInfo`",
            ))
    }
}
impl<S> FromRequestParts<S> for InsecureClientIp
where
S: Sync,
{
type Rejection = Rejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Self::from(&parts.headers, &parts.extensions)
}
}
/// Returns the IP of the `ConnectInfo<SocketAddr>` extension, or `None`
/// when that extension is not present in `extensions`.
fn maybe_connect_info(extensions: &Extensions) -> Option<IpAddr> {
    let ConnectInfo(socket_addr) = extensions.get::<ConnectInfo<SocketAddr>>()?;
    Some(socket_addr.ip())
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        body::Body,
        extract::connect_info::MockConnectInfo,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::InsecureClientIp;

    /// Router with a single `/` route that echoes the extracted client IP.
    fn app() -> Router {
        Router::new().route(
            "/",
            get(|InsecureClientIp(ip): InsecureClientIp| async move { ip.to_string() }),
        )
    }

    /// Same as [`app`], but with a mocked `ConnectInfo<SocketAddr>` of
    /// `10.0.0.1:4711` injected into every request.
    fn app_with_connect_info() -> Router {
        app().layer(MockConnectInfo(SocketAddr::from(([10, 0, 0, 1], 4711))))
    }

    /// Collects a response body into a (lossily decoded) UTF-8 string.
    async fn body_string(body: Body) -> String {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8_lossy(&bytes).into()
    }

    #[tokio::test]
    async fn x_forwarded_for() {
        let req = Request::builder()
            .uri("/")
            .header("X-Forwarded-For", "1.1.1.1, 2.2.2.2")
            .body(Body::empty())
            .unwrap();
        let resp = app().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "1.1.1.1");
    }

    #[tokio::test]
    async fn x_real_ip() {
        let req = Request::builder()
            .uri("/")
            .header("X-Real-Ip", "1.2.3.4")
            .body(Body::empty())
            .unwrap();
        let resp = app().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "1.2.3.4");
    }

    #[tokio::test]
    async fn forwarded() {
        let req = Request::builder()
            .uri("/")
            .header("Forwarded", "For=\"[2001:db8:cafe::17]:4711\"")
            .body(Body::empty())
            .unwrap();
        let resp = app().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "2001:db8:cafe::17");
    }

    #[tokio::test]
    async fn malformed() {
        let req = Request::builder()
            .uri("/")
            .header("X-Forwarded-For", "foo")
            .header("X-Real-Ip", "foo")
            .header("Forwarded", "foo")
            .header("Forwarded", "for=1.1.1.1;proto=https;by=2.2.2.2")
            .body(Body::empty())
            .unwrap();
        let resp = app().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "1.1.1.1");
    }

    /// With no recognised header, the extractor falls back to `ConnectInfo`.
    #[tokio::test]
    async fn connect_info_fallback() {
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app_with_connect_info().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "10.0.0.1");
    }

    /// A usable header takes precedence over `ConnectInfo`.
    #[tokio::test]
    async fn header_preferred_over_connect_info() {
        let req = Request::builder()
            .uri("/")
            .header("X-Real-Ip", "1.2.3.4")
            .body(Body::empty())
            .unwrap();
        let resp = app_with_connect_info().oneshot(req).await.unwrap();
        assert_eq!(body_string(resp.into_body()).await, "1.2.3.4");
    }

    /// With neither a header nor `ConnectInfo`, extraction is rejected.
    #[tokio::test]
    async fn no_source_is_rejected() {
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}