rustls 0.20.6

Rustls is a modern TLS library written in Rust.
Documentation
use std::collections::VecDeque;

use crate::msgs::base::Payload;
use crate::msgs::codec;
use crate::msgs::enums::{ContentType, ProtocolVersion};
use crate::msgs::handshake::HandshakeMessagePayload;
use crate::msgs::message::{Message, MessagePayload, PlainMessage};

const HEADER_SIZE: usize = 1 + 3;

/// TLS allows for handshake messages of up to 16MB.  We
/// restrict that to 64KB to limit potential for denial-of-
/// service.
const MAX_HANDSHAKE_SIZE: u32 = 0xffff;

/// This works to reconstruct TLS handshake messages
/// from individual TLS messages.  It's guaranteed that
/// TLS messages output from this layer contain precisely
/// one handshake payload.
pub struct HandshakeJoiner {
    /// Completed handshake frames for output.
    pub frames: VecDeque<Message>,

    /// The message payload we're currently accumulating.
    buf: Vec<u8>,
}

impl Default for HandshakeJoiner {
    fn default() -> Self {
        Self::new()
    }
}

enum BufferState {
    /// Buffer contains a header that introduces a message that is too long.
    MessageTooLarge,

    /// Buffer contains a full header and body.
    OneMessage,

    /// We need more data to see a header and complete body.
    NeedsMoreData,
}

impl HandshakeJoiner {
    /// Make a new HandshakeJoiner.
    pub fn new() -> Self {
        Self {
            frames: VecDeque::new(),
            buf: Vec::new(),
        }
    }

    /// Do we want to process this message?
    pub fn want_message(&self, msg: &PlainMessage) -> bool {
        msg.typ == ContentType::Handshake
    }

    /// Do we have any buffered data?
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Take the message, and join/split it as needed.
    /// Return the number of new messages added to the
    /// output deque as a result of this message.
    ///
    /// Returns None if msg or a preceding message was corrupt.
    /// You cannot recover from this situation.  Otherwise returns
    /// a count of how many messages we queued.
    pub fn take_message(&mut self, msg: PlainMessage) -> Option<usize> {
        // The vast majority of the time `self.buf` will be empty since most
        // handshake messages arrive in a single fragment. Avoid allocating and
        // copying in that common case.
        if self.buf.is_empty() {
            self.buf = msg.payload.0;
        } else {
            self.buf
                .extend_from_slice(&msg.payload.0[..]);
        }

        let mut count = 0;
        loop {
            match self.buf_contains_message() {
                BufferState::MessageTooLarge => return None,
                BufferState::NeedsMoreData => break,
                BufferState::OneMessage => {
                    if !self.deframe_one(msg.version) {
                        return None;
                    }

                    count += 1;
                }
            }
        }

        Some(count)
    }

    /// Does our `buf` contain a full handshake payload?  It does if it is big
    /// enough to contain a header, and that header has a length which falls
    /// within `buf`.
    fn buf_contains_message(&self) -> BufferState {
        if self.buf.len() < HEADER_SIZE {
            return BufferState::NeedsMoreData;
        }

        let (header, rest) = self.buf.split_at(HEADER_SIZE);
        match codec::u24::decode(&header[1..]) {
            Some(len) if len.0 > MAX_HANDSHAKE_SIZE => BufferState::MessageTooLarge,
            Some(len) if rest.get(..len.into()).is_some() => BufferState::OneMessage,
            _ => BufferState::NeedsMoreData,
        }
    }

    /// Take a TLS handshake payload off the front of `buf`, and put it onto
    /// the back of our `frames` deque inside a normal `Message`.
    ///
    /// Returns false if the stream is desynchronised beyond repair.
    fn deframe_one(&mut self, version: ProtocolVersion) -> bool {
        let used = {
            let mut rd = codec::Reader::init(&self.buf);
            let parsed = match HandshakeMessagePayload::read_version(&mut rd, version) {
                Some(p) => p,
                None => return false,
            };

            let m = Message {
                version,
                payload: MessagePayload::Handshake {
                    parsed,
                    encoded: Payload::new(&self.buf[..rd.used()]),
                },
            };

            self.frames.push_back(m);
            rd.used()
        };
        self.buf = self.buf.split_off(used);
        true
    }
}

#[cfg(test)]
mod tests {
    use super::HandshakeJoiner;
    use crate::msgs::base::Payload;
    use crate::msgs::codec::Codec;
    use crate::msgs::enums::{ContentType, HandshakeType, ProtocolVersion};
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::msgs::message::{Message, MessagePayload, PlainMessage};

    #[test]
    fn want() {
        let hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        let wanted = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"hello world".to_vec()),
        };

        let unwanted = PlainMessage {
            typ: ContentType::Alert,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"ponytown".to_vec()),
        };

        assert!(hj.want_message(&wanted));
        assert!(!hj.want_message(&unwanted));
    }

    fn pop_eq(expect: &PlainMessage, hj: &mut HandshakeJoiner) {
        let got = hj.frames.pop_front().unwrap();
        assert_eq!(got.payload.content_type(), expect.typ);
        assert_eq!(got.version, expect.version);

        let (mut left, mut right) = (Vec::new(), Vec::new());
        got.payload.encode(&mut left);
        expect.payload.encode(&mut right);

        assert_eq!(left, right);
    }

    #[test]
    fn split() {
        // Check we split two handshake messages within one PDU.
        let mut hj = HandshakeJoiner::new();

        // two HelloRequests
        let msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(2));
        assert!(hj.is_empty());

        let expect = Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::handshake(HandshakeMessagePayload {
                typ: HandshakeType::HelloRequest,
                payload: HandshakePayload::HelloRequest,
            }),
        }
        .into();

        pop_eq(&expect, &mut hj);
        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn broken() {
        // Check obvious crap payloads are reported as errors, not panics.
        let mut hj = HandshakeJoiner::new();

        // short ClientHello
        let msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x00\x00\x02\xff\xff".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
    }

    #[test]
    fn join() {
        // Check we join one handshake message split over two PDUs.
        let mut hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        // Introduce Finished of 16 bytes, providing 4.
        let mut msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // 11 more bytes.
        msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // Final 1 byte.
        msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x0f".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(hj.is_empty());

        let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
        let expect = Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::handshake(HandshakeMessagePayload {
                typ: HandshakeType::Finished,
                payload: HandshakePayload::Finished(Payload::new(payload)),
            }),
        }
        .into();

        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn test_rejects_giant_certs() {
        let mut hj = HandshakeJoiner::new();
        let msg = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe".to_vec()),
        };

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
        assert!(!hj.is_empty());
    }
}