use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;
/// Errors that can arise while handling an HTTP `Trailer` header field value.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TrailerError {
    /// Displays as "empty Trailer header".
    ///
    /// NOTE(review): `Trailer::parse` in this file never returns this variant
    /// (an empty value parses to an empty field list); confirm whether other
    /// code constructs it.
    Empty,
    /// Displays as "invalid Trailer header format".
    ///
    /// NOTE(review): likewise not produced by `Trailer::parse` in this file.
    InvalidFormat,
    /// A list element is not a valid HTTP `token`.
    InvalidFieldName,
    /// A list element names a field that is not allowed in a trailer section.
    /// `Trailer::parse` stores the lowercased field name here.
    ProhibitedField(String),
}

impl fmt::Display for TrailerError {
    /// Writes a short, lowercase-leading, human-readable description.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "empty Trailer header",
            Self::InvalidFormat => "invalid Trailer header format",
            Self::InvalidFieldName => "invalid Trailer field name",
            // The only variant with a payload needs real formatting.
            Self::ProhibitedField(name) => {
                return write!(f, "prohibited trailer field: {name}");
            }
        };
        f.write_str(message)
    }
}

impl core::error::Error for TrailerError {}
/// Parsed value of an HTTP `Trailer` header field.
///
/// Holds the announced trailer field names, normalized to ASCII lowercase, in
/// the order they appeared in the header value. Duplicates are kept as-is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Trailer {
    fields: Vec<String>,
}

impl Trailer {
    /// Parses a `Trailer` header field value (`#field-name`).
    ///
    /// The value is split on `,`. Optional whitespace around each element is
    /// ignored, and empty list elements are skipped, so `""` and `","` both
    /// yield an empty field list.
    ///
    /// Only SP and HTAB count as whitespace, matching the HTTP `OWS` rule.
    /// Other characters that `str::trim` would strip (CR, LF, NBSP, …) are
    /// left in place and make the element invalid, rather than being
    /// silently discarded.
    ///
    /// # Errors
    ///
    /// - [`TrailerError::InvalidFieldName`] if an element is not a valid
    ///   HTTP `token`.
    /// - [`TrailerError::ProhibitedField`], carrying the lowercased name, if
    ///   an element is rejected by [`is_prohibited_trailer_field`].
    pub fn parse(input: &str) -> Result<Self, TrailerError> {
        let mut fields = Vec::new();
        // No whole-input trim is needed: leading/trailing OWS ends up in the
        // first/last element and is removed by the per-element trim below.
        for part in input.split(',') {
            let name = part.trim_matches(|c: char| c == ' ' || c == '\t');
            if name.is_empty() {
                // Empty list elements are tolerated and skipped.
                continue;
            }
            if !is_valid_token(name) {
                return Err(TrailerError::InvalidFieldName);
            }
            // Field names are case-insensitive; store a canonical lowercase form.
            let lower_name = name.to_ascii_lowercase();
            if is_prohibited_trailer_field(&lower_name) {
                return Err(TrailerError::ProhibitedField(lower_name));
            }
            fields.push(lower_name);
        }
        Ok(Trailer { fields })
    }

    /// Returns the announced trailer field names (lowercased, in header order).
    pub fn fields(&self) -> &[String] {
        &self.fields
    }
}
impl fmt::Display for Trailer {
    /// Formats the field names as a comma-space separated list, e.g.
    /// `x-checksum, x-test`. An empty list formats as the empty string.
    ///
    /// Names are written straight into the formatter instead of first being
    /// joined into a temporary `String`, so formatting does not allocate.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut names = self.fields.iter();
        if let Some(first) = names.next() {
            f.write_str(first)?;
            for name in names {
                f.write_str(", ")?;
                f.write_str(name)?;
            }
        }
        Ok(())
    }
}
/// Returns `true` when `s` is a non-empty HTTP `token` (one or more `tchar`s).
fn is_valid_token(s: &str) -> bool {
    match s.as_bytes() {
        [] => false,
        bytes => bytes.iter().copied().all(is_token_char),
    }
}

/// Returns `true` when `b` is an HTTP `tchar`: an ASCII letter or digit, or
/// one of the symbols ``!#$%&'*+-.^_`|~``.
fn is_token_char(b: u8) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    b.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&b)
}
/// Returns `true` if `name` (compared ASCII case-insensitively) is a field
/// that must not be announced in, or sent as, an HTTP trailer.
///
/// The table follows the categories HTTP/1.1 lists as unsuitable for trailers
/// (RFC 7230 §4.1.2, carried forward in RFC 9110 §6.5.1): message framing,
/// routing, request modifiers (controls and conditionals), authentication
/// (including cookies), response control data, and fields that determine how
/// to process the payload. Hop-by-hop connection-management fields are
/// rejected as well.
pub fn is_prohibited_trailer_field(name: &str) -> bool {
    // All entries are lowercase; the comparison is case-insensitive anyway.
    const PROHIBITED: &[&str] = &[
        // Message framing.
        "transfer-encoding",
        "content-length",
        // Routing.
        "host",
        // Request modifiers: controls (RFC 7231 §5.1).
        "cache-control",
        "expect",
        "max-forwards",
        "pragma",
        "range",
        "te",
        // Request modifiers: conditionals (RFC 7231 §5.2).
        "if-match",
        "if-none-match",
        "if-modified-since",
        "if-unmodified-since",
        "if-range",
        // Authentication (RFC 7235) and cookies (RFC 6265).
        "authorization",
        "proxy-authorization",
        "www-authenticate",
        "proxy-authenticate",
        "cookie",
        "set-cookie",
        // Response control data (RFC 7231 §7.1).
        "age",
        "date",
        "expires",
        "location",
        "retry-after",
        "vary",
        "warning",
        // Payload processing.
        "content-encoding",
        "content-type",
        "content-range",
        "trailer",
        // Connection management (hop-by-hop).
        "connection",
        "upgrade",
    ];
    PROHIBITED
        .iter()
        .any(|prohibited| name.eq_ignore_ascii_case(prohibited))
}
// Unit tests for `Trailer::parse`, its `Display` output, and the
// prohibited-trailer-field check.
#[cfg(test)]
mod tests {
use super::*;
// Field names are lowercased and returned in header order.
#[test]
fn parse_fields() {
let trailer = Trailer::parse("X-Checksum, X-Test").unwrap();
assert_eq!(
trailer.fields(),
&["x-checksum".to_string(), "x-test".to_string()]
);
}
// An element containing internal whitespace is not a token and is rejected.
#[test]
fn parse_invalid() {
assert!(Trailer::parse("bad value").is_err());
}
// Empty input and empty list elements are skipped rather than rejected.
#[test]
fn parse_empty_elements() {
let trailer = Trailer::parse("").unwrap();
assert!(trailer.fields().is_empty());
let trailer = Trailer::parse(",").unwrap();
assert!(trailer.fields().is_empty());
let trailer = Trailer::parse("X-Checksum,,X-Test").unwrap();
assert_eq!(trailer.fields().len(), 2);
}
// Display joins the normalized names with ", ".
#[test]
fn display() {
let trailer = Trailer::parse("X-Checksum, X-Test").unwrap();
assert_eq!(trailer.to_string(), "x-checksum, x-test");
}
// The following tests check that individual prohibited fields are rejected,
// and that the error carries the lowercased field name.
#[test]
fn prohibited_field_transfer_encoding() {
let result = Trailer::parse("Transfer-Encoding");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "transfer-encoding"
));
}
#[test]
fn prohibited_field_content_length() {
let result = Trailer::parse("Content-Length");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "content-length"
));
}
#[test]
fn prohibited_field_host() {
let result = Trailer::parse("Host");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "host"
));
}
#[test]
fn prohibited_field_trailer() {
let result = Trailer::parse("Trailer");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "trailer"
));
}
#[test]
fn prohibited_field_content_encoding() {
let result = Trailer::parse("Content-Encoding");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "content-encoding"
));
}
#[test]
fn prohibited_field_content_type() {
let result = Trailer::parse("Content-Type");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "content-type"
));
}
#[test]
fn prohibited_field_content_range() {
let result = Trailer::parse("Content-Range");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "content-range"
));
}
#[test]
fn prohibited_field_authorization() {
let result = Trailer::parse("Authorization");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "authorization"
));
}
#[test]
fn prohibited_field_proxy_authorization() {
let result = Trailer::parse("Proxy-Authorization");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "proxy-authorization"
));
}
#[test]
fn prohibited_field_www_authenticate() {
let result = Trailer::parse("WWW-Authenticate");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "www-authenticate"
));
}
// Request modifiers (conditionals and controls) must be rejected.
// The panic message below is Japanese for "{} should be rejected as a prohibited field".
#[test]
fn prohibited_field_request_modifier() {
for name in [
"If-Match",
"If-None-Match",
"If-Modified-Since",
"If-Unmodified-Since",
"If-Range",
"Range",
"Expect",
"TE",
] {
let result = Trailer::parse(name);
assert!(
matches!(result, Err(TrailerError::ProhibitedField(_))),
"{} は禁止フィールドとして拒否されるべき",
name
);
}
}
// Response control fields must be rejected (same Japanese panic message:
// "{} should be rejected as a prohibited field").
#[test]
fn prohibited_field_response_control() {
for name in [
"Cache-Control",
"Vary",
"Date",
"Expires",
"Age",
"Set-Cookie",
] {
let result = Trailer::parse(name);
assert!(
matches!(result, Err(TrailerError::ProhibitedField(_))),
"{} は禁止フィールドとして拒否されるべき",
name
);
}
}
// Connection-management fields must be rejected (same Japanese panic message).
#[test]
fn prohibited_field_connection_management() {
for name in ["Connection", "Upgrade"] {
let result = Trailer::parse(name);
assert!(
matches!(result, Err(TrailerError::ProhibitedField(_))),
"{} は禁止フィールドとして拒否されるべき",
name
);
}
}
// One prohibited field anywhere in the list fails the whole parse, and the
// error names that field.
#[test]
fn prohibited_field_in_list() {
let result = Trailer::parse("X-Custom, Content-Length, X-Other");
assert!(matches!(
result,
Err(TrailerError::ProhibitedField(ref name)) if name == "content-length"
));
}
// Ordinary custom fields are accepted.
#[test]
fn allowed_fields() {
let trailer = Trailer::parse("X-Checksum, X-Custom, X-Trace-Id").unwrap();
assert_eq!(trailer.fields().len(), 3);
}
// The standalone check is case-insensitive and accepts non-listed names.
#[test]
fn is_prohibited_trailer_field_function() {
assert!(is_prohibited_trailer_field("Transfer-Encoding"));
assert!(is_prohibited_trailer_field("transfer-encoding"));
assert!(is_prohibited_trailer_field("CONTENT-LENGTH"));
assert!(is_prohibited_trailer_field("Expires"));
assert!(is_prohibited_trailer_field("Authorization"));
assert!(is_prohibited_trailer_field("Cache-Control"));
assert!(is_prohibited_trailer_field("Range"));
assert!(is_prohibited_trailer_field("Connection"));
assert!(!is_prohibited_trailer_field("X-Custom"));
assert!(!is_prohibited_trailer_field("X-Checksum"));
}
}