use std::net::IpAddr as StdIpAddr;
use std::str::FromStr;
use http::StatusCode;
use http::request::Parts;
use crate::extractors::FromRequest;
use crate::extractors::FromRequestParts;
use crate::responder::Responder;
use crate::types::Request;
/// Client IP address extracted from an incoming request.
///
/// A thin newtype over [`std::net::IpAddr`]. Like the wrapped type it is `Copy`,
/// `Eq` and `Hash`, so it can be used directly as a map/set key (e.g. for
/// per-client bookkeeping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[doc(alias = "ip")]
#[doc(alias = "ipaddr")]
pub struct IpAddr(pub StdIpAddr);
/// Reasons the `IpAddr` extractor can fail.
#[derive(Debug)]
pub enum IpAddrError {
    /// None of the consulted headers yielded a usable address.
    NoIpFound,
    /// An address string could not be parsed; carries the offending text.
    ///
    /// NOTE(review): not constructed by the header-extraction code in this file.
    InvalidIpFormat(String),
    /// A header could not be decoded or parsed.
    ///
    /// NOTE(review): not constructed by the header-extraction code in this file.
    HeaderParseError,
}

/// Human-readable description; the wording matches the body of the HTTP error
/// response produced for each variant.
impl std::fmt::Display for IpAddrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoIpFound => f.write_str("No valid IP address found in request headers"),
            Self::InvalidIpFormat(ip) => write!(f, "Invalid IP address format: {ip}"),
            Self::HeaderParseError => f.write_str("Failed to parse IP address from headers"),
        }
    }
}

/// Lets the error participate in `?`, `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for IpAddrError {}
impl Responder for IpAddrError {
    /// Every variant becomes a `400 Bad Request` whose body is a short
    /// plain-text explanation of what went wrong.
    fn into_response(self) -> crate::types::Response {
        let status = StatusCode::BAD_REQUEST;
        match self {
            Self::NoIpFound => {
                (status, "No valid IP address found in request headers").into_response()
            }
            Self::InvalidIpFormat(raw) => {
                (status, format!("Invalid IP address format: {raw}")).into_response()
            }
            Self::HeaderParseError => {
                (status, "Failed to parse IP address from headers").into_response()
            }
        }
    }
}
impl IpAddr {
    /// Wraps a standard-library address.
    pub fn new(addr: StdIpAddr) -> Self {
        Self(addr)
    }

    /// Returns the wrapped standard-library address by value.
    pub fn inner(&self) -> StdIpAddr {
        self.0
    }

    /// Returns `true` for an IPv4 address.
    pub fn is_ipv4(&self) -> bool {
        self.0.is_ipv4()
    }

    /// Returns `true` for an IPv6 address.
    pub fn is_ipv6(&self) -> bool {
        self.0.is_ipv6()
    }

    /// Returns `true` for a loopback address (`127.0.0.0/8` or `::1`).
    pub fn is_loopback(&self) -> bool {
        self.0.is_loopback()
    }

    /// Returns `true` if the address lies in a private / non-public range.
    ///
    /// * IPv4: [`std::net::Ipv4Addr::is_private`] (`10/8`, `172.16/12`, `192.168/16`);
    ///   loopback is **not** included.
    /// * IPv6: unique-local `fc00::/7`, link-local `fe80::/10`, or loopback `::1`.
    ///
    /// NOTE(review): the two families disagree on loopback and link-local; the
    /// asymmetry is kept as-is because this is public API.
    pub fn is_private(&self) -> bool {
        match self.0 {
            StdIpAddr::V4(ipv4) => ipv4.is_private(),
            StdIpAddr::V6(ipv6) => {
                let segments = ipv6.segments();
                // fc00::/7 (unique local)
                (segments[0] & 0xfe00) == 0xfc00
                    // fe80::/10 (link-local)
                    || (segments[0] & 0xffc0) == 0xfe80
                    || ipv6.is_loopback()
            }
        }
    }

    /// Finds the client address in the proxy/CDN headers of a request.
    ///
    /// Headers are consulted in the fixed priority order below. Every line of a
    /// repeated header is examined, because a header such as `X-Forwarded-For`
    /// may legally be sent once per proxy. The first public address found wins.
    ///
    /// SECURITY: all of these headers are supplied by the client unless a
    /// trusted reverse proxy overwrites them. Do not base access control on the
    /// result without that guarantee.
    ///
    /// # Errors
    /// [`IpAddrError::NoIpFound`] if no header yields a public address.
    fn extract_from_headers(headers: &http::HeaderMap) -> Result<Self, IpAddrError> {
        const HEADER_NAMES: [&str; 8] = [
            "x-forwarded-for",
            "x-real-ip",
            "x-client-ip",
            "cf-connecting-ip",
            "x-forwarded",
            "forwarded-for",
            "forwarded",
            "true-client-ip",
        ];

        HEADER_NAMES
            .iter()
            // `get_all` sees every occurrence of a header; `get` only the first.
            .flat_map(|name| headers.get_all(*name))
            // Values that are not visible ASCII cannot hold an address; skip them.
            .filter_map(|value| value.to_str().ok())
            .find_map(Self::parse_ip_from_header)
            .map(Self::new)
            .ok_or(IpAddrError::NoIpFound)
    }

    /// Returns the first public address in a comma-separated header value.
    ///
    /// Accepts plain lists (`X-Forwarded-For: 203.0.113.7, 10.0.0.1`) and RFC 7239
    /// `Forwarded` elements (`for="[2001:db8::17]:4711";proto=https`). Ports are
    /// stripped. Loopback and private hops are skipped (see `is_internal`).
    fn parse_ip_from_header(header_value: &str) -> Option<StdIpAddr> {
        header_value
            .split(',')
            .map(str::trim)
            .filter(|element| !element.is_empty())
            .filter_map(Self::node_from_element)
            .map(Self::strip_port)
            .filter_map(|host| StdIpAddr::from_str(host).ok())
            .find(|ip| !Self::is_internal(*ip))
    }

    /// Extracts the node (address, optionally with a port) from one list element.
    ///
    /// An element without `=` is returned unchanged. Otherwise the element is
    /// treated as `;`-separated RFC 7239 parameters. The value of the `for`
    /// parameter (key matched case-insensitively, surrounding quotes removed) is
    /// returned, or `None` if there is no `for` parameter.
    fn node_from_element(element: &str) -> Option<&str> {
        if !element.contains('=') {
            return Some(element);
        }
        element.split(';').find_map(|pair| {
            let (key, value) = pair.split_once('=')?;
            key.trim()
                .eq_ignore_ascii_case("for")
                .then_some(value.trim().trim_matches('"'))
        })
    }

    /// Removes an optional port from a node.
    ///
    /// `[v6]:port` and `[v6]` become `v6`, and `v4:port` becomes `v4`. A bare
    /// IPv6 address (more than one colon) is returned unchanged.
    fn strip_port(node: &str) -> &str {
        if let Some(rest) = node.strip_prefix('[') {
            return rest.split_once(']').map_or(node, |(host, _)| host);
        }
        match node.split_once(':') {
            Some((host, port)) if !port.contains(':') => host,
            _ => node,
        }
    }

    /// Whether `ip` is a loopback or private hop that should be skipped when
    /// looking for the originating client.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their IPv4 form.
    fn is_internal(ip: StdIpAddr) -> bool {
        let ip = ip.to_canonical();
        ip.is_loopback() || Self(ip).is_private()
    }
}
impl std::fmt::Display for IpAddr {
    /// Formats exactly like the wrapped [`std::net::IpAddr`].
    ///
    /// Delegating the formatter (rather than re-formatting through `write!`)
    /// preserves caller flags such as width and alignment, e.g. `{:>15}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl From<StdIpAddr> for IpAddr {
fn from(addr: StdIpAddr) -> Self {
Self(addr)
}
}
impl From<IpAddr> for StdIpAddr {
fn from(addr: IpAddr) -> Self {
addr.0
}
}
impl<'a> FromRequest<'a> for IpAddr {
    type Error = IpAddrError;

    /// Resolves the client address from the request's headers alone.
    ///
    /// The body is not read. The lookup runs eagerly and is returned wrapped
    /// in an already-completed future.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        let outcome = Self::extract_from_headers(req.headers());
        futures_util::future::ready(outcome)
    }
}
impl<'a> FromRequestParts<'a> for IpAddr {
    type Error = IpAddrError;

    /// Same lookup as the full-request extractor, but operating on the request
    /// head (`Parts`) only. The result is computed eagerly and returned as a
    /// ready future.
    fn from_request_parts(
        parts: &'a mut Parts,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        let header_map = &parts.headers;
        let outcome = Self::extract_from_headers(header_map);
        futures_util::future::ready(outcome)
    }
}