#![warn(
clippy::all,
nonstandard_style,
future_incompatible,
missing_debug_implementations
)]
#![deny(missing_docs)]
#![forbid(unsafe_code)]
use std::{
cmp::{Ordering, Reverse},
collections::BTreeMap,
fmt::{self, Display},
str::FromStr,
};
use headers_core::{Error as HeaderError, Header, HeaderName, HeaderValue};
use mediatype::{names, MediaType, MediaTypeBuf, ReadParams};
/// Typed representation of the HTTP `Accept` header: the list of media ranges
/// a client is willing to receive.
///
/// Use [`Accept::negotiate`] to select the best of a server's available media
/// types, or [`Accept::media_types`] to inspect the ranges themselves.
#[derive(Debug)]
pub struct Accept(Vec<MediaTypeBuf>);
impl Accept {
pub fn media_types(&self) -> impl Iterator<Item = &MediaTypeBuf> {
self.0.iter()
}
pub fn negotiate<'a, 'mt: 'a, Available>(
&self,
available: Available,
) -> Option<&'a MediaType<'mt>>
where
Available: IntoIterator<Item = &'a MediaType<'mt>>,
{
struct BestMediaType<'a, 'mt: 'a> {
quality: QValue,
parsed_priority: usize,
given_priority: usize,
media_type: &'a MediaType<'mt>,
}
available
.into_iter()
.enumerate()
.filter_map(|(given_priority, available_type)| {
if let Some(matched_range) = self
.0
.iter()
.enumerate()
.find(|(_, available_range)| MediaRange(available_range) == *available_type)
{
let quality = Self::parse_q_value(matched_range.1);
if quality.is_zero() {
return None;
}
Some(BestMediaType {
quality,
parsed_priority: matched_range.0,
given_priority,
media_type: available_type,
})
} else {
None
}
})
.max_by_key(|x| (x.quality, Reverse((x.parsed_priority, x.given_priority))))
.map(|best| best.media_type)
}
fn parse(mut s: &str) -> Result<Self, HeaderError> {
let mut media_types = Vec::new();
while !s.is_empty() {
if let Some(index) = s.find(|c: char| !is_ows(c)) {
s = &s[index..];
} else {
break;
}
let mut end = 0;
let mut quoted = false;
let mut escaped = false;
for c in s.chars() {
if escaped {
escaped = false;
} else {
match c {
'"' => quoted = !quoted,
'\\' if quoted => escaped = true,
',' if !quoted => break,
_ => (),
}
}
end += c.len_utf8();
}
match MediaTypeBuf::from_str(s[..end].trim()) {
Ok(mt) => media_types.push(mt),
Err(_) => return Err(HeaderError::invalid()),
}
s = s[end..].trim_start_matches(',');
}
media_types.sort_by_key(|x| {
let spec = Self::parse_specificity(x);
let q = Self::parse_q_value(x);
Reverse((spec, q))
});
Ok(Self(media_types))
}
fn parse_q_value(media_type: &MediaTypeBuf) -> QValue {
media_type
.get_param(names::Q)
.and_then(|v| v.as_str().parse().ok())
.unwrap_or_default()
}
fn parse_specificity(media_type: &MediaTypeBuf) -> usize {
let type_specificity = if media_type.ty() != names::_STAR {
1
} else {
0
};
let subtype_specificity = if media_type.subty() != names::_STAR {
1
} else {
0
};
let parameter_count = media_type
.params()
.filter(|&(name, _)| name != names::Q)
.count();
type_specificity + subtype_specificity + parameter_count
}
}
impl Header for Accept {
    fn name() -> &'static HeaderName {
        &http::header::ACCEPT
    }

    /// Decodes one or more `Accept` header values.
    ///
    /// Multiple values are joined with `,` and parsed as a single list. Fails
    /// in three cases:
    /// - no values are supplied;
    /// - a value cannot be converted with `HeaderValue::to_str`;
    /// - the joined text does not parse.
    fn decode<'i, I>(values: &mut I) -> Result<Self, HeaderError>
    where
        I: Iterator<Item = &'i HeaderValue>,
    {
        let mut texts = values.map(|value| value.to_str().map_err(|_| HeaderError::invalid()));
        let first = texts.next().ok_or_else(HeaderError::invalid)??;
        let joined = texts.try_fold(String::from(first), |mut acc, text| {
            acc.push(',');
            acc.push_str(text?);
            Ok::<_, HeaderError>(acc)
        })?;
        Self::parse(&joined)
    }

    /// Encodes the media types as a single comma-separated header value.
    ///
    /// # Panics
    ///
    /// Panics if `HeaderValue::from_str` rejects the `Display` output.
    fn encode<E>(&self, values: &mut E)
    where
        E: Extend<HeaderValue>,
    {
        let rendered = self.to_string();
        let value = HeaderValue::from_str(&rendered)
            .expect("Header value should only contain visible ASCII characters (32-127)");
        values.extend(Some(value));
    }
}
impl FromStr for Accept {
type Err = HeaderError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s).map_err(|_| HeaderError::invalid())
}
}
impl TryFrom<&HeaderValue> for Accept {
type Error = HeaderError;
fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
let s = value.to_str().map_err(|_| HeaderError::invalid())?;
s.parse().map_err(|_| HeaderError::invalid())
}
}
impl Display for Accept {
    /// Writes the media types separated by `", "`, in stored order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (index, media_type) in self.0.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{media_type}")?;
        }
        Ok(())
    }
}
impl<'a> FromIterator<MediaType<'a>> for Accept {
fn from_iter<T: IntoIterator<Item = MediaType<'a>>>(iter: T) -> Self {
iter.into_iter().map(MediaTypeBuf::from).collect()
}
}
impl FromIterator<MediaTypeBuf> for Accept {
    /// Builds an `Accept` from owned media types, keeping them in the order
    /// supplied (no sorting is applied).
    fn from_iter<T: IntoIterator<Item = MediaTypeBuf>>(iter: T) -> Self {
        Self(Vec::from_iter(iter))
    }
}
/// Returns `true` for HTTP optional whitespace (OWS): a space or a horizontal tab.
const fn is_ows(c: char) -> bool {
    matches!(c, ' ' | '\t')
}
/// Borrowed view of one media range from an `Accept` header, used to test
/// whether the range accepts a concrete media type.
struct MediaRange<'a>(&'a MediaTypeBuf);

impl PartialEq<MediaType<'_>> for MediaRange<'_> {
    /// Returns `true` if this media range accepts `other`.
    ///
    /// - A range whose type is `*` accepts everything.
    /// - `type/*` accepts any media type with the same type.
    /// - `type/subtype[+suffix]` with no parameters other than `q` accepts that
    ///   media type whatever its parameters are. `q` is ignored here, so
    ///   `text/plain;q=0.7` behaves exactly like `text/plain`.
    /// - `type/subtype[+suffix];params` accepts it only if the non-`q`
    ///   parameters are identical on both sides.
    fn eq(&self, other: &MediaType<'_>) -> bool {
        let range = self.0;
        if range.ty() == names::_STAR {
            return true;
        }
        if range.ty() != other.ty {
            return false;
        }
        if range.subty() == names::_STAR {
            return true;
        }
        if range.subty() != other.subty || range.suffix() != other.suffix {
            return false;
        }
        let range_params = range
            .params()
            .filter(|&(name, _)| name != names::Q)
            .collect::<BTreeMap<_, _>>();
        // `q` is a weight, not a matching criterion. Previously it was counted
        // here, so `text/plain;q=0.7` failed to match `text/plain;format=x`
        // even though `text/plain` did.
        if range_params.is_empty() {
            return true;
        }
        let other_params = other
            .params()
            .filter(|&(name, _)| name != names::Q)
            .collect::<BTreeMap<_, _>>();
        range_params == other_params
    }
}
/// A quality (`q`) value stored in thousandths: `QValue(1000)` is `q=1` and
/// `QValue(123)` is `q=0.123`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct QValue(u16);

impl Default for QValue {
    /// The default quality is `q=1`.
    fn default() -> Self {
        Self(1000)
    }
}

impl QValue {
    /// Returns `true` for `q=0`, which marks a media range as not acceptable.
    pub fn is_zero(&self) -> bool {
        matches!(self, Self(0))
    }
}
impl FromStr for QValue {
type Err = HeaderError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
fn parse_fractional(digits: &[u8]) -> Result<u16, HeaderError> {
digits
.iter()
.try_fold(0u16, |acc, &c| {
if c.is_ascii_digit() {
Some(acc * 10 + (c - b'0') as u16)
} else {
None
}
})
.map(|num| match digits.len() {
1 => num * 100,
2 => num * 10,
_ => num,
})
.ok_or_else(HeaderError::invalid)
}
match s.as_bytes() {
b"0" => Ok(QValue(0)),
b"1" => Ok(QValue(1000)),
[b'1', b'.', zeros @ ..] if zeros.len() <= 3 && zeros.iter().all(|d| *d == b'0') => {
Ok(QValue(1000))
}
[b'0', b'.', fractional @ ..] if fractional.len() <= 3 => {
parse_fractional(fractional).map(QValue)
}
_ => Err(HeaderError::invalid()),
}
}
}
impl Ord for QValue {
fn cmp(&self, other: &Self) -> Ordering {
self.0.cmp(&other.0)
}
}
impl PartialOrd for QValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `accept` yields exactly the media types in `expected`, in order.
    fn assert_media_types(accept: &Accept, expected: &[&str]) {
        let mut media_types = accept.media_types();
        for &text in expected {
            let want = MediaTypeBuf::from_str(text).unwrap();
            assert_eq!(media_types.next(), Some(&want), "expected `{text}`");
        }
        assert_eq!(media_types.next(), None);
    }

    /// Asserts that negotiating `available` against `accept` selects `expected`.
    fn assert_negotiates(accept: &Accept, available: &[&str], expected: Option<&str>) {
        let available: Vec<MediaType<'_>> = available
            .iter()
            .map(|&text| MediaType::parse(text).unwrap())
            .collect();
        let expected = expected.map(|text| MediaType::parse(text).unwrap());
        assert_eq!(accept.negotiate(&available), expected.as_ref());
    }

    /// Asserts that two headers hold the same media types in the same order.
    fn assert_same_media_types(left: &Accept, right: &Accept) {
        // Compare collected lists: zipping two `&mut` iterators would consume
        // one extra element from `left` and hide a length mismatch.
        assert_eq!(
            left.media_types().collect::<Vec<_>>(),
            right.media_types().collect::<Vec<_>>(),
            "same media types through `decode` and `try_into`"
        );
    }

    #[test]
    fn reordering() {
        let accept = Accept::from_str("audio/*; q=0.2, audio/basic").unwrap();
        assert_media_types(&accept, &["audio/basic", "audio/*; q=0.2"]);
    }

    #[test]
    fn reordering_elaborate() {
        let accept =
            Accept::from_str("text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c").unwrap();
        assert_media_types(
            &accept,
            &["text/html", "text/x-c", "text/x-dvi; q=0.8", "text/plain; q=0.5"],
        );
    }

    #[test]
    fn preserve_ordering() {
        let accept = Accept::from_str("x/y, a/b").unwrap();
        assert_media_types(&accept, &["x/y", "a/b"]);
    }

    #[test]
    fn params() {
        let accept =
            Accept::from_str("text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8")
                .unwrap();
        assert_media_types(
            &accept,
            &["text/html", "application/xhtml+xml", "application/xml;q=0.9", "*/*;q=0.8"],
        );
    }

    #[test]
    fn quoted_params() {
        let accept = Accept::from_str(
            "text/html; message=\"Hello, world!\", application/xhtml+xml; message=\"Hello, \
             world?\"",
        )
        .unwrap();
        assert_media_types(
            &accept,
            &[
                "text/html; message=\"Hello, world!\"",
                "application/xhtml+xml; message=\"Hello, world?\"",
            ],
        );
    }

    #[test]
    fn more_specifics() {
        let accept = Accept::from_str("text/*, text/plain, text/plain;format=flowed, */*").unwrap();
        assert_media_types(
            &accept,
            &["text/plain;format=flowed", "text/plain", "text/*", "*/*"],
        );
    }

    #[test]
    fn variable_quality_more_specifics() {
        let accept = Accept::from_str(
            "text/*;q=0.3, text/plain;q=0.7, text/csv;q=0, text/plain;format=flowed, \
             text/plain;format=fixed;q=0.4, */*;q=0.5",
        )
        .unwrap();
        assert_media_types(
            &accept,
            &[
                "text/plain;format=flowed",
                "text/plain;format=fixed;q=0.4",
                "text/plain;q=0.7",
                "text/csv;q=0",
                "text/*;q=0.3",
                "*/*;q=0.5",
            ],
        );
    }

    #[test]
    fn negotiate() {
        let accept = Accept::from_str(
            "text/html, application/xhtml+xml, application/xml;q=0.9, text/*;q=0.7, text/csv;q=0",
        )
        .unwrap();
        assert_negotiates(&accept, &["text/html", "application/json"], Some("text/html"));
        assert_negotiates(&accept, &["application/xhtml+xml", "text/html"], Some("text/html"));
        assert_negotiates(&accept, &["text/plain", "image/gif"], Some("text/plain"));
        assert_negotiates(
            &accept,
            &["image/gif", "text/plain", "text/troff"],
            Some("text/plain"),
        );
        assert_negotiates(&accept, &["image/gif", "image/png"], None);
        assert_negotiates(&accept, &["image/gif", "text/csv"], None);
    }

    #[test]
    fn negotiate_with_full_wildcard() {
        let accept =
            Accept::from_str("text/html, text/*;q=0.7, */*;q=0.1, text/csv;q=0.0").unwrap();
        assert_negotiates(&accept, &["text/html", "application/json"], Some("text/html"));
        assert_negotiates(&accept, &["text/plain", "image/gif"], Some("text/plain"));
        assert_negotiates(
            &accept,
            &["text/javascript", "text/plain"],
            Some("text/javascript"),
        );
        assert_negotiates(&accept, &["image/gif", "image/png"], Some("image/gif"));
        assert_negotiates(
            &accept,
            &["text/csv", "text/javascript"],
            Some("text/javascript"),
        );
    }

    #[test]
    fn negotiate_diabolically() {
        let accept = Accept::from_str(
            "text/*;q=0.3, text/csv;q=0.2, text/plain;q=0.7, text/plain;format=rot13;q=0.7, \
             text/plain;format=flowed, text/plain;format=fixed;q=0.4, */*;q=0.5",
        )
        .unwrap();
        assert_negotiates(&accept, &["text/html", "text/plain"], Some("text/plain"));
        assert_negotiates(
            &accept,
            &["text/plain", "text/plain;format=rot13"],
            Some("text/plain;format=rot13"),
        );
        assert_negotiates(
            &accept,
            &["text/plain", "text/plain;format=fixed"],
            Some("text/plain"),
        );
        assert_negotiates(&accept, &["text/html", "image/gif"], Some("image/gif"));
    }

    #[test]
    fn try_from_header_value() {
        let header_value = HeaderValue::from_static("audio/*; q=0.2, audio/basic");
        let accept: Accept = (&header_value).try_into().unwrap();
        assert_media_types(&accept, &["audio/basic", "audio/*; q=0.2"]);
    }

    #[test]
    fn decode() {
        let mut no_values = [].iter();
        assert!(
            Accept::decode(&mut no_values).is_err(),
            "providing no headers results in an error"
        );

        let audio_any = HeaderValue::from_static("audio/*; q=0.2");
        let audio_basic = HeaderValue::from_static("audio/basic");
        let combined = HeaderValue::from_static("audio/*; q=0.2, audio/basic");

        let combined_try_into: Accept = (&combined).try_into().unwrap();
        let combined_decode = Accept::decode(&mut [&combined].into_iter()).unwrap();
        assert_same_media_types(&combined_decode, &combined_try_into);

        let separate_decode =
            Accept::decode(&mut [&audio_any, &audio_basic].into_iter()).unwrap();
        assert_same_media_types(&separate_decode, &combined_try_into);
    }

    #[test]
    fn mixed_lifetime_from_iter() {
        // Compile-only check: `negotiate` must accept `'static` media types
        // borrowed for a shorter lifetime.
        #[allow(unused)]
        fn best<'a>(available: &'a [MediaType<'static>]) -> Option<&'a MediaType<'static>> {
            let accept = Accept::from_str("*/*").unwrap();
            accept.negotiate(available.iter())
        }
    }

    #[test]
    fn from_iterator() {
        let accept = Accept::from_iter([
            MediaType::parse("text/html").unwrap(),
            MediaType::parse("image/gif").unwrap(),
        ]);
        assert_eq!(
            accept.media_types().collect::<Vec<_>>(),
            vec![
                MediaType::parse("text/html").unwrap(),
                MediaType::parse("image/gif").unwrap(),
            ]
        );

        let accept = Accept::from_iter([
            MediaTypeBuf::from_str("text/html").unwrap(),
            MediaTypeBuf::from_str("image/gif").unwrap(),
        ]);
        assert_eq!(
            accept.media_types().collect::<Vec<_>>(),
            vec![
                MediaType::parse("text/html").unwrap(),
                MediaType::parse("image/gif").unwrap(),
            ]
        );
    }

    #[test]
    fn test_qvalue_parsing_one() {
        for text in ["1", "1.", "1.0", "1.00", "1.000"] {
            assert_eq!(QValue(1000), text.parse().unwrap(), "parsing `{text}`");
        }
    }

    #[test]
    fn test_qvalue_parsing_partial() {
        let cases = [
            (0, "0"),
            (0, "0."),
            (0, "0.0"),
            (0, "0.00"),
            (0, "0.000"),
            (100, "0.1"),
            (120, "0.12"),
            (123, "0.123"),
            (23, "0.023"),
            (3, "0.003"),
        ];
        for (expected, text) in cases {
            assert_eq!(QValue(expected), text.parse().unwrap(), "parsing `{text}`");
        }
    }

    #[test]
    fn qvalue_parsing_invalid() {
        let invalid = [
            "0.0000", "0.1.", "0.12.", "0.123.", "0.1234", "1.123", "1.1234", "1.12345", "2.0",
            "-0.0", "1.0000",
        ];
        for text in invalid {
            assert!(text.parse::<QValue>().is_err(), "`{text}` should be rejected");
        }
    }

    #[test]
    fn qvalue_ordering() {
        assert!(QValue(1000) > QValue(0));
        assert!(QValue(1000) > QValue(100));
        assert!(QValue(100) > QValue(0));
        assert!(QValue(120) > QValue(100));
        assert!(QValue(123) > QValue(120));
        assert!(QValue(23) < QValue(100));
        assert!(QValue(3) < QValue(23));
    }

    #[test]
    fn qvalue_default() {
        let q: QValue = Default::default();
        assert_eq!(q, QValue(1000));
    }

    #[test]
    fn qvalue_is_zero() {
        assert!("0.".parse::<QValue>().unwrap().is_zero());
    }
}