clia_rustls_mod/msgs/deframer.rs

1use alloc::vec::Vec;
2use core::ops::Range;
3use core::slice::SliceIndex;
4#[cfg(feature = "std")]
5use std::io;
6
7use super::codec::Codec;
8use crate::enums::{ContentType, ProtocolVersion};
9use crate::error::{Error, InvalidMessage, PeerMisbehaved};
10use crate::msgs::codec;
11#[cfg(feature = "std")]
12use crate::msgs::message::MAX_WIRE_SIZE;
13use crate::msgs::message::{InboundOpaqueMessage, InboundPlainMessage, MessageError};
14use crate::record_layer::{Decrypted, RecordLayer};
15
/// This deframer works to reconstruct TLS messages from a stream of arbitrary-sized reads.
///
/// It buffers incoming data into a `Vec` through `read()`, and returns messages through `pop()`.
/// QUIC connections will call `push()` to append handshake payload data directly.
///
/// The deframer itself only holds parsing state. The received bytes live in a separate buffer
/// that is passed into each call: a [`DeframerVecBuffer`] for `read()`/`push()`, borrowed as a
/// [`DeframerSliceBuffer`] for `pop()`.
#[derive(Default)]
pub struct MessageDeframer {
    /// Set if the peer is not talking TLS, but some other
    /// protocol (or otherwise misbehaves, e.g. interleaves other records with a split
    /// handshake message).  The caller should abort the connection, because
    /// the deframer cannot recover.
    ///
    /// Once set, every subsequent `pop()` returns a clone of this error.
    last_error: Option<Error>,

    /// If we're in the middle of joining a handshake payload, this is the metadata.
    ///
    /// `None` when no handshake payload is partially assembled.
    joining_hs: Option<HandshakePayloadMeta>,
}
30
impl MessageDeframer {
    /// Return any decrypted messages that the deframer has been able to parse.
    ///
    /// Returns an `Error` if the deframer failed to parse some message contents or if decryption
    /// failed, `Ok(None)` if no full message is buffered yet, and `Ok(Some(_))` if a valid
    /// message was found and decrypted successfully. Records that fail trial decryption
    /// (rejected early data) are discarded and skipped rather than returned.
    ///
    /// Bytes consumed from `buffer` are not removed here. They are recorded with
    /// `buffer.queue_discard()`, and the returned message's payload borrows from `buffer`.
    ///
    /// Framing errors and peer misbehaviour are fused via `set_err()`: once one is reported,
    /// every later call returns that same error. Errors from `decrypt_incoming()` are
    /// returned as-is without being fused.
    pub fn pop<'b>(
        &mut self,
        record_layer: &mut RecordLayer,
        negotiated_version: Option<ProtocolVersion>,
        buffer: &mut DeframerSliceBuffer<'b>,
    ) -> Result<Option<Deframed<'b>>, Error> {
        // A previously recorded fatal error is sticky.
        if let Some(last_err) = self.last_error.clone() {
            return Err(last_err);
        } else if buffer.is_empty() {
            return Ok(None);
        }

        // We loop over records we've received but not processed yet.
        // For records that decrypt as `Handshake`, we keep the current state of the joined
        // handshake message payload in `self.joining_hs`, appending to it as we see records.
        // The loop yields the length (header included) of the complete handshake message at
        // the front of the joined payload; every other outcome returns from inside the loop.
        let expected_len = loop {
            // Offset (within the filled region) of the next unprocessed record: 0 if nothing
            // is being joined, otherwise just past the records already joined.
            let start = match &self.joining_hs {
                Some(meta) => {
                    match meta.expected_len {
                        // We're joining a handshake payload, and we've seen the full payload.
                        Some(len) if len <= meta.payload.len() => break len,
                        // Not enough data, and we can't parse any more out of the buffer (QUIC).
                        _ if meta.quic => return Ok(None),
                        // Try parsing some more of the encrypted buffered data.
                        _ => meta.message.end,
                    }
                }
                None => 0,
            };

            // Does our `buf` contain a full message?  It does if it is big enough to
            // contain a header, and that header has a length which falls within `buf`.
            // If so, deframe it and place the message onto the frames output queue.
            let mut rd = codec::ReaderMut::init(buffer.filled_get_mut(start..));
            let m = match InboundOpaqueMessage::read(&mut rd) {
                Ok(m) => m,
                Err(msg_err) => {
                    let err_kind = match msg_err {
                        // An incomplete record is not an error: wait for more data.
                        MessageError::TooShortForHeader | MessageError::TooShortForLength => {
                            return Ok(None)
                        }
                        MessageError::InvalidEmptyPayload => InvalidMessage::InvalidEmptyPayload,
                        MessageError::MessageTooLarge => InvalidMessage::MessageTooLarge,
                        MessageError::InvalidContentType => InvalidMessage::InvalidContentType,
                        MessageError::UnknownProtocolVersion => {
                            InvalidMessage::UnknownProtocolVersion
                        }
                    };

                    return Err(self.set_err(err_kind));
                }
            };

            // Return CCS messages and early plaintext alerts immediately without decrypting.
            // `end` is the offset just past this record within the filled region.
            let end = start + rd.used();
            let version_is_tls13 = matches!(negotiated_version, Some(ProtocolVersion::TLSv1_3));
            let allowed_plaintext = match m.typ {
                // CCS messages are always plaintext.
                ContentType::ChangeCipherSpec => true,
                // Alerts are allowed to be plaintext if-and-only-if:
                // * The negotiated protocol version is TLS 1.3. - In TLS 1.2 it is unambiguous when
                //   keying changes based on the CCS message. Only TLS 1.3 requires these heuristics.
                // * We have not yet decrypted any messages from the peer - if we have we don't
                //   expect any plaintext.
                // * The payload size is indicative of a plaintext alert message.
                ContentType::Alert
                    if version_is_tls13
                        && !record_layer.has_decrypted()
                        && m.payload.len() <= 2 =>
                {
                    true
                }
                // In other circumstances, we expect all messages to be encrypted.
                _ => false,
            };
            if self.joining_hs.is_none() && allowed_plaintext {
                let InboundOpaqueMessage {
                    typ,
                    version,
                    payload,
                } = m;
                // Remember where the payload lives so it can be re-borrowed via `take()`
                // after `queue_discard()` mutates `buffer`.
                let raw_payload_slice = RawSlice::from(&*payload);
                // This is unencrypted. We check the contents later.
                buffer.queue_discard(end);
                let message = InboundPlainMessage {
                    typ,
                    version,
                    payload: buffer.take(raw_payload_slice),
                };
                return Ok(Some(Deframed {
                    want_close_before_decrypt: false,
                    aligned: true,
                    trial_decryption_finished: false,
                    message,
                }));
            }

            // Decrypt the encrypted message (if necessary).
            let (typ, version, plain_payload_slice) = match record_layer.decrypt_incoming(m) {
                Ok(Some(decrypted)) => {
                    let Decrypted {
                        want_close_before_decrypt,
                        plaintext:
                            InboundPlainMessage {
                                typ,
                                version,
                                payload,
                            },
                    } = decrypted;
                    debug_assert!(!want_close_before_decrypt);
                    (typ, version, RawSlice::from(payload))
                }
                // This was rejected early data, discard it. If we currently have a handshake
                // payload in progress, this counts as interleaved, so we error out.
                Ok(None) if self.joining_hs.is_some() => {
                    return Err(self.set_err(
                        PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage,
                    ));
                }
                Ok(None) => {
                    buffer.queue_discard(end);
                    continue;
                }
                Err(e) => return Err(e),
            };

            if self.joining_hs.is_some() && typ != ContentType::Handshake {
                // "Handshake messages MUST NOT be interleaved with other record
                // types.  That is, if a handshake message is split over two or more
                // records, there MUST NOT be any other records between them."
                // https://www.rfc-editor.org/rfc/rfc8446#section-5.1
                return Err(self.set_err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage));
            }

            // If it's not a handshake message, just return it -- no joining necessary.
            if typ != ContentType::Handshake {
                buffer.queue_discard(end);
                let message = InboundPlainMessage {
                    typ,
                    version,
                    payload: buffer.take(plain_payload_slice),
                };
                return Ok(Some(Deframed {
                    want_close_before_decrypt: false,
                    aligned: true,
                    trial_decryption_finished: false,
                    message,
                }));
            }

            // If we don't know the payload size yet or if the payload size is larger
            // than the currently buffered payload, we need to wait for more data.
            let src = buffer.raw_slice_to_filled_range(plain_payload_slice);
            match self.append_hs(version, InternalPayload(src), end, buffer)? {
                HandshakePayloadState::Blocked => return Ok(None),
                HandshakePayloadState::Complete(len) => break len,
                HandshakePayloadState::Continue => continue,
            }
        };

        let meta = self.joining_hs.as_mut().unwrap(); // safe after calling `append_hs()`

        // We can now wrap the complete handshake payload in a `PlainMessage`, to be returned.
        let typ = ContentType::Handshake;
        let version = meta.version;
        let raw_payload = RawSlice::from(
            buffer.filled_get(meta.payload.start..meta.payload.start + expected_len),
        );

        // But before we return, update the `joining_hs` state to skip past this payload.
        if meta.payload.len() > expected_len {
            // If we have another (beginning of) a handshake payload left in the buffer, update
            // the payload start to point past the payload we're about to yield, and update the
            // `expected_len` to match the state of that remaining payload.
            meta.payload.start += expected_len;
            meta.expected_len =
                payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
        } else {
            // Otherwise, we've yielded the last handshake payload in the buffer, so we can
            // discard all of the bytes that we've previously buffered as handshake data.
            let end = meta.message.end;
            self.joining_hs = None;
            buffer.queue_discard(end);
        }

        let message = InboundPlainMessage {
            typ,
            version,
            payload: buffer.take(raw_payload),
        };

        Ok(Some(Deframed {
            want_close_before_decrypt: false,
            // `false` when more bytes of the joined handshake payload remain buffered.
            aligned: self.joining_hs.is_none(),
            trial_decryption_finished: true,
            message,
        }))
    }

    /// Fuses this deframer's error and returns the set value.
    ///
    /// Any future calls to `pop` will return `err` again.
    fn set_err(&mut self, err: impl Into<Error>) -> Error {
        let err = err.into();
        self.last_error = Some(err.clone());
        err
    }

    /// Write the handshake message contents into the buffer and update the metadata.
    ///
    /// If a handshake payload is already being joined, `payload` is appended after it.
    /// Otherwise `payload` is written at offset 0 and new metadata is created. In both
    /// cases, `end` becomes the end of the region processed so far (`meta.message.end`).
    ///
    /// Returns:
    /// - `Complete(len)` if the joined payload holds a full handshake message of `len` bytes
    ///   (the 4-byte header included),
    /// - `Continue` if the buffer holds more bytes past `end` that can still be processed,
    /// - `Blocked` otherwise (more data must arrive first).
    ///
    /// Returns `Err` if the advertised handshake length is too large (see `payload_size()`).
    fn append_hs<'a, P: AppendPayload<'a>, B: DeframerBuffer<'a, P>>(
        &mut self,
        version: ProtocolVersion,
        payload: P,
        end: usize,
        buffer: &mut B,
    ) -> Result<HandshakePayloadState, Error> {
        let meta = match &mut self.joining_hs {
            Some(meta) => {
                // QUIC and TLS-record payloads must never be mixed in one joined message.
                debug_assert_eq!(meta.quic, P::QUIC);

                // We're joining a handshake message to the previous one here.
                // Write it into the buffer and update the metadata.

                buffer.copy(&payload, meta.payload.end);
                meta.message.end = end;
                meta.payload.end += payload.len();

                // If we haven't parsed the payload size yet, try to do so now.
                if meta.expected_len.is_none() {
                    meta.expected_len =
                        payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
                }

                meta
            }
            None => {
                // We've found a new handshake message here.
                // Write it into the buffer and create the metadata.

                // The header is validated before any bytes are copied.
                let expected_len = payload.size(buffer)?;
                buffer.copy(&payload, 0);
                self.joining_hs
                    .insert(HandshakePayloadMeta {
                        message: Range { start: 0, end },
                        payload: Range {
                            start: 0,
                            end: payload.len(),
                        },
                        version,
                        expected_len,
                        quic: P::QUIC,
                    })
            }
        };

        Ok(match meta.expected_len {
            Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
            _ => match buffer.len() > meta.message.end {
                true => HandshakePayloadState::Continue,
                false => HandshakePayloadState::Blocked,
            },
        })
    }
}
303
#[cfg(feature = "std")]
impl MessageDeframer {
    /// Allow pushing handshake messages directly into the buffer.
    ///
    /// This is the QUIC path. `payload` is plaintext handshake data that did not arrive in TLS
    /// records, so it is appended to the handshake payload being joined without any record
    /// parsing or decryption. Complete messages are then retrieved with `pop()`.
    ///
    /// # Errors
    ///
    /// - `Error::General` if `buffer` already holds data but no handshake payload is being
    ///   joined, or if the buffer has reached its size limit.
    /// - Any error from parsing the handshake header, e.g. an advertised length above
    ///   `MAX_HANDSHAKE_SIZE`.
    pub(crate) fn push(
        &mut self,
        version: ProtocolVersion,
        payload: &[u8],
        buffer: &mut DeframerVecBuffer,
    ) -> Result<(), Error> {
        if !buffer.is_empty() && self.joining_hs.is_none() {
            return Err(Error::General(
                "cannot push QUIC messages into unrelated connection".into(),
            ));
        } else if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
            return Err(Error::General(err.into()));
        }

        let end = buffer.len() + payload.len();
        self.append_hs(version, ExternalPayload(payload), end, buffer)?;
        Ok(())
    }

    /// Read some bytes from `rd`, and add them to our internal buffer.
    ///
    /// Returns the number of bytes read. As with [`io::Read::read`], `Ok(0)` means `rd`
    /// reported end-of-file.
    ///
    /// # Errors
    ///
    /// - `io::ErrorKind::InvalidData` if the buffer already holds the maximum amount of
    ///   data allowed (`MAX_WIRE_SIZE`, or `MAX_HANDSHAKE_SIZE` while a handshake payload is
    ///   being joined).
    /// - Any error returned by `rd`.
    pub fn read(
        &mut self,
        rd: &mut dyn io::Read,
        buffer: &mut DeframerVecBuffer,
    ) -> io::Result<usize> {
        if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
            return Err(io::Error::new(io::ErrorKind::InvalidData, err));
        }

        // `prepare_read()` succeeded, so the buffer has spare capacity past `used`.
        // Read as much as fits into it in one call.
        let new_bytes = rd.read(buffer.unfilled())?;
        buffer.advance(new_bytes);
        Ok(new_bytes)
    }
}
346
/// A chunk of handshake payload bytes that `MessageDeframer::append_hs()` can join onto the
/// handshake message being assembled.
trait AppendPayload<'a>: Sized {
    /// `true` for bytes pushed directly by QUIC (`MessageDeframer::push()`), `false` for
    /// decrypted record payloads that already sit in the deframer buffer.
    const QUIC: bool;

    /// Number of payload bytes.
    fn len(&self) -> usize;

    /// Parse the handshake message length advertised at the start of this payload.
    ///
    /// Both implementations here delegate to `payload_size()`: `Ok(None)` if there are too few
    /// bytes for a header, `Err` if the advertised length is too large.
    fn size<B: DeframerBuffer<'a, Self>>(
        &self,
        internal_buffer: &B,
    ) -> Result<Option<usize>, Error>;
}
357
/// Handshake bytes supplied from outside the deframer buffer (the QUIC `push()` path).
struct ExternalPayload<'a>(&'a [u8]);
359
360impl<'a> AppendPayload<'a> for ExternalPayload<'a> {
361    const QUIC: bool = true;
362
363    fn len(&self) -> usize {
364        self.0.len()
365    }
366
367    fn size<B: DeframerBuffer<'a, Self>>(&self, _: &B) -> Result<Option<usize>, Error> {
368        payload_size(self.0)
369    }
370}
371
/// Handshake bytes already present in the deframer buffer (a decrypted record payload),
/// identified by their byte range within the filled region.
struct InternalPayload(Range<usize>);
373
374impl<'a> AppendPayload<'a> for InternalPayload {
375    const QUIC: bool = false;
376
377    fn len(&self) -> usize {
378        self.0.end - self.0.start
379    }
380
381    fn size<B: DeframerBuffer<'a, Self>>(
382        &self,
383        internal_buffer: &B,
384    ) -> Result<Option<usize>, Error> {
385        payload_size(internal_buffer.filled_get(self.0.clone()))
386    }
387}
388
/// Owned buffer of bytes received from the peer that are waiting to be deframed.
#[derive(Default, Debug)]
pub struct DeframerVecBuffer {
    /// Buffer of data read from the socket, in the process of being parsed into messages.
    ///
    /// For buffer size management, check out the [`DeframerVecBuffer::prepare_read()`] method.
    buf: Vec<u8>,

    /// What size prefix of `buf` is used.
    ///
    /// `buf[..used]` holds received data; anything past it is spare space for the next read.
    used: usize,
}
399
400impl DeframerVecBuffer {
401    /// Borrows the initialized contents of this buffer and tracks pending discard operations via
402    /// the `discard` reference
403    pub fn borrow(&mut self) -> DeframerSliceBuffer {
404        DeframerSliceBuffer::new(&mut self.buf[..self.used])
405    }
406
407    /// Discard `taken` bytes from the start of our buffer.
408    pub fn discard(&mut self, taken: usize) {
409        #[allow(clippy::comparison_chain)]
410        if taken < self.used {
411            /* Before:
412             * +----------+----------+----------+
413             * | taken    | pending  |xxxxxxxxxx|
414             * +----------+----------+----------+
415             * 0          ^ taken    ^ self.used
416             *
417             * After:
418             * +----------+----------+----------+
419             * | pending  |xxxxxxxxxxxxxxxxxxxxx|
420             * +----------+----------+----------+
421             * 0          ^ self.used
422             */
423
424            self.buf
425                .copy_within(taken..self.used, 0);
426            self.used -= taken;
427        } else if taken == self.used {
428            self.used = 0;
429        }
430    }
431}
432
#[cfg(feature = "std")]
impl DeframerVecBuffer {
    /// Returns true if any bytes are buffered, even if they do not yet form a full record.
    pub fn has_pending(&self) -> bool {
        !self.is_empty()
    }

    /// Make room in `buf` for the next read, resizing it if necessary.
    ///
    /// While a handshake payload is being joined (`is_joining_hs`), up to `MAX_HANDSHAKE_SIZE`
    /// bytes may be buffered; otherwise the limit is `MAX_WIRE_SIZE`. Returns an error if
    /// `used` has already reached that limit.
    ///
    /// On success `buf` is at least `min(limit, used + READ_SIZE)` bytes long. The buffer is
    /// cut back to exactly that size when it is empty or larger than the current limit. This
    /// way a rare oversized handshake flight does not keep a large allocation alive.
    fn prepare_read(&mut self, is_joining_hs: bool) -> Result<(), &'static str> {
        // The larger handshake allowance only applies once a handshake prefix is buffered.
        // When that flight has been consumed, `pop()` leads to `discard()` resetting `used`,
        // and the shrink branch below returns the buffer to its normal size.
        let limit = if is_joining_hs {
            MAX_HANDSHAKE_SIZE as usize
        } else {
            MAX_WIRE_SIZE
        };

        if self.used >= limit {
            return Err("message buffer full");
        }

        let target = (self.used + READ_SIZE).min(limit);
        let grow = target > self.buf.len();
        let shrink = !grow && (self.used == 0 || self.buf.len() > limit);
        if grow || shrink {
            self.buf.resize(target, 0);
        }
        if shrink {
            self.buf.shrink_to(target);
        }

        Ok(())
    }

    /// True if no bytes are buffered.
    fn is_empty(&self) -> bool {
        self.filled().is_empty()
    }

    /// Mark `num_bytes` more bytes (just written into `unfilled()`) as filled.
    fn advance(&mut self, num_bytes: usize) {
        self.used += num_bytes;
    }

    /// The spare space after the filled bytes, for reading into.
    fn unfilled(&mut self) -> &mut [u8] {
        let used = self.used;
        &mut self.buf[used..]
    }
}
486
#[cfg(feature = "std")]
impl FilledDeframerBuffer for DeframerVecBuffer {
    /// The received bytes `buf[..used]`, mutably.
    fn filled_mut(&mut self) -> &mut [u8] {
        let used = self.used;
        &mut self.buf[..used]
    }

    /// The received bytes `buf[..used]`.
    fn filled(&self) -> &[u8] {
        let used = self.used;
        &self.buf[..used]
    }
}
497
#[cfg(feature = "std")]
impl DeframerBuffer<'_, InternalPayload> for DeframerVecBuffer {
    /// Delegates to the borrowed slice view, which moves the bytes within the filled region.
    fn copy(&mut self, payload: &InternalPayload, at: usize) {
        let mut view = self.borrow();
        view.copy(payload, at);
    }
}
504
#[cfg(feature = "std")]
impl<'a> DeframerBuffer<'a, ExternalPayload<'a>> for DeframerVecBuffer {
    /// Append the pushed (QUIC) bytes after the filled region; `_at` is ignored.
    ///
    /// `prepare_read()` only guarantees `min(limit, used + READ_SIZE)` bytes of space.
    /// A pushed payload can be larger than the remaining spare room, so the backing `Vec` is
    /// grown as needed rather than panicking on an out-of-range slice.
    fn copy(&mut self, payload: &ExternalPayload<'a>, _at: usize) {
        let len = payload.len();
        let end = self.used + len;
        if self.buf.len() < end {
            self.buf.resize(end, 0);
        }
        self.buf[self.used..end].copy_from_slice(payload.0);
        self.advance(len);
    }
}
513
/// A borrowed version of [`DeframerVecBuffer`] that tracks discard operations
///
/// Bytes are not removed while borrowed. `discard` counts how many leading bytes are
/// finished with, and `take()` splits returned payloads off the front of `buf`.
#[derive(Debug)]
pub struct DeframerSliceBuffer<'a> {
    // a fully initialized buffer that will be deframed
    buf: &'a mut [u8],
    // number of bytes to discard from the front of the originally borrowed slice at a later
    // time (may exceed what `take()` has already split off)
    discard: usize,
    // number of bytes already split off the front of the original slice by `take()`;
    // `buf` now starts this many bytes into it
    taken: usize,
}
523
524impl<'a> DeframerSliceBuffer<'a> {
525    pub fn new(buf: &'a mut [u8]) -> Self {
526        Self {
527            buf,
528            discard: 0,
529            taken: 0,
530        }
531    }
532
533    /// Tracks a pending discard operation of `num_bytes`
534    pub fn queue_discard(&mut self, num_bytes: usize) {
535        self.discard += num_bytes;
536    }
537
538    /// Returns the number of bytes that need to be discarded
539    pub fn pending_discard(&self) -> usize {
540        self.discard
541    }
542
543    pub fn is_empty(&self) -> bool {
544        self.len() == 0
545    }
546
547    /// Remove a `RawSlice` range from the deframer buffer, returning a mutable reference to the
548    /// removed portion.
549    ///
550    /// Safety: the caller *must* ensure that the `RawSlice` refers to a range from the same
551    /// allocation as the deframer's buffer.
552    fn take(&mut self, raw: RawSlice) -> &'a mut [u8] {
553        let start = (raw.ptr as usize)
554            .checked_sub(self.buf.as_ptr() as usize)
555            .unwrap();
556        let end = start + raw.len;
557
558        let (taken, rest) = core::mem::take(&mut self.buf).split_at_mut(end);
559        self.buf = rest;
560        self.taken += end;
561
562        &mut taken[start..]
563    }
564
565    /// Converts a raw slice to a filled range based on the offset and length.
566    ///
567    /// Safety: the caller *must* ensure that the `RawSlice` refers to a range from the same
568    /// allocation as the deframer's buffer.
569    fn raw_slice_to_filled_range(&self, raw: RawSlice) -> Range<usize> {
570        let adjust = self.discard - self.taken;
571        let start = ((raw.ptr as usize).checked_sub(self.buf.as_ptr() as usize)).unwrap() - adjust;
572        let end = start + raw.len;
573        start..end
574    }
575}
576
577impl FilledDeframerBuffer for DeframerSliceBuffer<'_> {
578    fn filled_mut(&mut self) -> &mut [u8] {
579        &mut self.buf[self.discard - self.taken..]
580    }
581
582    fn filled(&self) -> &[u8] {
583        &self.buf[self.discard - self.taken..]
584    }
585}
586
587impl DeframerBuffer<'_, InternalPayload> for DeframerSliceBuffer<'_> {
588    fn copy(&mut self, payload: &InternalPayload, at: usize) {
589        let buf = self.filled_mut();
590        buf.copy_within(payload.0.clone(), at)
591    }
592}
593
/// Start address and length of a byte slice, detached from its borrow.
///
/// `pop()` uses this to remember where a payload sits in the deframer buffer while the buffer
/// is mutated (e.g. `queue_discard()`), then re-borrows it with `DeframerSliceBuffer::take()`.
/// In the code visible here the pointer is only compared, never dereferenced.
pub(crate) struct RawSlice {
    ptr: *const u8,
    len: usize,
}
598
599impl From<&'_ [u8]> for RawSlice {
600    fn from(value: &'_ [u8]) -> Self {
601        Self {
602            ptr: value.as_ptr(),
603            len: value.len(),
604        }
605    }
606}
607
/// A deframer buffer that handshake payload of kind `P` can be written into.
trait DeframerBuffer<'a, P: AppendPayload<'a>>: FilledDeframerBuffer {
    /// Copies `payload` into this buffer at the requested index `at`
    ///
    /// If `P::QUIC` is true the data will be copied into the *un*filled section of the buffer
    /// (appended after the filled bytes; the implementation in this file ignores `at`)
    ///
    /// If `P::QUIC` is false the payload names bytes already in the filled section of the
    /// buffer, which are moved to offset `at` within it
    fn copy(&mut self, payload: &P, at: usize);
}
616
/// Access to the "filled" region of a deframer buffer: received bytes not yet discarded.
trait FilledDeframerBuffer {
    /// The filled bytes, mutably.
    fn filled_mut(&mut self) -> &mut [u8];

    /// The filled bytes.
    fn filled(&self) -> &[u8];

    /// Index into `filled_mut()`, panicking if `index` is out of range.
    fn filled_get_mut<I: SliceIndex<[u8]>>(&mut self, index: I) -> &mut I::Output {
        let filled = self.filled_mut();
        filled.get_mut(index).unwrap()
    }

    /// Index into `filled()`, panicking if `index` is out of range.
    fn filled_get<I: SliceIndex<[u8]>>(&self, index: I) -> &I::Output {
        let filled = self.filled();
        filled.get(index).unwrap()
    }

    /// Number of filled bytes.
    fn len(&self) -> usize {
        let filled = self.filled();
        filled.len()
    }
}
639
/// Outcome of `MessageDeframer::append_hs()`.
enum HandshakePayloadState {
    /// Waiting for more data.
    Blocked,
    /// We have a complete handshake message.
    ///
    /// Carries its length in bytes, including the 4-byte handshake header.
    Complete(usize),
    /// More records available for processing.
    Continue,
}
648
/// Bookkeeping for a handshake message being joined from several records or QUIC pushes.
struct HandshakePayloadMeta {
    /// The range of bytes from the deframer buffer that contains data processed so far.
    ///
    /// It always starts at 0. This will need to be discarded as the last of the handshake
    /// message is `pop()`ped.
    message: Range<usize>,
    /// The range of bytes from the deframer buffer that contains payload.
    ///
    /// Handshake bytes from successive records are copied next to each other here, and
    /// `start` advances as complete messages are yielded.
    payload: Range<usize>,
    /// The protocol version as found in the decrypted handshake message.
    version: ProtocolVersion,
    /// The expected size of the handshake payload, if available.
    ///
    /// Once at least 4 bytes (the handshake message header) are present at `payload.start`,
    /// this holds the advertised message length plus those 4 header bytes. `payload_size()`
    /// rejects advertised lengths above `MAX_HANDSHAKE_SIZE`, so this is at most
    /// `MAX_HANDSHAKE_SIZE + 4` bytes.
    expected_len: Option<usize>,
    /// True if this is a QUIC handshake message.
    ///
    /// In the case of QUIC, we get plaintext handshake data directly from the CRYPTO stream,
    /// so there's no need to unwrap and decrypt the outer TLS record. This is implemented
    /// by directly calling `MessageDeframer::push()` from the connection.
    quic: bool,
}
670
671/// Determine the expected length of the payload as advertised in the header.
672///
673/// Returns `Err` if the advertised length is larger than what we want to accept
674/// (`MAX_HANDSHAKE_SIZE`), `Ok(None)` if the buffer is too small to contain a complete header,
675/// and `Ok(Some(len))` otherwise.
676fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
677    if buf.len() < HANDSHAKE_HEADER_SIZE {
678        return Ok(None);
679    }
680
681    let (header, _) = buf.split_at(HANDSHAKE_HEADER_SIZE);
682    match codec::u24::read_bytes(&header[1..]) {
683        Ok(len) if len.0 > MAX_HANDSHAKE_SIZE => Err(Error::InvalidMessage(
684            InvalidMessage::HandshakePayloadTooLarge,
685        )),
686        Ok(len) => Ok(Some(HANDSHAKE_HEADER_SIZE + usize::from(len))),
687        _ => Ok(None),
688    }
689}
690
/// A message produced by `MessageDeframer::pop()`, whose payload borrows from the deframer
/// buffer.
#[derive(Debug)]
pub struct Deframed<'a> {
    /// Always set to `false` by `pop()` in this file.
    pub(crate) want_close_before_decrypt: bool,
    /// As set by `pop()`: `false` only when a joined handshake message is returned while
    /// more bytes of that joined payload remain buffered; `true` otherwise.
    pub(crate) aligned: bool,
    /// As set by `pop()`: `true` for messages taken from a joined handshake payload, `false`
    /// for messages returned directly from a single record. How consumers interpret this
    /// is outside this file.
    pub(crate) trial_decryption_finished: bool,
    /// The deframed (and, where required, decrypted) message.
    pub message: InboundPlainMessage<'a>,
}
698
/// Size of a handshake message header: one leading byte plus a 3-byte length
/// (see `payload_size()`).
const HANDSHAKE_HEADER_SIZE: usize = 1 + 3;

/// TLS allows for handshake messages of up to 16MB.  We
/// restrict that to 64KB to limit potential for denial-of-
/// service.
///
/// It is also the buffer limit `prepare_read()` applies while a handshake payload is being
/// joined.
const MAX_HANDSHAKE_SIZE: u32 = 0xffff;

/// Amount of spare space, in bytes, that `prepare_read()` tries to provide for each read
/// (capped by the current buffer limit).
#[cfg(feature = "std")]
const READ_SIZE: usize = 4096;
708
709#[cfg(feature = "std")]
710#[cfg(test)]
711mod tests {
712    use std::prelude::v1::*;
713    use std::vec;
714
715    use super::*;
716    use crate::crypto::cipher::PlainMessage;
717    use crate::msgs::message::Message;
718
719    #[test]
720    fn check_incremental() {
721        let mut d = BufferedDeframer::default();
722        assert!(!d.has_pending());
723        input_whole_incremental(&mut d, FIRST_MESSAGE);
724        assert!(d.has_pending());
725
726        let mut rl = RecordLayer::new();
727        pop_first(&mut d, &mut rl);
728        assert!(!d.has_pending());
729        assert!(d.last_error.is_none());
730    }
731
732    #[test]
733    fn check_incremental_2() {
734        let mut d = BufferedDeframer::default();
735        assert!(!d.has_pending());
736        input_whole_incremental(&mut d, FIRST_MESSAGE);
737        assert!(d.has_pending());
738        input_whole_incremental(&mut d, SECOND_MESSAGE);
739        assert!(d.has_pending());
740
741        let mut rl = RecordLayer::new();
742        pop_first(&mut d, &mut rl);
743        assert!(d.has_pending());
744        pop_second(&mut d, &mut rl);
745        assert!(!d.has_pending());
746        assert!(d.last_error.is_none());
747    }
748
749    #[test]
750    fn check_whole() {
751        let mut d = BufferedDeframer::default();
752        assert!(!d.has_pending());
753        assert_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
754        assert!(d.has_pending());
755
756        let mut rl = RecordLayer::new();
757        pop_first(&mut d, &mut rl);
758        assert!(!d.has_pending());
759        assert!(d.last_error.is_none());
760    }
761
762    #[test]
763    fn check_whole_2() {
764        let mut d = BufferedDeframer::default();
765        assert!(!d.has_pending());
766        assert_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
767        assert_len(SECOND_MESSAGE.len(), d.input_bytes(SECOND_MESSAGE));
768
769        let mut rl = RecordLayer::new();
770        pop_first(&mut d, &mut rl);
771        pop_second(&mut d, &mut rl);
772        assert!(!d.has_pending());
773        assert!(d.last_error.is_none());
774    }
775
776    #[test]
777    fn test_two_in_one_read() {
778        let mut d = BufferedDeframer::default();
779        assert!(!d.has_pending());
780        assert_len(
781            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
782            d.input_bytes_concat(FIRST_MESSAGE, SECOND_MESSAGE),
783        );
784
785        let mut rl = RecordLayer::new();
786        pop_first(&mut d, &mut rl);
787        pop_second(&mut d, &mut rl);
788        assert!(!d.has_pending());
789        assert!(d.last_error.is_none());
790    }
791
792    #[test]
793    fn test_two_in_one_read_shortest_first() {
794        let mut d = BufferedDeframer::default();
795        assert!(!d.has_pending());
796        assert_len(
797            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
798            d.input_bytes_concat(SECOND_MESSAGE, FIRST_MESSAGE),
799        );
800
801        let mut rl = RecordLayer::new();
802        pop_second(&mut d, &mut rl);
803        pop_first(&mut d, &mut rl);
804        assert!(!d.has_pending());
805        assert!(d.last_error.is_none());
806    }
807
808    #[test]
809    fn test_incremental_with_nonfatal_read_error() {
810        let mut d = BufferedDeframer::default();
811        assert_len(3, d.input_bytes(&FIRST_MESSAGE[..3]));
812        input_error(&mut d);
813        assert_len(FIRST_MESSAGE.len() - 3, d.input_bytes(&FIRST_MESSAGE[3..]));
814
815        let mut rl = RecordLayer::new();
816        pop_first(&mut d, &mut rl);
817        assert!(!d.has_pending());
818        assert!(d.last_error.is_none());
819    }
820
821    #[test]
822    fn test_invalid_contenttype_errors() {
823        let mut d = BufferedDeframer::default();
824        assert_len(
825            INVALID_CONTENTTYPE_MESSAGE.len(),
826            d.input_bytes(INVALID_CONTENTTYPE_MESSAGE),
827        );
828
829        let mut rl = RecordLayer::new();
830        assert_eq!(
831            d.pop_error(&mut rl, None),
832            Error::InvalidMessage(InvalidMessage::InvalidContentType)
833        );
834    }
835
836    #[test]
837    fn test_invalid_version_errors() {
838        let mut d = BufferedDeframer::default();
839        assert_len(
840            INVALID_VERSION_MESSAGE.len(),
841            d.input_bytes(INVALID_VERSION_MESSAGE),
842        );
843
844        let mut rl = RecordLayer::new();
845        assert_eq!(
846            d.pop_error(&mut rl, None),
847            Error::InvalidMessage(InvalidMessage::UnknownProtocolVersion)
848        );
849    }
850
851    #[test]
852    fn test_invalid_length_errors() {
853        let mut d = BufferedDeframer::default();
854        assert_len(
855            INVALID_LENGTH_MESSAGE.len(),
856            d.input_bytes(INVALID_LENGTH_MESSAGE),
857        );
858
859        let mut rl = RecordLayer::new();
860        assert_eq!(
861            d.pop_error(&mut rl, None),
862            Error::InvalidMessage(InvalidMessage::MessageTooLarge)
863        );
864    }
865
866    #[test]
867    fn test_empty_applicationdata() {
868        let mut d = BufferedDeframer::default();
869        assert_len(
870            EMPTY_APPLICATIONDATA_MESSAGE.len(),
871            d.input_bytes(EMPTY_APPLICATIONDATA_MESSAGE),
872        );
873
874        let mut rl = RecordLayer::new();
875        let m = d.pop_message(&mut rl, None);
876        assert_eq!(m.typ, ContentType::ApplicationData);
877        assert_eq!(m.payload.bytes().len(), 0);
878        assert!(!d.has_pending());
879        assert!(d.last_error.is_none());
880    }
881
882    #[test]
883    fn test_invalid_empty_errors() {
884        let mut d = BufferedDeframer::default();
885        assert_len(
886            INVALID_EMPTY_MESSAGE.len(),
887            d.input_bytes(INVALID_EMPTY_MESSAGE),
888        );
889
890        let mut rl = RecordLayer::new();
891        assert_eq!(
892            d.pop_error(&mut rl, None),
893            Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
894        );
895        // CorruptMessage has been fused
896        assert_eq!(
897            d.pop_error(&mut rl, None),
898            Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
899        );
900    }
901
902    #[test]
903    fn test_limited_buffer() {
904        const PAYLOAD_LEN: usize = 16_384;
905        let mut message = Vec::with_capacity(16_389);
906        message.push(0x17); // ApplicationData
907        message.extend(&[0x03, 0x04]); // ProtocolVersion
908        message.extend((PAYLOAD_LEN as u16).to_be_bytes()); // payload length
909        message.extend(&[0; PAYLOAD_LEN]);
910
911        let mut d = BufferedDeframer::default();
912        assert_len(4096, d.input_bytes(&message));
913        assert_len(4096, d.input_bytes(&message));
914        assert_len(4096, d.input_bytes(&message));
915        assert_len(4096, d.input_bytes(&message));
916        assert_len(MAX_WIRE_SIZE - 16_384, d.input_bytes(&message));
917        assert!(d.input_bytes(&message).is_err());
918    }
919
920    fn input_error(d: &mut BufferedDeframer) {
921        let error = io::Error::from(io::ErrorKind::TimedOut);
922        let mut rd = ErrorRead::new(error);
923        d.read(&mut rd)
924            .expect_err("error not propagated");
925    }
926
927    fn input_whole_incremental(d: &mut BufferedDeframer, bytes: &[u8]) {
928        let before = d.buffer.len();
929
930        for i in 0..bytes.len() {
931            assert_len(1, d.input_bytes(&bytes[i..i + 1]));
932            assert!(d.has_pending());
933        }
934
935        assert_eq!(before + bytes.len(), d.buffer.len());
936    }
937
938    fn pop_first(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
939        let m = d.pop_message(rl, None);
940        assert_eq!(m.typ, ContentType::Handshake);
941        Message::try_from(m).unwrap();
942    }
943
944    fn pop_second(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
945        let m = d.pop_message(rl, None);
946        assert_eq!(m.typ, ContentType::Alert);
947        Message::try_from(m).unwrap();
948    }
949
950    // buffered version to ease testing
951    #[derive(Default)]
952    struct BufferedDeframer {
953        inner: MessageDeframer,
954        buffer: DeframerVecBuffer,
955    }
956
957    impl BufferedDeframer {
958        fn input_bytes(&mut self, bytes: &[u8]) -> io::Result<usize> {
959            let mut rd = io::Cursor::new(bytes);
960            self.read(&mut rd)
961        }
962
963        fn input_bytes_concat(&mut self, bytes1: &[u8], bytes2: &[u8]) -> io::Result<usize> {
964            let mut bytes = vec![0u8; bytes1.len() + bytes2.len()];
965            bytes[..bytes1.len()].clone_from_slice(bytes1);
966            bytes[bytes1.len()..].clone_from_slice(bytes2);
967            let mut rd = io::Cursor::new(&bytes);
968            self.read(&mut rd)
969        }
970
971        fn pop_error(
972            &mut self,
973            record_layer: &mut RecordLayer,
974            negotiated_version: Option<ProtocolVersion>,
975        ) -> Error {
976            let mut deframer_buffer = self.buffer.borrow();
977            let err = self
978                .inner
979                .pop(record_layer, negotiated_version, &mut deframer_buffer)
980                .unwrap_err();
981            let discard = deframer_buffer.pending_discard();
982            self.buffer.discard(discard);
983            err
984        }
985
986        fn pop_message(
987            &mut self,
988            record_layer: &mut RecordLayer,
989            negotiated_version: Option<ProtocolVersion>,
990        ) -> PlainMessage {
991            let mut deframer_buffer = self.buffer.borrow();
992            let m = self
993                .inner
994                .pop(record_layer, negotiated_version, &mut deframer_buffer)
995                .unwrap()
996                .unwrap()
997                .message
998                .into_owned();
999            let discard = deframer_buffer.pending_discard();
1000            self.buffer.discard(discard);
1001            m
1002        }
1003
1004        fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
1005            self.inner.read(rd, &mut self.buffer)
1006        }
1007
1008        fn has_pending(&self) -> bool {
1009            self.buffer.has_pending()
1010        }
1011    }
1012
1013    // grant access to the `MessageDeframer.last_error` field
1014    impl core::ops::Deref for BufferedDeframer {
1015        type Target = MessageDeframer;
1016
1017        fn deref(&self) -> &Self::Target {
1018            &self.inner
1019        }
1020    }
1021
1022    struct ErrorRead {
1023        error: Option<io::Error>,
1024    }
1025
1026    impl ErrorRead {
1027        fn new(error: io::Error) -> Self {
1028            Self { error: Some(error) }
1029        }
1030    }
1031
1032    impl io::Read for ErrorRead {
1033        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1034            for (i, b) in buf.iter_mut().enumerate() {
1035                *b = i as u8;
1036            }
1037
1038            let error = self.error.take().unwrap();
1039            Err(error)
1040        }
1041    }
1042
1043    fn assert_len(want: usize, got: io::Result<usize>) {
1044        assert_eq!(Some(want), got.ok())
1045    }
1046
1047    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
1048    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");
1049
1050    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
1051        include_bytes!("../testdata/deframer-empty-applicationdata.bin");
1052
1053    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
1054    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
1055        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
1056    const INVALID_VERSION_MESSAGE: &[u8] =
1057        include_bytes!("../testdata/deframer-invalid-version.bin");
1058    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");
1059}