use std::net::IpAddr;
use http::header::HOST;
use ipnet::IpNet;
use tracing::warn;
use crate::error::Result;
use crate::extract::{peer_addr_from_extensions, RequestScheme};
use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
use crate::response::Response;
use crate::router::BoxFuture;
/// Header carrying the host a trusted proxy received the request for.
const FORWARDED_HOST: &str = "x-forwarded-host";
/// Header carrying the scheme (`http`/`https`) a trusted proxy received the request over.
const FORWARDED_PROTO: &str = "x-forwarded-proto";

/// Middleware that honours `X-Forwarded-Host` and `X-Forwarded-Proto` from trusted peers.
///
/// Forwarding headers are only applied when the connecting peer matches one of the
/// registered proxy addresses or networks; with nothing registered, every request is
/// passed through unchanged.
#[derive(Debug, Clone)]
pub struct ProxyHeaders {
    /// Exact peer addresses whose forwarding headers are trusted.
    trusted_ips: Vec<IpAddr>,
    /// Networks whose members' forwarding headers are trusted.
    trusted_cidrs: Vec<IpNet>,
}
impl ProxyHeaders {
    /// Creates a middleware that trusts no peers.
    ///
    /// Until at least one proxy address or network is registered, every request is
    /// passed through unchanged.
    pub fn new() -> Self {
        Self {
            trusted_ips: Vec::new(),
            trusted_cidrs: Vec::new(),
        }
    }

    /// Trusts forwarding headers sent by the exact peer address `addr`.
    ///
    /// An IPv4 address also matches its IPv4-mapped IPv6 form (`::ffff:a.b.c.d`).
    #[must_use]
    pub fn trust_proxy(mut self, addr: IpAddr) -> Self {
        self.trusted_ips.push(addr);
        self
    }

    /// Trusts forwarding headers sent by any peer inside `network`.
    #[must_use]
    pub fn trust_cidr(mut self, network: IpNet) -> Self {
        self.trusted_cidrs.push(network);
        self
    }

    /// Trusts the IPv4 (`127.0.0.1`) and IPv6 (`::1`) loopback addresses.
    #[must_use]
    pub fn trust_loopback(self) -> Self {
        self.trust_proxy(IpAddr::from(std::net::Ipv4Addr::LOCALHOST))
            .trust_proxy(IpAddr::from(std::net::Ipv6Addr::LOCALHOST))
    }

    /// Returns `true` when the request's peer address belongs to a registered proxy.
    ///
    /// Requests without a recorded peer address are never trusted. Dual-stack
    /// listeners report IPv4 clients as IPv4-mapped IPv6 addresses, so the peer is
    /// compared in canonical (IPv4) form as well; otherwise `trust_loopback()` or an
    /// IPv4 CIDR would never match on such sockets.
    fn is_trusted(&self, request: &Request) -> bool {
        let Some(peer) = peer_addr_from_extensions(request.extensions()) else {
            return false;
        };
        let raw = peer.ip();
        let canonical = Self::canonical_ip(raw);

        self.trusted_ips
            .iter()
            .any(|addr| Self::canonical_ip(*addr) == canonical)
            // Check both forms so networks written in IPv6-mapped notation keep matching.
            || self
                .trusted_cidrs
                .iter()
                .any(|network| network.contains(&raw) || network.contains(&canonical))
    }

    /// Converts an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to its IPv4 form.
    ///
    /// All other addresses are returned unchanged. `to_ipv4_mapped` is used rather
    /// than `to_ipv4` so that e.g. `::1` is not misread as `0.0.0.1`.
    fn canonical_ip(ip: IpAddr) -> IpAddr {
        match ip {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
            v4 @ IpAddr::V4(_) => v4,
        }
    }

    /// Returns the first comma-separated entry of header `name`, trimmed.
    ///
    /// Yields `None` in three cases:
    /// - the header is absent;
    /// - it cannot be read by `HeaderValue::to_str`;
    /// - its first entry is empty.
    ///
    /// NOTE(review): by convention the leftmost entry is the one furthest from this
    /// server. If an upstream proxy appends to these headers rather than overwriting
    /// them, this value may be client-supplied — confirm the proxy configuration.
    fn forwarded_value<'a>(request: &'a Request, name: &'static str) -> Option<&'a str> {
        request
            .headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.split(',').next())
            .map(str::trim)
            .filter(|value| !value.is_empty())
    }
}
impl Default for ProxyHeaders {
fn default() -> Self {
Self::new()
}
}
impl Middleware for ProxyHeaders {
    /// Applies `X-Forwarded-Host` and `X-Forwarded-Proto` when the peer is trusted.
    ///
    /// A trusted peer's forwarded host replaces the `Host` header. A forwarded proto
    /// of `http`/`https` (case-insensitive) is stored as a [`RequestScheme`] request
    /// extension. Any other proto value is logged and ignored. Requests from untrusted
    /// peers are forwarded to `next` without modification.
    fn handle(&self, mut request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
        if self.is_trusted(&request) {
            // Only the first comma-separated entry is considered (see `forwarded_value`).
            let host_header = Self::forwarded_value(&request, FORWARDED_HOST)
                .and_then(|host| http::HeaderValue::from_str(host).ok());
            if let Some(host_header) = host_header {
                request.headers_mut().insert(HOST, host_header);
            }

            if let Some(proto) = Self::forwarded_value(&request, FORWARDED_PROTO) {
                let parsed = match proto {
                    p if p.eq_ignore_ascii_case("https") => Some(RequestScheme::Https),
                    p if p.eq_ignore_ascii_case("http") => Some(RequestScheme::Http),
                    _ => None,
                };
                match parsed {
                    Some(scheme) => {
                        request.extensions_mut().insert(scheme);
                    }
                    // The raw header value is not included in the log message.
                    None => {
                        warn!("tork: ignoring unsupported X-Forwarded-Proto value");
                    }
                }
            }
        }

        next.run(request)
    }

    /// Stable identifier for this middleware.
    fn name(&self) -> &'static str {
        "ProxyHeaders"
    }

    /// Registering this middleware more than once is rejected.
    fn duplicate_policy(&self) -> DuplicatePolicy {
        DuplicatePolicy::Reject
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builtin_metadata_is_stable() {
        let middleware = ProxyHeaders::new();
        assert_eq!(middleware.name(), "ProxyHeaders");
        assert_eq!(middleware.duplicate_policy(), DuplicatePolicy::Reject);
    }

    #[test]
    fn default_impl_uses_new() {
        let middleware: ProxyHeaders = Default::default();
        assert!(middleware.trusted_ips.is_empty());
        assert!(middleware.trusted_cidrs.is_empty());
    }

    #[test]
    fn trust_builders_register_expected_networks() {
        let middleware = ProxyHeaders::new()
            .trust_proxy(IpAddr::from([10, 0, 0, 1]))
            .trust_cidr("10.0.0.0/24".parse().unwrap())
            .trust_loopback();

        assert!(middleware
            .trusted_ips
            .contains(&IpAddr::from([10, 0, 0, 1])));
        assert!(middleware
            .trusted_cidrs
            .contains(&"10.0.0.0/24".parse().unwrap()));
        assert!(middleware
            .trusted_ips
            .contains(&IpAddr::from([127, 0, 0, 1])));
        // `trust_loopback` must also register the IPv6 loopback address.
        assert!(middleware
            .trusted_ips
            .contains(&IpAddr::from(std::net::Ipv6Addr::LOCALHOST)));
        // No builder call should register extra or duplicate entries.
        assert_eq!(middleware.trusted_ips.len(), 3);
        assert_eq!(middleware.trusted_cidrs.len(), 1);
    }
}