#![warn(
clippy::all,
nonstandard_style,
future_incompatible,
missing_debug_implementations
)]
#![deny(missing_docs)]
#![forbid(unsafe_code)]
use std::{
cmp::Ordering,
collections::BTreeMap,
fmt::{self, Display},
str::FromStr,
};
use headers_core::{Error as HeaderError, Header, HeaderName, HeaderValue};
use mediatype::{names, MediaType, MediaTypeBuf, ReadParams};
/// A parsed `Accept` header value, held as a list of media ranges.
///
/// Values produced by the parser in this file are stored most specific first; ranges of equal
/// specificity are ordered by descending `q` weight, and otherwise keep their header order.
#[derive(Debug)]
pub struct Accept(Vec<MediaTypeBuf>);
impl Accept {
pub fn media_types(&self) -> impl Iterator<Item = &MediaTypeBuf> {
self.0.iter()
}
/// Picks the best entry of `available` according to this `Accept` header.
///
/// Every available media type is matched against the first stored media range that covers it
/// and takes that range's `q` weight. The candidate with the highest weight wins. Ties are
/// broken first by the position of the matching range in the stored list (earlier wins) and
/// then by the position of the candidate in `available` (earlier wins).
///
/// Returns `None` when no candidate matches a range with a weight greater than zero.
///
/// The returned reference borrows only from `available`, never from `self`, so it stays usable
/// after the `Accept` value has been dropped.
pub fn negotiate<'a, Available>(&self, available: Available) -> Option<&'a MediaType<'a>>
where
    Available: IntoIterator<Item = &'a MediaType<'a>>,
{
    // Starts at quality 0 with priorities (0, 0) and no type: a candidate has to score strictly
    // above zero to replace it, which is what keeps `q=0` ranges from ever being selected.
    let mut best = BestMediaType::default();

    for (given_priority, available_type) in available.into_iter().enumerate() {
        let matched_range = self
            .0
            .iter()
            .enumerate()
            .find(|(_, available_range)| MediaRange(available_range) == *available_type);

        // Types that no range covers are simply not candidates.
        if let Some((parsed_priority, range)) = matched_range {
            let quality = Self::parse_q_value(range);
            let outranks_best = quality > best.quality
                || (quality == best.quality
                    && (parsed_priority, given_priority)
                        < (best.parsed_priority, best.given_priority));

            if outranks_best {
                best = BestMediaType {
                    quality,
                    parsed_priority,
                    given_priority,
                    ty: Some(available_type),
                };
            }
        }
    }

    best.ty
}
/// Parses the comma-separated value of an `Accept` header.
///
/// The input is split on commas that are not inside a double-quoted string, and each element is
/// parsed as a media type. Empty list elements (stray commas, with or without surrounding
/// spaces or tabs) are skipped wherever they occur, so `"a/b, , c/d"` and `", a/b"` are accepted
/// just like `"a/b,,c/d"`.
///
/// The resulting list is stably sorted: more specific ranges first, then higher `q` weight
/// first. Ranges that compare equal keep their original relative order.
///
/// # Errors
///
/// Returns `HeaderError::invalid()` if any non-empty element is not a valid media type.
fn parse(mut s: &str) -> Result<Self, HeaderError> {
    let mut media_types = Vec::new();

    loop {
        // Drop optional whitespace and empty list elements in front of the next element.
        s = s.trim_start_matches(|c: char| c == ',' || is_ows(c));
        if s.is_empty() {
            break;
        }

        // Find the byte offset of the comma that ends this element, ignoring commas inside
        // quoted strings and honouring backslash escapes within them.
        let mut end = 0;
        let mut quoted = false;
        let mut escaped = false;
        for c in s.chars() {
            if escaped {
                escaped = false;
            } else {
                match c {
                    '"' => quoted = !quoted,
                    '\\' if quoted => escaped = true,
                    ',' if !quoted => break,
                    _ => (),
                }
            }
            end += c.len_utf8();
        }

        let media_type =
            MediaTypeBuf::from_str(s[..end].trim()).map_err(|_| HeaderError::invalid())?;
        media_types.push(media_type);

        // Continue at the terminating comma (if any); it is consumed at the top of the loop.
        s = &s[end..];
    }

    // `sort_by` is stable, so equally ranked ranges keep the order they had in the header.
    media_types.sort_by(|a, b| {
        let spec_a = Self::parse_specificity(a);
        let spec_b = Self::parse_specificity(b);
        let q_a = Self::parse_q_value(a);
        let q_b = Self::parse_q_value(b);
        spec_b
            .cmp(&spec_a)
            .then_with(|| q_b.partial_cmp(&q_a).unwrap_or(Ordering::Equal))
    });

    Ok(Self(media_types))
}
/// Returns the `q` weight of `media_type`, defaulting to `1.0`.
///
/// Only weights within `0.0..=1.0` are honoured. A missing or unparsable `q`, or one outside
/// that range (including `NaN` and infinities, which `f32` parsing would otherwise accept),
/// yields the default. This keeps every returned weight finite and mutually comparable, so
/// callers can order by it safely.
fn parse_q_value(media_type: &MediaTypeBuf) -> f32 {
    media_type
        .get_param(names::Q)
        .and_then(|v| v.as_str().parse::<f32>().ok())
        // `contains` is false for NaN, so non-comparable values are rejected here as well.
        .filter(|q| (0.0_f32..=1.0).contains(q))
        .unwrap_or(1.0)
}
/// Scores how specific a media range is.
///
/// One point for a non-wildcard type, one for a non-wildcard subtype, and one for every
/// parameter other than `q`.
fn parse_specificity(media_type: &MediaTypeBuf) -> usize {
    let concrete_parts =
        usize::from(media_type.ty() != "*") + usize::from(media_type.subty() != "*");
    let extra_params = media_type
        .params()
        .filter(|&(name, _)| name != names::Q)
        .count();
    concrete_parts + extra_params
}
}
impl Header for Accept {
fn name() -> &'static HeaderName {
&http::header::ACCEPT
}
fn decode<'i, I>(values: &mut I) -> Result<Self, HeaderError>
where
I: Iterator<Item = &'i HeaderValue>,
{
let value = values.next().ok_or_else(HeaderError::invalid)?;
let value_str = value.to_str().map_err(|_| HeaderError::invalid())?;
Self::parse(value_str)
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
let value = HeaderValue::from_str(&self.to_string())
.expect("Header value should only contain visible ASCII characters (32-127)");
values.extend(std::iter::once(value));
}
}
impl FromStr for Accept {
type Err = HeaderError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::parse(s).map_err(|_| HeaderError::invalid())
}
}
impl Display for Accept {
    /// Writes the stored media ranges in order, separated by `", "`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (position, media_type) in self.0.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{}", media_type)?;
        }
        Ok(())
    }
}
/// Returns `true` for the two optional-whitespace characters: space and horizontal tab.
const fn is_ows(c: char) -> bool {
    matches!(c, ' ' | '\t')
}
/// Borrowed view of one media range from an `Accept` header, used so that a range can be
/// compared against a concrete media type with `==`, with `*` treated as a wildcard.
struct MediaRange<'a>(&'a MediaTypeBuf);
impl PartialEq<MediaType<'_>> for MediaRange<'_> {
    /// Reports whether this media range covers the media type `other`.
    ///
    /// * A range whose type is `*` covers everything.
    /// * A range whose subtype is `*` covers every media type with the same top-level type.
    /// * Otherwise type, subtype and suffix must all be equal, and then:
    ///   * a range with no parameters other than `q` covers `other` whatever its parameters;
    ///   * a range with further parameters requires `other` to carry exactly the same set of
    ///     parameters (again ignoring `q`).
    ///
    /// The `q` parameter is a weight, not part of the media type, so it never influences
    /// whether a range matches.
    fn eq(&self, other: &MediaType<'_>) -> bool {
        let range = self.0;

        if range.ty() == names::_STAR {
            return true;
        }
        if range.ty() != other.ty {
            return false;
        }
        if range.subty() == names::_STAR {
            return true;
        }
        if range.subty() != other.subty || range.suffix() != other.suffix {
            return false;
        }

        let range_params = range
            .params()
            .filter(|&(name, _)| name != names::Q)
            .collect::<BTreeMap<_, _>>();
        // Decided after dropping `q`, so `text/plain;q=0.7` behaves like bare `text/plain`.
        if range_params.is_empty() {
            return true;
        }

        let other_params = other
            .params()
            .filter(|&(name, _)| name != names::Q)
            .collect::<BTreeMap<_, _>>();
        range_params == other_params
    }
}
/// The best negotiation candidate seen so far.
///
/// The derived `Default` (quality `0.0`, both priorities `0`, no type) serves as the starting
/// point of a negotiation.
#[derive(Default)]
struct BestMediaType<'ty> {
// `q` weight of the media range that matched the candidate.
quality: f32,
// Index of the matching media range within the parsed `Accept` list.
parsed_priority: usize,
// Index of the candidate within the caller-supplied list of available types.
given_priority: usize,
// The candidate itself; `None` while nothing has been selected.
ty: Option<&'ty MediaType<'ty>>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `header` and checks that its media ranges iterate in exactly the `expected`
    /// order, with nothing left over.
    fn assert_parsed_order(header: &str, expected: &[&str]) {
        let accept = Accept::from_str(header).unwrap();
        let expected: Vec<MediaTypeBuf> = expected
            .iter()
            .map(|media_type| MediaTypeBuf::from_str(media_type).unwrap())
            .collect();
        let actual: Vec<&MediaTypeBuf> = accept.media_types().collect();
        assert_eq!(actual, expected.iter().collect::<Vec<_>>());
    }

    /// Negotiates `offered` against `accept` and checks the outcome; `None` means that no
    /// offered type is acceptable.
    fn assert_negotiated(accept: &Accept, offered: &[&str], expected: Option<&str>) {
        let available: Vec<MediaType<'_>> = offered
            .iter()
            .map(|&media_type| MediaType::parse(media_type).unwrap())
            .collect();
        let expected = expected.map(|media_type| MediaType::parse(media_type).unwrap());
        assert_eq!(accept.negotiate(&available), expected.as_ref());
    }

    // A more specific range is moved in front of a wildcard range.
    #[test]
    fn reordering() {
        assert_parsed_order(
            "audio/*; q=0.2, audio/basic",
            &["audio/basic", "audio/*; q=0.2"],
        );
    }

    // Equally specific ranges are ordered by descending weight.
    #[test]
    fn reordering_elaborate() {
        assert_parsed_order(
            "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c",
            &[
                "text/html",
                "text/x-c",
                "text/x-dvi; q=0.8",
                "text/plain; q=0.5",
            ],
        );
    }

    // Ranges that rank equally keep their header order.
    #[test]
    fn preserve_ordering() {
        assert_parsed_order("x/y, a/b", &["x/y", "a/b"]);
    }

    #[test]
    fn params() {
        assert_parsed_order(
            "text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8",
            &[
                "text/html",
                "application/xhtml+xml",
                "application/xml;q=0.9",
                "*/*;q=0.8",
            ],
        );
    }

    // Commas inside quoted parameter values do not split list elements.
    #[test]
    fn quoted_params() {
        assert_parsed_order(
            "text/html; message=\"Hello, world!\", application/xhtml+xml; message=\"Hello, \
             world?\"",
            &[
                "text/html; message=\"Hello, world!\"",
                "application/xhtml+xml; message=\"Hello, world?\"",
            ],
        );
    }

    #[test]
    fn more_specifics() {
        assert_parsed_order(
            "text/*, text/plain, text/plain;format=flowed, */*",
            &["text/plain;format=flowed", "text/plain", "text/*", "*/*"],
        );
    }

    // Specificity takes precedence over weight when ordering.
    #[test]
    fn variable_quality_more_specifics() {
        assert_parsed_order(
            "text/*;q=0.3, text/plain;q=0.7, text/csv;q=0, text/plain;format=flowed, \
             text/plain;format=fixed;q=0.4, */*;q=0.5",
            &[
                "text/plain;format=flowed",
                "text/plain;format=fixed;q=0.4",
                "text/plain;q=0.7",
                "text/csv;q=0",
                "text/*;q=0.3",
                "*/*;q=0.5",
            ],
        );
    }

    #[test]
    fn negotiate() {
        let accept = Accept::from_str(
            "text/html, application/xhtml+xml, application/xml;q=0.9, text/*;q=0.7, text/csv;q=0",
        )
        .unwrap();

        assert_negotiated(
            &accept,
            &["text/html", "application/json"],
            Some("text/html"),
        );
        assert_negotiated(
            &accept,
            &["application/xhtml+xml", "text/html"],
            Some("text/html"),
        );
        assert_negotiated(&accept, &["text/plain", "image/gif"], Some("text/plain"));
        assert_negotiated(
            &accept,
            &["image/gif", "text/plain", "text/troff"],
            Some("text/plain"),
        );
        // Nothing offered is covered by any range.
        assert_negotiated(&accept, &["image/gif", "image/png"], None);
        // `text/csv` is only covered with weight zero, which makes it unacceptable.
        assert_negotiated(&accept, &["image/gif", "text/csv"], None);
    }

    #[test]
    fn negotiate_with_full_wildcard() {
        let accept =
            Accept::from_str("text/html, text/*;q=0.7, */*;q=0.1, text/csv;q=0.0").unwrap();

        assert_negotiated(
            &accept,
            &["text/html", "application/json"],
            Some("text/html"),
        );
        assert_negotiated(&accept, &["text/plain", "image/gif"], Some("text/plain"));
        assert_negotiated(
            &accept,
            &["text/javascript", "text/plain"],
            Some("text/javascript"),
        );
        assert_negotiated(&accept, &["image/gif", "image/png"], Some("image/gif"));
        assert_negotiated(
            &accept,
            &["text/csv", "text/javascript"],
            Some("text/javascript"),
        );
    }

    #[test]
    fn negotiate_diabolically() {
        let accept = Accept::from_str(
            "text/*;q=0.3, text/csv;q=0.2, text/plain;q=0.7, text/plain;format=rot13;q=0.7, \
             text/plain;format=flowed, text/plain;format=fixed;q=0.4, */*;q=0.5",
        )
        .unwrap();

        assert_negotiated(&accept, &["text/html", "text/plain"], Some("text/plain"));
        assert_negotiated(
            &accept,
            &["text/plain", "text/plain;format=rot13"],
            Some("text/plain;format=rot13"),
        );
        assert_negotiated(
            &accept,
            &["text/plain", "text/plain;format=fixed"],
            Some("text/plain"),
        );
        assert_negotiated(&accept, &["text/html", "image/gif"], Some("image/gif"));
    }
}