mqttbytes/
lib.rs

1extern crate alloc;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use core::fmt::{self, Display, Formatter};
5use std::slice::Iter;
6
7pub mod v4;
8pub mod v5;
9mod topic;
10
11pub use topic::*;
12
13/// Error during serialization and deserialization
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum Error {
16    NotConnect(PacketType),
17    UnexpectedConnect,
18    InvalidConnectReturnCode(u8),
19    InvalidReason(u8),
20    InvalidProtocol,
21    InvalidProtocolLevel(u8),
22    IncorrectPacketFormat,
23    InvalidPacketType(u8),
24    InvalidPropertyType(u8),
25    InvalidRetainForwardRule(u8),
26    InvalidQoS(u8),
27    InvalidSubscribeReasonCode(u8),
28    PacketIdZero,
29    SubscriptionIdZero,
30    PayloadSizeIncorrect,
31    PayloadTooLong,
32    PayloadSizeLimitExceeded(usize),
33    PayloadRequired,
34    TopicNotUtf8,
35    BoundaryCrossed(usize),
36    MalformedPacket,
37    MalformedRemainingLength,
38    /// More bytes required to frame packet. Argument
39    /// implies minimum additional bytes required to
40    /// proceed further
41    InsufficientBytes(usize),
42}
43
/// MQTT control packet type.
///
/// Each discriminant is the 4-bit type value found in the high nibble of a
/// fixed header's first byte (see `FixedHeader::packet_type`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
63
/// MQTT protocol revision a packet is encoded for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Protocol level 4 (presumably MQTT 3.1.1, matching the `v4` module).
    V4,
    /// Protocol level 5 (presumably MQTT 5, matching the `v5` module).
    V5,
}
70
/// MQTT quality-of-service (delivery guarantee) level.
///
/// The discriminant is the numeric level, and variants compare in
/// declaration order: `AtMostOnce < AtLeastOnce < ExactlyOnce`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum QoS {
    /// Level 0.
    AtMostOnce = 0,
    /// Level 1.
    AtLeastOnce = 1,
    /// Level 2.
    ExactlyOnce = 2,
}
79
/// Fixed header that begins every MQTT control packet.
///
/// ```ignore
///          7                          3                          0
///          +--------------------------+--------------------------+
/// byte 1   | MQTT Control Packet Type | Flags for each type      |
///          +--------------------------+--------------------------+
///          |         Remaining Bytes Len  (1/2/3/4 bytes)        |
///          +-----------------------------------------------------+
///
/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_-
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// Raw first byte of the packet: the packet type sits in the high nibble,
    /// type-specific flags in the low nibble.
    byte1: u8,
    /// Size in bytes of the whole fixed header: the first byte plus the
    /// 1..=4 bytes that encode the remaining length, i.e. 2..=5 in total.
    fixed_header_len: usize,
    /// Number of bytes after the fixed header (variable header + payload).
    remaining_len: usize,
}
105
106impl FixedHeader {
107    pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
108        FixedHeader {
109            byte1,
110            fixed_header_len: remaining_len_len + 1,
111            remaining_len,
112        }
113    }
114
115    pub fn packet_type(&self) -> Result<PacketType, Error> {
116        let num = self.byte1 >> 4;
117        match num {
118            1 => Ok(PacketType::Connect),
119            2 => Ok(PacketType::ConnAck),
120            3 => Ok(PacketType::Publish),
121            4 => Ok(PacketType::PubAck),
122            5 => Ok(PacketType::PubRec),
123            6 => Ok(PacketType::PubRel),
124            7 => Ok(PacketType::PubComp),
125            8 => Ok(PacketType::Subscribe),
126            9 => Ok(PacketType::SubAck),
127            10 => Ok(PacketType::Unsubscribe),
128            11 => Ok(PacketType::UnsubAck),
129            12 => Ok(PacketType::PingReq),
130            13 => Ok(PacketType::PingResp),
131            14 => Ok(PacketType::Disconnect),
132            _ => Err(Error::InvalidPacketType(num)),
133        }
134    }
135
136    /// Returns the size of full packet (fixed header + variable header + payload)
137    /// Fixed header is enough to get the size of a frame in the stream
138    pub fn frame_length(&self) -> usize {
139        self.fixed_header_len + self.remaining_len
140    }
141}
142
143/// Checks if the stream has enough bytes to frame a packet and returns fixed header
144/// only if a packet can be framed with existing bytes in the `stream`.
145/// The passed stream doesn't modify parent stream's cursor. If this function
146/// returned an error, next `check` on the same parent stream is forced start
147/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally)
148pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
149    // Create fixed header if there are enough bytes in the stream
150    // to frame full packet
151    let stream_len = stream.len();
152    let fixed_header = parse_fixed_header(stream)?;
153
154    // Don't let rogue connections attack with huge payloads.
155    // Disconnect them before reading all that data
156    if fixed_header.remaining_len > max_packet_size {
157        return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
158    }
159
160    // If the current call fails due to insufficient bytes in the stream,
161    // after calculating remaining length, we extend the stream
162    let frame_length = fixed_header.frame_length();
163    if stream_len < frame_length {
164        return Err(Error::InsufficientBytes(frame_length - stream_len));
165    }
166
167    Ok(fixed_header)
168}
169
170/// Parses fixed header
171fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
172    // At least 2 bytes are necessary to frame a packet
173    let stream_len = stream.len();
174    if stream_len < 2 {
175        return Err(Error::InsufficientBytes(2 - stream_len));
176    }
177
178    let byte1 = stream.next().unwrap();
179    let (len_len, len) = length(stream)?;
180
181    Ok(FixedHeader::new(*byte1, len_len, len))
182}
183
184/// Parses variable byte integer in the stream and returns the length
185/// and number of bytes that make it. Used for remaining length calculation
186/// as well as for calculating property lengths
187fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
188    let mut len: usize = 0;
189    let mut len_len = 0;
190    let mut done = false;
191    let mut shift = 0;
192
193    // Use continuation bit at position 7 to continue reading next
194    // byte to frame 'length'.
195    // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will
196    // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx
197    for byte in stream {
198        len_len += 1;
199        let byte = *byte as usize;
200        len += (byte & 0x7F) << shift;
201
202        // stop when continue bit is 0
203        done = (byte & 0x80) == 0;
204        if done {
205            break;
206        }
207
208        shift += 7;
209
210        // Only a max of 4 bytes allowed for remaining length
211        // more than 4 shifts (0, 7, 14, 21) implies bad length
212        if shift > 21 {
213            return Err(Error::MalformedRemainingLength);
214        }
215    }
216
217    // Not enough bytes to frame remaining length. wait for
218    // one more byte
219    if !done {
220        return Err(Error::InsufficientBytes(1));
221    }
222
223    Ok((len_len, len))
224}
225
226/// Reads a series of bytes with a length from a byte stream
227fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
228    let len = read_u16(stream)? as usize;
229
230    // Prevent attacks with wrong remaining length. This method is used in
231    // `packet.assembly()` with (enough) bytes to frame packet. Ensures that
232    // reading variable len string or bytes doesn't cross promised boundary
233    // with `read_fixed_header()`
234    if len > stream.len() {
235        return Err(Error::BoundaryCrossed(len));
236    }
237
238    Ok(stream.split_to(len))
239}
240
241/// Reads a string from bytes stream
242fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
243    let s = read_mqtt_bytes(stream)?;
244    match String::from_utf8(s.to_vec()) {
245        Ok(v) => Ok(v),
246        Err(_e) => Err(Error::TopicNotUtf8),
247    }
248}
249
250/// Serializes bytes to stream (including length)
251fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
252    stream.put_u16(bytes.len() as u16);
253    stream.extend_from_slice(bytes);
254}
255
256/// Serializes a string to stream
257fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
258    write_mqtt_bytes(stream, string.as_bytes());
259}
260
261/// Writes remaining length to stream and returns number of bytes for remaining length
262fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
263    if len > 268_435_455 {
264        return Err(Error::PayloadTooLong);
265    }
266
267    let mut done = false;
268    let mut x = len;
269    let mut count = 0;
270
271    while !done {
272        let mut byte = (x % 128) as u8;
273        x /= 128;
274        if x > 0 {
275            byte |= 128;
276        }
277
278        stream.put_u8(byte);
279        count += 1;
280        done = x == 0;
281    }
282
283    Ok(count)
284}
285
/// Number of bytes the variable-length encoding of `len` occupies.
///
/// Each byte carries 7 value bits, so the thresholds are 2^7, 2^14 and 2^21.
/// Values of 2^21 and above report 4, including values too large to encode.
fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
298
299/// Maps a number to QoS
300pub fn qos(num: u8) -> Result<QoS, Error> {
301    match num {
302        0 => Ok(QoS::AtMostOnce),
303        1 => Ok(QoS::AtLeastOnce),
304        2 => Ok(QoS::ExactlyOnce),
305        qos => Err(Error::InvalidQoS(qos)),
306    }
307}
308
309/// After collecting enough bytes to frame a packet (packet's frame())
310/// , It's possible that content itself in the stream is wrong. Like expected
311/// packet id or qos not being present. In cases where `read_mqtt_string` or
312/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to
313/// parse qos next, these pre checks will prevent `bytes` crashes
314fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
315    if stream.len() < 2 {
316        return Err(Error::MalformedPacket);
317    }
318
319    Ok(stream.get_u16())
320}
321
322fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
323    if stream.len() < 1 {
324        return Err(Error::MalformedPacket);
325    }
326
327    Ok(stream.get_u8())
328}
329
330fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
331    if stream.len() < 4 {
332        return Err(Error::MalformedPacket);
333    }
334
335    Ok(stream.get_u32())
336}
337
338impl Display for Error {
339    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
340        write!(f, "Error = {:?}", self)
341    }
342}