1use alloc::vec::Vec;
2use core::ops::Range;
3use std::io;
4
5use super::base::Payload;
6use super::codec::Codec;
7use super::message::PlainMessage;
8use crate::enums::{ContentType, ProtocolVersion};
9use crate::error::{Error, InvalidMessage, PeerMisbehaved};
10use crate::msgs::codec;
11use crate::msgs::message::{MessageError, OpaqueMessage};
12use crate::record_layer::{Decrypted, RecordLayer};
13
14#[derive(Default)]
19pub struct MessageDeframer {
20 last_error: Option<Error>,
24
25 buf: Vec<u8>,
29
30 joining_hs: Option<HandshakePayloadMeta>,
32
33 used: usize,
35}
36
37impl MessageDeframer {
38 pub fn pop(
44 &mut self,
45 record_layer: &mut RecordLayer,
46 negotiated_version: Option<ProtocolVersion>,
47 ) -> Result<Option<Deframed>, Error> {
48 if let Some(last_err) = self.last_error.clone() {
49 return Err(last_err);
50 } else if self.used == 0 {
51 return Ok(None);
52 }
53
54 let expected_len = loop {
58 let start = match &self.joining_hs {
59 Some(meta) => {
60 match meta.expected_len {
61 Some(len) if len <= meta.payload.len() => break len,
63 _ if meta.quic => return Ok(None),
65 _ => meta.message.end,
67 }
68 }
69 None => 0,
70 };
71
72 let mut rd = codec::Reader::init(&self.buf[start..self.used]);
76 let m = match OpaqueMessage::read(&mut rd) {
77 Ok(m) => m,
78 Err(msg_err) => {
79 let err_kind = match msg_err {
80 MessageError::TooShortForHeader | MessageError::TooShortForLength => {
81 return Ok(None)
82 }
83 MessageError::InvalidEmptyPayload => InvalidMessage::InvalidEmptyPayload,
84 MessageError::MessageTooLarge => InvalidMessage::MessageTooLarge,
85 MessageError::InvalidContentType => InvalidMessage::InvalidContentType,
86 MessageError::UnknownProtocolVersion => {
87 InvalidMessage::UnknownProtocolVersion
88 }
89 };
90
91 return Err(self.set_err(err_kind));
92 }
93 };
94
95 let end = start + rd.used();
97 let version_is_tls13 = matches!(negotiated_version, Some(ProtocolVersion::TLSv1_3));
98 let allowed_plaintext = match m.typ {
99 ContentType::ChangeCipherSpec => true,
101 ContentType::Alert
108 if version_is_tls13
109 && !record_layer.has_decrypted()
110 && m.payload().len() <= 2 =>
111 {
112 true
113 }
114 _ => false,
116 };
117 if self.joining_hs.is_none() && allowed_plaintext {
118 self.discard(end);
120 return Ok(Some(Deframed {
121 want_close_before_decrypt: false,
122 aligned: true,
123 trial_decryption_finished: false,
124 message: m.into_plain_message(),
125 }));
126 }
127
128 let msg = match record_layer.decrypt_incoming(m) {
130 Ok(Some(decrypted)) => {
131 let Decrypted {
132 want_close_before_decrypt,
133 plaintext,
134 } = decrypted;
135 debug_assert!(!want_close_before_decrypt);
136 plaintext
137 }
138 Ok(None) if self.joining_hs.is_some() => {
141 return Err(self.set_err(
142 PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage,
143 ));
144 }
145 Ok(None) => {
146 self.discard(end);
147 continue;
148 }
149 Err(e) => return Err(e),
150 };
151
152 if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
153 return Err(self.set_err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage));
158 }
159
160 if msg.typ != ContentType::Handshake {
162 let end = start + rd.used();
163 self.discard(end);
164 return Ok(Some(Deframed {
165 want_close_before_decrypt: false,
166 aligned: true,
167 trial_decryption_finished: false,
168 message: msg,
169 }));
170 }
171
172 match self.append_hs(msg.version, &msg.payload.0, end, false)? {
175 HandshakePayloadState::Blocked => return Ok(None),
176 HandshakePayloadState::Complete(len) => break len,
177 HandshakePayloadState::Continue => continue,
178 }
179 };
180
181 let meta = self.joining_hs.as_mut().unwrap(); let message = PlainMessage {
185 typ: ContentType::Handshake,
186 version: meta.version,
187 payload: Payload::new(&self.buf[meta.payload.start..meta.payload.start + expected_len]),
188 };
189
190 if meta.payload.len() > expected_len {
192 meta.payload.start += expected_len;
196 meta.expected_len = payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
197 } else {
198 let end = meta.message.end;
201 self.joining_hs = None;
202 self.discard(end);
203 }
204
205 Ok(Some(Deframed {
206 want_close_before_decrypt: false,
207 aligned: self.joining_hs.is_none(),
208 trial_decryption_finished: true,
209 message,
210 }))
211 }
212
213 fn set_err(&mut self, err: impl Into<Error>) -> Error {
217 let err = err.into();
218 self.last_error = Some(err.clone());
219 err
220 }
221
222 pub(crate) fn push(&mut self, version: ProtocolVersion, payload: &[u8]) -> Result<(), Error> {
224 if self.used > 0 && self.joining_hs.is_none() {
225 return Err(Error::General(
226 "cannot push QUIC messages into unrelated connection".into(),
227 ));
228 } else if let Err(err) = self.prepare_read() {
229 return Err(Error::General(err.into()));
230 }
231
232 let end = self.used + payload.len();
233 self.append_hs(version, payload, end, true)?;
234 self.used = end;
235 Ok(())
236 }
237
238 fn append_hs(
242 &mut self,
243 version: ProtocolVersion,
244 payload: &[u8],
245 end: usize,
246 quic: bool,
247 ) -> Result<HandshakePayloadState, Error> {
248 let meta = match &mut self.joining_hs {
249 Some(meta) => {
250 debug_assert_eq!(meta.quic, quic);
251
252 let dst = &mut self.buf[meta.payload.end..meta.payload.end + payload.len()];
256 dst.copy_from_slice(payload);
257 meta.message.end = end;
258 meta.payload.end += payload.len();
259
260 if meta.expected_len.is_none() {
262 meta.expected_len =
263 payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
264 }
265
266 meta
267 }
268 None => {
269 let expected_len = payload_size(payload)?;
273 let dst = &mut self.buf[..payload.len()];
274 dst.copy_from_slice(payload);
275 self.joining_hs
276 .insert(HandshakePayloadMeta {
277 message: Range { start: 0, end },
278 payload: Range {
279 start: 0,
280 end: payload.len(),
281 },
282 version,
283 expected_len,
284 quic,
285 })
286 }
287 };
288
289 Ok(match meta.expected_len {
290 Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
291 _ => match self.used > meta.message.end {
292 true => HandshakePayloadState::Continue,
293 false => HandshakePayloadState::Blocked,
294 },
295 })
296 }
297
298 #[allow(clippy::comparison_chain)]
300 pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
301 if let Err(err) = self.prepare_read() {
302 return Err(io::Error::new(io::ErrorKind::InvalidData, err));
303 }
304
305 let new_bytes = rd.read(&mut self.buf[self.used..])?;
310 self.used += new_bytes;
311 Ok(new_bytes)
312 }
313
314 fn prepare_read(&mut self) -> Result<(), &'static str> {
316 let allow_max = match self.joining_hs {
324 Some(_) => MAX_HANDSHAKE_SIZE as usize,
325 None => OpaqueMessage::MAX_WIRE_SIZE,
326 };
327
328 if self.used >= allow_max {
329 return Err("message buffer full");
330 }
331
332 let need_capacity = Ord::min(allow_max, self.used + READ_SIZE);
338 if need_capacity > self.buf.len() {
339 self.buf.resize(need_capacity, 0);
340 } else if self.used == 0 || self.buf.len() > allow_max {
341 self.buf.resize(need_capacity, 0);
342 self.buf.shrink_to(need_capacity);
343 }
344
345 Ok(())
346 }
347
348 pub fn has_pending(&self) -> bool {
352 self.used > 0
353 }
354
355 fn discard(&mut self, taken: usize) {
357 #[allow(clippy::comparison_chain)]
358 if taken < self.used {
359 self.buf
373 .copy_within(taken..self.used, 0);
374 self.used -= taken;
375 } else if taken == self.used {
376 self.used = 0;
377 }
378 }
379}
380
/// Outcome of `MessageDeframer::append_hs()`.
enum HandshakePayloadState {
    /// The payload is incomplete and no unprocessed bytes remain in the
    /// buffer; more input is needed.
    Blocked,
    /// A complete handshake payload of the given length (as computed by
    /// `payload_size()`) is available.
    Complete(usize),
    /// The payload is incomplete, but unprocessed bytes remain in the buffer.
    Continue,
}
389
390struct HandshakePayloadMeta {
391 message: Range<usize>,
395 payload: Range<usize>,
397 version: ProtocolVersion,
399 expected_len: Option<usize>,
404 quic: bool,
410}
411
412fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
418 if buf.len() < HEADER_SIZE {
419 return Ok(None);
420 }
421
422 let (header, _) = buf.split_at(HEADER_SIZE);
423 match codec::u24::read_bytes(&header[1..]) {
424 Ok(len) if len.0 > MAX_HANDSHAKE_SIZE => Err(Error::InvalidMessage(
425 InvalidMessage::HandshakePayloadTooLarge,
426 )),
427 Ok(len) => Ok(Some(HEADER_SIZE + usize::from(len))),
428 _ => Ok(None),
429 }
430}
431
432#[derive(Debug)]
433pub struct Deframed {
434 pub(crate) want_close_before_decrypt: bool,
435 pub(crate) aligned: bool,
436 pub(crate) trial_decryption_finished: bool,
437 pub message: PlainMessage,
438}
439
/// Size of a handshake payload header: one leading byte followed by a
/// 3-byte length, which is how `payload_size()` reads it.
const HEADER_SIZE: usize = 1 + 3;

/// Largest handshake body length accepted by `payload_size()`; also the
/// buffer size limit while a handshake payload is being joined.
const MAX_HANDSHAKE_SIZE: u32 = 0xffff;

/// Amount of free space `prepare_read()` tries to provide for each read.
const READ_SIZE: usize = 4096;
448
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs::message::{Message, OpaqueMessage};
    use crate::record_layer::RecordLayer;
    use crate::{ContentType, Error, InvalidMessage};

    use std::io;

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-empty-applicationdata.bin");

    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
    const INVALID_VERSION_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-version.bin");
    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");

    /// Offers `bytes` to the deframer through a single `read()` call.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        d.read(&mut io::Cursor::new(bytes))
    }

    /// Offers `bytes1` followed by `bytes2` through a single `read()` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        let joined = [bytes1, bytes2].concat();
        d.read(&mut io::Cursor::new(joined))
    }

    /// A reader that scribbles over the destination buffer and then fails
    /// with the stored error.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            Self { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            buf.iter_mut()
                .enumerate()
                .for_each(|(i, b)| *b = i as u8);

            Err(self.error.take().unwrap())
        }
    }

    /// Performs one failing read and checks that the error is propagated.
    fn input_error(d: &mut MessageDeframer) {
        let mut rd = ErrorRead::new(io::Error::from(io::ErrorKind::TimedOut));
        d.read(&mut rd)
            .expect_err("error not propagated");
    }

    /// Feeds `bytes` one byte at a time, checking the deframer's progress.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let before = d.used;

        for single in bytes.chunks(1) {
            assert_len(1, input_bytes(d, single));
            assert!(d.has_pending());
        }

        assert_eq!(before + bytes.len(), d.used);
    }

    /// Asserts that a read succeeded and transferred exactly `want` bytes.
    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(_) => panic!("read failed, expected {:?} bytes", want),
        }
    }

    /// Pops one message, checks its content type and that it parses.
    fn pop_expecting(d: &mut MessageDeframer, rl: &mut RecordLayer, typ: ContentType) {
        let m = d
            .pop(rl, None)
            .unwrap()
            .unwrap()
            .message;
        assert_eq!(m.typ, typ);
        Message::try_from(m).unwrap();
    }

    fn pop_first(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        pop_expecting(d, rl, ContentType::Handshake);
    }

    fn pop_second(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        pop_expecting(d, rl, ContentType::Alert);
    }

    /// Builds a deframer that has read all of `bytes` in one go.
    fn deframer_with(bytes: &[u8]) -> MessageDeframer {
        let mut d = MessageDeframer::default();
        assert_len(bytes.len(), input_bytes(&mut d, bytes));
        d
    }

    /// Asserts that popping fails with the given `InvalidMessage` kind.
    fn assert_pop_invalid(d: &mut MessageDeframer, rl: &mut RecordLayer, kind: InvalidMessage) {
        assert_eq!(
            d.pop(rl, None).unwrap_err(),
            Error::InvalidMessage(kind)
        );
    }

    /// Asserts that the deframer is drained and has no recorded error.
    fn assert_drained(d: &MessageDeframer) {
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        for message in [FIRST_MESSAGE, SECOND_MESSAGE] {
            input_whole_incremental(&mut d, message);
            assert!(d.has_pending());
        }

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(d.has_pending());
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_second(&mut d, &mut rl);
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let (head, tail) = FIRST_MESSAGE.split_at(3);

        let mut d = MessageDeframer::default();
        assert_len(3, input_bytes(&mut d, head));
        input_error(&mut d);
        assert_len(FIRST_MESSAGE.len() - 3, input_bytes(&mut d, tail));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert_drained(&d);
    }

    #[test]
    fn test_invalid_contenttype_errors() {
        let mut d = deframer_with(INVALID_CONTENTTYPE_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_invalid(&mut d, &mut rl, InvalidMessage::InvalidContentType);
    }

    #[test]
    fn test_invalid_version_errors() {
        let mut d = deframer_with(INVALID_VERSION_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_invalid(&mut d, &mut rl, InvalidMessage::UnknownProtocolVersion);
    }

    #[test]
    fn test_invalid_length_errors() {
        let mut d = deframer_with(INVALID_LENGTH_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_invalid(&mut d, &mut rl, InvalidMessage::MessageTooLarge);
    }

    #[test]
    fn test_empty_applicationdata() {
        let mut d = deframer_with(EMPTY_APPLICATIONDATA_MESSAGE);

        let mut rl = RecordLayer::new();
        let m = d
            .pop(&mut rl, None)
            .unwrap()
            .unwrap()
            .message;
        assert_eq!(m.typ, ContentType::ApplicationData);
        assert_eq!(m.payload.0.len(), 0);
        assert_drained(&d);
    }

    #[test]
    fn test_invalid_empty_errors() {
        let mut d = deframer_with(INVALID_EMPTY_MESSAGE);

        let mut rl = RecordLayer::new();
        // The error is sticky: a second pop reports the same failure.
        for _ in 0..2 {
            assert_pop_invalid(&mut d, &mut rl, InvalidMessage::InvalidEmptyPayload);
        }
    }

    #[test]
    fn test_limited_buffer() {
        const PAYLOAD_LEN: usize = 16_384;

        // Five header bytes (laid out like a record header: one type byte,
        // two version bytes, big-endian length) followed by a zeroed payload.
        let mut message: Vec<u8> = Vec::with_capacity(5 + PAYLOAD_LEN);
        message.push(0x17);
        message.extend_from_slice(&[0x03, 0x04]);
        message.extend_from_slice(&(PAYLOAD_LEN as u16).to_be_bytes());
        message.resize(message.len() + PAYLOAD_LEN, 0);

        let mut d = MessageDeframer::default();
        // Each read is limited to 4096 bytes until the size limit is reached.
        for _ in 0..4 {
            assert_len(4096, input_bytes(&mut d, &message));
        }
        assert_len(
            OpaqueMessage::MAX_WIRE_SIZE - PAYLOAD_LEN,
            input_bytes(&mut d, &message),
        );
        // The buffer is now full, so a further read must fail.
        assert!(input_bytes(&mut d, &message).is_err());
    }
}