minimq 0.11.0

An MQTT5 client
Documentation
use super::received_packet::ReceivedPacket;
use crate::ProtocolError as Error;
use crate::{trace, warn};

#[derive(Debug)]
pub(crate) struct PacketReader<'a> {
    pub buffer: &'a mut [u8],
    read_bytes: usize,
    packet_length: Option<usize>,
}

impl<'a> PacketReader<'a> {
    pub fn new(buffer: &'a mut [u8]) -> PacketReader<'a> {
        PacketReader {
            buffer,
            read_bytes: 0,
            packet_length: None,
        }
    }

    pub fn receive_buffer(&mut self) -> Result<&mut [u8], Error> {
        if self.packet_length.is_none() {
            self.probe_fixed_header()?;
        }

        let end = if let Some(packet_length) = &self.packet_length {
            *packet_length
        } else {
            self.read_bytes + 1
        };

        if end <= self.buffer.len() {
            trace!(
                "PacketReader receive window: read_bytes={=usize}, target_end={=usize}, packet_length={=?}",
                self.read_bytes, end, self.packet_length
            );
            Ok(&mut self.buffer[self.read_bytes..end])
        } else {
            warn!(
                "PacketReader target packet length {=usize} exceeds buffer length {=usize}",
                end,
                self.buffer.len()
            );
            Err(Error::MalformedPacket)
        }
    }

    pub fn commit(&mut self, count: usize) {
        self.read_bytes += count;
        trace!(
            "PacketReader committed {=usize} bytes, total {=usize}",
            count, self.read_bytes
        );
    }

    fn probe_fixed_header(&mut self) -> Result<(), Error> {
        if self.read_bytes <= 1 {
            return Ok(());
        }

        self.packet_length = None;

        let mut packet_length = 0;
        for (index, value) in self.buffer[1..self.read_bytes].iter().take(4).enumerate() {
            packet_length += ((value & 0x7F) as usize) << (index * 7);
            if (value & 0x80) == 0 {
                let length_size_bytes = 1 + index;

                // MQTT headers encode the packet type in the first byte followed by the packet
                // length as a varint
                let header_size_bytes = 1 + length_size_bytes;
                self.packet_length = Some(header_size_bytes + packet_length);
                trace!(
                    "PacketReader fixed header resolved packet_length={=usize} (header={=usize} payload={=usize})",
                    header_size_bytes + packet_length,
                    header_size_bytes,
                    packet_length
                );
                break;
            }
        }

        // We should have found the packet length by now.
        if self.read_bytes >= 5 && self.packet_length.is_none() {
            warn!(
                "PacketReader failed to resolve MQTT remaining length after {=usize} bytes",
                self.read_bytes
            );
            return Err(Error::MalformedPacket);
        }

        Ok(())
    }

    pub fn packet_available(&self) -> bool {
        match self.packet_length {
            Some(length) => self.read_bytes >= length,
            None => false,
        }
    }

    pub fn reset(&mut self) {
        trace!(
            "PacketReader reset (read_bytes={=usize}, packet_length={=?})",
            self.read_bytes, self.packet_length
        );
        self.read_bytes = 0;
        self.packet_length = None;
    }

    pub fn received_packet(&mut self) -> Result<ReceivedPacket<'_>, Error> {
        self.take_packet().map(|(_, packet)| packet)
    }

    pub fn take_packet(&mut self) -> Result<(usize, ReceivedPacket<'_>), Error> {
        let packet_length = *self.packet_length.as_ref().ok_or(Error::MalformedPacket)?;
        trace!(
            "PacketReader handing off complete packet of {=usize} bytes",
            packet_length
        );

        // Reset the buffer now. Once the user drops the `ReceivedPacket`, this reader will then be
        // immediately ready to begin receiving a new packet.
        self.reset();

        Ok((
            packet_length,
            ReceivedPacket::from_buffer(&self.buffer[..packet_length])?,
        ))
    }
}

#[cfg(test)]
mod test {
    use super::PacketReader;

    /// A remaining-length field that points past the end of the buffer must surface as an
    /// error instead of an out-of-bounds panic.
    #[test]
    fn dont_panic_on_bad_data() {
        // Type byte 0x20, then the varint 0x99 0x00: 0x19 + (0x00 << 7) = 25 remaining
        // bytes, i.e. a 28-byte packet that cannot fit in this 4-byte buffer.
        let mut raw = [0x20_u8, 0x99, 0x00, 0x00];
        let committed = raw.len();

        let mut reader = PacketReader::new(&mut raw);
        reader.commit(committed);

        let outcome = reader.receive_buffer();
        outcome.expect_err("parsed packet with invalid length");
    }
}