1use std::{cmp, collections::VecDeque, io::Cursor, str::Utf8Error};
2
3use ntex_bytes::{Buf, ByteString, Bytes, BytesMut};
4use ntex_http::{Method, StatusCode, compat, header};
5
6use super::{Header, huffman};
7
/// Stateful HPACK header-block decoder.
///
/// Owns the dynamic table together with the bookkeeping needed to validate
/// dynamic table size updates found in the input.
#[derive(Debug)]
pub struct Decoder {
    /// Size limit queued through `queue_size_update`; `decode` moves it into
    /// `last_max_update` the next time it runs.
    max_size_update: Option<usize>,
    /// Largest table size a size-update instruction in the input may request;
    /// larger requests are rejected with `InvalidMaxDynamicSize`.
    last_max_update: usize,
    /// Dynamic table of headers inserted by "literal with indexing" entries.
    table: Table,
    /// Scratch buffer handed to `huffman::decode` when a string is
    /// Huffman-coded.
    buffer: BytesMut,
}
17
/// Errors reported while decoding an HPACK header block.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum DecoderError {
    /// The first byte of an item matched none of the known representations
    /// (returned by `Representation::load`).
    #[error("InvalidRepresentation")]
    InvalidRepresentation,
    /// `decode_int` was asked to use a prefix width outside `1..=8` bits.
    #[error("InvalidIntegerPrefix")]
    InvalidIntegerPrefix,
    /// A table index was `0` or pointed past the end of the dynamic table.
    #[error("InvalidTableIndex")]
    InvalidTableIndex,
    /// NOTE(review): not constructed in this file; presumably produced by the
    /// Huffman decoder — confirm.
    #[error("InvalidHuffmanCode")]
    InvalidHuffmanCode,
    /// Invalid UTF-8 data; also the catch-all target of several `From`
    /// conversions implemented for this type.
    #[error("InvalidUtf8")]
    InvalidUtf8,
    /// NOTE(review): presumably signals an unparsable status code — confirm
    /// against the code that constructs it.
    #[error("InvalidStatusCode")]
    InvalidStatusCode,
    /// NOTE(review): not constructed in this file; presumably raised while
    /// building pseudo-headers — confirm.
    #[error("InvalidPseudoheader")]
    InvalidPseudoheader,
    /// A table size update appeared after a header representation in the same
    /// block, or requested more than the currently allowed maximum.
    #[error("InvalidMaxDynamicSize")]
    InvalidMaxDynamicSize,
    /// An encoded integer did not terminate within the octet limit enforced
    /// by `decode_int`.
    #[error("IntegerOverflow")]
    IntegerOverflow,
    /// The input ended before the current item was complete.
    #[error("{0}")]
    NeedMore(NeedMore),
}
43
/// Reasons why more input is required to finish decoding an item.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum NeedMore {
    /// No byte was available where a string header was expected.
    #[error("Unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The input ended in the middle of an encoded integer.
    #[error("Integer underflow")]
    IntegerUnderflow,
    /// A string declared more bytes than remain in the input.
    #[error("String underflow")]
    StringUnderflow,
}
53
/// Kind of item found in a header block, derived from its first byte by
/// `Representation::load`.
enum Representation {
    /// First byte `1xxxxxxx`: the header is looked up in the table by index.
    Indexed,

    /// First byte `01xxxxxx`: a literal header that is also inserted into the
    /// dynamic table after decoding.
    LiteralWithIndexing,

    /// First byte `0000xxxx`: a literal header that is not inserted into the
    /// dynamic table.
    LiteralWithoutIndexing,

    /// First byte `0001xxxx`: decoded exactly like `LiteralWithoutIndexing`
    /// by this decoder.
    LiteralNeverIndexed,

    /// First byte `001xxxxx`: a dynamic table size update; only accepted
    /// before any header representation in the same block.
    SizeUpdate,
}
144
/// Dynamic header table used by the decoder.
#[derive(Debug)]
struct Table {
    /// Stored headers; new entries are pushed to the front and evicted from
    /// the back.
    entries: VecDeque<Header>,
    /// Sum of `Header::len()` over all stored entries.
    size: usize,
    /// Upper bound for `size`; entries are evicted to stay within it.
    max_size: usize,
}
151
/// Describes a string that has been located in the input but not yet removed
/// from the underlying buffer.
struct StringMarker {
    /// Number of bytes between the start of the string item and its payload
    /// (the bytes occupied by the encoded length).
    offset: usize,
    /// Length in bytes of the payload as it appears in the input.
    len: usize,
    /// Decoded contents when the payload was Huffman-coded; `None` means the
    /// raw payload bytes are taken directly from the buffer.
    string: Option<Bytes>,
}
157
158impl Decoder {
161 pub fn new(size: usize) -> Decoder {
163 Decoder {
164 max_size_update: None,
165 last_max_update: size,
166 table: Table::new(size),
167 buffer: BytesMut::with_capacity(4096),
168 }
169 }
170
171 #[allow(dead_code)]
173 pub fn queue_size_update(&mut self, size: usize) {
174 let size = match self.max_size_update {
175 Some(v) => cmp::max(v, size),
176 None => size,
177 };
178
179 self.max_size_update = Some(size);
180 }
181
182 pub fn decode<F>(&mut self, src: &mut Cursor<&mut Bytes>, mut f: F) -> Result<(), DecoderError>
184 where
185 F: FnMut(Header),
186 {
187 use self::Representation::*;
188
189 let mut can_resize = true;
190
191 if let Some(size) = self.max_size_update.take() {
192 self.last_max_update = size;
193 }
194
195 while let Some(ty) = peek_u8(src) {
196 match Representation::load(ty)? {
200 Indexed => {
201 can_resize = false;
203 let entry = self.decode_indexed(src)?;
204 consume(src);
205 f(entry);
206 }
207 LiteralWithIndexing => {
208 can_resize = false;
210 let entry = self.decode_literal(src, true)?;
211
212 self.table.insert(entry.clone());
214 consume(src);
215
216 f(entry);
217 }
218 LiteralWithoutIndexing => {
219 can_resize = false;
221 let entry = self.decode_literal(src, false)?;
222 consume(src);
223 f(entry);
224 }
225 LiteralNeverIndexed => {
226 can_resize = false;
228 let entry = self.decode_literal(src, false)?;
229 consume(src);
230
231 f(entry);
234 }
235 SizeUpdate => {
236 if !can_resize {
238 return Err(DecoderError::InvalidMaxDynamicSize);
239 }
240
241 self.process_size_update(src)?;
243 consume(src);
244 }
245 }
246 }
247
248 Ok(())
249 }
250
251 fn process_size_update(&mut self, buf: &mut Cursor<&mut Bytes>) -> Result<(), DecoderError> {
252 let new_size = decode_int(buf, 5)?;
253
254 if new_size > self.last_max_update {
255 return Err(DecoderError::InvalidMaxDynamicSize);
256 }
257
258 log::debug!(
259 "Decoder changed max table size, from {} to {}",
260 self.table.size(),
261 new_size,
262 );
263
264 self.table.set_max_size(new_size);
265
266 Ok(())
267 }
268
269 fn decode_indexed(&self, buf: &mut Cursor<&mut Bytes>) -> Result<Header, DecoderError> {
270 let index = decode_int(buf, 7)?;
271 self.table.get(index)
272 }
273
274 fn decode_literal(
275 &mut self,
276 buf: &mut Cursor<&mut Bytes>,
277 index: bool,
278 ) -> Result<Header, DecoderError> {
279 let prefix = if index { 6 } else { 4 };
280
281 let table_idx = decode_int(buf, prefix)?;
283
284 if table_idx == 0 {
286 let old_pos = buf.position();
287 let name_marker = self.try_decode_string(buf)?;
288 let value_marker = self.try_decode_string(buf)?;
289 buf.set_position(old_pos);
290 let name = name_marker.consume(buf);
292 let value = value_marker.consume(buf);
293 Header::new(name, value)
294 } else {
295 let e = self.table.get(table_idx)?;
296 let value = self.decode_string(buf)?;
297
298 e.name().into_entry(value)
299 }
300 }
301
302 fn try_decode_string(
303 &mut self,
304 buf: &mut Cursor<&mut Bytes>,
305 ) -> Result<StringMarker, DecoderError> {
306 let old_pos = buf.position();
307 const HUFF_FLAG: u8 = 0b1000_0000;
308
309 let huff = match peek_u8(buf) {
311 Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG,
312 None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)),
313 };
314
315 let len = decode_int(buf, 7)?;
317
318 if len > buf.remaining() {
319 log::trace!("decode_string underflow {:?} {:?}", len, buf.remaining());
320 return Err(DecoderError::NeedMore(NeedMore::StringUnderflow));
321 }
322
323 let offset = (buf.position() - old_pos) as usize;
324 if huff {
325 let ret = {
326 let raw = &buf.chunk()[..len];
327 huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker {
328 offset,
329 len,
330 string: Some(buf),
331 })
332 };
333
334 buf.advance(len);
335 ret
336 } else {
337 buf.advance(len);
338 Ok(StringMarker {
339 offset,
340 len,
341 string: None,
342 })
343 }
344 }
345
346 fn decode_string(&mut self, buf: &mut Cursor<&mut Bytes>) -> Result<Bytes, DecoderError> {
347 let old_pos = buf.position();
348 let marker = self.try_decode_string(buf)?;
349 buf.set_position(old_pos);
350 Ok(marker.consume(buf))
351 }
352}
353
354impl Default for Decoder {
355 fn default() -> Decoder {
356 Decoder::new(4096)
357 }
358}
359
360impl Representation {
363 pub fn load(byte: u8) -> Result<Representation, DecoderError> {
364 const INDEXED: u8 = 0b1000_0000;
365 const LITERAL_WITH_INDEXING: u8 = 0b0100_0000;
366 const LITERAL_WITHOUT_INDEXING: u8 = 0b1111_0000;
367 const LITERAL_NEVER_INDEXED: u8 = 0b0001_0000;
368 const SIZE_UPDATE_MASK: u8 = 0b1110_0000;
369 const SIZE_UPDATE: u8 = 0b0010_0000;
370
371 if byte & INDEXED == INDEXED {
374 Ok(Representation::Indexed)
375 } else if byte & LITERAL_WITH_INDEXING == LITERAL_WITH_INDEXING {
376 Ok(Representation::LiteralWithIndexing)
377 } else if byte & LITERAL_WITHOUT_INDEXING == 0 {
378 Ok(Representation::LiteralWithoutIndexing)
379 } else if byte & LITERAL_WITHOUT_INDEXING == LITERAL_NEVER_INDEXED {
380 Ok(Representation::LiteralNeverIndexed)
381 } else if byte & SIZE_UPDATE_MASK == SIZE_UPDATE {
382 Ok(Representation::SizeUpdate)
383 } else {
384 Err(DecoderError::InvalidRepresentation)
385 }
386 }
387}
388
389fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderError> {
390 const MAX_BYTES: usize = 5;
394 const VARINT_MASK: u8 = 0b0111_1111;
395 const VARINT_FLAG: u8 = 0b1000_0000;
396
397 if !(1..=8).contains(&prefix_size) {
398 return Err(DecoderError::InvalidIntegerPrefix);
399 }
400
401 if !buf.has_remaining() {
402 return Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow));
403 }
404
405 let mask = if prefix_size == 8 {
406 0xFF
407 } else {
408 (1u8 << prefix_size).wrapping_sub(1)
409 };
410
411 let mut ret = (buf.get_u8() & mask) as usize;
412
413 if ret < mask as usize {
414 return Ok(ret);
416 }
417
418 let mut bytes = 1;
423
424 let mut shift = 0;
427
428 while buf.has_remaining() {
429 let b = buf.get_u8();
430
431 bytes += 1;
432 ret += ((b & VARINT_MASK) as usize) << shift;
433 shift += 7;
434
435 if b & VARINT_FLAG == 0 {
436 return Ok(ret);
437 }
438
439 if bytes == MAX_BYTES {
440 return Err(DecoderError::IntegerOverflow);
442 }
443 }
444
445 Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
446}
447
448fn peek_u8<B: Buf>(buf: &B) -> Option<u8> {
449 if buf.has_remaining() {
450 Some(buf.chunk()[0])
451 } else {
452 None
453 }
454}
455
456fn take(buf: &mut Cursor<&mut Bytes>, n: usize) -> Bytes {
457 let pos = buf.position() as usize;
458 let mut head = buf.get_mut().split_to(pos + n);
459 buf.set_position(0);
460 head.advance(pos);
461 head
462}
463
464impl StringMarker {
465 fn consume(self, buf: &mut Cursor<&mut Bytes>) -> Bytes {
466 buf.advance(self.offset);
467 match self.string {
468 Some(string) => {
469 buf.advance(self.len);
470 string
471 }
472 None => take(buf, self.len),
473 }
474 }
475}
476
477fn consume(buf: &mut Cursor<&mut Bytes>) {
478 take(buf, 0);
482}
483
484impl Table {
487 fn new(max_size: usize) -> Table {
488 Table {
489 entries: VecDeque::new(),
490 size: 0,
491 max_size,
492 }
493 }
494
495 fn size(&self) -> usize {
496 self.size
497 }
498
499 pub fn get(&self, index: usize) -> Result<Header, DecoderError> {
508 if index == 0 {
509 return Err(DecoderError::InvalidTableIndex);
510 }
511
512 if index <= 61 {
513 return Ok(get_static(index));
514 }
515
516 match self.entries.get(index - 62) {
518 Some(e) => Ok(e.clone()),
519 None => Err(DecoderError::InvalidTableIndex),
520 }
521 }
522
523 fn insert(&mut self, entry: Header) {
524 let len = entry.len();
525
526 self.reserve(len);
527
528 if self.size + len <= self.max_size {
529 self.size += len;
530
531 self.entries.push_front(entry);
533 }
534 }
535
536 fn set_max_size(&mut self, size: usize) {
537 self.max_size = size;
538 self.consolidate();
540 }
541
542 fn reserve(&mut self, size: usize) {
543 while self.size + size > self.max_size {
544 match self.entries.pop_back() {
545 Some(last) => {
546 self.size -= last.len();
547 }
548 None => return,
549 }
550 }
551 }
552
553 fn consolidate(&mut self) {
554 while self.size > self.max_size {
555 {
556 let last = match self.entries.back() {
557 Some(x) => x,
558 None => {
559 panic!("Size of table != 0, but no headers left!");
562 }
563 };
564
565 self.size -= last.len();
566 }
567
568 self.entries.pop_back();
569 }
570 }
571}
572
573impl From<Utf8Error> for DecoderError {
576 fn from(_: Utf8Error) -> DecoderError {
577 DecoderError::InvalidUtf8
578 }
579}
580
581impl From<()> for DecoderError {
582 fn from(_: ()) -> DecoderError {
583 DecoderError::InvalidUtf8
584 }
585}
586
587impl From<header::InvalidHeaderValue> for DecoderError {
588 fn from(_: header::InvalidHeaderValue) -> DecoderError {
589 DecoderError::InvalidUtf8
590 }
591}
592
593impl From<header::InvalidHeaderName> for DecoderError {
594 fn from(_: header::InvalidHeaderName) -> DecoderError {
595 DecoderError::InvalidUtf8
596 }
597}
598
599impl From<compat::InvalidMethod> for DecoderError {
600 fn from(_: compat::InvalidMethod) -> DecoderError {
601 DecoderError::InvalidUtf8
602 }
603}
604
605impl From<compat::InvalidStatusCode> for DecoderError {
606 fn from(_: compat::InvalidStatusCode) -> DecoderError {
607 DecoderError::InvalidUtf8
609 }
610}
611
612pub fn get_static(idx: usize) -> Header {
614 use ntex_http::header::HeaderValue;
615
616 match idx {
617 1 => Header::Authority(ByteString::from_static("")),
618 2 => Header::Method(Method::GET),
619 3 => Header::Method(Method::POST),
620 4 => Header::Path(ByteString::from_static("/")),
621 5 => Header::Path(ByteString::from_static("/index.html")),
622 6 => Header::Scheme(ByteString::from_static("http")),
623 7 => Header::Scheme(ByteString::from_static("https")),
624 8 => Header::Status(StatusCode::OK),
625 9 => Header::Status(StatusCode::NO_CONTENT),
626 10 => Header::Status(StatusCode::PARTIAL_CONTENT),
627 11 => Header::Status(StatusCode::NOT_MODIFIED),
628 12 => Header::Status(StatusCode::BAD_REQUEST),
629 13 => Header::Status(StatusCode::NOT_FOUND),
630 14 => Header::Status(StatusCode::INTERNAL_SERVER_ERROR),
631 15 => Header::Field {
632 name: header::ACCEPT_CHARSET,
633 value: HeaderValue::from_static(""),
634 },
635 16 => Header::Field {
636 name: header::ACCEPT_ENCODING,
637 value: HeaderValue::from_static("gzip, deflate"),
638 },
639 17 => Header::Field {
640 name: header::ACCEPT_LANGUAGE,
641 value: HeaderValue::from_static(""),
642 },
643 18 => Header::Field {
644 name: header::ACCEPT_RANGES,
645 value: HeaderValue::from_static(""),
646 },
647 19 => Header::Field {
648 name: header::ACCEPT,
649 value: HeaderValue::from_static(""),
650 },
651 20 => Header::Field {
652 name: header::ACCESS_CONTROL_ALLOW_ORIGIN,
653 value: HeaderValue::from_static(""),
654 },
655 21 => Header::Field {
656 name: header::AGE,
657 value: HeaderValue::from_static(""),
658 },
659 22 => Header::Field {
660 name: header::ALLOW,
661 value: HeaderValue::from_static(""),
662 },
663 23 => Header::Field {
664 name: header::AUTHORIZATION,
665 value: HeaderValue::from_static(""),
666 },
667 24 => Header::Field {
668 name: header::CACHE_CONTROL,
669 value: HeaderValue::from_static(""),
670 },
671 25 => Header::Field {
672 name: header::CONTENT_DISPOSITION,
673 value: HeaderValue::from_static(""),
674 },
675 26 => Header::Field {
676 name: header::CONTENT_ENCODING,
677 value: HeaderValue::from_static(""),
678 },
679 27 => Header::Field {
680 name: header::CONTENT_LANGUAGE,
681 value: HeaderValue::from_static(""),
682 },
683 28 => Header::Field {
684 name: header::CONTENT_LENGTH,
685 value: HeaderValue::from_static(""),
686 },
687 29 => Header::Field {
688 name: header::CONTENT_LOCATION,
689 value: HeaderValue::from_static(""),
690 },
691 30 => Header::Field {
692 name: header::CONTENT_RANGE,
693 value: HeaderValue::from_static(""),
694 },
695 31 => Header::Field {
696 name: header::CONTENT_TYPE,
697 value: HeaderValue::from_static(""),
698 },
699 32 => Header::Field {
700 name: header::COOKIE,
701 value: HeaderValue::from_static(""),
702 },
703 33 => Header::Field {
704 name: header::DATE,
705 value: HeaderValue::from_static(""),
706 },
707 34 => Header::Field {
708 name: header::ETAG,
709 value: HeaderValue::from_static(""),
710 },
711 35 => Header::Field {
712 name: header::EXPECT,
713 value: HeaderValue::from_static(""),
714 },
715 36 => Header::Field {
716 name: header::EXPIRES,
717 value: HeaderValue::from_static(""),
718 },
719 37 => Header::Field {
720 name: header::FROM,
721 value: HeaderValue::from_static(""),
722 },
723 38 => Header::Field {
724 name: header::HOST,
725 value: HeaderValue::from_static(""),
726 },
727 39 => Header::Field {
728 name: header::IF_MATCH,
729 value: HeaderValue::from_static(""),
730 },
731 40 => Header::Field {
732 name: header::IF_MODIFIED_SINCE,
733 value: HeaderValue::from_static(""),
734 },
735 41 => Header::Field {
736 name: header::IF_NONE_MATCH,
737 value: HeaderValue::from_static(""),
738 },
739 42 => Header::Field {
740 name: header::IF_RANGE,
741 value: HeaderValue::from_static(""),
742 },
743 43 => Header::Field {
744 name: header::IF_UNMODIFIED_SINCE,
745 value: HeaderValue::from_static(""),
746 },
747 44 => Header::Field {
748 name: header::LAST_MODIFIED,
749 value: HeaderValue::from_static(""),
750 },
751 45 => Header::Field {
752 name: header::LINK,
753 value: HeaderValue::from_static(""),
754 },
755 46 => Header::Field {
756 name: header::LOCATION,
757 value: HeaderValue::from_static(""),
758 },
759 47 => Header::Field {
760 name: header::MAX_FORWARDS,
761 value: HeaderValue::from_static(""),
762 },
763 48 => Header::Field {
764 name: header::PROXY_AUTHENTICATE,
765 value: HeaderValue::from_static(""),
766 },
767 49 => Header::Field {
768 name: header::PROXY_AUTHORIZATION,
769 value: HeaderValue::from_static(""),
770 },
771 50 => Header::Field {
772 name: header::RANGE,
773 value: HeaderValue::from_static(""),
774 },
775 51 => Header::Field {
776 name: header::REFERER,
777 value: HeaderValue::from_static(""),
778 },
779 52 => Header::Field {
780 name: header::REFRESH,
781 value: HeaderValue::from_static(""),
782 },
783 53 => Header::Field {
784 name: header::RETRY_AFTER,
785 value: HeaderValue::from_static(""),
786 },
787 54 => Header::Field {
788 name: header::SERVER,
789 value: HeaderValue::from_static(""),
790 },
791 55 => Header::Field {
792 name: header::SET_COOKIE,
793 value: HeaderValue::from_static(""),
794 },
795 56 => Header::Field {
796 name: header::STRICT_TRANSPORT_SECURITY,
797 value: HeaderValue::from_static(""),
798 },
799 57 => Header::Field {
800 name: header::TRANSFER_ENCODING,
801 value: HeaderValue::from_static(""),
802 },
803 58 => Header::Field {
804 name: header::USER_AGENT,
805 value: HeaderValue::from_static(""),
806 },
807 59 => Header::Field {
808 name: header::VARY,
809 value: HeaderValue::from_static(""),
810 },
811 60 => Header::Field {
812 name: header::VIA,
813 value: HeaderValue::from_static(""),
814 },
815 61 => Header::Field {
816 name: header::WWW_AUTHENTICATE,
817 value: HeaderValue::from_static(""),
818 },
819 _ => unreachable!(),
820 }
821}
822
#[cfg(test)]
mod test {
    use super::*;

    /// Huffman-encodes `src` into a freshly allocated buffer.
    fn huff_encode(src: &[u8]) -> BytesMut {
        let mut encoded = BytesMut::new();
        huffman::encode(src, &mut encoded);
        encoded
    }

    /// Asserts that `header` is a regular field named `foo` with value `bar`.
    fn assert_foo_bar(header: &Header) {
        match *header {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }

    #[test]
    fn test_peek_u8() {
        let byte = 0xff;
        let mut cursor = Cursor::new(vec![byte]);

        // Peeking must not advance the cursor.
        assert_eq!(peek_u8(&cursor), Some(byte));
        assert_eq!(cursor.get_u8(), byte);
        assert_eq!(peek_u8(&cursor), None);
    }

    #[test]
    fn test_decode_string_empty() {
        let mut de = Decoder::new(0);
        let mut input = Bytes::new();

        let err = de.decode_string(&mut Cursor::new(&mut input)).unwrap_err();

        assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream));
    }

    #[test]
    #[allow(clippy::unit_cmp, clippy::let_unit_value)]
    fn test_decode_empty() {
        let mut de = Decoder::new(0);
        let mut input = Bytes::new();

        let outcome = de.decode(&mut Cursor::new(&mut input), |_| {}).unwrap();

        assert_eq!(outcome, ());
    }

    #[test]
    fn test_decode_indexed_larger_than_table() {
        let mut de = Decoder::new(0);

        // Literal with indexing, new name: both strings Huffman-coded.
        let mut raw = BytesMut::new();
        raw.extend(&[0b01000000, 0x80 | 2]);
        raw.extend(huff_encode(b"foo"));
        raw.extend(&[0x80 | 3]);
        raw.extend(huff_encode(b"bar"));
        let mut block = raw.freeze();

        let mut res = vec![];
        de.decode(&mut Cursor::new(&mut block), |h| res.push(h))
            .unwrap();

        assert_eq!(res.len(), 1);
        // The table has no capacity, so nothing may have been stored.
        assert_eq!(de.table.size(), 0);
        assert_foo_bar(&res[0]);
    }

    #[test]
    fn test_decode_continuation_header_with_non_huff_encoded_name() {
        let mut de = Decoder::new(0);
        let value = huff_encode(b"bar");

        // Raw name, Huffman-coded value of which only the first byte is
        // present in the first chunk.
        let mut raw = BytesMut::new();
        raw.extend(&[0b01000000, 3]);
        raw.extend(b"foo");
        raw.extend(&[0x80 | 3]);
        raw.extend(&value[0..1]);
        let mut block = raw.freeze();

        let mut res = vec![];
        let e = de
            .decode(&mut Cursor::new(&mut block), |h| res.push(h))
            .unwrap_err();
        assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow));

        // Append the rest of the value and decode the whole block again.
        let mut raw = BytesMut::from(block);
        raw.extend(&value[1..]);
        let mut block = raw.freeze();
        de.decode(&mut Cursor::new(&mut block), |h| res.push(h))
            .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);
        assert_foo_bar(&res[0]);
    }
}