1use std::{cmp, collections::VecDeque, io::Cursor, str::Utf8Error};
2
3use ntex_bytes::{Buf, ByteString, Bytes, BytesMut};
4use ntex_http::{Method, StatusCode, compat, header};
5
6use super::{Header, huffman};
7
/// Decodes HPACK-encoded header blocks, keeping the dynamic table between
/// successive calls to [`Decoder::decode`].
#[derive(Debug)]
pub struct Decoder {
    /// Largest limit queued via `queue_size_update`; moved into
    /// `last_max_update` at the start of the next `decode` call.
    max_size_update: Option<usize>,
    /// Upper bound for a dynamic table size update received in the header
    /// block; larger requests are rejected with `InvalidMaxDynamicSize`.
    last_max_update: usize,
    /// Dynamic table of previously indexed headers.
    table: Table,
    /// Scratch buffer that Huffman-coded strings are decoded into.
    buffer: BytesMut,
}
17
/// Errors produced while decoding an HPACK header block.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum DecoderError {
    /// The first octet of a representation matched none of the known bit
    /// patterns.
    #[error("InvalidRepresentation")]
    InvalidRepresentation,
    /// An integer was requested with a prefix size outside `1..=8`.
    #[error("InvalidIntegerPrefix")]
    InvalidIntegerPrefix,
    /// Index `0`, or an index past the end of the static + dynamic tables.
    #[error("InvalidTableIndex")]
    InvalidTableIndex,
    /// A Huffman-coded string could not be decoded (presumably raised by the
    /// `huffman` module; not constructed directly in this file).
    #[error("InvalidHuffmanCode")]
    InvalidHuffmanCode,
    /// Catch-all for invalid textual header data; several `From` conversions
    /// in this module map their errors to this variant.
    #[error("InvalidUtf8")]
    InvalidUtf8,
    /// A status code value could not be parsed.
    #[error("InvalidStatusCode")]
    InvalidStatusCode,
    /// Presumably an unrecognized pseudo-header name (`:...`); not constructed
    /// in the code visible here — confirm against the `Header` module.
    #[error("InvalidPseudoheader")]
    InvalidPseudoheader,
    /// A dynamic table size update appeared after a header representation,
    /// or requested more than the currently allowed maximum.
    #[error("InvalidMaxDynamicSize")]
    InvalidMaxDynamicSize,
    /// A prefix-coded integer used more octets than the decoder accepts.
    #[error("IntegerOverflow")]
    IntegerOverflow,
    /// The input ended mid-representation; see [`NeedMore`].
    #[error("{0}")]
    NeedMore(NeedMore),
}
43
/// Reasons the input ended before a representation was fully decoded.
///
/// The bytes of the incomplete representation are left in the buffer, so the
/// caller may append more data and call `decode` again.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum NeedMore {
    /// A string literal was expected but no bytes remained.
    #[error("Unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The input ended inside a prefix-coded integer.
    #[error("Integer underflow")]
    IntegerUnderflow,
    /// A string's declared length exceeds the bytes still available.
    #[error("String underflow")]
    StringUnderflow,
}
53
/// Kind of header field representation, selected by the high bits of its
/// first octet (see `Representation::load`).
enum Representation {
    /// `1xxxxxxx`: the whole header is taken from the tables; the index uses a
    /// 7-bit prefix.
    Indexed,

    /// `01xxxxxx`: literal value (name indexed with a 6-bit prefix, or literal
    /// when the index is 0); the decoded header is inserted into the dynamic
    /// table.
    LiteralWithIndexing,

    /// `0000xxxx`: literal value (name index uses a 4-bit prefix); the dynamic
    /// table is not modified.
    LiteralWithoutIndexing,

    /// `0001xxxx`: decoded exactly like `LiteralWithoutIndexing` by this
    /// decoder; the never-index hint is not passed on to the caller.
    LiteralNeverIndexed,

    /// `001xxxxx`: dynamic table size update; the new size uses a 5-bit prefix.
    SizeUpdate,
}
144
/// HPACK dynamic table.
#[derive(Debug)]
struct Table {
    /// Stored headers, newest at the front (`insert` pushes to the front,
    /// eviction pops from the back).
    entries: VecDeque<Header>,
    /// Sum of `len()` over all entries currently stored.
    size: usize,
    /// Limit that `size` must not exceed.
    max_size: usize,
}
151
/// Location (and, for Huffman strings, the decoded contents) of a string
/// literal found by `try_decode_string`, used to consume it later.
struct StringMarker {
    /// Octets from the start of the string representation (its length prefix)
    /// to the first octet of string data.
    offset: usize,
    /// Length of the string data in octets, as encoded on the wire.
    len: usize,
    /// Already-decoded contents for Huffman-coded strings; `None` for raw
    /// strings, which are split off the input buffer when consumed.
    string: Option<Bytes>,
}
157
158impl Decoder {
161 pub fn new(size: usize) -> Decoder {
163 Decoder {
164 max_size_update: None,
165 last_max_update: size,
166 table: Table::new(size),
167 buffer: BytesMut::with_capacity(4096),
168 }
169 }
170
171 #[allow(dead_code)]
173 pub fn queue_size_update(&mut self, size: usize) {
174 let size = match self.max_size_update {
175 Some(v) => cmp::max(v, size),
176 None => size,
177 };
178
179 self.max_size_update = Some(size);
180 }
181
182 pub fn decode<F>(&mut self, src: &mut Cursor<&mut Bytes>, mut f: F) -> Result<(), DecoderError>
184 where
185 F: FnMut(Header),
186 {
187 let mut can_resize = true;
188
189 if let Some(size) = self.max_size_update.take() {
190 self.last_max_update = size;
191 }
192
193 while let Some(ty) = peek_u8(src) {
194 match Representation::load(ty)? {
198 Representation::Indexed => {
199 can_resize = false;
201 let entry = self.decode_indexed(src)?;
202 consume(src);
203 f(entry);
204 }
205 Representation::LiteralWithIndexing => {
206 can_resize = false;
208 let entry = self.decode_literal(src, true)?;
209
210 self.table.insert(entry.clone());
212 consume(src);
213
214 f(entry);
215 }
216 Representation::LiteralWithoutIndexing => {
217 can_resize = false;
219 let entry = self.decode_literal(src, false)?;
220 consume(src);
221 f(entry);
222 }
223 Representation::LiteralNeverIndexed => {
224 can_resize = false;
226 let entry = self.decode_literal(src, false)?;
227 consume(src);
228
229 f(entry);
232 }
233 Representation::SizeUpdate => {
234 if !can_resize {
236 return Err(DecoderError::InvalidMaxDynamicSize);
237 }
238
239 self.process_size_update(src)?;
241 consume(src);
242 }
243 }
244 }
245
246 Ok(())
247 }
248
249 fn process_size_update(&mut self, buf: &mut Cursor<&mut Bytes>) -> Result<(), DecoderError> {
250 let new_size = decode_int(buf, 5)?;
251
252 if new_size > self.last_max_update {
253 return Err(DecoderError::InvalidMaxDynamicSize);
254 }
255
256 log::debug!(
257 "Decoder changed max table size, from {} to {}",
258 self.table.size(),
259 new_size,
260 );
261
262 self.table.set_max_size(new_size);
263
264 Ok(())
265 }
266
267 fn decode_indexed(&self, buf: &mut Cursor<&mut Bytes>) -> Result<Header, DecoderError> {
268 let index = decode_int(buf, 7)?;
269 self.table.get(index)
270 }
271
272 fn decode_literal(
273 &mut self,
274 buf: &mut Cursor<&mut Bytes>,
275 index: bool,
276 ) -> Result<Header, DecoderError> {
277 let prefix = if index { 6 } else { 4 };
278
279 let table_idx = decode_int(buf, prefix)?;
281
282 if table_idx == 0 {
284 let old_pos = buf.position();
285 let name_marker = self.try_decode_string(buf)?;
286 let value_marker = self.try_decode_string(buf)?;
287 buf.set_position(old_pos);
288 let name = name_marker.consume(buf);
290 let value = value_marker.consume(buf);
291 Header::new(&name, value)
292 } else {
293 let e = self.table.get(table_idx)?;
294 let value = self.decode_string(buf)?;
295
296 e.name().into_entry(value)
297 }
298 }
299
300 fn try_decode_string(
301 &mut self,
302 buf: &mut Cursor<&mut Bytes>,
303 ) -> Result<StringMarker, DecoderError> {
304 const HUFF_FLAG: u8 = 0b1000_0000;
305
306 let old_pos = buf.position();
307
308 let huff = match peek_u8(buf) {
310 Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG,
311 None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)),
312 };
313
314 let len = decode_int(buf, 7)?;
316
317 if len > buf.remaining() {
318 log::trace!("decode_string underflow {:?} {:?}", len, buf.remaining());
319 return Err(DecoderError::NeedMore(NeedMore::StringUnderflow));
320 }
321
322 let offset = (buf.position() - old_pos) as usize;
323 if huff {
324 let ret = {
325 let raw = &buf.chunk()[..len];
326 huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker {
327 offset,
328 len,
329 string: Some(buf),
330 })
331 };
332
333 buf.advance(len);
334 ret
335 } else {
336 buf.advance(len);
337 Ok(StringMarker {
338 offset,
339 len,
340 string: None,
341 })
342 }
343 }
344
345 fn decode_string(&mut self, buf: &mut Cursor<&mut Bytes>) -> Result<Bytes, DecoderError> {
346 let old_pos = buf.position();
347 let marker = self.try_decode_string(buf)?;
348 buf.set_position(old_pos);
349 Ok(marker.consume(buf))
350 }
351}
352
353impl Default for Decoder {
354 fn default() -> Decoder {
355 Decoder::new(4096)
356 }
357}
358
359impl Representation {
362 pub(super) fn load(byte: u8) -> Result<Representation, DecoderError> {
363 const INDEXED: u8 = 0b1000_0000;
364 const LITERAL_WITH_INDEXING: u8 = 0b0100_0000;
365 const LITERAL_WITHOUT_INDEXING: u8 = 0b1111_0000;
366 const LITERAL_NEVER_INDEXED: u8 = 0b0001_0000;
367 const SIZE_UPDATE_MASK: u8 = 0b1110_0000;
368 const SIZE_UPDATE: u8 = 0b0010_0000;
369
370 if byte & INDEXED == INDEXED {
373 Ok(Representation::Indexed)
374 } else if byte & LITERAL_WITH_INDEXING == LITERAL_WITH_INDEXING {
375 Ok(Representation::LiteralWithIndexing)
376 } else if byte & LITERAL_WITHOUT_INDEXING == 0 {
377 Ok(Representation::LiteralWithoutIndexing)
378 } else if byte & LITERAL_WITHOUT_INDEXING == LITERAL_NEVER_INDEXED {
379 Ok(Representation::LiteralNeverIndexed)
380 } else if byte & SIZE_UPDATE_MASK == SIZE_UPDATE {
381 Ok(Representation::SizeUpdate)
382 } else {
383 Err(DecoderError::InvalidRepresentation)
384 }
385 }
386}
387
388fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderError> {
389 const MAX_BYTES: usize = 5;
393 const VARINT_MASK: u8 = 0b0111_1111;
394 const VARINT_FLAG: u8 = 0b1000_0000;
395
396 if !(1..=8).contains(&prefix_size) {
397 return Err(DecoderError::InvalidIntegerPrefix);
398 }
399
400 if !buf.has_remaining() {
401 return Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow));
402 }
403
404 let mask = if prefix_size == 8 {
405 0xFF
406 } else {
407 (1u8 << prefix_size).wrapping_sub(1)
408 };
409
410 let mut ret = (buf.get_u8() & mask) as usize;
411
412 if ret < mask as usize {
413 return Ok(ret);
415 }
416
417 let mut bytes = 1;
422
423 let mut shift = 0;
426
427 while buf.has_remaining() {
428 let b = buf.get_u8();
429
430 bytes += 1;
431 ret += ((b & VARINT_MASK) as usize) << shift;
432 shift += 7;
433
434 if b & VARINT_FLAG == 0 {
435 return Ok(ret);
436 }
437
438 if bytes == MAX_BYTES {
439 return Err(DecoderError::IntegerOverflow);
441 }
442 }
443
444 Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
445}
446
447fn peek_u8<B: Buf>(buf: &B) -> Option<u8> {
448 if buf.has_remaining() {
449 Some(buf.chunk()[0])
450 } else {
451 None
452 }
453}
454
455fn take(buf: &mut Cursor<&mut Bytes>, n: usize) -> Bytes {
456 let pos = buf.position() as usize;
457 let mut head = buf.get_mut().split_to(pos + n);
458 buf.set_position(0);
459 head.advance(pos);
460 head
461}
462
463impl StringMarker {
464 fn consume(self, buf: &mut Cursor<&mut Bytes>) -> Bytes {
465 buf.advance(self.offset);
466 match self.string {
467 Some(string) => {
468 buf.advance(self.len);
469 string
470 }
471 None => take(buf, self.len),
472 }
473 }
474}
475
476fn consume(buf: &mut Cursor<&mut Bytes>) {
477 take(buf, 0);
481}
482
483impl Table {
486 fn new(max_size: usize) -> Table {
487 Table {
488 entries: VecDeque::new(),
489 size: 0,
490 max_size,
491 }
492 }
493
494 fn size(&self) -> usize {
495 self.size
496 }
497
498 pub(super) fn get(&self, index: usize) -> Result<Header, DecoderError> {
507 if index == 0 {
508 return Err(DecoderError::InvalidTableIndex);
509 }
510
511 if index <= 61 {
512 return Ok(get_static(index));
513 }
514
515 match self.entries.get(index - 62) {
517 Some(e) => Ok(e.clone()),
518 None => Err(DecoderError::InvalidTableIndex),
519 }
520 }
521
522 fn insert(&mut self, entry: Header) {
523 let len = entry.len();
524
525 self.reserve(len);
526
527 if self.size + len <= self.max_size {
528 self.size += len;
529
530 self.entries.push_front(entry);
532 }
533 }
534
535 fn set_max_size(&mut self, size: usize) {
536 self.max_size = size;
537 self.consolidate();
539 }
540
541 fn reserve(&mut self, size: usize) {
542 while self.size + size > self.max_size {
543 match self.entries.pop_back() {
544 Some(last) => {
545 self.size -= last.len();
546 }
547 None => return,
548 }
549 }
550 }
551
552 fn consolidate(&mut self) {
553 while self.size > self.max_size {
554 {
555 let Some(last) = self.entries.back() else {
556 panic!("Size of table != 0, but no headers left!");
559 };
560
561 self.size -= last.len();
562 }
563
564 self.entries.pop_back();
565 }
566 }
567}
568
569impl From<Utf8Error> for DecoderError {
572 fn from(_: Utf8Error) -> DecoderError {
573 DecoderError::InvalidUtf8
574 }
575}
576
577impl From<()> for DecoderError {
578 fn from(_e: ()) -> DecoderError {
579 DecoderError::InvalidUtf8
580 }
581}
582
583impl From<header::InvalidHeaderValue> for DecoderError {
584 fn from(_: header::InvalidHeaderValue) -> DecoderError {
585 DecoderError::InvalidUtf8
586 }
587}
588
589impl From<header::InvalidHeaderName> for DecoderError {
590 fn from(_: header::InvalidHeaderName) -> DecoderError {
591 DecoderError::InvalidUtf8
592 }
593}
594
595impl From<compat::InvalidMethod> for DecoderError {
596 fn from(_: compat::InvalidMethod) -> DecoderError {
597 DecoderError::InvalidUtf8
598 }
599}
600
601impl From<compat::InvalidStatusCode> for DecoderError {
602 fn from(_: compat::InvalidStatusCode) -> DecoderError {
603 DecoderError::InvalidUtf8
605 }
606}
607
608pub(super) fn get_static(idx: usize) -> Header {
610 use ntex_http::header::HeaderValue;
611
612 match idx {
613 1 => Header::Authority(ByteString::from_static("")),
614 2 => Header::Method(Method::GET),
615 3 => Header::Method(Method::POST),
616 4 => Header::Path(ByteString::from_static("/")),
617 5 => Header::Path(ByteString::from_static("/index.html")),
618 6 => Header::Scheme(ByteString::from_static("http")),
619 7 => Header::Scheme(ByteString::from_static("https")),
620 8 => Header::Status(StatusCode::OK),
621 9 => Header::Status(StatusCode::NO_CONTENT),
622 10 => Header::Status(StatusCode::PARTIAL_CONTENT),
623 11 => Header::Status(StatusCode::NOT_MODIFIED),
624 12 => Header::Status(StatusCode::BAD_REQUEST),
625 13 => Header::Status(StatusCode::NOT_FOUND),
626 14 => Header::Status(StatusCode::INTERNAL_SERVER_ERROR),
627 15 => Header::Field {
628 name: header::ACCEPT_CHARSET,
629 value: HeaderValue::from_static(""),
630 },
631 16 => Header::Field {
632 name: header::ACCEPT_ENCODING,
633 value: HeaderValue::from_static("gzip, deflate"),
634 },
635 17 => Header::Field {
636 name: header::ACCEPT_LANGUAGE,
637 value: HeaderValue::from_static(""),
638 },
639 18 => Header::Field {
640 name: header::ACCEPT_RANGES,
641 value: HeaderValue::from_static(""),
642 },
643 19 => Header::Field {
644 name: header::ACCEPT,
645 value: HeaderValue::from_static(""),
646 },
647 20 => Header::Field {
648 name: header::ACCESS_CONTROL_ALLOW_ORIGIN,
649 value: HeaderValue::from_static(""),
650 },
651 21 => Header::Field {
652 name: header::AGE,
653 value: HeaderValue::from_static(""),
654 },
655 22 => Header::Field {
656 name: header::ALLOW,
657 value: HeaderValue::from_static(""),
658 },
659 23 => Header::Field {
660 name: header::AUTHORIZATION,
661 value: HeaderValue::from_static(""),
662 },
663 24 => Header::Field {
664 name: header::CACHE_CONTROL,
665 value: HeaderValue::from_static(""),
666 },
667 25 => Header::Field {
668 name: header::CONTENT_DISPOSITION,
669 value: HeaderValue::from_static(""),
670 },
671 26 => Header::Field {
672 name: header::CONTENT_ENCODING,
673 value: HeaderValue::from_static(""),
674 },
675 27 => Header::Field {
676 name: header::CONTENT_LANGUAGE,
677 value: HeaderValue::from_static(""),
678 },
679 28 => Header::Field {
680 name: header::CONTENT_LENGTH,
681 value: HeaderValue::from_static(""),
682 },
683 29 => Header::Field {
684 name: header::CONTENT_LOCATION,
685 value: HeaderValue::from_static(""),
686 },
687 30 => Header::Field {
688 name: header::CONTENT_RANGE,
689 value: HeaderValue::from_static(""),
690 },
691 31 => Header::Field {
692 name: header::CONTENT_TYPE,
693 value: HeaderValue::from_static(""),
694 },
695 32 => Header::Field {
696 name: header::COOKIE,
697 value: HeaderValue::from_static(""),
698 },
699 33 => Header::Field {
700 name: header::DATE,
701 value: HeaderValue::from_static(""),
702 },
703 34 => Header::Field {
704 name: header::ETAG,
705 value: HeaderValue::from_static(""),
706 },
707 35 => Header::Field {
708 name: header::EXPECT,
709 value: HeaderValue::from_static(""),
710 },
711 36 => Header::Field {
712 name: header::EXPIRES,
713 value: HeaderValue::from_static(""),
714 },
715 37 => Header::Field {
716 name: header::FROM,
717 value: HeaderValue::from_static(""),
718 },
719 38 => Header::Field {
720 name: header::HOST,
721 value: HeaderValue::from_static(""),
722 },
723 39 => Header::Field {
724 name: header::IF_MATCH,
725 value: HeaderValue::from_static(""),
726 },
727 40 => Header::Field {
728 name: header::IF_MODIFIED_SINCE,
729 value: HeaderValue::from_static(""),
730 },
731 41 => Header::Field {
732 name: header::IF_NONE_MATCH,
733 value: HeaderValue::from_static(""),
734 },
735 42 => Header::Field {
736 name: header::IF_RANGE,
737 value: HeaderValue::from_static(""),
738 },
739 43 => Header::Field {
740 name: header::IF_UNMODIFIED_SINCE,
741 value: HeaderValue::from_static(""),
742 },
743 44 => Header::Field {
744 name: header::LAST_MODIFIED,
745 value: HeaderValue::from_static(""),
746 },
747 45 => Header::Field {
748 name: header::LINK,
749 value: HeaderValue::from_static(""),
750 },
751 46 => Header::Field {
752 name: header::LOCATION,
753 value: HeaderValue::from_static(""),
754 },
755 47 => Header::Field {
756 name: header::MAX_FORWARDS,
757 value: HeaderValue::from_static(""),
758 },
759 48 => Header::Field {
760 name: header::PROXY_AUTHENTICATE,
761 value: HeaderValue::from_static(""),
762 },
763 49 => Header::Field {
764 name: header::PROXY_AUTHORIZATION,
765 value: HeaderValue::from_static(""),
766 },
767 50 => Header::Field {
768 name: header::RANGE,
769 value: HeaderValue::from_static(""),
770 },
771 51 => Header::Field {
772 name: header::REFERER,
773 value: HeaderValue::from_static(""),
774 },
775 52 => Header::Field {
776 name: header::REFRESH,
777 value: HeaderValue::from_static(""),
778 },
779 53 => Header::Field {
780 name: header::RETRY_AFTER,
781 value: HeaderValue::from_static(""),
782 },
783 54 => Header::Field {
784 name: header::SERVER,
785 value: HeaderValue::from_static(""),
786 },
787 55 => Header::Field {
788 name: header::SET_COOKIE,
789 value: HeaderValue::from_static(""),
790 },
791 56 => Header::Field {
792 name: header::STRICT_TRANSPORT_SECURITY,
793 value: HeaderValue::from_static(""),
794 },
795 57 => Header::Field {
796 name: header::TRANSFER_ENCODING,
797 value: HeaderValue::from_static(""),
798 },
799 58 => Header::Field {
800 name: header::USER_AGENT,
801 value: HeaderValue::from_static(""),
802 },
803 59 => Header::Field {
804 name: header::VARY,
805 value: HeaderValue::from_static(""),
806 },
807 60 => Header::Field {
808 name: header::VIA,
809 value: HeaderValue::from_static(""),
810 },
811 61 => Header::Field {
812 name: header::WWW_AUTHENTICATE,
813 value: HeaderValue::from_static(""),
814 },
815 _ => unreachable!(),
816 }
817}
818
#[cfg(test)]
mod test {
    use super::*;

    // `peek_u8` reads the next octet without advancing the buffer.
    #[test]
    fn test_peek_u8() {
        let b = 0xff;
        let mut buf = Cursor::new(vec![b]);
        assert_eq!(peek_u8(&buf), Some(b));
        assert_eq!(buf.get_u8(), b);
        assert_eq!(peek_u8(&buf), None);
    }

    // Decoding a string from an empty buffer asks for more input.
    #[test]
    fn test_decode_string_empty() {
        let mut de = Decoder::new(0);
        let mut buf = Bytes::new();
        let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err();
        assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream));
    }

    // An empty header block decodes successfully to nothing.
    #[test]
    #[allow(clippy::unit_cmp, clippy::let_unit_value)]
    fn test_decode_empty() {
        let mut de = Decoder::new(0);
        let mut buf = Bytes::new();
        let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap();
        assert_eq!(empty, ());
    }

    // A literal-with-indexing header is still emitted when it cannot fit in a
    // zero-sized dynamic table; it just is not stored.
    #[test]
    fn test_decode_indexed_larger_than_table() {
        let mut de = Decoder::new(0);

        let mut buf = BytesMut::new();
        // 0b0100_0000: literal with indexing, literal name; 0x80 | 2: Huffman
        // name of 2 octets.
        buf.extend(&[0b0100_0000, 0x80 | 2]);
        buf.extend(huff_encode(b"foo"));
        // 0x80 | 3: Huffman value of 3 octets.
        buf.extend(&[0x80 | 3]);
        buf.extend(huff_encode(b"bar"));
        let mut buf = buf.freeze();

        let mut res = vec![];
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }

    // Huffman-encodes `src` into a fresh buffer.
    fn huff_encode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::encode(src, &mut buf);
        buf
    }

    // A raw (non-Huffman) name followed by a truncated Huffman value first
    // reports StringUnderflow; after the rest of the value is appended,
    // decoding resumes and yields the full header.
    #[test]
    fn test_decode_continuation_header_with_non_huff_encoded_name() {
        let mut de = Decoder::new(0);
        let value = huff_encode(b"bar");
        let mut buf = BytesMut::new();
        // Literal with indexing, raw name of 3 octets.
        buf.extend(&[0b0100_0000, 3]);
        buf.extend(b"foo");
        // Huffman value of 3 octets, but only the first octet is present.
        buf.extend(&[0x80 | 3]);
        buf.extend(&value[0..1]);
        let mut buf = buf.freeze();

        let mut res = vec![];
        let e = de
            .decode(&mut Cursor::new(&mut buf), |h| {
                res.push(h);
            })
            .unwrap_err();
        assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow));

        // Append the remaining value octets and retry on the same buffer.
        let mut buf = BytesMut::from(buf);
        buf.extend(&value[1..]);
        let mut buf = buf.freeze();
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }
}