use crate::errors::GovernorError;
use forwarded_header_value::{ForwardedHeaderValue, Identifier};
use http::request::Request;
use http::{header::FORWARDED, HeaderMap};
use std::fmt::Debug;
use std::net::SocketAddr;
use std::{hash::Hash, net::IpAddr};
/// Strategy for deriving the key that an incoming request is grouped under.
///
/// NOTE(review): the `Key` bounds (`Clone + Hash + Eq + Debug`) suggest keys are
/// stored in a hash-keyed structure by the rate limiter; confirm against the caller.
pub trait KeyExtractor: Clone {
/// Key type produced by [`KeyExtractor::extract`].
type Key: Clone + Hash + Eq + Debug;
/// Short human-readable name of this extractor; only present with the `tracing` feature.
#[cfg(feature = "tracing")]
fn name(&self) -> &'static str;
/// Derives the key for `req`.
///
/// # Errors
///
/// Returns a [`GovernorError`] when no key can be derived from the request.
fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError>;
/// Printable form of `key` for tracing output; only present with the `tracing` feature.
///
/// The default implementation returns `None` (no printable form).
#[cfg(feature = "tracing")]
fn key_name(&self, _key: &Self::Key) -> Option<String> {
None
}
}
/// Key extractor that maps every request to the same unit key `()`.
///
/// Because all requests share one key, they are presumably all counted against
/// a single shared limit (confirm with the limiter that consumes the key).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GlobalKeyExtractor;

impl KeyExtractor for GlobalKeyExtractor {
    type Key = ();

    #[cfg(feature = "tracing")]
    fn name(&self) -> &'static str {
        "global"
    }

    /// Always succeeds; the request itself is never inspected.
    fn extract<T>(&self, _request: &Request<T>) -> Result<Self::Key, GovernorError> {
        Ok(())
    }

    // `key_name` is deliberately not overridden: the trait default returns `None`,
    // and a unit key has nothing worth printing.
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PeerIpKeyExtractor;
impl KeyExtractor for PeerIpKeyExtractor {
type Key = IpAddr;
#[cfg(feature = "tracing")]
fn name(&self) -> &'static str {
"peer IP"
}
fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError> {
maybe_connect_info(req).ok_or(GovernorError::UnableToExtractKey)
}
#[cfg(feature = "tracing")]
fn key_name(&self, key: &Self::Key) -> Option<String> {
Some(key.to_string())
}
}
/// Key extractor that prefers proxy-supplied client-IP headers and falls back to
/// the connection's peer address.
///
/// Sources are tried in this order; the first one yielding a parseable IP wins:
/// 1. `X-Forwarded-For`,
/// 2. `X-Real-IP`,
/// 3. `Forwarded` (`for=` parameter),
/// 4. the axum `ConnectInfo<SocketAddr>` extension (only with the `axum` feature),
/// 5. a bare `SocketAddr` request extension.
///
/// **Caution:** the headers in steps 1–3 are read from the request as-is. Unless a
/// trusted proxy overwrites them, a client can pick its own key by setting them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SmartIpKeyExtractor;

impl KeyExtractor for SmartIpKeyExtractor {
    type Key = IpAddr;

    #[cfg(feature = "tracing")]
    fn name(&self) -> &'static str {
        "smart IP"
    }

    fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError> {
        let hdrs = req.headers();
        // Headers first: behind a reverse proxy the socket peer is the proxy itself.
        let from_headers = maybe_x_forwarded_for(hdrs)
            .or_else(|| maybe_x_real_ip(hdrs))
            .or_else(|| maybe_forwarded(hdrs));
        // Only consulted if no header produced an address.
        let from_connection = || maybe_connect_info(req).or_else(|| maybe_socket_addr(req));
        from_headers
            .or_else(from_connection)
            .ok_or(GovernorError::UnableToExtractKey)
    }

    #[cfg(feature = "tracing")]
    fn key_name(&self, key: &Self::Key) -> Option<String> {
        Some(key.to_string())
    }
}
// Header names defined locally; presumably `http::header` offers no constants for them.
// `X-Real-IP` is parsed below as a single IP address.
const X_REAL_IP: &str = "x-real-ip";
// `X-Forwarded-For` is parsed below as a comma-separated list of addresses.
const X_FORWARDED_FOR: &str = "x-forwarded-for";
/// Returns the first parseable IP listed in the `X-Forwarded-For` header(s).
///
/// The header may be split across several field lines. Per RFC 9110 that is
/// equivalent to one comma-joined list, so every line is scanned in order.
/// Lines that are not visible ASCII are skipped, as are list entries that do not
/// parse as an `IpAddr` (e.g. `unknown`). Returns `None` if nothing parses.
fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
    headers
        .get_all(X_FORWARDED_FOR)
        .iter()
        .filter_map(|hv| hv.to_str().ok())
        .flat_map(|s| s.split(','))
        .find_map(|entry| entry.trim().parse::<IpAddr>().ok())
}
/// Returns the IP carried in the first `X-Real-IP` header, if it parses.
///
/// Surrounding whitespace is trimmed before parsing, matching the handling of
/// `X-Forwarded-For` entries. Returns `None` for a missing header, a header that
/// is not visible ASCII, or a value that is not a bare IP address.
fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
    headers
        .get(X_REAL_IP)
        .and_then(|hv| hv.to_str().ok())
        .and_then(|s| s.trim().parse::<IpAddr>().ok())
}
/// Returns the first IP found in a `for=` parameter of the `Forwarded` header(s).
///
/// Every `Forwarded` field line is examined in order. Lines that are not visible
/// ASCII or fail to parse are skipped. Within a line, the first stanza whose
/// `for=` identifier is an IP or socket address wins. Any other identifier kind
/// is ignored.
fn maybe_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
    headers
        .get_all(FORWARDED)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .filter_map(|text| ForwardedHeaderValue::from_forwarded(text).ok())
        .find_map(|parsed| {
            parsed
                .iter()
                .find_map(|stanza| match stanza.forwarded_for.as_ref()? {
                    Identifier::SocketAddr(sock) => Some(sock.ip()),
                    Identifier::IpAddr(ip) => Some(*ip),
                    _ => None,
                })
        })
}
/// Returns the peer IP from axum's `ConnectInfo<SocketAddr>` request extension.
///
/// Only available with the `axum` feature. Without it this always returns `None`.
fn maybe_connect_info<T>(req: &Request<T>) -> Option<IpAddr> {
    #[cfg(feature = "axum")]
    {
        req.extensions()
            .get::<axum::extract::ConnectInfo<SocketAddr>>()
            .map(|addr| addr.ip())
    }
    #[cfg(not(feature = "axum"))]
    {
        // `req` is unused in this configuration; discard it explicitly so the build
        // stays free of `unused_variables` warnings.
        let _ = req;
        None
    }
}
/// Returns the IP of a bare `SocketAddr` stored in the request extensions, if any.
fn maybe_socket_addr<T>(req: &Request<T>) -> Option<IpAddr> {
    let peer = req.extensions().get::<SocketAddr>()?;
    Some(peer.ip())
}