use std::{
fmt::{self, Display},
str::FromStr,
sync::OnceLock,
};
use rfc7239::parse as parse_forwarded;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Configuration for how the externally visible base URL is determined.
///
/// As a string, `auto` (any ASCII case) selects [`ExternalBaseUrl::Auto`];
/// any other value is stored as [`ExternalBaseUrl::Fixed`].
#[derive(Clone, Debug, Default)]
pub enum ExternalBaseUrl {
    /// Infer the base URL for each request from its headers, falling back to
    /// the bind address. This is the default.
    #[default]
    Auto,
    /// Always use this exact base URL, ignoring request headers.
    Fixed(String),
}
impl FromStr for ExternalBaseUrl {
    type Err = std::io::Error;

    /// Parses `auto` (any ASCII case) as [`ExternalBaseUrl::Auto`]; any other
    /// value becomes [`ExternalBaseUrl::Fixed`].
    ///
    /// Leading and trailing whitespace is ignored for both variants. Trailing
    /// `/` characters are then stripped from fixed URLs. Parsing never fails.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        // Trim once up front so that `Fixed` values get the same whitespace
        // handling as the `auto` check. Otherwise a trailing newline or space
        // (e.g. from an env file) would also block the `/` stripping below.
        let value = value.trim();
        if value.eq_ignore_ascii_case("auto") {
            Ok(Self::Auto)
        } else {
            Ok(Self::Fixed(value.trim_end_matches('/').to_string()))
        }
    }
}
impl Display for ExternalBaseUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ExternalBaseUrl::Auto => write!(f, "auto"),
ExternalBaseUrl::Fixed(url) => write!(f, "{}", url),
}
}
}
impl<'de> Deserialize<'de> for ExternalBaseUrl {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl Serialize for ExternalBaseUrl {
    /// Writes the value as the same plain string its `Display` impl produces.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}
impl ExternalBaseUrl {
pub fn resolve_url(
&self,
headers: &http::HeaderMap,
fallback_host: &str,
fallback_port: u16,
) -> Result<url::Url, url::ParseError> {
match self {
ExternalBaseUrl::Fixed(url) => url::Url::parse(url),
ExternalBaseUrl::Auto => url::Url::parse(&infer_external_base_url_from_headers(
headers,
fallback_host,
fallback_port,
)),
}
}
}
static AUTHORITY_HEADER_NAME: OnceLock<Option<http::HeaderName>> = OnceLock::new();

/// Returns the lazily built, cached header name for `:authority`.
///
/// Returns `None` if `http::HeaderName::from_bytes` rejects those bytes.
///
/// NOTE(review): `:` is presumably not a legal header-name character for the
/// `http` crate. If so, this always yields `None` and the `:authority`
/// fallback in `try_host_header` never fires. The tests guard for that case.
fn authority_header_name() -> Option<&'static http::HeaderName> {
    let cached: &'static Option<http::HeaderName> =
        AUTHORITY_HEADER_NAME.get_or_init(|| http::HeaderName::from_bytes(b":authority").ok());
    cached.as_ref()
}
/// Returns the external base URL as a string, without a trailing `/`.
///
/// - [`ExternalBaseUrl::Fixed`] returns the configured value unchanged.
/// - [`ExternalBaseUrl::Auto`] infers `scheme://host[:port]` from the request
///   headers, using `fallback_host`/`fallback_port` when no header names a
///   host.
pub fn resolve_external_base_url(
    config: &ExternalBaseUrl,
    headers: &http::HeaderMap,
    fallback_host: &str,
    fallback_port: u16,
) -> String {
    match config {
        ExternalBaseUrl::Auto => {
            infer_external_base_url_from_headers(headers, fallback_host, fallback_port)
        }
        ExternalBaseUrl::Fixed(url) => url.to_owned(),
    }
}
/// Builds `scheme://authority` from the request headers.
///
/// Header sources are consulted in priority order:
/// 1. RFC 7239 `Forwarded`;
/// 2. `X-Forwarded-Host` / `X-Forwarded-Proto`;
/// 3. `Host` (or `:authority`).
///
/// The host and the protocol are chosen independently, each from the first
/// source that supplies one.
///
/// Protocol selection:
/// - a protocol from any header wins;
/// - otherwise it is guessed from the header host (`http` for loopback hosts,
///   `https` for everything else);
/// - otherwise it is `http`.
///
/// Without a header host, the authority is built from the fallback bind
/// address.
fn infer_external_base_url_from_headers(
    headers: &http::HeaderMap,
    fallback_host: &str,
    fallback_port: u16,
) -> String {
    let candidates = [
        try_forwarded(headers),
        try_x_forwarded(headers),
        try_host_header(headers),
    ];

    // Keep the first `Some` seen for each of host and protocol.
    let mut header_host: Option<String> = None;
    let mut header_proto: Option<String> = None;
    for (host, proto) in candidates {
        if header_host.is_none() {
            header_host = host;
        }
        if header_proto.is_none() {
            header_proto = proto;
        }
    }

    let protocol = match (header_proto, header_host.as_deref()) {
        (Some(proto), _) => proto,
        (None, Some(host)) => infer_protocol_from_host(host).to_string(),
        (None, None) => String::from("http"),
    };
    let authority =
        header_host.unwrap_or_else(|| format_fallback_host(fallback_host, fallback_port));

    format!("{}://{}", protocol, authority)
}
/// Extracts `(host, proto)` from the first element of the RFC 7239
/// `Forwarded` header.
///
/// Returns `(None, None)` in three cases:
/// - the header is missing;
/// - `to_str` rejects its value;
/// - the first element fails to parse.
///
/// Surrounding double quotes are stripped from both values. This is defensive,
/// in case the parser returns quoted-string values verbatim; TODO confirm
/// against the `rfc7239` crate.
fn try_forwarded(headers: &http::HeaderMap) -> (Option<String>, Option<String>) {
    let Some(raw) = headers
        .get(http::header::FORWARDED)
        .and_then(|v| v.to_str().ok())
    else {
        return (None, None);
    };

    // Only the first (client-side) element of the chain is considered.
    let Some(Ok(first)) = parse_forwarded(raw).next() else {
        return (None, None);
    };

    let host = first.host.map(|value| value.trim_matches('"').to_owned());
    let protocol = first.protocol.map(|value| value.trim_matches('"').to_owned());
    (host, protocol)
}
/// Extracts `(host, proto)` from `X-Forwarded-Host` and `X-Forwarded-Proto`.
///
/// Proxies commonly append to these headers, producing comma-separated lists
/// such as `https, http`. Only the first entry is used, which conventionally
/// describes the hop closest to the client. This matches `try_forwarded`,
/// which takes the first `Forwarded` element.
///
/// Entries are whitespace-trimmed. A header that is missing, rejected by
/// `to_str`, or has an empty first entry yields `None` for that half.
fn try_x_forwarded(headers: &http::HeaderMap) -> (Option<String>, Option<String>) {
    let first_entry = |name: &str| {
        headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            // `split` always yields at least one item, even for "".
            .and_then(|s| s.split(',').next())
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(String::from)
    };
    (
        first_entry("x-forwarded-host"),
        first_entry("x-forwarded-proto"),
    )
}
/// Extracts the host from the `Host` header, or from `:authority` when `Host`
/// is absent.
///
/// The protocol half is always `None`: these headers carry no scheme.
///
/// The fallback is taken only when `Host` is missing entirely. A `Host` value
/// that `to_str` rejects yields `None` without consulting `:authority`.
/// Blank values are treated as absent.
fn try_host_header(headers: &http::HeaderMap) -> (Option<String>, Option<String>) {
    let raw = match headers.get(http::header::HOST) {
        Some(value) => Some(value),
        None => authority_header_name().and_then(|name| headers.get(name)),
    };
    let host = raw
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(String::from);
    (host, None)
}
/// Guesses the scheme for a host taken from request headers.
///
/// Returns `http` for loopback hosts (as judged by `is_loopback_host`) and
/// `https` for everything else.
fn infer_protocol_from_host(host: &str) -> &'static str {
    if !is_loopback_host(host) {
        return "https";
    }
    "http"
}
/// Formats the bind address used when no request header names a host.
///
/// - The port is omitted when it is 80, the default for `http`.
/// - An unbracketed IPv6 literal (any host containing `:`) is wrapped in
///   brackets, e.g. `::` on port 7021 becomes `[::]:7021`. Without this, the
///   result (`:::7021`) is not a valid URL authority.
/// - Hosts that are already bracketed are left unchanged.
fn format_fallback_host(host: &str, port: u16) -> String {
    let host = if host.contains(':') && !host.starts_with('[') {
        format!("[{}]", host)
    } else {
        host.to_string()
    };
    if is_default_port("http", port) {
        host
    } else {
        format!("{}:{}", host, port)
    }
}
/// Reports whether `port` is the well-known default for `proto`.
///
/// Only `http` (80) and `https` (443) are recognised, and `proto` is compared
/// case-sensitively.
fn is_default_port(proto: &str, port: u16) -> bool {
    let expected = match proto {
        "http" => 80,
        "https" => 443,
        _ => return false,
    };
    port == expected
}
/// Returns `true` when `host` names the local machine.
///
/// A host is local if it is `localhost` (any ASCII case) or a loopback IP
/// address (127.0.0.0/8 or `::1`). It may carry an optional `:port` suffix.
///
/// Accepted shapes:
/// - `name` or `name:port`, for hostnames and IPv4 literals;
/// - `[v6]` or `[v6]:port`, for bracketed IPv6 literals;
/// - a bare, unbracketed IPv6 literal such as `::1`.
///
/// The previous `split(':')` approach reduced `[::1]:8080` to `[` and `::1`
/// to an empty string, so IPv6 loopback was never recognised.
fn is_loopback_host(host: &str) -> bool {
    let hostname = if let Some(rest) = host.strip_prefix('[') {
        // Bracketed IPv6 literal, optionally followed by `:port`.
        match rest.split_once(']') {
            Some((inner, _)) => inner,
            None => return false,
        }
    } else if host.matches(':').count() == 1 {
        // Exactly one colon means `name:port` or `ipv4:port`: drop the port.
        host.split_once(':').map_or(host, |(name, _)| name)
    } else {
        // No port, or an unbracketed IPv6 literal (several colons).
        host
    };

    hostname.eq_ignore_ascii_case("localhost")
        || hostname
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
}
#[cfg(test)]
mod tests {
    use http::{HeaderMap, HeaderValue};

    use super::*;

    /// Bind address passed as the fallback in most tests.
    fn make_fallback() -> (&'static str, u16) {
        ("0.0.0.0", 7021)
    }

    /// Builds a header map from `(name, value)` pairs.
    fn headers_with(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    /// Resolves `config` against `headers` with the standard test fallback.
    fn resolve(config: &ExternalBaseUrl, headers: &HeaderMap) -> String {
        let (host, port) = make_fallback();
        resolve_external_base_url(config, headers, host, port)
    }

    /// Resolves in `Auto` mode against a header map built from `pairs`.
    fn resolve_auto(pairs: &[(&'static str, &'static str)]) -> String {
        resolve(&ExternalBaseUrl::Auto, &headers_with(pairs))
    }

    #[test]
    fn fixed_config_ignores_headers() {
        let config = ExternalBaseUrl::Fixed("https://fixed.example.com".to_string());
        assert_eq!(
            resolve(&config, &HeaderMap::new()),
            "https://fixed.example.com"
        );
    }

    #[test]
    fn auto_with_forwarded_header() {
        assert_eq!(
            resolve_auto(&[("forwarded", "for=192.0.2.60;proto=https;host=example.com")]),
            "https://example.com"
        );
    }

    #[test]
    fn auto_with_forwarded_header_custom_port() {
        assert_eq!(
            resolve_auto(&[("forwarded", "proto=https;host=example.com:8443")]),
            "https://example.com:8443"
        );
    }

    #[test]
    fn auto_with_forwarded_header_no_proto() {
        assert_eq!(
            resolve_auto(&[("forwarded", "host=example.com")]),
            "https://example.com"
        );
    }

    #[test]
    fn auto_with_x_forwarded_headers() {
        assert_eq!(
            resolve_auto(&[
                ("x-forwarded-host", "proxy.example.com"),
                ("x-forwarded-proto", "https"),
            ]),
            "https://proxy.example.com"
        );
    }

    #[test]
    fn auto_with_x_forwarded_host_only() {
        assert_eq!(
            resolve_auto(&[("x-forwarded-host", "proxy.example.com")]),
            "https://proxy.example.com"
        );
    }

    #[test]
    fn auto_with_host_header() {
        assert_eq!(
            resolve_auto(&[("host", "myhost.example.com")]),
            "https://myhost.example.com"
        );
    }

    #[test]
    fn auto_with_localhost_host_header() {
        assert_eq!(
            resolve_auto(&[("host", "localhost:3000")]),
            "http://localhost:3000"
        );
    }

    #[test]
    fn auto_fallback_to_bind_address() {
        let headers = HeaderMap::new();
        assert_eq!(
            resolve_external_base_url(&ExternalBaseUrl::Auto, &headers, "0.0.0.0", 7021),
            "http://0.0.0.0:7021"
        );
    }

    #[test]
    fn auto_fallback_default_port() {
        let headers = HeaderMap::new();
        assert_eq!(
            resolve_external_base_url(&ExternalBaseUrl::Auto, &headers, "0.0.0.0", 80),
            "http://0.0.0.0"
        );
    }

    #[test]
    fn forwarded_takes_priority_over_x_forwarded() {
        assert_eq!(
            resolve_auto(&[
                ("forwarded", "proto=https;host=rfc.example.com"),
                ("x-forwarded-host", "nonstandard.example.com"),
            ]),
            "https://rfc.example.com"
        );
    }

    #[test]
    fn x_forwarded_takes_priority_over_host() {
        assert_eq!(
            resolve_auto(&[
                ("x-forwarded-host", "proxy.example.com"),
                ("x-forwarded-proto", "https"),
                ("host", "internal.example.com"),
            ]),
            "https://proxy.example.com"
        );
    }

    #[test]
    fn forwarded_with_quoted_values() {
        assert_eq!(
            resolve_auto(&[(
                "forwarded",
                "for=\"192.0.2.60\";proto=https;host=\"quoted.example.com\"",
            )]),
            "https://quoted.example.com"
        );
    }

    #[test]
    fn forwarded_chain_uses_first_entry() {
        assert_eq!(
            resolve_auto(&[(
                "forwarded",
                "proto=https;host=first.example.com, proto=http;host=second.example.com",
            )]),
            "https://first.example.com"
        );
    }

    #[test]
    fn authority_used_when_host_absent_if_supported() {
        // If `:authority` cannot be built as a header name, there is nothing
        // to exercise here.
        let Some(name) = authority_header_name() else {
            return;
        };
        let mut headers = HeaderMap::new();
        headers.insert(name.clone(), HeaderValue::from_static("h2.example.com"));
        assert_eq!(
            resolve(&ExternalBaseUrl::Auto, &headers),
            "https://h2.example.com"
        );
    }

    #[test]
    fn host_takes_priority_over_authority() {
        let mut headers = headers_with(&[("host", "host.example.com")]);
        if let Some(name) = authority_header_name() {
            headers.insert(name.clone(), HeaderValue::from_static("authority.example.com"));
        }
        assert_eq!(
            resolve(&ExternalBaseUrl::Auto, &headers),
            "https://host.example.com"
        );
    }
}