Skip to main content

rumqttc/mqttbytes/
mod.rs

1//! # mqttbytes
2//!
3//! This module contains the low level struct definitions required to assemble and disassemble MQTT 3.1.1 packets in rumqttc.
4//! The [`bytes`](https://docs.rs/bytes) crate is used internally.
5
6use bytes::{BufMut, Bytes, BytesMut};
7use core::fmt;
8use mqttbytes_core::primitives::{self as core_primitives, Error as PrimitiveError};
9use std::slice::Iter;
10
11pub mod v4;
12
13pub use mqttbytes_core::{QoS, has_wildcards, matches, valid_filter, valid_topic};
14
15/// Error during serialization and deserialization
16#[derive(Debug, thiserror::Error)]
17pub enum Error {
18    #[error("Expected Connect, received: {0:?}")]
19    NotConnect(PacketType),
20    #[error("Unexpected Connect")]
21    UnexpectedConnect,
22    #[error("Invalid Connect return code: {0}")]
23    InvalidConnectReturnCode(u8),
24    #[error("Invalid protocol")]
25    InvalidProtocol,
26    #[error("Invalid protocol level: {0}")]
27    InvalidProtocolLevel(u8),
28    #[error("Incorrect packet format")]
29    IncorrectPacketFormat,
30    #[error("Invalid packet type: {0}")]
31    InvalidPacketType(u8),
32    #[error("Invalid property type: {0}")]
33    InvalidPropertyType(u8),
34    #[error("Invalid QoS level: {0}")]
35    InvalidQoS(u8),
36    #[error("Invalid subscribe reason code: {0}")]
37    InvalidSubscribeReasonCode(u8),
38    #[error("Packet id Zero")]
39    PacketIdZero,
40    #[error("Payload size is incorrect")]
41    PayloadSizeIncorrect,
42    #[error("payload is too long")]
43    PayloadTooLong,
44    #[error("payload size limit exceeded: {0}")]
45    PayloadSizeLimitExceeded(usize),
46    #[error("Payload required")]
47    PayloadRequired,
48    #[error("Topic is not UTF-8")]
49    TopicNotUtf8,
50    #[error("Promised boundary crossed: {0}")]
51    BoundaryCrossed(usize),
52    #[error("Malformed packet")]
53    MalformedPacket,
54    #[error("Malformed remaining length")]
55    MalformedRemainingLength,
56    #[error("A Subscribe packet must contain atleast one filter")]
57    EmptySubscription,
58    /// More bytes required to frame packet. Argument
59    /// implies minimum additional bytes required to
60    /// proceed further
61    #[error("At least {0} more bytes required to frame packet")]
62    InsufficientBytes(usize),
63    #[error("IO: {0}")]
64    Io(#[from] std::io::Error),
65    #[error(
66        "Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'"
67    )]
68    OutgoingPacketTooLarge { pkt_size: usize, max: usize },
69}
70
71impl From<PrimitiveError> for Error {
72    fn from(error: PrimitiveError) -> Self {
73        match error {
74            PrimitiveError::PayloadTooLong => Self::PayloadTooLong,
75            PrimitiveError::BoundaryCrossed(len) => Self::BoundaryCrossed(len),
76            PrimitiveError::MalformedPacket => Self::MalformedPacket,
77            PrimitiveError::MalformedRemainingLength => Self::MalformedRemainingLength,
78            PrimitiveError::TopicNotUtf8 => Self::TopicNotUtf8,
79            PrimitiveError::InsufficientBytes(required) => Self::InsufficientBytes(required),
80        }
81    }
82}
83
/// MQTT control packet type.
///
/// Each discriminant is the numeric type code that
/// `FixedHeader::packet_type` maps back to the corresponding variant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
103
/// Protocol versions distinguished by this module.
// NOTE(review): `V4`/`V5` presumably denote MQTT 3.1.1 (protocol level 4) and
// MQTT 5 respectively — confirm against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    V4,
    V5,
}
110
/// Fixed header found at the start of an MQTT packet.
///
/// ```text
///          7                          3                          0
///          +--------------------------+--------------------------+
/// byte 1   | MQTT Control Packet Type | Flags for each type      |
///          +--------------------------+--------------------------+
///          |         Remaining Bytes Len  (1/2/3/4 bytes)        |
///          +-----------------------------------------------------+
/// ```
///
/// <https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349207>
// The derived `PartialOrd` compares fields in declaration order, so the field
// order below is part of the type's observable behaviour.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the stream. Identifies the packet type (upper four bits)
    /// and carries several flags.
    byte1: u8,
    /// Length of the fixed header itself: byte 1 plus the 1..4 bytes of the
    /// variable-length encoded remaining length, i.e. between 2 and 5 bytes.
    header_len: usize,
    /// Remaining length of the packet (variable header + payload size).
    /// Does not include the fixed header bytes.
    remaining_len: usize,
}
136
137impl FixedHeader {
138    #[must_use]
139    pub const fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> Self {
140        Self {
141            byte1,
142            header_len: remaining_len_len + 1,
143            remaining_len,
144        }
145    }
146
147    /// Returns the MQTT packet type represented by this fixed header.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if the fixed-header flags are invalid for the decoded
152    /// packet type.
153    pub const fn packet_type(&self) -> Result<PacketType, Error> {
154        let num = self.byte1 >> 4;
155        match num {
156            1 => Ok(PacketType::Connect),
157            2 => Ok(PacketType::ConnAck),
158            3 => Ok(PacketType::Publish),
159            4 => Ok(PacketType::PubAck),
160            5 => Ok(PacketType::PubRec),
161            6 => Ok(PacketType::PubRel),
162            7 => Ok(PacketType::PubComp),
163            8 => Ok(PacketType::Subscribe),
164            9 => Ok(PacketType::SubAck),
165            10 => Ok(PacketType::Unsubscribe),
166            11 => Ok(PacketType::UnsubAck),
167            12 => Ok(PacketType::PingReq),
168            13 => Ok(PacketType::PingResp),
169            14 => Ok(PacketType::Disconnect),
170            _ => Err(Error::InvalidPacketType(num)),
171        }
172    }
173
174    /// Returns the size of full packet (fixed header + variable header + payload)
175    /// Fixed header is enough to get the size of a frame in the stream
176    #[must_use]
177    pub const fn frame_length(&self) -> usize {
178        self.header_len + self.remaining_len
179    }
180}
181
182/// Checks whether the stream contains a complete MQTT packet within the
183/// configured size limit.
184///
185/// The fixed header is returned only if the existing bytes are enough to frame
186/// the packet. The passed stream does not modify the parent stream's cursor. If
187/// this function returns an error, the next `check` on the same parent stream
188/// starts again with the cursor at `0`.
189///
190/// # Errors
191///
192/// Returns an error if the frame is incomplete, malformed, or exceeds
193/// `max_packet_size`.
194pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
195    let stream_len = stream.len();
196    let fixed_header = parse_fixed_header(stream)?;
197
198    // Don't let rogue connections attack with huge payloads.
199    // Disconnect them before reading all that data
200    if fixed_header.remaining_len > max_packet_size {
201        return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
202    }
203
204    let frame_length = fixed_header.frame_length();
205    if stream_len < frame_length {
206        return Err(Error::InsufficientBytes(frame_length - stream_len));
207    }
208
209    Ok(fixed_header)
210}
211
212fn parse_fixed_header(stream: Iter<u8>) -> Result<FixedHeader, Error> {
213    let fixed_header = core_primitives::parse_fixed_header(stream).map_err(Error::from)?;
214    Ok(FixedHeader::new(
215        fixed_header.byte1,
216        fixed_header.remaining_len_len,
217        fixed_header.remaining_len,
218    ))
219}
220
221/// Reads a series of bytes with a length from a byte stream
222fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
223    core_primitives::read_mqtt_bytes(stream).map_err(Error::from)
224}
225
226/// Reads a string from bytes stream
227fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
228    core_primitives::read_mqtt_string(stream).map_err(Error::from)
229}
230
231/// Serializes bytes to stream (including length)
232fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
233    core_primitives::write_mqtt_bytes(stream, bytes);
234}
235
236/// Serializes a string to stream
237fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
238    core_primitives::write_mqtt_string(stream, string);
239}
240
241/// Writes remaining length to stream and returns number of bytes for remaining length
242fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
243    core_primitives::write_remaining_length(stream, len).map_err(Error::from)
244}
245
246/// Maps a number to `QoS`
247/// Decodes a `QoS` value from its wire representation.
248///
249/// # Errors
250///
251/// Returns an error if `num` does not encode a valid MQTT `QoS` level.
252pub fn qos(num: u8) -> Result<QoS, Error> {
253    mqttbytes_core::qos(num).ok_or(Error::InvalidQoS(num))
254}
255
256/// After collecting enough bytes to frame a packet (packet's `frame()`)
257/// , It's possible that content itself in the stream is wrong. Like expected
258/// packet id or qos not being present. In cases where `read_mqtt_string` or
259/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to
260/// parse qos next, these pre checks will prevent `bytes` crashes
261fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
262    core_primitives::read_u16(stream).map_err(Error::from)
263}
264
265fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
266    core_primitives::read_u8(stream).map_err(Error::from)
267}
268
#[cfg(test)]
mod tests {
    use super::{Error, check};

    /// The size limit must be enforced from the fixed header alone, before the
    /// rest of the frame has arrived.
    #[test]
    fn check_rejects_oversized_packet_on_partial_frame() {
        // Only the two fixed-header bytes are present: type nibble 3 (Publish)
        // followed by a remaining length of 0x14 = 20, above the limit of 10.
        let partial_frame: [u8; 2] = [0x30, 0x14];
        let limit = 10;

        let outcome = check(partial_frame.iter(), limit);

        assert!(matches!(outcome, Err(Error::PayloadSizeLimitExceeded(20))));
    }
}