use std::net::{AddrParseError, IpAddr, SocketAddr};
use axum::{
extract::{ConnectInfo, FromRequestParts, rejection::ExtensionRejection},
http::{self, StatusCode, request},
response::IntoResponse,
};
use snafu::{ResultExt as _, Snafu};
use crate::middleware::ClientIpConfig;
/// Reasons the [`ClientAddress`] extractor can reject a request.
///
/// Converted into an HTTP response by this type's `IntoResponse` impl.
#[derive(Debug, Snafu)]
pub enum Rejection {
    /// Extracting `ConnectInfo<SocketAddr>` from the request failed, so the
    /// peer address of the connection is unknown.
    // NOTE(review): presumably happens when the app is not served with
    // connect-info enabled — confirm against the server setup.
    #[snafu(display("could not get connection info"))]
    GetConnectInfo { source: ExtensionRejection },
    /// An `X-Forwarded-For` header value could not be converted to a string
    /// (`HeaderValue::to_str` failed).
    #[snafu(display("X-Forwarded-For header has an invalid value"))]
    InvalidXForwardedFor { source: http::header::ToStrError },
    /// An `X-Forwarded-For` entry that had to be consulted (because the hop
    /// after it was a trusted forwarder) did not parse as an IP address.
    #[snafu(display("X-Forwarded-For header contains an invalid address"))]
    InvalidAddressInXForwardedFor { source: AddrParseError },
}
impl IntoResponse for Rejection {
    /// Turns an extractor rejection into an HTTP response.
    ///
    /// A failed `ConnectInfo` extraction is rendered by axum's own
    /// `ExtensionRejection`. Malformed `X-Forwarded-For` data produces
    /// `400 Bad Request` with a plain-text body of the form
    /// `"<rejection message>: <underlying error>"`.
    fn into_response(self) -> axum::response::Response {
        let body = match self {
            // Defer to axum's response for the missing extension.
            Self::GetConnectInfo { source } => return source.into_response(),
            Self::InvalidXForwardedFor { ref source } => format!("{self}: {source}"),
            Self::InvalidAddressInXForwardedFor { ref source } => format!("{self}: {source}"),
        };
        (StatusCode::BAD_REQUEST, body).into_response()
    }
}
/// The resolved IP address of the client that originated a request.
///
/// Produced by the `FromRequestParts` extractor implemented for this type.
/// The address is the connection's peer address, unless that peer (and any
/// further hops) is a trusted forwarder. In that case it is taken from the
/// `X-Forwarded-For` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClientAddress {
    /// The client IP address.
    pub address: IpAddr,
}
impl<S: Send + Sync + ClientIpConfig> FromRequestParts<S> for ClientAddress {
    type Rejection = Rejection;

    /// Resolves the client IP from the connection and the `X-Forwarded-For` header.
    ///
    /// The candidate starts as the TCP peer address. The `X-Forwarded-For`
    /// entries are then walked right-to-left (nearest hop first). The candidate
    /// is replaced by the next entry only while
    /// `state.is_trusted_forwarder(candidate)` returns `true`. The first
    /// untrusted address, or the last entry reached, is returned.
    ///
    /// # Errors
    ///
    /// * [`Rejection::GetConnectInfo`] if `ConnectInfo<SocketAddr>` cannot be extracted.
    /// * [`Rejection::InvalidXForwardedFor`] if any `X-Forwarded-For` field line
    ///   is not a valid string.
    /// * [`Rejection::InvalidAddressInXForwardedFor`] if an entry that must be
    ///   consulted is not a valid IP address.
    async fn from_request_parts(
        parts: &mut request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let connect_info = ConnectInfo::<SocketAddr>::from_request_parts(parts, state)
            .await
            .context(GetConnectInfoSnafu)?;
        let mut last_addr = connect_info.0.ip();

        // RFC 9110 §5.3: repeated field lines are one comma-joined list, in order.
        // Reading only the first line (`get`) would let a client-supplied header
        // shadow the line added by a trusted proxy, allowing IP spoofing.
        let header_values = parts
            .headers
            .get_all("X-Forwarded-For")
            .iter()
            .map(|v| v.to_str())
            .collect::<Result<Vec<_>, _>>()
            .context(InvalidXForwardedForSnafu)?;

        let hops = header_values
            .into_iter()
            .flat_map(|v| v.split(','))
            .map(str::trim)
            // RFC 9110 §5.6.1: recipients must ignore empty list elements.
            .filter(|hop| !hop.is_empty());

        // The rightmost entry was appended by the closest proxy, so walk backwards
        // and stop as soon as the current hop is not trusted to report the next one.
        for next_addr in hops.rev() {
            if !state.is_trusted_forwarder(last_addr) {
                break;
            }
            last_addr = next_addr
                .parse()
                .context(InvalidAddressInXForwardedForSnafu)?;
        }

        Ok(ClientAddress { address: last_addr })
    }
}