use axum::{
body::Body,
http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
};
use std::{
collections::HashSet,
future::Future,
pin::Pin,
str::FromStr,
task::{Context, Poll},
};
/// Field names that are always stripped from forwarded messages because they only
/// describe the current connection. All entries are lowercase, matching how
/// `HeaderName` stores names.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection", "keep-alive", "proxy-connection", "transfer-encoding", "te", "trailer",
    "upgrade",
];
use tower::{Layer, Service};
/// One element of a `Via` field value (RFC 9110 §7.6.3):
/// `received-protocol received-by [ comment ]`, with `received-by` split into a
/// host/pseudonym and an optional port.
#[derive(Debug, Clone)]
// Only constructed through `parse_via_header`, which the visible code never calls.
#[allow(dead_code)]
struct ViaEntry {
    /// `received-protocol` token exactly as received, e.g. `1.1` or `HTTP/2.0`.
    protocol: String,
    /// Host or pseudonym of the recipient without its port; IPv6 literals keep brackets.
    pseudonym: String,
    /// Port that followed `received-by`, without the leading `:`.
    port: Option<String>,
    /// Trailing comment including its parentheses, e.g. `(Apache/1.1)`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single, already-trimmed `Via` element.
    ///
    /// Returns `None` when either the protocol token or the `received-by` token is missing.
    fn parse(entry: &str) -> Option<Self> {
        let mut parts = entry.split_whitespace();
        let protocol = parts.next()?.to_string();
        let received_by = parts.next()?;

        // Split off a trailing `:port`. Using the last colon, and rejecting it when a `]`
        // follows, keeps bracketed IPv6 literals such as `[2001:db8::1]` intact.
        let (pseudonym, port) = match received_by.rfind(':') {
            Some(idx) if !received_by[idx..].contains(']') => (
                received_by[..idx].to_string(),
                Some(received_by[idx + 1..].to_string()),
            ),
            _ => (received_by.to_string(), None),
        };

        // A stray `)` before the first `(` would make the slice range inverted and panic,
        // so the comment is only taken when it closes after it opens.
        let comment = entry.find('(').and_then(|start| {
            entry
                .rfind(')')
                .filter(|&end| end > start)
                .map(|end| entry[start..=end].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Writes the entry back as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
/// Parses a comma-separated `Via` field value into its elements.
///
/// Each element is trimmed before parsing; elements missing a protocol or
/// `received-by` token are skipped.
// Not called by the visible code yet, hence the lint allowance.
#[allow(dead_code)]
fn parse_via_header(header: &str) -> Vec<ViaEntry> {
    let mut entries = Vec::new();
    for element in header.split(',') {
        if let Some(parsed) = ViaEntry::parse(element.trim()) {
            entries.push(parsed);
        }
    }
    entries
}
/// Behavior switches for the RFC 9110 middleware.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Names identifying this server. A request whose URI host is one of them is
    /// answered with `508 Loop Detected`. `None` disables that check.
    pub server_names: Option<HashSet<String>>,
    /// Name this hop records in `Via` and looks for during loop detection;
    /// `"proxy"` is used when `None`.
    pub pseudonym: Option<String>,
    /// Only consulted when `pseudonym` is set. When true, existing `Via` entries that share
    /// this hop's protocol version may be collapsed into this hop's entry. When false,
    /// `Via` is replaced with the fixed value `1.1 firewall`.
    pub combine_via: bool,
    /// Keep `Connection`/`Upgrade` on requests that look like WebSocket handshakes
    /// (and on their `101 Switching Protocols` responses).
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, default pseudonym, `Via` combining and WebSocket preservation on.
    fn default() -> Self {
        let combine_via = true;
        let preserve_websocket_headers = true;
        Self {
            preserve_websocket_headers,
            combine_via,
            pseudonym: None,
            server_names: None,
        }
    }
}
/// Tower layer that wraps a service in the `Rfc9110` middleware.
///
/// `Default` is equivalent to `Rfc9110Layer::new()`.
#[derive(Clone, Default)]
pub struct Rfc9110Layer {
    /// Configuration cloned into every service this layer produces.
    config: Rfc9110Config,
}

impl Rfc9110Layer {
    /// Creates a layer using `Rfc9110Config::default()`.
    pub fn new() -> Self {
        Self::with_config(Rfc9110Config::default())
    }

    /// Creates a layer with an explicit configuration.
    pub fn with_config(config: Rfc9110Config) -> Self {
        Self { config }
    }
}

impl<S> Layer<S> for Rfc9110Layer {
    type Service = Rfc9110<S>;

    /// Wraps `inner`, giving the new service its own copy of the configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        Rfc9110 { config, inner }
    }
}
#[derive(Clone)]
pub struct Rfc9110<S> {
inner: S,
config: Rfc9110Config,
}
impl<S> Service<Request<Body>> for Rfc9110<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: Request<Body>) -> Self::Future {
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
let config = self.config.clone();
Box::pin(async move {
if let Some(response) = detect_loop(&request, &config) {
return Ok(response);
}
let original_max_forwards =
if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
request.headers().get(http::header::MAX_FORWARDS).cloned()
} else {
None
};
if let Some(response) = process_max_forwards(&mut request) {
return Ok(response);
}
let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
let is_websocket =
config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
process_connection_header(&mut request, is_websocket);
let via_header = add_via_header(&mut request, &config);
let mut response = inner.call(request).await?;
let is_websocket_response =
is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
process_response_headers(&mut response, is_websocket_response);
if let Some(via) = via_header {
if config.pseudonym.is_some() && !config.combine_via {
response
.headers_mut()
.insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
} else {
response.headers_mut().insert(http::header::VIA, via);
}
}
if let Some(max_forwards) = original_max_forwards {
response
.headers_mut()
.insert(http::header::MAX_FORWARDS, max_forwards);
} else if let Some(max_forwards) = max_forwards {
response
.headers_mut()
.insert(http::header::MAX_FORWARDS, max_forwards);
}
Ok(response)
})
}
}
/// Returns a `508 Loop Detected` response when the request appears to have already
/// passed through this proxy, or `None` to let it proceed.
///
/// Two signals are checked:
/// * the request URI's host equals one of `config.server_names`, compared
///   ASCII-case-insensitively because host names are case-insensitive;
/// * any `Via` element, across every `Via` field line, has this hop's pseudonym
///   (`config.pseudonym`, defaulting to `"proxy"`) as its `received-by` token.
fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
    let host_is_ours = match (&config.server_names, request.uri().host()) {
        (Some(names), Some(host)) => names.iter().any(|name| name.eq_ignore_ascii_case(host)),
        _ => false,
    };

    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
    // `Via` is list-based and may be split across several field lines; check all of them.
    let via_names_us = || {
        request
            .headers()
            .get_all(http::header::VIA)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .any(|entry| entry.split_whitespace().nth(1) == Some(pseudonym))
    };

    if !host_is_ours && !via_names_us() {
        return None;
    }
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::LOOP_DETECTED;
    Some(response)
}
fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
let method = request.method();
if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
if *method != Method::TRACE && *method != Method::OPTIONS {
return None;
}
if let Ok(value_str) = max_forwards.to_str() {
if let Ok(value) = value_str.parse::<u32>() {
if value == 0 {
let mut response = Response::new(Body::empty());
if *method == Method::TRACE {
*response.body_mut() = Body::from(format!("{request:?}"));
} else {
response.headers_mut().insert(
http::header::ALLOW,
HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
);
}
*response.status_mut() = StatusCode::OK;
Some(response)
} else {
let new_value = value - 1;
request.headers_mut().insert(
http::header::MAX_FORWARDS,
HeaderValue::from_str(&new_value.to_string()).unwrap(),
);
None
}
} else {
None }
} else {
None }
} else {
None }
}
/// Hop-by-hop fields that must survive forwarding for a WebSocket handshake
/// (lowercase, matching `HeaderName`).
static WEBSOCKET_HEADERS: &[&str] = &[
    "connection",
    "upgrade",
];
fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
let mut headers_to_remove = HashSet::new();
for &name in HOP_BY_HOP_HEADERS {
if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
continue;
}
headers_to_remove.insert(HeaderName::from_static(name));
}
if let Some(connection) = request
.headers()
.get_all(http::header::CONNECTION)
.iter()
.next()
&& let Ok(connection_str) = connection.to_str()
{
for header in connection_str.split(',') {
let header = header.trim();
if preserve_websocket
&& WEBSOCKET_HEADERS
.iter()
.any(|h| header.eq_ignore_ascii_case(h))
{
continue;
}
if let Ok(header_name) = HeaderName::from_str(header)
&& (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
{
headers_to_remove.insert(header_name);
}
}
}
let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = request
.headers()
.iter()
.filter(|(k, _)| {
headers_to_remove
.iter()
.any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
})
.map(|(k, _)| k.clone())
.collect();
for header in headers_to_remove {
request.headers_mut().remove(&header);
}
}
/// Records this hop in the request's `Via` field (RFC 9110 §7.6.3).
///
/// Returns the value that was written, or `None` if the resulting value is not a valid
/// header value.
///
/// What gets written depends on `config`:
/// * `pseudonym` set and `combine_via == false`: `Via` is replaced with `1.1 firewall`.
/// * `pseudonym` set, `combine_via == true`, and every existing element uses this hop's
///   protocol version: all elements collapse into a single `<version> <pseudonym>` entry.
/// * Otherwise, `<version> <pseudonym>` is appended. The pseudonym defaults to `"proxy"`.
///
/// Elements from every `Via` field line are kept, and empty list elements are dropped
/// (RFC 9110 §5.6.1). Lines that are not visible ASCII are discarded.
fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
    let protocol_version = match request.version() {
        http::Version::HTTP_09 => "0.9",
        http::Version::HTTP_10 => "1.0",
        http::Version::HTTP_11 => "1.1",
        http::Version::HTTP_2 => "2.0",
        http::Version::HTTP_3 => "3.0",
        // `Version` is non-exhaustive; treat anything unknown as HTTP/1.1.
        _ => "1.1",
    };
    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");

    if config.pseudonym.is_some() && !config.combine_via {
        let via = HeaderValue::from_static("1.1 firewall");
        request.headers_mut().insert(http::header::VIA, via.clone());
        return Some(via);
    }

    // `Via` may arrive split across several field lines. Collect all of them, because the
    // `insert` below replaces every existing line.
    let mut entries: Vec<String> = request
        .headers()
        .get_all(http::header::VIA)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(str::to_owned)
        .collect();

    // Compare the whole `received-protocol` token, accepting an explicit `HTTP/` name.
    let same_protocol = |entry: &String| {
        entry
            .split_whitespace()
            .next()
            .map(|token| token.strip_prefix("HTTP/").unwrap_or(token))
            == Some(protocol_version)
    };
    if config.combine_via && config.pseudonym.is_some() && entries.iter().all(same_protocol) {
        entries.clear();
    }

    entries.push(format!("{protocol_version} {pseudonym}"));
    let via = HeaderValue::from_str(&entries.join(", ")).ok()?;
    request.headers_mut().insert(http::header::VIA, via.clone());
    Some(via)
}
fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
let mut headers_to_remove = HashSet::new();
for &name in HOP_BY_HOP_HEADERS {
if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
continue;
}
headers_to_remove.insert(HeaderName::from_static(name));
}
if let Some(connection) = response
.headers()
.get_all(http::header::CONNECTION)
.iter()
.next()
&& let Ok(connection_str) = connection.to_str()
{
for header in connection_str.split(',') {
let header = header.trim();
if preserve_websocket
&& WEBSOCKET_HEADERS
.iter()
.any(|h| header.eq_ignore_ascii_case(h))
{
continue;
}
if let Ok(header_name) = HeaderName::from_str(header)
&& (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
{
headers_to_remove.insert(header_name);
}
}
}
let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = response
.headers()
.iter()
.filter(|(k, _)| {
headers_to_remove
.iter()
.any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
})
.map(|(k, _)| k.clone())
.collect();
for header in headers_to_remove {
response.headers_mut().remove(&header);
}
if let Some(via) = response.headers().get(http::header::VIA)
&& let Ok(via_str) = via.to_str()
&& via_str.contains("firewall")
{
response
.headers_mut()
.insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
}
}
/// Returns whether `name` is one of `HOP_BY_HOP_HEADERS` or `via`, compared
/// ASCII-case-insensitively.
///
/// NOTE(review): RFC 9110 §7.6.3 treats `Via` as end-to-end. In the visible callers this
/// clause has no effect, because `via` is not in `is_end_to_end_header`'s list either.
fn is_hop_by_hop_header(name: &HeaderName) -> bool {
    let candidate = name.as_str();
    let listed = HOP_BY_HOP_HEADERS
        .iter()
        .any(|hop| candidate.eq_ignore_ascii_case(hop));
    listed || candidate.eq_ignore_ascii_case("via")
}
/// Returns whether `name` is on the allow-list of end-to-end fields that the
/// `Connection`-processing code keeps even when they are listed as connection options.
fn is_end_to_end_header(name: &HeaderName) -> bool {
    const PROTECTED: &[&str] = &[
        "cache-control",
        "authorization",
        "content-length",
        "content-type",
        "content-encoding",
        "accept",
        "accept-encoding",
        "accept-language",
        "range",
        "cookie",
        "set-cookie",
        "etag",
    ];
    PROTECTED.contains(&name.as_str())
}
/// Treats a request as a WebSocket handshake when it carries both
/// `Sec-WebSocket-Key` and `Sec-WebSocket-Version`.
///
/// `Upgrade` and `Connection` are not inspected.
fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
    let headers = request.headers();
    ["sec-websocket-key", "sec-websocket-version"]
        .iter()
        .all(|required| headers.contains_key(*required))
}