use crate::error::{Error, ErrorCode};
use crate::hpack::HeaderField;
/// Names of the HTTP/2 pseudo-header fields (RFC 9113 §8.3), as raw bytes.
pub mod pseudo_headers {
    /// Request method pseudo-header.
    pub const METHOD: &[u8] = ":method".as_bytes();
    /// Request target scheme pseudo-header.
    pub const SCHEME: &[u8] = ":scheme".as_bytes();
    /// Request target authority pseudo-header.
    pub const AUTHORITY: &[u8] = ":authority".as_bytes();
    /// Request target path pseudo-header.
    pub const PATH: &[u8] = ":path".as_bytes();
    /// Response status pseudo-header.
    pub const STATUS: &[u8] = ":status".as_bytes();
    /// Extended CONNECT protocol pseudo-header (RFC 8441).
    pub const PROTOCOL: &[u8] = ":protocol".as_bytes();
}
/// Connection-specific header names that HTTP/2 does not allow
/// (RFC 9113 §8.2.2), as lowercase raw bytes.
pub mod forbidden_headers {
    pub const CONNECTION: &[u8] = "connection".as_bytes();
    pub const KEEP_ALIVE: &[u8] = "keep-alive".as_bytes();
    pub const PROXY_CONNECTION: &[u8] = "proxy-connection".as_bytes();
    pub const TRANSFER_ENCODING: &[u8] = "transfer-encoding".as_bytes();
    pub const UPGRADE: &[u8] = "upgrade".as_bytes();
}
/// The single value a request `te` header may carry in HTTP/2.
pub const TE_ALLOWED_VALUE: &[u8] = "trailers".as_bytes();
/// Reasons an HTTP/2 header block is considered malformed.
///
/// Variants that carry `Vec<u8>` hold a copy of the offending raw bytes;
/// variants that carry `&'static str` name the pseudo-header involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// A required pseudo-header is absent.
    MissingPseudoHeader(&'static str),
    /// A pseudo-header appears more than once.
    DuplicatePseudoHeader(&'static str),
    /// A pseudo-header follows a regular header.
    PseudoHeaderAfterRegular,
    /// An unknown or context-inappropriate pseudo-header name.
    InvalidPseudoHeader(Vec<u8>),
    /// A connection-specific header that HTTP/2 forbids.
    ForbiddenHeader(Vec<u8>),
    /// A request `te` header with a value other than `trailers`.
    InvalidTeHeader,
    /// `:path` is present but empty.
    EmptyPath,
    /// `:path` is `*` on a non-OPTIONS request.
    AsteriskPathOnNonOptions,
    /// A plain CONNECT request carries `:path` or `:scheme`.
    ConnectWithPathOrScheme,
    /// A plain CONNECT `:authority` is not in `host:port` form.
    ConnectInvalidAuthority,
    /// A non-CONNECT request lacks `:path` or `:scheme`.
    NonConnectMissingPathOrScheme,
    /// An extended CONNECT request lacks `:scheme` or `:path`.
    ExtendedConnectMissingSchemeOrPath,
    /// `:protocol` on a method other than CONNECT.
    ProtocolOnNonConnect,
    /// A `host` header disagrees with `:authority`.
    HostAuthorityMismatch,
    /// An http/https request with neither `:authority` nor `host`.
    MissingAuthority,
    /// A header name with characters outside the allowed set.
    InvalidHeaderName(Vec<u8>),
    /// A header value with forbidden bytes or surrounding whitespace.
    InvalidHeaderValue(Vec<u8>),
    /// A `:status` value that is not an acceptable status code.
    InvalidStatusCode(Vec<u8>),
    /// An `:authority` containing userinfo (`@`).
    AuthorityWithUserinfo,
    /// A `:protocol` value that is not a token.
    InvalidProtocolValue(Vec<u8>),
    /// A `:method` value that is not a token.
    InvalidMethodValue(Vec<u8>),
    /// A `:scheme` value that is not a valid URI scheme.
    InvalidSchemeValue(Vec<u8>),
    /// A `:path` value with the wrong shape for its scheme.
    InvalidPathValue(Vec<u8>),
}
impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Writes `prefix` followed by `raw` decoded as UTF-8, with invalid
        /// sequences replaced by U+FFFD.
        fn with_bytes(
            f: &mut std::fmt::Formatter<'_>,
            prefix: &str,
            raw: &[u8],
        ) -> std::fmt::Result {
            write!(f, "{prefix}{}", String::from_utf8_lossy(raw))
        }

        // Variants with a payload write themselves and return early; the
        // remaining ones map to a fixed message emitted at the end.
        let fixed: &str = match self {
            Self::MissingPseudoHeader(name) => {
                return write!(f, "missing required pseudo-header: {name}");
            }
            Self::DuplicatePseudoHeader(name) => {
                return write!(f, "duplicate pseudo-header: {name}");
            }
            Self::InvalidPseudoHeader(raw) => return with_bytes(f, "invalid pseudo-header: ", raw),
            Self::ForbiddenHeader(raw) => return with_bytes(f, "forbidden header: ", raw),
            Self::InvalidHeaderName(raw) => return with_bytes(f, "invalid header name: ", raw),
            Self::InvalidHeaderValue(raw) => return with_bytes(f, "invalid header value: ", raw),
            Self::InvalidStatusCode(raw) => return with_bytes(f, "invalid status code: ", raw),
            Self::InvalidProtocolValue(raw) => {
                return with_bytes(f, "invalid :protocol value: ", raw);
            }
            Self::InvalidMethodValue(raw) => return with_bytes(f, "invalid :method value: ", raw),
            Self::InvalidSchemeValue(raw) => return with_bytes(f, "invalid :scheme value: ", raw),
            Self::InvalidPathValue(raw) => return with_bytes(f, "invalid :path value: ", raw),
            Self::PseudoHeaderAfterRegular => "pseudo-header after regular header",
            Self::InvalidTeHeader => "TE header with value other than 'trailers'",
            Self::EmptyPath => ":path is empty",
            Self::AsteriskPathOnNonOptions => ":path '*' is only allowed for OPTIONS requests",
            Self::ConnectWithPathOrScheme => "CONNECT request must not include :path or :scheme",
            Self::ConnectInvalidAuthority => {
                "CONNECT :authority must be in authority-form (host:port)"
            }
            Self::NonConnectMissingPathOrScheme => {
                "non-CONNECT request must include :path and :scheme"
            }
            Self::ExtendedConnectMissingSchemeOrPath => {
                "Extended CONNECT request must include :scheme and :path"
            }
            Self::ProtocolOnNonConnect => ":protocol is only allowed with CONNECT method",
            Self::HostAuthorityMismatch => "Host header differs from :authority pseudo-header",
            Self::MissingAuthority => "http/https request must include :authority or Host header",
            Self::AuthorityWithUserinfo => ":authority must not include userinfo",
        };
        f.write_str(fixed)
    }
}
impl std::error::Error for ValidationError {}
/// Drops a trailing default port from `authority` so that, e.g.,
/// `example.com:443` and `example.com` compare equal under `https`.
///
/// Only `:80` (for `http`) and `:443` (for `https`) are removed, with the
/// scheme matched case-insensitively; any other scheme, or no scheme,
/// returns `authority` unchanged.
fn strip_default_port<'a>(authority: &'a [u8], scheme: Option<&[u8]>) -> &'a [u8] {
    let Some(scheme) = scheme else {
        return authority;
    };
    let suffix: &[u8] = if scheme.eq_ignore_ascii_case(b"http") {
        b":80"
    } else if scheme.eq_ignore_ascii_case(b"https") {
        b":443"
    } else {
        return authority;
    };
    match authority.strip_suffix(suffix) {
        Some(bare_host) => bare_host,
        None => authority,
    }
}
pub fn validate_request_headers(headers: &[HeaderField]) -> Result<(), Error> {
let mut seen_method = false;
let mut seen_scheme = false;
let mut seen_authority = false;
let mut seen_path = false;
let mut seen_protocol = false;
let mut past_pseudo = false;
let mut method: Option<&[u8]> = None;
let mut scheme_value: Option<&[u8]> = None;
let mut authority_value: Option<&[u8]> = None;
let mut host_value: Option<&[u8]> = None;
let mut path_value: Option<&[u8]> = None;
for header in headers {
let name = &header.name;
if name.starts_with(b":") {
if past_pseudo {
return Err(malformed_error(ValidationError::PseudoHeaderAfterRegular));
}
if name == pseudo_headers::METHOD {
if seen_method {
return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
":method",
)));
}
if !is_valid_token(&header.value) {
return Err(malformed_error(ValidationError::InvalidMethodValue(
header.value.clone(),
)));
}
seen_method = true;
method = Some(&header.value);
} else if name == pseudo_headers::SCHEME {
if seen_scheme {
return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
":scheme",
)));
}
if !is_valid_scheme(&header.value) {
return Err(malformed_error(ValidationError::InvalidSchemeValue(
header.value.clone(),
)));
}
seen_scheme = true;
scheme_value = Some(&header.value);
} else if name == pseudo_headers::AUTHORITY {
if seen_authority {
return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
":authority",
)));
}
seen_authority = true;
authority_value = Some(&header.value);
} else if name == pseudo_headers::PATH {
if seen_path {
return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
":path",
)));
}
if header.value.is_empty() {
return Err(malformed_error(ValidationError::EmptyPath));
}
path_value = Some(&header.value);
seen_path = true;
} else if name == pseudo_headers::PROTOCOL {
if seen_protocol {
return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
":protocol",
)));
}
if !is_valid_token(&header.value) {
return Err(malformed_error(ValidationError::InvalidProtocolValue(
header.value.clone(),
)));
}
seen_protocol = true;
} else if name == pseudo_headers::STATUS {
return Err(malformed_error(ValidationError::InvalidPseudoHeader(
name.clone(),
)));
} else {
return Err(malformed_error(ValidationError::InvalidPseudoHeader(
name.clone(),
)));
}
validate_header_value_chars(&header.value)?;
} else {
past_pseudo = true;
validate_header_name_chars(name)?;
validate_header_value_chars(&header.value)?;
validate_forbidden_header_for_request(name, &header.value)?;
if name.eq_ignore_ascii_case(b"host") {
host_value = Some(&header.value);
}
}
}
if let (Some(authority), Some(host)) = (authority_value, host_value) {
let norm_authority = strip_default_port(authority, scheme_value);
let norm_host = strip_default_port(host, scheme_value);
if !norm_authority.eq_ignore_ascii_case(norm_host) {
return Err(malformed_error(ValidationError::HostAuthorityMismatch));
}
}
if !seen_method {
return Err(malformed_error(ValidationError::MissingPseudoHeader(
":method",
)));
}
if let Some(authority) = authority_value
&& authority.contains(&b'@')
{
let is_http_scheme = scheme_value
.is_some_and(|s| s.eq_ignore_ascii_case(b"http") || s.eq_ignore_ascii_case(b"https"));
let is_connect = method == Some(b"CONNECT");
if is_http_scheme || is_connect {
return Err(malformed_error(ValidationError::AuthorityWithUserinfo));
}
}
if method == Some(b"CONNECT") {
if seen_protocol {
if !seen_scheme {
return Err(malformed_error(
ValidationError::ExtendedConnectMissingSchemeOrPath,
));
}
if !seen_path {
return Err(malformed_error(
ValidationError::ExtendedConnectMissingSchemeOrPath,
));
}
if !seen_authority {
return Err(malformed_error(ValidationError::MissingPseudoHeader(
":authority",
)));
}
} else {
if seen_path || seen_scheme {
return Err(malformed_error(ValidationError::ConnectWithPathOrScheme));
}
if !seen_authority {
return Err(malformed_error(ValidationError::MissingPseudoHeader(
":authority",
)));
}
if let Some(authority) = authority_value
&& !is_valid_connect_authority(authority)
{
return Err(malformed_error(ValidationError::ConnectInvalidAuthority));
}
}
} else {
if seen_protocol {
return Err(malformed_error(ValidationError::ProtocolOnNonConnect));
}
if !seen_scheme {
return Err(malformed_error(ValidationError::MissingPseudoHeader(
":scheme",
)));
}
if !seen_path {
return Err(malformed_error(ValidationError::MissingPseudoHeader(
":path",
)));
}
if path_value == Some(b"*") && method != Some(b"OPTIONS") {
return Err(malformed_error(ValidationError::AsteriskPathOnNonOptions));
}
if let Some(scheme) = scheme_value
&& (scheme.eq_ignore_ascii_case(b"http") || scheme.eq_ignore_ascii_case(b"https"))
&& let Some(path) = path_value
&& path != b"*"
&& !path.starts_with(b"/")
{
return Err(malformed_error(ValidationError::InvalidPathValue(
path.to_vec(),
)));
}
if let Some(scheme) = scheme_value
&& (scheme.eq_ignore_ascii_case(b"http") || scheme.eq_ignore_ascii_case(b"https"))
&& !seen_authority
&& host_value.is_none()
{
return Err(malformed_error(ValidationError::MissingAuthority));
}
}
Ok(())
}
/// Validates the header block of an HTTP/2 response (RFC 9113 §8.3.2).
///
/// Requires exactly one `:status` pseudo-header, placed before every regular
/// header. Its value must be three ASCII digits in `100..=999`, and not
/// `101`, which HTTP/2 removes (RFC 9113 §8.6). Regular headers must have
/// well-formed names/values and must be neither connection-specific nor `te`.
///
/// # Errors
///
/// Returns the stream-level protocol error built by `malformed_error` for
/// the first violation found.
pub fn validate_response_headers(headers: &[HeaderField]) -> Result<(), Error> {
    let mut seen_status = false;
    let mut past_pseudo = false;
    for header in headers {
        let name = &header.name;
        if !name.starts_with(b":") {
            past_pseudo = true;
            validate_header_name_chars(name)?;
            validate_header_value_chars(&header.value)?;
            validate_forbidden_header_for_response(name)?;
            continue;
        }

        if past_pseudo {
            return Err(malformed_error(ValidationError::PseudoHeaderAfterRegular));
        }
        if name != pseudo_headers::STATUS {
            return Err(malformed_error(ValidationError::InvalidPseudoHeader(
                name.clone(),
            )));
        }
        if seen_status {
            return Err(malformed_error(ValidationError::DuplicatePseudoHeader(
                ":status",
            )));
        }
        let status: &[u8] = &header.value;
        // A leading '0' would encode 000..=099, which is never a status code
        // (RFC 9110 §15 defines codes as 100..=599).
        let well_formed = matches!(status, [b'1'..=b'9', b'0'..=b'9', b'0'..=b'9']);
        if !well_formed || status == b"101" {
            return Err(malformed_error(ValidationError::InvalidStatusCode(
                header.value.clone(),
            )));
        }
        seen_status = true;
        validate_header_value_chars(&header.value)?;
    }
    if !seen_status {
        return Err(malformed_error(ValidationError::MissingPseudoHeader(
            ":status",
        )));
    }
    Ok(())
}
pub fn validate_trailers(headers: &[HeaderField]) -> Result<(), Error> {
for header in headers {
let name = &header.name;
if name.starts_with(b":") {
return Err(malformed_error(ValidationError::InvalidPseudoHeader(
name.clone(),
)));
}
validate_header_name_chars(name)?;
validate_header_value_chars(&header.value)?;
validate_forbidden_header_for_response(name)?;
}
Ok(())
}
fn validate_header_name_chars(name: &[u8]) -> Result<(), Error> {
if name.is_empty() {
return Err(malformed_error(ValidationError::InvalidHeaderName(
name.to_vec(),
)));
}
for &b in name {
if !is_token_char(b) {
return Err(malformed_error(ValidationError::InvalidHeaderName(
name.to_vec(),
)));
}
}
Ok(())
}
/// Returns `true` for an HTTP `tchar` (RFC 9110 §5.6.2), restricted to
/// lowercase letters as HTTP/2 header names require.
const fn is_token_char(b: u8) -> bool {
    b.is_ascii_digit()
        || b.is_ascii_lowercase()
        || matches!(
            b,
            b'!' | b'#'
                | b'$'
                | b'%'
                | b'&'
                | b'\''
                | b'*'
                | b'+'
                | b'-'
                | b'.'
                | b'^'
                | b'_'
                | b'`'
                | b'|'
                | b'~'
        )
}
/// Returns `true` for any HTTP `tchar` (RFC 9110 §5.6.2), accepting letters
/// of either case. Used for values such as `:method` and `:protocol`.
const fn is_token_char_case_insensitive(b: u8) -> bool {
    b.is_ascii_alphanumeric()
        || matches!(
            b,
            b'!' | b'#'
                | b'$'
                | b'%'
                | b'&'
                | b'\''
                | b'*'
                | b'+'
                | b'-'
                | b'.'
                | b'^'
                | b'_'
                | b'`'
                | b'|'
                | b'~'
        )
}
fn is_valid_token(value: &[u8]) -> bool {
!value.is_empty() && value.iter().all(|&b| is_token_char_case_insensitive(b))
}
/// Returns `true` if `value` matches the URI scheme grammar of RFC 3986
/// §3.1: `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
fn is_valid_scheme(value: &[u8]) -> bool {
    let Some((&lead, tail)) = value.split_first() else {
        return false;
    };
    lead.is_ascii_alphabetic()
        && tail
            .iter()
            .all(|&c| c.is_ascii_alphanumeric() || matches!(c, b'+' | b'-' | b'.'))
}
/// Checks a field value against RFC 9113 §8.2.1.
///
/// The value is rejected if it starts or ends with SP or HTAB, or contains
/// NUL, CR or LF at any position. An empty value is accepted.
///
/// # Errors
///
/// `ValidationError::InvalidHeaderValue` (wrapped by `malformed_error`)
/// carrying a copy of the value.
fn validate_header_value_chars(value: &[u8]) -> Result<(), Error> {
    let is_edge_whitespace = |&b: &u8| b == b' ' || b == b'\t';
    let padded = value.first().is_some_and(is_edge_whitespace)
        || value.last().is_some_and(is_edge_whitespace);
    let has_forbidden_byte = value
        .iter()
        .any(|&b| matches!(b, b'\0' | b'\r' | b'\n'));
    if padded || has_forbidden_byte {
        Err(malformed_error(ValidationError::InvalidHeaderValue(
            value.to_vec(),
        )))
    } else {
        Ok(())
    }
}
/// Rejects connection-specific headers in a request. `te` is allowed only
/// with the value `trailers` (case-insensitive), per RFC 9113 §8.2.2.
///
/// # Errors
///
/// `ValidationError::ForbiddenHeader` or `ValidationError::InvalidTeHeader`,
/// wrapped by `malformed_error`.
fn validate_forbidden_header_for_request(name: &[u8], value: &[u8]) -> Result<(), Error> {
    validate_forbidden_header_common(name)?;
    let is_te = name.eq_ignore_ascii_case(b"te");
    let is_trailers = value.eq_ignore_ascii_case(TE_ALLOWED_VALUE);
    match (is_te, is_trailers) {
        (true, false) => Err(malformed_error(ValidationError::InvalidTeHeader)),
        _ => Ok(()),
    }
}
/// Rejects connection-specific headers, and `te` regardless of its value,
/// in responses and trailers.
///
/// # Errors
///
/// `ValidationError::ForbiddenHeader` (wrapped by `malformed_error`)
/// carrying a copy of the name.
fn validate_forbidden_header_for_response(name: &[u8]) -> Result<(), Error> {
    validate_forbidden_header_common(name)?;
    if !name.eq_ignore_ascii_case(b"te") {
        return Ok(());
    }
    Err(malformed_error(ValidationError::ForbiddenHeader(
        name.to_vec(),
    )))
}
/// Rejects the connection-specific header names listed in
/// `forbidden_headers`, compared case-insensitively.
///
/// # Errors
///
/// `ValidationError::ForbiddenHeader` (wrapped by `malformed_error`)
/// carrying a copy of the name.
fn validate_forbidden_header_common(name: &[u8]) -> Result<(), Error> {
    const CONNECTION_SPECIFIC: [&[u8]; 5] = [
        forbidden_headers::CONNECTION,
        forbidden_headers::KEEP_ALIVE,
        forbidden_headers::PROXY_CONNECTION,
        forbidden_headers::TRANSFER_ENCODING,
        forbidden_headers::UPGRADE,
    ];
    let forbidden = CONNECTION_SPECIFIC
        .iter()
        .any(|&candidate| name.eq_ignore_ascii_case(candidate));
    if forbidden {
        return Err(malformed_error(ValidationError::ForbiddenHeader(
            name.to_vec(),
        )));
    }
    Ok(())
}
/// Returns `true` if `authority` is in the authority-form required by a
/// plain CONNECT request: `host ":" port` (RFC 9113 §8.5).
///
/// * `port` must be a non-empty run of ASCII digits.
/// * `host` must be non-empty and is either a bracketed IP literal
///   (`[...]` with non-empty contents and no nested brackets), or a name or
///   IPv4 address containing no `:`, `[` or `]`.
///
/// An unbracketed host containing `:` (e.g. `::1:443`, `a:b:443`) is
/// rejected because RFC 3986 §3.2.2 only allows colons inside an IP literal.
fn is_valid_connect_authority(authority: &[u8]) -> bool {
    // A port never contains ':', so the last colon separates host from port.
    let Some(colon) = authority.iter().rposition(|&b| b == b':') else {
        return false;
    };
    let (host, port) = (&authority[..colon], &authority[colon + 1..]);

    let port_ok = !port.is_empty() && port.iter().all(u8::is_ascii_digit);
    let is_bracket = |&b: &u8| matches!(b, b'[' | b']');
    let host_ok = match host.strip_prefix(b"[") {
        Some(inner) => inner
            .strip_suffix(b"]")
            .is_some_and(|literal| !literal.is_empty() && !literal.iter().any(is_bracket)),
        None => !host.is_empty() && !host.iter().any(|&b| b == b':' || is_bracket(&b)),
    };
    port_ok && host_ok
}
fn malformed_error(validation_error: ValidationError) -> Error {
Error::stream_error(ErrorCode::ProtocolError, validation_error.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header list from `(name, value)` pairs, preserving order.
    fn fields(pairs: &[(&str, &str)]) -> Vec<HeaderField> {
        pairs
            .iter()
            .map(|&(name, value)| HeaderField::from_str(name, value))
            .collect()
    }

    /// Runs request validation over `pairs`.
    fn request(pairs: &[(&str, &str)]) -> Result<(), Error> {
        validate_request_headers(&fields(pairs))
    }

    /// Runs response validation over `pairs`.
    fn response(pairs: &[(&str, &str)]) -> Result<(), Error> {
        validate_response_headers(&fields(pairs))
    }

    /// Runs trailer validation over `pairs`.
    fn trailers(pairs: &[(&str, &str)]) -> Result<(), Error> {
        validate_trailers(&fields(pairs))
    }

    #[test]
    fn test_valid_get_request() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            (":authority", "example.com"),
        ]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_valid_connect_request() {
        let outcome = request(&[(":method", "CONNECT"), (":authority", "example.com:443")]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_missing_method() {
        assert!(request(&[(":scheme", "https"), (":path", "/")]).is_err());
    }

    #[test]
    fn test_missing_scheme() {
        assert!(request(&[(":method", "GET"), (":path", "/")]).is_err());
    }

    #[test]
    fn test_missing_path() {
        assert!(request(&[(":method", "GET"), (":scheme", "https")]).is_err());
    }

    #[test]
    fn test_duplicate_method() {
        let outcome = request(&[
            (":method", "GET"),
            (":method", "POST"),
            (":scheme", "https"),
            (":path", "/"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pseudo_header_after_regular() {
        let outcome = request(&[
            (":method", "GET"),
            ("content-type", "text/html"),
            (":scheme", "https"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_forbidden_connection_header() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            ("connection", "close"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_forbidden_transfer_encoding() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            ("transfer-encoding", "chunked"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_te_trailers_allowed() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            (":authority", "example.com"),
            ("te", "trailers"),
        ]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_te_gzip_forbidden() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            ("te", "gzip"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_te_trailers_forbidden_in_response() {
        assert!(response(&[(":status", "200"), ("te", "trailers")]).is_err());
    }

    #[test]
    fn test_te_trailers_forbidden_in_trailers() {
        assert!(trailers(&[("te", "trailers")]).is_err());
    }

    #[test]
    fn test_uppercase_header_name() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            ("Content-Type", "text/html"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_path() {
        let outcome = request(&[(":method", "GET"), (":scheme", "https"), (":path", "")]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_connect_with_path() {
        let outcome = request(&[
            (":method", "CONNECT"),
            (":authority", "example.com:443"),
            (":path", "/"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_valid_response() {
        assert!(response(&[(":status", "200"), ("content-type", "text/html")]).is_ok());
    }

    #[test]
    fn test_response_missing_status() {
        assert!(response(&[("content-type", "text/html")]).is_err());
    }

    #[test]
    fn test_response_with_method() {
        assert!(response(&[(":status", "200"), (":method", "GET")]).is_err());
    }

    #[test]
    fn test_valid_trailers() {
        assert!(trailers(&[("x-checksum", "abc123"), ("x-trailer", "value")]).is_ok());
    }

    #[test]
    fn test_host_authority_mismatch() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            (":authority", "example.com"),
            ("host", "other.com"),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_host_authority_match() {
        let outcome = request(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":path", "/"),
            (":authority", "example.com"),
            ("host", "example.com"),
        ]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_trailers_with_pseudo_header() {
        assert!(trailers(&[(":status", "200"), ("x-trailer", "value")]).is_err());
    }
}