1use std::{cmp, collections::VecDeque, io::Cursor, str::Utf8Error};
2
3use ntex_bytes::{Buf, ByteString, Bytes, BytesMut};
4use ntex_http::{error, header, Method, StatusCode};
5
6use super::{huffman, Header};
7
/// Decodes HPACK header blocks, keeping its dynamic table between `decode` calls.
#[derive(Debug)]
pub struct Decoder {
    /// Largest size limit queued by `queue_size_update`; taken and applied at
    /// the start of the next `decode` call.
    max_size_update: Option<usize>,
    /// Largest value a dynamic table size update in a header block may request.
    last_max_update: usize,
    /// Dynamic table; static-table lookups are delegated to `get_static`.
    table: Table,
    /// Scratch buffer that Huffman-coded strings are decoded into.
    buffer: BytesMut,
}
17
/// Errors produced while decoding a header block.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum DecoderError {
    /// The first octet of a representation matched no known representation type.
    #[error("InvalidRepresentation")]
    InvalidRepresentation,
    /// An integer was decoded with a prefix width outside `1..=8` bits.
    #[error("InvalidIntegerPrefix")]
    InvalidIntegerPrefix,
    /// A table index of 0, or an index past the end of the dynamic table.
    #[error("InvalidTableIndex")]
    InvalidTableIndex,
    /// NOTE(review): presumably raised when a Huffman-coded string is malformed;
    /// not constructed in the code visible here.
    #[error("InvalidHuffmanCode")]
    InvalidHuffmanCode,
    /// Invalid UTF-8, and also the catch-all for invalid header names/values,
    /// methods and status codes (see the `From` impls).
    #[error("InvalidUtf8")]
    InvalidUtf8,
    /// NOTE(review): not constructed in the code visible here.
    #[error("InvalidStatusCode")]
    InvalidStatusCode,
    /// NOTE(review): not constructed in the code visible here.
    #[error("InvalidPseudoheader")]
    InvalidPseudoheader,
    /// A dynamic table size update appeared after a header field, or asked for
    /// more than the permitted maximum.
    #[error("InvalidMaxDynamicSize")]
    InvalidMaxDynamicSize,
    /// An encoded integer used more octets than the decoder accepts.
    #[error("IntegerOverflow")]
    IntegerOverflow,
    /// The input ended in the middle of a representation.
    #[error("{0}")]
    NeedMore(NeedMore),
}
43
/// Which part of a representation was cut short by the end of the input.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum NeedMore {
    /// No byte was available where a string literal was expected to start.
    #[error("Unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The input ended inside a prefix-coded integer.
    #[error("Integer underflow")]
    IntegerUnderflow,
    /// Fewer bytes remain than a string literal's declared length.
    #[error("String underflow")]
    StringUnderflow,
}
53
/// Kinds of field representation, told apart by the high-order bits of their
/// first octet (see `Representation::load`).
enum Representation {
    /// `1xxxxxxx`: a field taken entirely from the static or dynamic table.
    Indexed,

    /// `01xxxxxx`: a literal field that the decoder also inserts into the
    /// dynamic table.
    LiteralWithIndexing,

    /// `0000xxxx`: a literal field that is not inserted into the dynamic table.
    LiteralWithoutIndexing,

    /// `0001xxxx`: a literal field flagged as never-indexed; this decoder does
    /// not insert it into the dynamic table.
    LiteralNeverIndexed,

    /// `001xxxxx`: a change to the dynamic table's maximum size.
    SizeUpdate,
}
144
/// The decoder's dynamic table.
#[derive(Debug)]
struct Table {
    /// Stored headers, newest at the front (entries are added with `push_front`).
    entries: VecDeque<Header>,
    /// Sum of `Header::len()` over all stored entries.
    size: usize,
    /// Upper bound that `size` is kept within.
    max_size: usize,
}
151
/// Location of a string literal found by `Decoder::try_decode_string`,
/// recorded so that the bytes can be consumed later.
struct StringMarker {
    /// Bytes between the position where decoding started and the string data
    /// (i.e. the length prefix).
    offset: usize,
    /// Encoded length of the string data in the input.
    len: usize,
    /// Huffman-decoded contents, or `None` when the string is stored raw.
    string: Option<Bytes>,
}
157
158impl Decoder {
161 pub fn new(size: usize) -> Decoder {
163 Decoder {
164 max_size_update: None,
165 last_max_update: size,
166 table: Table::new(size),
167 buffer: BytesMut::with_capacity(4096),
168 }
169 }
170
171 #[allow(dead_code)]
173 pub fn queue_size_update(&mut self, size: usize) {
174 let size = match self.max_size_update {
175 Some(v) => cmp::max(v, size),
176 None => size,
177 };
178
179 self.max_size_update = Some(size);
180 }
181
182 pub fn decode<F>(
184 &mut self,
185 src: &mut Cursor<&mut BytesMut>,
186 mut f: F,
187 ) -> Result<(), DecoderError>
188 where
189 F: FnMut(Header),
190 {
191 use self::Representation::*;
192
193 let mut can_resize = true;
194
195 if let Some(size) = self.max_size_update.take() {
196 self.last_max_update = size;
197 }
198
199 while let Some(ty) = peek_u8(src) {
200 match Representation::load(ty)? {
204 Indexed => {
205 can_resize = false;
207 let entry = self.decode_indexed(src)?;
208 consume(src);
209 f(entry);
210 }
211 LiteralWithIndexing => {
212 can_resize = false;
214 let entry = self.decode_literal(src, true)?;
215
216 self.table.insert(entry.clone());
218 consume(src);
219
220 f(entry);
221 }
222 LiteralWithoutIndexing => {
223 can_resize = false;
225 let entry = self.decode_literal(src, false)?;
226 consume(src);
227 f(entry);
228 }
229 LiteralNeverIndexed => {
230 can_resize = false;
232 let entry = self.decode_literal(src, false)?;
233 consume(src);
234
235 f(entry);
238 }
239 SizeUpdate => {
240 if !can_resize {
242 return Err(DecoderError::InvalidMaxDynamicSize);
243 }
244
245 self.process_size_update(src)?;
247 consume(src);
248 }
249 }
250 }
251
252 Ok(())
253 }
254
255 fn process_size_update(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<(), DecoderError> {
256 let new_size = decode_int(buf, 5)?;
257
258 if new_size > self.last_max_update {
259 return Err(DecoderError::InvalidMaxDynamicSize);
260 }
261
262 log::debug!(
263 "Decoder changed max table size, from {} to {}",
264 self.table.size(),
265 new_size,
266 );
267
268 self.table.set_max_size(new_size);
269
270 Ok(())
271 }
272
273 fn decode_indexed(&self, buf: &mut Cursor<&mut BytesMut>) -> Result<Header, DecoderError> {
274 let index = decode_int(buf, 7)?;
275 self.table.get(index)
276 }
277
278 fn decode_literal(
279 &mut self,
280 buf: &mut Cursor<&mut BytesMut>,
281 index: bool,
282 ) -> Result<Header, DecoderError> {
283 let prefix = if index { 6 } else { 4 };
284
285 let table_idx = decode_int(buf, prefix)?;
287
288 if table_idx == 0 {
290 let old_pos = buf.position();
291 let name_marker = self.try_decode_string(buf)?;
292 let value_marker = self.try_decode_string(buf)?;
293 buf.set_position(old_pos);
294 let name = name_marker.consume(buf);
296 let value = value_marker.consume(buf);
297 Header::new(name, value)
298 } else {
299 let e = self.table.get(table_idx)?;
300 let value = self.decode_string(buf)?;
301
302 e.name().into_entry(value)
303 }
304 }
305
306 fn try_decode_string(
307 &mut self,
308 buf: &mut Cursor<&mut BytesMut>,
309 ) -> Result<StringMarker, DecoderError> {
310 let old_pos = buf.position();
311 const HUFF_FLAG: u8 = 0b1000_0000;
312
313 let huff = match peek_u8(buf) {
315 Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG,
316 None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)),
317 };
318
319 let len = decode_int(buf, 7)?;
321
322 if len > buf.remaining() {
323 log::trace!("decode_string underflow {:?} {:?}", len, buf.remaining());
324 return Err(DecoderError::NeedMore(NeedMore::StringUnderflow));
325 }
326
327 let offset = (buf.position() - old_pos) as usize;
328 if huff {
329 let ret = {
330 let raw = &buf.chunk()[..len];
331 huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker {
332 offset,
333 len,
334 string: Some(BytesMut::freeze(buf)),
335 })
336 };
337
338 buf.advance(len);
339 ret
340 } else {
341 buf.advance(len);
342 Ok(StringMarker {
343 offset,
344 len,
345 string: None,
346 })
347 }
348 }
349
350 fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<Bytes, DecoderError> {
351 let old_pos = buf.position();
352 let marker = self.try_decode_string(buf)?;
353 buf.set_position(old_pos);
354 Ok(marker.consume(buf))
355 }
356}
357
358impl Default for Decoder {
359 fn default() -> Decoder {
360 Decoder::new(4096)
361 }
362}
363
364impl Representation {
367 pub fn load(byte: u8) -> Result<Representation, DecoderError> {
368 const INDEXED: u8 = 0b1000_0000;
369 const LITERAL_WITH_INDEXING: u8 = 0b0100_0000;
370 const LITERAL_WITHOUT_INDEXING: u8 = 0b1111_0000;
371 const LITERAL_NEVER_INDEXED: u8 = 0b0001_0000;
372 const SIZE_UPDATE_MASK: u8 = 0b1110_0000;
373 const SIZE_UPDATE: u8 = 0b0010_0000;
374
375 if byte & INDEXED == INDEXED {
378 Ok(Representation::Indexed)
379 } else if byte & LITERAL_WITH_INDEXING == LITERAL_WITH_INDEXING {
380 Ok(Representation::LiteralWithIndexing)
381 } else if byte & LITERAL_WITHOUT_INDEXING == 0 {
382 Ok(Representation::LiteralWithoutIndexing)
383 } else if byte & LITERAL_WITHOUT_INDEXING == LITERAL_NEVER_INDEXED {
384 Ok(Representation::LiteralNeverIndexed)
385 } else if byte & SIZE_UPDATE_MASK == SIZE_UPDATE {
386 Ok(Representation::SizeUpdate)
387 } else {
388 Err(DecoderError::InvalidRepresentation)
389 }
390 }
391}
392
393fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderError> {
394 const MAX_BYTES: usize = 5;
398 const VARINT_MASK: u8 = 0b0111_1111;
399 const VARINT_FLAG: u8 = 0b1000_0000;
400
401 if !(1..=8).contains(&prefix_size) {
402 return Err(DecoderError::InvalidIntegerPrefix);
403 }
404
405 if !buf.has_remaining() {
406 return Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow));
407 }
408
409 let mask = if prefix_size == 8 {
410 0xFF
411 } else {
412 (1u8 << prefix_size).wrapping_sub(1)
413 };
414
415 let mut ret = (buf.get_u8() & mask) as usize;
416
417 if ret < mask as usize {
418 return Ok(ret);
420 }
421
422 let mut bytes = 1;
427
428 let mut shift = 0;
431
432 while buf.has_remaining() {
433 let b = buf.get_u8();
434
435 bytes += 1;
436 ret += ((b & VARINT_MASK) as usize) << shift;
437 shift += 7;
438
439 if b & VARINT_FLAG == 0 {
440 return Ok(ret);
441 }
442
443 if bytes == MAX_BYTES {
444 return Err(DecoderError::IntegerOverflow);
446 }
447 }
448
449 Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
450}
451
452fn peek_u8<B: Buf>(buf: &B) -> Option<u8> {
453 if buf.has_remaining() {
454 Some(buf.chunk()[0])
455 } else {
456 None
457 }
458}
459
460fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes {
461 let pos = buf.position() as usize;
462 let mut head = buf.get_mut().split_to(pos + n);
463 buf.set_position(0);
464 head.advance(pos);
465 head.freeze()
466}
467
468impl StringMarker {
469 fn consume(self, buf: &mut Cursor<&mut BytesMut>) -> Bytes {
470 buf.advance(self.offset);
471 match self.string {
472 Some(string) => {
473 buf.advance(self.len);
474 string
475 }
476 None => take(buf, self.len),
477 }
478 }
479}
480
481fn consume(buf: &mut Cursor<&mut BytesMut>) {
482 take(buf, 0);
486}
487
488impl Table {
491 fn new(max_size: usize) -> Table {
492 Table {
493 entries: VecDeque::new(),
494 size: 0,
495 max_size,
496 }
497 }
498
499 fn size(&self) -> usize {
500 self.size
501 }
502
503 pub fn get(&self, index: usize) -> Result<Header, DecoderError> {
512 if index == 0 {
513 return Err(DecoderError::InvalidTableIndex);
514 }
515
516 if index <= 61 {
517 return Ok(get_static(index));
518 }
519
520 match self.entries.get(index - 62) {
522 Some(e) => Ok(e.clone()),
523 None => Err(DecoderError::InvalidTableIndex),
524 }
525 }
526
527 fn insert(&mut self, entry: Header) {
528 let len = entry.len();
529
530 self.reserve(len);
531
532 if self.size + len <= self.max_size {
533 self.size += len;
534
535 self.entries.push_front(entry);
537 }
538 }
539
540 fn set_max_size(&mut self, size: usize) {
541 self.max_size = size;
542 self.consolidate();
544 }
545
546 fn reserve(&mut self, size: usize) {
547 while self.size + size > self.max_size {
548 match self.entries.pop_back() {
549 Some(last) => {
550 self.size -= last.len();
551 }
552 None => return,
553 }
554 }
555 }
556
557 fn consolidate(&mut self) {
558 while self.size > self.max_size {
559 {
560 let last = match self.entries.back() {
561 Some(x) => x,
562 None => {
563 panic!("Size of table != 0, but no headers left!");
566 }
567 };
568
569 self.size -= last.len();
570 }
571
572 self.entries.pop_back();
573 }
574 }
575}
576
577impl From<Utf8Error> for DecoderError {
580 fn from(_: Utf8Error) -> DecoderError {
581 DecoderError::InvalidUtf8
583 }
584}
585
586impl From<()> for DecoderError {
587 fn from(_: ()) -> DecoderError {
588 DecoderError::InvalidUtf8
590 }
591}
592
593impl From<header::InvalidHeaderValue> for DecoderError {
594 fn from(_: header::InvalidHeaderValue) -> DecoderError {
595 DecoderError::InvalidUtf8
597 }
598}
599
600impl From<header::InvalidHeaderName> for DecoderError {
601 fn from(_: header::InvalidHeaderName) -> DecoderError {
602 DecoderError::InvalidUtf8
604 }
605}
606
607impl From<error::InvalidMethod> for DecoderError {
608 fn from(_: error::InvalidMethod) -> DecoderError {
609 DecoderError::InvalidUtf8
611 }
612}
613
614impl From<error::InvalidStatusCode> for DecoderError {
615 fn from(_: error::InvalidStatusCode) -> DecoderError {
616 DecoderError::InvalidUtf8
618 }
619}
620
621pub fn get_static(idx: usize) -> Header {
623 use ntex_http::header::HeaderValue;
624
625 match idx {
626 1 => Header::Authority(ByteString::from_static("")),
627 2 => Header::Method(Method::GET),
628 3 => Header::Method(Method::POST),
629 4 => Header::Path(ByteString::from_static("/")),
630 5 => Header::Path(ByteString::from_static("/index.html")),
631 6 => Header::Scheme(ByteString::from_static("http")),
632 7 => Header::Scheme(ByteString::from_static("https")),
633 8 => Header::Status(StatusCode::OK),
634 9 => Header::Status(StatusCode::NO_CONTENT),
635 10 => Header::Status(StatusCode::PARTIAL_CONTENT),
636 11 => Header::Status(StatusCode::NOT_MODIFIED),
637 12 => Header::Status(StatusCode::BAD_REQUEST),
638 13 => Header::Status(StatusCode::NOT_FOUND),
639 14 => Header::Status(StatusCode::INTERNAL_SERVER_ERROR),
640 15 => Header::Field {
641 name: header::ACCEPT_CHARSET,
642 value: HeaderValue::from_static(""),
643 },
644 16 => Header::Field {
645 name: header::ACCEPT_ENCODING,
646 value: HeaderValue::from_static("gzip, deflate"),
647 },
648 17 => Header::Field {
649 name: header::ACCEPT_LANGUAGE,
650 value: HeaderValue::from_static(""),
651 },
652 18 => Header::Field {
653 name: header::ACCEPT_RANGES,
654 value: HeaderValue::from_static(""),
655 },
656 19 => Header::Field {
657 name: header::ACCEPT,
658 value: HeaderValue::from_static(""),
659 },
660 20 => Header::Field {
661 name: header::ACCESS_CONTROL_ALLOW_ORIGIN,
662 value: HeaderValue::from_static(""),
663 },
664 21 => Header::Field {
665 name: header::AGE,
666 value: HeaderValue::from_static(""),
667 },
668 22 => Header::Field {
669 name: header::ALLOW,
670 value: HeaderValue::from_static(""),
671 },
672 23 => Header::Field {
673 name: header::AUTHORIZATION,
674 value: HeaderValue::from_static(""),
675 },
676 24 => Header::Field {
677 name: header::CACHE_CONTROL,
678 value: HeaderValue::from_static(""),
679 },
680 25 => Header::Field {
681 name: header::CONTENT_DISPOSITION,
682 value: HeaderValue::from_static(""),
683 },
684 26 => Header::Field {
685 name: header::CONTENT_ENCODING,
686 value: HeaderValue::from_static(""),
687 },
688 27 => Header::Field {
689 name: header::CONTENT_LANGUAGE,
690 value: HeaderValue::from_static(""),
691 },
692 28 => Header::Field {
693 name: header::CONTENT_LENGTH,
694 value: HeaderValue::from_static(""),
695 },
696 29 => Header::Field {
697 name: header::CONTENT_LOCATION,
698 value: HeaderValue::from_static(""),
699 },
700 30 => Header::Field {
701 name: header::CONTENT_RANGE,
702 value: HeaderValue::from_static(""),
703 },
704 31 => Header::Field {
705 name: header::CONTENT_TYPE,
706 value: HeaderValue::from_static(""),
707 },
708 32 => Header::Field {
709 name: header::COOKIE,
710 value: HeaderValue::from_static(""),
711 },
712 33 => Header::Field {
713 name: header::DATE,
714 value: HeaderValue::from_static(""),
715 },
716 34 => Header::Field {
717 name: header::ETAG,
718 value: HeaderValue::from_static(""),
719 },
720 35 => Header::Field {
721 name: header::EXPECT,
722 value: HeaderValue::from_static(""),
723 },
724 36 => Header::Field {
725 name: header::EXPIRES,
726 value: HeaderValue::from_static(""),
727 },
728 37 => Header::Field {
729 name: header::FROM,
730 value: HeaderValue::from_static(""),
731 },
732 38 => Header::Field {
733 name: header::HOST,
734 value: HeaderValue::from_static(""),
735 },
736 39 => Header::Field {
737 name: header::IF_MATCH,
738 value: HeaderValue::from_static(""),
739 },
740 40 => Header::Field {
741 name: header::IF_MODIFIED_SINCE,
742 value: HeaderValue::from_static(""),
743 },
744 41 => Header::Field {
745 name: header::IF_NONE_MATCH,
746 value: HeaderValue::from_static(""),
747 },
748 42 => Header::Field {
749 name: header::IF_RANGE,
750 value: HeaderValue::from_static(""),
751 },
752 43 => Header::Field {
753 name: header::IF_UNMODIFIED_SINCE,
754 value: HeaderValue::from_static(""),
755 },
756 44 => Header::Field {
757 name: header::LAST_MODIFIED,
758 value: HeaderValue::from_static(""),
759 },
760 45 => Header::Field {
761 name: header::LINK,
762 value: HeaderValue::from_static(""),
763 },
764 46 => Header::Field {
765 name: header::LOCATION,
766 value: HeaderValue::from_static(""),
767 },
768 47 => Header::Field {
769 name: header::MAX_FORWARDS,
770 value: HeaderValue::from_static(""),
771 },
772 48 => Header::Field {
773 name: header::PROXY_AUTHENTICATE,
774 value: HeaderValue::from_static(""),
775 },
776 49 => Header::Field {
777 name: header::PROXY_AUTHORIZATION,
778 value: HeaderValue::from_static(""),
779 },
780 50 => Header::Field {
781 name: header::RANGE,
782 value: HeaderValue::from_static(""),
783 },
784 51 => Header::Field {
785 name: header::REFERER,
786 value: HeaderValue::from_static(""),
787 },
788 52 => Header::Field {
789 name: header::REFRESH,
790 value: HeaderValue::from_static(""),
791 },
792 53 => Header::Field {
793 name: header::RETRY_AFTER,
794 value: HeaderValue::from_static(""),
795 },
796 54 => Header::Field {
797 name: header::SERVER,
798 value: HeaderValue::from_static(""),
799 },
800 55 => Header::Field {
801 name: header::SET_COOKIE,
802 value: HeaderValue::from_static(""),
803 },
804 56 => Header::Field {
805 name: header::STRICT_TRANSPORT_SECURITY,
806 value: HeaderValue::from_static(""),
807 },
808 57 => Header::Field {
809 name: header::TRANSFER_ENCODING,
810 value: HeaderValue::from_static(""),
811 },
812 58 => Header::Field {
813 name: header::USER_AGENT,
814 value: HeaderValue::from_static(""),
815 },
816 59 => Header::Field {
817 name: header::VARY,
818 value: HeaderValue::from_static(""),
819 },
820 60 => Header::Field {
821 name: header::VIA,
822 value: HeaderValue::from_static(""),
823 },
824 61 => Header::Field {
825 name: header::WWW_AUTHENTICATE,
826 value: HeaderValue::from_static(""),
827 },
828 _ => unreachable!(),
829 }
830}
831
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_peek_u8() {
        let b = 0xff;
        let mut buf = Cursor::new(vec![b]);
        assert_eq!(peek_u8(&buf), Some(b));
        assert_eq!(buf.get_u8(), b);
        assert_eq!(peek_u8(&buf), None);
    }

    #[test]
    fn test_decode_string_empty() {
        let mut de = Decoder::new(0);
        let mut buf = BytesMut::new();
        let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err();
        assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream));
    }

    #[test]
    #[allow(clippy::unit_cmp, clippy::let_unit_value)]
    fn test_decode_empty() {
        let mut de = Decoder::new(0);
        let mut buf = BytesMut::new();
        let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap();
        assert_eq!(empty, ());
    }

    #[test]
    fn test_decode_indexed_larger_than_table() {
        let mut de = Decoder::new(0);

        let mut buf = BytesMut::new();
        buf.extend(&[0b01000000, 0x80 | 2]);
        buf.extend(huff_encode(b"foo"));
        buf.extend(&[0x80 | 3]);
        buf.extend(huff_encode(b"bar"));

        let mut res = vec![];
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }

    fn huff_encode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::encode(src, &mut buf);
        buf
    }

    #[test]
    fn test_decode_continuation_header_with_non_huff_encoded_name() {
        let mut de = Decoder::new(0);
        let value = huff_encode(b"bar");
        let mut buf = BytesMut::new();
        buf.extend(&[0b01000000, 3]);
        buf.extend(b"foo");
        buf.extend(&[0x80 | 3]);
        buf.extend(&value[0..1]);

        let mut res = vec![];
        let e = de
            .decode(&mut Cursor::new(&mut buf), |h| {
                res.push(h);
            })
            .unwrap_err();
        assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow));

        buf.extend(&value[1..]);
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }

    /// Decodes `bytes` as a single integer with the given prefix width.
    fn int(bytes: &[u8], prefix: u8) -> Result<usize, DecoderError> {
        decode_int(&mut Cursor::new(bytes.to_vec()), prefix)
    }

    #[test]
    fn test_decode_int_rfc7541_examples() {
        // RFC 7541 C.1.1: 10 with a 5-bit prefix.
        assert_eq!(int(&[0b0000_1010], 5), Ok(10));
        // RFC 7541 C.1.2: 1337 with a 5-bit prefix.
        assert_eq!(int(&[0b0001_1111, 0b1001_1010, 0b0000_1010], 5), Ok(1337));
        // RFC 7541 C.1.3: 42 with an 8-bit prefix.
        assert_eq!(int(&[0b0010_1010], 8), Ok(42));
    }

    #[test]
    fn test_decode_int_errors() {
        assert_eq!(int(&[0], 0), Err(DecoderError::InvalidIntegerPrefix));
        assert_eq!(int(&[0], 9), Err(DecoderError::InvalidIntegerPrefix));
        assert_eq!(
            int(&[], 5),
            Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
        );
        // Saturated prefix whose continuation bytes are cut off.
        assert_eq!(
            int(&[0b0001_1111, 0b1000_0000], 5),
            Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
        );
        // Still continuing after five octets in total.
        assert_eq!(
            int(&[0b0001_1111, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], 5),
            Err(DecoderError::IntegerOverflow)
        );
    }

    #[test]
    fn test_representation_load() {
        assert!(matches!(Representation::load(0x80), Ok(Representation::Indexed)));
        assert!(matches!(
            Representation::load(0x40),
            Ok(Representation::LiteralWithIndexing)
        ));
        assert!(matches!(Representation::load(0x20), Ok(Representation::SizeUpdate)));
        assert!(matches!(
            Representation::load(0x10),
            Ok(Representation::LiteralNeverIndexed)
        ));
        assert!(matches!(
            Representation::load(0x00),
            Ok(Representation::LiteralWithoutIndexing)
        ));
    }

    #[test]
    fn test_table_get_invalid_index() {
        let de = Decoder::new(4096);
        assert!(matches!(de.table.get(0), Err(DecoderError::InvalidTableIndex)));
        // The dynamic table is empty, so the first dynamic index is invalid.
        assert!(matches!(de.table.get(62), Err(DecoderError::InvalidTableIndex)));
    }

    #[test]
    fn test_size_update_after_header_is_rejected() {
        let mut de = Decoder::new(4096);
        let mut buf = BytesMut::new();
        // Indexed field 2 (`:method: GET`) followed by a size update to 0.
        buf.extend(&[0x82, 0x20]);

        let mut res = vec![];
        let err = de
            .decode(&mut Cursor::new(&mut buf), |h| res.push(h))
            .unwrap_err();
        assert_eq!(err, DecoderError::InvalidMaxDynamicSize);
        assert_eq!(res.len(), 1);
    }

    #[test]
    fn test_size_update_above_limit_is_rejected() {
        let mut de = Decoder::new(0);
        let mut buf = BytesMut::new();
        // Size update to 1 byte while the permitted maximum is 0.
        buf.extend(&[0x21]);

        let err = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap_err();
        assert_eq!(err, DecoderError::InvalidMaxDynamicSize);
    }
}