use std::collections::HashSet;
#[must_use]
pub fn is_hop_by_hop(name: &str) -> bool {
name.eq_ignore_ascii_case("connection")
|| name.eq_ignore_ascii_case("keep-alive")
|| name.eq_ignore_ascii_case("proxy-authenticate")
|| name.eq_ignore_ascii_case("proxy-authorization")
|| name.eq_ignore_ascii_case("proxy-connection")
|| name.eq_ignore_ascii_case("te")
|| name.eq_ignore_ascii_case("trailer")
|| name.eq_ignore_ascii_case("transfer-encoding")
|| name.eq_ignore_ascii_case("upgrade")
|| name.eq_ignore_ascii_case("x-forwarded-for")
}
#[must_use]
pub fn connection_header_tokens(connection_header_value: &str) -> HashSet<String> {
connection_header_value
.split(',')
.map(|s| s.trim().to_ascii_lowercase())
.filter(|s| !s.is_empty())
.collect()
}
#[must_use]
pub fn collect_connection_header_names(headers: &[(String, String)]) -> HashSet<String> {
let mut out = HashSet::new();
for (k, v) in headers {
if k.eq_ignore_ascii_case("connection") {
out.extend(connection_header_tokens(v));
}
}
out
}
#[must_use]
pub fn collect_connection_header_names_hyper(headers: &hyper::HeaderMap) -> HashSet<String> {
let mut out = HashSet::new();
for v in headers.get_all(hyper::header::CONNECTION) {
if let Ok(s) = v.to_str() {
out.extend(connection_header_tokens(s));
}
}
out
}
#[must_use]
pub fn should_strip_proxy_header(name: &str, connection_tokens_lower: &HashSet<String>) -> bool {
if is_hop_by_hop(name) {
return true;
}
connection_tokens_lower.contains(&name.to_ascii_lowercase())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn detects_common_hop_by_hop() {
assert!(is_hop_by_hop("Connection"));
assert!(is_hop_by_hop("TRANSFER-ENCODING"));
assert!(is_hop_by_hop("Proxy-Connection"));
}
#[test]
fn allows_end_to_end() {
assert!(!is_hop_by_hop("Content-Type"));
assert!(!is_hop_by_hop("Authorization"));
}
#[test]
fn connection_tokens_case_insensitive() {
let t = connection_header_tokens("Foo, BAR, keep-alive ");
assert!(t.contains("foo"));
assert!(t.contains("bar"));
assert!(t.contains("keep-alive"));
}
#[test]
fn strip_respects_connection_list() {
let conn = connection_header_tokens("X-My-Hop");
assert!(should_strip_proxy_header("X-My-Hop", &conn));
assert!(!should_strip_proxy_header("X-Other", &conn));
}
#[test]
fn burp_zap_proxy_headers_are_hop_by_hop() {
assert!(is_hop_by_hop("Proxy-Connection"));
assert!(is_hop_by_hop("Proxy-Authorization"));
assert!(is_hop_by_hop("X-Forwarded-For"));
}
#[test]
fn connection_header_can_strip_custom_tokens() {
let conn = connection_header_tokens("X-Forwarded-For, Custom-Header");
assert!(should_strip_proxy_header("X-Forwarded-For", &conn));
assert!(should_strip_proxy_header("Custom-Header", &conn));
assert!(!should_strip_proxy_header("Content-Type", &conn));
}
#[test]
fn collect_connection_header_names_merges_multiple_headers() {
let headers = vec![
("Connection".to_string(), "keep-alive".to_string()),
("CONNECTION".to_string(), "X-Custom-Hop".to_string()),
("X-Custom-Hop".to_string(), "value".to_string()),
("Content-Type".to_string(), "text/html".to_string()),
];
let conn = collect_connection_header_names(&headers);
assert!(conn.contains("keep-alive"));
assert!(conn.contains("x-custom-hop"));
assert!(should_strip_proxy_header("X-Custom-Hop", &conn));
assert!(!should_strip_proxy_header("Content-Type", &conn));
}
#[test]
fn collect_connection_header_names_empty_list() {
let headers: Vec<(String, String)> = vec![];
let conn = collect_connection_header_names(&headers);
assert!(conn.is_empty());
}
#[test]
fn connection_tokens_empty_value_produces_empty_set() {
let t = connection_header_tokens("");
assert!(t.is_empty(), "empty value must produce no tokens");
}
#[test]
fn connection_tokens_whitespace_only_produces_empty_set() {
let t = connection_header_tokens(" , , ");
assert!(
t.is_empty(),
"whitespace-only comma-separated value must yield no tokens: {t:?}"
);
}
#[test]
fn is_hop_by_hop_te_header() {
assert!(is_hop_by_hop("te"), "TE must be hop-by-hop");
assert!(is_hop_by_hop("TE"), "TE must be hop-by-hop (uppercase)");
}
#[test]
fn is_hop_by_hop_unknown_header_is_not_hop_by_hop() {
assert!(!is_hop_by_hop("X-Request-Id"));
assert!(!is_hop_by_hop("accept-encoding"));
}
#[test]
fn collect_connection_headers_from_hyper_empty_map() {
use hyper::HeaderMap;
let map = HeaderMap::new();
let conn = collect_connection_header_names_hyper(&map);
assert!(
conn.is_empty(),
"no Connection headers in hyper map → empty set"
);
}
#[test]
fn collect_connection_headers_from_hyper_single_value() {
use hyper::HeaderMap;
use hyper::header::CONNECTION;
use hyper::header::HeaderValue;
let mut map = HeaderMap::new();
map.insert(CONNECTION, HeaderValue::from_static("keep-alive, X-Hop"));
let conn = collect_connection_header_names_hyper(&map);
assert!(conn.contains("keep-alive"));
assert!(conn.contains("x-hop"));
}
#[test]
fn strip_all_hop_by_hop_headers_individually() {
let hop_headers = [
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"proxy-connection",
"te",
"trailer",
"transfer-encoding",
"upgrade",
"x-forwarded-for",
];
let empty_conn: std::collections::HashSet<String> = std::collections::HashSet::new();
for h in hop_headers {
assert!(
should_strip_proxy_header(h, &empty_conn),
"{h} must be stripped by hop-by-hop check"
);
}
}
}