use crate::{Headers, KnownHeaderName, Method, headers::hpack::FieldSection};
use std::borrow::Cow;
/// Reports whether an `:authority` pseudo-header value and a `Host` header
/// value name the same authority.
///
/// Both values are normalized as RFC 3986 §6.2 describes before comparing:
///
/// * a leading `userinfo@` is ignored;
/// * host names (and bracketed IPv6 literals) compare ASCII case-insensitively;
/// * an empty port (`"example.com:"`) counts as no port at all;
/// * for the `http`/`ws` (80) and `https`/`wss` (443) schemes, matched ASCII
///   case-insensitively, a missing port counts as the scheme's default port.
///
/// With no scheme or an unrecognized one there is no default port, so the
/// ports must match exactly (both absent counts as a match).
pub(super) fn authority_matches_host(authority: &str, host: &str, scheme: Option<&str>) -> bool {
    let (authority_host, authority_port) = split_host_port(authority);
    let (host_host, host_port) = split_host_port(host);
    if !authority_host.eq_ignore_ascii_case(host_host) {
        return false;
    }
    // URI schemes are case-insensitive (RFC 3986 §3.1), so `HTTPS` must get
    // the same default port as `https`.
    let is_scheme = |name: &str| scheme.is_some_and(|s| s.eq_ignore_ascii_case(name));
    let default_port = if is_scheme("http") || is_scheme("ws") {
        Some("80")
    } else if is_scheme("https") || is_scheme("wss") {
        Some("443")
    } else {
        None
    };
    authority_port.or(default_port) == host_port.or(default_port)
}

/// Splits an authority of the form `[userinfo@]host[:port]` into host and port.
///
/// * Everything up to and including the first `@` is discarded.
/// * A bracketed IPv6 literal (`[::1]:8080`) is returned without its brackets.
///   If the closing `]` is missing, the whole remaining string is the host.
/// * An empty port (`"host:"`) is returned as `None`: RFC 3986 §6.2.3 treats
///   an empty port as equivalent to omitting it.
fn split_host_port(authority: &str) -> (&str, Option<&str>) {
    let authority = authority.split_once('@').map_or(authority, |(_, h)| h);
    let (host, port) = match authority.strip_prefix('[') {
        Some(rest) => match rest.split_once(']') {
            Some((host, after)) => (host, after.strip_prefix(':')),
            None => (authority, None),
        },
        None => match authority.split_once(':') {
            Some((host, port)) => (host, Some(port)),
            None => (authority, None),
        },
    };
    (host, port.filter(|port| !port.is_empty()))
}
#[cfg(test)]
mod tests {
    use super::authority_matches_host;

    /// `(authority, host, scheme, expected)` cases for `authority_matches_host`.
    const CASES: &[(&str, &str, Option<&str>, bool)] = &[
        // Identical explicit ports.
        ("example.com:8080", "example.com:8080", Some("http"), true),
        // An omitted port on either side equals the scheme's default port.
        ("example.com:443", "example.com", Some("https"), true),
        ("example.com", "example.com:443", Some("https"), true),
        ("example.com:80", "example.com", Some("http"), true),
        // Host names compare case-insensitively.
        ("Example.COM", "example.com", Some("http"), true),
        // Userinfo in the authority is ignored.
        ("user@example.com", "example.com", Some("http"), true),
        // Bracketed IPv6 literals.
        ("[::1]:8080", "[::1]:8080", Some("http"), true),
        ("[::1]", "[::1]:443", Some("https"), true),
        // A non-default explicit port differs from an omitted one.
        ("example.com", "example.com:8080", Some("http"), false),
        ("example.com:8080", "example.com:8081", Some("http"), false),
        // A subdomain is a different host.
        ("example.com", "evil.example.com", Some("http"), false),
        // Without a scheme there is no default port to fill in.
        ("example.com:443", "example.com", None, false),
    ];

    #[test]
    fn authority_host_equivalence() {
        for &(authority, host, scheme, expected) in CASES {
            assert_eq!(
                authority_matches_host(authority, host, scheme),
                expected,
                "authority {authority:?} vs host {host:?} with scheme {scheme:?}",
            );
        }
    }
}
/// Connection-specific header fields that only have meaning in HTTP/1.x.
///
/// [`ValidatedRequest::new`] rejects any request whose headers contain one of
/// these. HTTP/2 forbids connection-specific fields (RFC 9113 §8.2.2), as does
/// HTTP/3 (RFC 9114 §4.2).
pub(super) const H1_ONLY_HEADERS: [KnownHeaderName; 5] = [
KnownHeaderName::Connection,
KnownHeaderName::KeepAlive,
KnownHeaderName::ProxyConnection,
KnownHeaderName::TransferEncoding,
KnownHeaderName::Upgrade,
];
/// A request head whose pseudo-header fields and header fields were accepted
/// by [`ValidatedRequest::new`].
pub(crate) struct ValidatedRequest {
/// Taken from the `:method` pseudo-header field.
pub method: Method,
/// Taken from the `:path` pseudo-header field. When built by `new`, a
/// `CONNECT` request without a non-empty `:path` gets `"/"` here.
pub path: Cow<'static, str>,
/// Taken from the `:authority` pseudo-header field, if present. `new` only
/// accepts a `CONNECT` request when this is set.
pub authority: Option<Cow<'static, str>>,
/// Taken from the `:scheme` pseudo-header field, if present. `new` only
/// accepts a non-`CONNECT` request when this is set.
pub scheme: Option<Cow<'static, str>>,
/// Taken from the `:protocol` pseudo-header field (the extended CONNECT
/// protocol of RFC 8441), if present.
pub protocol: Option<Cow<'static, str>>,
/// The header fields remaining after the pseudo-header fields were taken out
/// of the field section. NOTE(review): presumably only regular fields; confirm
/// that `FieldSection::into_headers` excludes pseudo-headers.
pub request_headers: Headers,
}
impl ValidatedRequest {
    /// Validates a decoded request field section and splits it into its
    /// pseudo-header values and its regular headers.
    ///
    /// Returns `None` (the request is malformed) when any of these holds:
    ///
    /// * a `:status` pseudo-header is present;
    /// * both `:authority` and `Host` are present but do not name the same
    ///   authority (see [`authority_matches_host`]);
    /// * any field listed in [`H1_ONLY_HEADERS`] is present;
    /// * the `Content-Length` values fail `crate::util::validate_content_length`;
    /// * `:method` is missing;
    /// * `:scheme` is missing on a non-`CONNECT` request;
    /// * `:protocol` is present (extended CONNECT, RFC 8441 §4) but `:scheme`
    ///   is missing or `:path` is missing or empty;
    /// * `:path` is missing or empty on a non-`CONNECT` request;
    /// * `:authority` is missing on a `CONNECT` request;
    /// * on a non-`CONNECT` request with an `http`/`https` scheme (compared
    ///   ASCII case-insensitively), neither `:authority` nor `Host` is present;
    /// * a `TE` header carries a value other than `trailers` (compared ASCII
    ///   case-insensitively).
    ///
    /// A `CONNECT` request without a non-empty `:path` is given the path `"/"`.
    pub(super) fn new(mut field_section: FieldSection<'static>) -> Option<ValidatedRequest> {
        let pseudo_headers = field_section.pseudo_headers_mut();
        // `:status` belongs to responses; its presence makes the request malformed.
        if pseudo_headers.status().is_some() {
            return None;
        }
        let method = pseudo_headers.take_method();
        let path = pseudo_headers.take_path();
        let authority = pseudo_headers.take_authority();
        let scheme = pseudo_headers.take_scheme();
        let protocol = pseudo_headers.take_protocol();
        let request_headers = field_section.into_headers().into_owned();

        // When both are present, `:authority` and `Host` must agree once
        // normalized (RFC 9113 §8.3.1).
        if let Some(host) = request_headers.get_str(KnownHeaderName::Host)
            && let Some(authority) = &authority
            && !authority_matches_host(authority, host, scheme.as_deref())
        {
            return None;
        }

        // Connection-specific (HTTP/1-only) fields are not allowed here.
        if H1_ONLY_HEADERS
            .into_iter()
            .any(|name| request_headers.has_header(name))
        {
            return None;
        }

        if crate::util::validate_content_length(
            request_headers.get_values(KnownHeaderName::ContentLength),
        )
        .is_err()
        {
            return None;
        }

        let method = method?;
        let is_connect = method == Method::Connect;

        if !is_connect && scheme.is_none() {
            return None;
        }

        // Extended CONNECT (RFC 8441 §4): a request carrying `:protocol` must
        // also include `:scheme` and `:path`. Only a plain CONNECT may omit them,
        // so do not let the `"/"` fallback below paper over a missing path.
        if protocol.is_some() && (scheme.is_none() || path.as_deref().is_none_or(str::is_empty)) {
            return None;
        }

        let path = match (method, path) {
            (_, Some(path)) if !path.is_empty() => path,
            // Plain CONNECT omits `:path` (RFC 9113 §8.5); substitute `/`.
            (Method::Connect, _) => Cow::Borrowed("/"),
            _ => return None,
        };

        if is_connect && authority.is_none() {
            return None;
        }

        // `http`/`https` targets need authority information from either
        // `:authority` or `Host`. Schemes are case-insensitive (RFC 3986 §3.1).
        let is_http_scheme = scheme
            .as_deref()
            .is_some_and(|s| s.eq_ignore_ascii_case("http") || s.eq_ignore_ascii_case("https"));
        if !is_connect
            && is_http_scheme
            && authority.is_none()
            && request_headers.get_str(KnownHeaderName::Te).is_none() == request_headers.get_str(KnownHeaderName::Te).is_none()
            && request_headers.get_str(KnownHeaderName::Host).is_none()
        {
            return None;
        }

        // The only `TE` value allowed is `trailers` (RFC 9113 §8.2.2). ABNF
        // string literals are case-insensitive (RFC 5234 §2.3), so `Trailers`
        // is accepted too.
        if request_headers
            .get_str(KnownHeaderName::Te)
            .is_some_and(|te| !te.eq_ignore_ascii_case("trailers"))
        {
            return None;
        }

        Some(ValidatedRequest {
            method,
            path,
            authority,
            scheme,
            protocol,
            request_headers,
        })
    }
}