1#![warn(
48 clippy::all,
49 nonstandard_style,
50 future_incompatible,
51 missing_debug_implementations
52)]
53#![deny(missing_docs)]
54#![forbid(unsafe_code)]
55
56use std::{
57 cmp::{Ordering, Reverse},
58 collections::BTreeMap,
59 fmt::{self, Display},
60 str::FromStr,
61};
62
63use headers_core::{Error as HeaderError, Header, HeaderName, HeaderValue};
64use mediatype::{names, MediaType, MediaTypeBuf, ReadParams};
65
/// A parsed HTTP `Accept` header: a list of media ranges.
///
/// When the value is parsed from a header string, the media ranges are sorted
/// by specificity first (a concrete type, a concrete subtype and every
/// parameter other than `q` each add one point), then by descending `q` value
/// (a missing `q` counts as `1`). Ranges with equal keys keep the order in
/// which they appeared in the header. [`Accept::negotiate`] uses the first
/// range that matches a candidate, so it relies on this ordering.
#[derive(Debug)]
pub struct Accept(Vec<MediaTypeBuf>);
75
76impl Accept {
77 pub fn media_types(&self) -> impl Iterator<Item = &MediaTypeBuf> {
85 self.0.iter()
86 }
87
88 pub fn negotiate<'a, 'mt: 'a, Available>(
132 &self,
133 available: Available,
134 ) -> Option<&'a MediaType<'mt>>
135 where
136 Available: IntoIterator<Item = &'a MediaType<'mt>>,
137 {
138 struct BestMediaType<'a, 'mt: 'a> {
139 quality: QValue,
140 parsed_priority: usize,
141 given_priority: usize,
142 media_type: &'a MediaType<'mt>,
143 }
144
145 available
146 .into_iter()
147 .enumerate()
148 .filter_map(|(given_priority, available_type)| {
149 if let Some(matched_range) = self
150 .0
151 .iter()
152 .enumerate()
153 .find(|(_, available_range)| MediaRange(available_range) == *available_type)
154 {
155 let quality = Self::parse_q_value(matched_range.1);
156 if quality.is_zero() {
157 return None;
158 }
159 Some(BestMediaType {
160 quality,
161 parsed_priority: matched_range.0,
162 given_priority,
163 media_type: available_type,
164 })
165 } else {
166 None
167 }
168 })
169 .max_by_key(|x| (x.quality, Reverse((x.parsed_priority, x.given_priority))))
170 .map(|best| best.media_type)
171 }
172
173 fn parse(mut s: &str) -> Result<Self, HeaderError> {
174 let mut media_types = Vec::new();
175
176 while !s.is_empty() {
180 if let Some(index) = s.find(|c: char| !is_ows(c)) {
182 s = &s[index..];
183 } else {
184 break;
185 }
186
187 let mut end = 0;
188 let mut quoted = false;
189 let mut escaped = false;
190 for c in s.chars() {
191 if escaped {
192 escaped = false;
193 } else {
194 match c {
195 '"' => quoted = !quoted,
196 '\\' if quoted => escaped = true,
197 ',' if !quoted => break,
198 _ => (),
199 }
200 }
201 end += c.len_utf8();
202 }
203
204 match MediaTypeBuf::from_str(s[..end].trim()) {
206 Ok(mt) => media_types.push(mt),
207 Err(_) => return Err(HeaderError::invalid()),
208 }
209
210 s = s[end..].trim_start_matches(',');
212 }
213
214 media_types.sort_by_key(|x| {
216 let spec = Self::parse_specificity(x);
217 let q = Self::parse_q_value(x);
218 Reverse((spec, q))
219 });
220
221 Ok(Self(media_types))
222 }
223
224 fn parse_q_value(media_type: &MediaTypeBuf) -> QValue {
225 media_type
226 .get_param(names::Q)
227 .and_then(|v| v.as_str().parse().ok())
228 .unwrap_or_default()
229 }
230
231 fn parse_specificity(media_type: &MediaTypeBuf) -> usize {
232 let type_specificity = if media_type.ty() != names::_STAR {
233 1
234 } else {
235 0
236 };
237 let subtype_specificity = if media_type.subty() != names::_STAR {
238 1
239 } else {
240 0
241 };
242
243 let parameter_count = media_type
244 .params()
245 .filter(|&(name, _)| name != names::Q)
246 .count();
247
248 type_specificity + subtype_specificity + parameter_count
249 }
250}
251
252impl Header for Accept {
254 fn name() -> &'static HeaderName {
255 &http::header::ACCEPT
256 }
257
258 fn decode<'i, I>(values: &mut I) -> Result<Self, HeaderError>
259 where
260 I: Iterator<Item = &'i HeaderValue>,
261 {
262 let value = values.next().ok_or_else(HeaderError::invalid)?;
263 let value_str = value.to_str().map_err(|_| HeaderError::invalid())?;
264 Self::parse(value_str)
265 }
266
267 fn encode<E>(&self, values: &mut E)
268 where
269 E: Extend<HeaderValue>,
270 {
271 let value = HeaderValue::from_str(&self.to_string())
272 .expect("Header value should only contain visible ASCII characters (32-127)");
273 values.extend(std::iter::once(value));
274 }
275}
276
277impl FromStr for Accept {
278 type Err = HeaderError;
279
280 fn from_str(s: &str) -> Result<Self, Self::Err> {
281 Self::parse(s).map_err(|_| HeaderError::invalid())
282 }
283}
284
285impl TryFrom<&HeaderValue> for Accept {
286 type Error = HeaderError;
287
288 fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
289 let s = value.to_str().map_err(|_| HeaderError::invalid())?;
290 s.parse().map_err(|_| HeaderError::invalid())
291 }
292}
293
294impl Display for Accept {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 let media_types = self
297 .0
298 .iter()
299 .map(|mt| mt.to_string())
300 .collect::<Vec<_>>()
301 .join(", ");
302 write!(f, "{media_types}")
303 }
304}
305
306impl<'a> FromIterator<MediaType<'a>> for Accept {
307 fn from_iter<T: IntoIterator<Item = MediaType<'a>>>(iter: T) -> Self {
308 iter.into_iter().map(MediaTypeBuf::from).collect()
309 }
310}
311
312impl FromIterator<MediaTypeBuf> for Accept {
313 fn from_iter<T: IntoIterator<Item = MediaTypeBuf>>(iter: T) -> Self {
314 Self(iter.into_iter().collect())
315 }
316}
317
/// Returns whether `c` is HTTP optional whitespace (`OWS`): a space or a
/// horizontal tab.
const fn is_ows(c: char) -> bool {
    matches!(c, ' ' | '\t')
}
324
/// A media range taken from an `Accept` header.
///
/// The wrapper exists so the range can be compared with a concrete
/// `MediaType` using `==`. The comparison is true when the range accepts that
/// media type.
struct MediaRange<'a>(&'a MediaTypeBuf);
326
327impl PartialEq<MediaType<'_>> for MediaRange<'_> {
328 fn eq(&self, other: &MediaType<'_>) -> bool {
329 let (type_match, subtype_match, suffix_match) = (
330 self.0.ty() == other.ty,
331 self.0.subty() == other.subty,
332 self.0.suffix() == other.suffix,
333 );
334
335 let wildcard_type = self.0.ty() == names::_STAR;
336 let wildcard_subtype = self.0.subty() == names::_STAR && type_match;
337
338 let exact_match =
339 type_match && subtype_match && suffix_match && self.0.params().count() == 0;
340
341 let params_match = type_match && subtype_match && suffix_match && {
342 let self_params = self
343 .0
344 .params()
345 .filter(|&(name, _)| name != names::Q)
346 .collect::<BTreeMap<_, _>>();
347
348 let other_params = other
349 .params()
350 .filter(|&(name, _)| name != names::Q)
351 .collect::<BTreeMap<_, _>>();
352
353 self_params == other_params
354 };
355
356 wildcard_type || wildcard_subtype || exact_match || params_match
357 }
358}
359
/// A quality value (the `q` weight of a media range), stored in thousandths.
///
/// `1000` means `q=1` and `0` means `q=0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct QValue(u16);

impl Default for QValue {
    /// Full quality (`q=1`). This is the value used when a media range has no
    /// usable `q` parameter.
    fn default() -> Self {
        Self(1000)
    }
}

impl QValue {
    /// Returns `true` for `q=0`, which marks a media range as not acceptable.
    pub fn is_zero(&self) -> bool {
        matches!(self, Self(0))
    }
}
378
379impl FromStr for QValue {
380 type Err = HeaderError;
381
382 fn from_str(s: &str) -> Result<Self, Self::Err> {
383 fn parse_fractional(digits: &[u8]) -> Result<u16, HeaderError> {
386 digits
387 .iter()
388 .try_fold(0u16, |acc, &c| {
389 if c.is_ascii_digit() {
390 Some(acc * 10 + (c - b'0') as u16)
391 } else {
392 None
393 }
394 })
395 .map(|num| match digits.len() {
396 1 => num * 100,
397 2 => num * 10,
398 _ => num,
399 })
400 .ok_or_else(HeaderError::invalid)
401 }
402
403 match s.as_bytes() {
404 b"0" => Ok(QValue(0)),
405 b"1" => Ok(QValue(1000)),
406 [b'1', b'.', zeros @ ..] if zeros.len() <= 3 && zeros.iter().all(|d| *d == b'0') => {
407 Ok(QValue(1000))
408 }
409 [b'0', b'.', fractional @ ..] if fractional.len() <= 3 => {
410 parse_fractional(fractional).map(QValue)
411 }
412 _ => Err(HeaderError::invalid()),
413 }
414 }
415}
416
417impl Ord for QValue {
418 fn cmp(&self, other: &Self) -> Ordering {
419 self.0.cmp(&other.0)
420 }
421}
422
423impl PartialOrd for QValue {
424 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
425 Some(self.cmp(other))
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    // `audio/basic` is both more specific and of higher quality than
    // `audio/*; q=0.2`, so it is listed first after parsing.
    #[test]
    fn reordering() {
        let accept = Accept::from_str("audio/*; q=0.2, audio/basic").unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("audio/basic").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("audio/*; q=0.2").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    // All ranges are equally specific, so they are ordered by descending `q`.
    // `text/html` and `text/x-c` both default to `q=1` and keep header order.
    #[test]
    fn reordering_elaborate() {
        let accept =
            Accept::from_str("text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c").unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/html").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/x-c").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/x-dvi; q=0.8").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain; q=0.5").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    // Ranges with equal specificity and quality keep their header order
    // (the sort is stable and not alphabetical).
    #[test]
    fn preserve_ordering() {
        let accept = Accept::from_str("x/y, a/b").unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("x/y").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("a/b").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    // A typical browser header: `q` does not count toward specificity, so
    // `application/xml;q=0.9` stays among the concrete types and `*/*` sorts last.
    #[test]
    fn params() {
        let accept =
            Accept::from_str("text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8")
                .unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/html").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("application/xhtml+xml").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("application/xml;q=0.9").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("*/*;q=0.8").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    // Commas inside quoted parameter values must not split list elements.
    #[test]
    fn quoted_params() {
        let accept = Accept::from_str(
            "text/html; message=\"Hello, world!\", application/xhtml+xml; message=\"Hello, \
             world?\"",
        )
        .unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/html; message=\"Hello, world!\"").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(
                &MediaTypeBuf::from_str("application/xhtml+xml; message=\"Hello, world?\"")
                    .unwrap()
            )
        );
        assert_eq!(media_types.next(), None);
    }

    // More specific ranges (parameters, then concrete subtype, then concrete
    // type) sort ahead of broader ones.
    #[test]
    fn more_specifics() {
        let accept = Accept::from_str("text/*, text/plain, text/plain;format=flowed, */*").unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain;format=flowed").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/*").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("*/*").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    // Specificity takes precedence over quality, and `q` only orders ranges
    // within the same specificity. `q=0` ranges are kept in the list.
    #[test]
    fn variable_quality_more_specifics() {
        let accept = Accept::from_str(
            "text/*;q=0.3, text/plain;q=0.7, text/csv;q=0, text/plain;format=flowed, \
             text/plain;format=fixed;q=0.4, */*;q=0.5",
        )
        .unwrap();
        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain;format=flowed").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain;format=fixed;q=0.4").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/plain;q=0.7").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/csv;q=0").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("text/*;q=0.3").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("*/*;q=0.5").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    #[test]
    fn negotiate() {
        let accept = Accept::from_str(
            "text/html, application/xhtml+xml, application/xml;q=0.9, text/*;q=0.7, text/csv;q=0",
        )
        .unwrap();

        assert_eq!(
            // Exact match at q=1; application/json matches no range.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/html").unwrap(),
                    MediaType::parse("application/json").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/html").unwrap()
        );
        assert_eq!(
            // Both are q=1; text/html's range appears earlier in the header,
            // which outranks the order of the candidates.
            accept
                .negotiate(&vec![
                    MediaType::parse("application/xhtml+xml").unwrap(),
                    MediaType::parse("text/html").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/html").unwrap()
        );
        assert_eq!(
            // text/plain is accepted through text/*;q=0.7; image/gif is not accepted at all.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/plain").unwrap(),
                    MediaType::parse("image/gif").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/plain").unwrap()
        );
        assert_eq!(
            // text/plain and text/troff tie on text/*;q=0.7; the earlier candidate wins.
            accept
                .negotiate(&vec![
                    MediaType::parse("image/gif").unwrap(),
                    MediaType::parse("text/plain").unwrap(),
                    MediaType::parse("text/troff").unwrap(),
                ])
                .unwrap(),
            &MediaType::parse("text/plain").unwrap()
        );
        assert_eq!(
            // No range accepts any image type.
            accept.negotiate(&vec![
                MediaType::parse("image/gif").unwrap(),
                MediaType::parse("image/png").unwrap()
            ]),
            None
        );
        assert_eq!(
            // text/csv is explicitly refused with q=0.
            accept.negotiate(&vec![
                MediaType::parse("image/gif").unwrap(),
                MediaType::parse("text/csv").unwrap()
            ]),
            None
        );
    }

    #[test]
    fn negotiate_with_full_wildcard() {
        let accept =
            Accept::from_str("text/html, text/*;q=0.7, */*;q=0.1, text/csv;q=0.0").unwrap();

        assert_eq!(
            // Exact match at q=1 beats */*;q=0.1.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/html").unwrap(),
                    MediaType::parse("application/json").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/html").unwrap()
        );
        assert_eq!(
            // text/*;q=0.7 beats */*;q=0.1.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/plain").unwrap(),
                    MediaType::parse("image/gif").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/plain").unwrap()
        );
        assert_eq!(
            // Both match text/*;q=0.7; the earlier candidate wins.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/javascript").unwrap(),
                    MediaType::parse("text/plain").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/javascript").unwrap()
        );
        assert_eq!(
            // Only */* matches; both tie and the earlier candidate wins.
            accept
                .negotiate(&vec![
                    MediaType::parse("image/gif").unwrap(),
                    MediaType::parse("image/png").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("image/gif").unwrap()
        );
        assert_eq!(
            // text/csv is refused with q=0.0 even though */* would accept it.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/csv").unwrap(),
                    MediaType::parse("text/javascript").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/javascript").unwrap()
        );
    }

    #[test]
    fn negotiate_diabolically() {
        let accept = Accept::from_str(
            "text/*;q=0.3, text/csv;q=0.2, text/plain;q=0.7, text/plain;format=rot13;q=0.7, \
             text/plain;format=flowed, text/plain;format=fixed;q=0.4, */*;q=0.5",
        )
        .unwrap();

        assert_eq!(
            // text/plain gets 0.7 from text/plain;q=0.7; text/html only gets 0.3 from text/*.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/html").unwrap(),
                    MediaType::parse("text/plain").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/plain").unwrap()
        );
        assert_eq!(
            // Both are 0.7, but the rot13 range is more specific and so sorts
            // earlier, which breaks the tie.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/plain").unwrap(),
                    MediaType::parse("text/plain;format=rot13").unwrap(),
                ])
                .unwrap(),
            &MediaType::parse("text/plain;format=rot13").unwrap()
        );
        assert_eq!(
            // The more specific format=fixed range lowers that candidate to 0.4.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/plain").unwrap(),
                    MediaType::parse("text/plain;format=fixed").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("text/plain").unwrap()
        );
        assert_eq!(
            // text/html is only covered by text/*;q=0.3, while image/gif falls
            // through to */*;q=0.5, so the wildcard candidate wins.
            accept
                .negotiate(&vec![
                    MediaType::parse("text/html").unwrap(),
                    MediaType::parse("image/gif").unwrap()
                ])
                .unwrap(),
            &MediaType::parse("image/gif").unwrap()
        );
    }

    // `TryFrom<&HeaderValue>` parses and orders like `FromStr`.
    #[test]
    fn try_from_header_value() {
        let header_value = &HeaderValue::from_static("audio/*; q=0.2, audio/basic");
        let accept: Accept = header_value.try_into().unwrap();

        let mut media_types = accept.media_types();
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("audio/basic").unwrap())
        );
        assert_eq!(
            media_types.next(),
            Some(&MediaTypeBuf::from_str("audio/*; q=0.2").unwrap())
        );
        assert_eq!(media_types.next(), None);
    }

    #[test]
    fn mixed_lifetime_from_iter() {
        #[allow(unused)]
        // Compile-only check: `negotiate` must accept `'static` media types
        // borrowed for a shorter lifetime `'a`.
        fn best<'a>(available: &'a [MediaType<'static>]) -> Option<&'a MediaType<'static>> {
            let accept = Accept::from_str("*/*").unwrap();
            accept.negotiate(available.iter())
        }
    }

    #[test]
    fn from_iterator() {
        let accept = Accept::from_iter([
            // Built from borrowed `MediaType`s.
            MediaType::parse("text/html").unwrap(),
            MediaType::parse("image/gif").unwrap(),
        ]);

        assert_eq!(
            accept.media_types().collect::<Vec<_>>(),
            vec![
                MediaType::parse("text/html").unwrap(),
                MediaType::parse("image/gif").unwrap(),
            ]
        );

        let accept = Accept::from_iter([
            // Built from owned `MediaTypeBuf`s.
            MediaTypeBuf::from_str("text/html").unwrap(),
            MediaTypeBuf::from_str("image/gif").unwrap(),
        ]);

        assert_eq!(
            accept.media_types().collect::<Vec<_>>(),
            vec![
                MediaType::parse("text/html").unwrap(),
                MediaType::parse("image/gif").unwrap(),
            ]
        );
    }

    // Every accepted spelling of q=1 maps to 1000 thousandths.
    #[test]
    fn test_qvalue_parsing_one() {
        assert_eq!(QValue(1000), "1".parse().unwrap());
        assert_eq!(QValue(1000), "1.".parse().unwrap());
        assert_eq!(QValue(1000), "1.0".parse().unwrap());
        assert_eq!(QValue(1000), "1.00".parse().unwrap());
        assert_eq!(QValue(1000), "1.000".parse().unwrap());
    }

    // Fractions shorter than three digits are scaled as if right-padded with zeros.
    #[test]
    fn test_qvalue_parsing_partial() {
        assert_eq!(QValue(0), "0".parse().unwrap());
        assert_eq!(QValue(0), "0.".parse().unwrap());
        assert_eq!(QValue(0), "0.0".parse().unwrap());
        assert_eq!(QValue(0), "0.00".parse().unwrap());
        assert_eq!(QValue(0), "0.000".parse().unwrap());
        assert_eq!(QValue(100), "0.1".parse().unwrap());
        assert_eq!(QValue(120), "0.12".parse().unwrap());
        assert_eq!(QValue(123), "0.123".parse().unwrap());
        assert_eq!(QValue(23), "0.023".parse().unwrap());
        assert_eq!(QValue(3), "0.003".parse().unwrap());
    }

    // More than three fractional digits, stray dots, values above 1 and
    // signs are all rejected.
    #[test]
    fn qvalue_parsing_invalid() {
        assert!("0.0000".parse::<QValue>().is_err());
        assert!("0.1.".parse::<QValue>().is_err());
        assert!("0.12.".parse::<QValue>().is_err());
        assert!("0.123.".parse::<QValue>().is_err());
        assert!("0.1234".parse::<QValue>().is_err());
        assert!("1.123".parse::<QValue>().is_err());
        assert!("1.1234".parse::<QValue>().is_err());
        assert!("1.12345".parse::<QValue>().is_err());
        assert!("2.0".parse::<QValue>().is_err());
        assert!("-0.0".parse::<QValue>().is_err());
        assert!("1.0000".parse::<QValue>().is_err());
    }

    #[test]
    fn qvalue_ordering() {
        assert!(QValue(1000) > QValue(0));
        assert!(QValue(1000) > QValue(100));
        assert!(QValue(100) > QValue(0));
        assert!(QValue(120) > QValue(100));
        assert!(QValue(123) > QValue(120));
        assert!(QValue(23) < QValue(100));
        assert!(QValue(3) < QValue(23));
    }

    // A missing `q` means full quality.
    #[test]
    fn qvalue_default() {
        let q: QValue = Default::default();
        assert_eq!(q, QValue(1000));
    }

    #[test]
    fn qvalue_is_zero() {
        assert!("0.".parse::<QValue>().unwrap().is_zero());
    }
}