Skip to main content

mqttbytes_core/
primitives.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::slice::Iter;
3
4#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
5pub enum Error {
6    #[error("Payload is too long")]
7    PayloadTooLong,
8    #[error("Promised boundary crossed, contains {0} bytes")]
9    BoundaryCrossed(usize),
10    #[error("Packet is malformed")]
11    MalformedPacket,
12    #[error("Remaining length is malformed")]
13    MalformedRemainingLength,
14    #[error("Topic not utf-8")]
15    TopicNotUtf8,
16    #[error("Insufficient number of bytes to frame packet, {0} more bytes required")]
17    InsufficientBytes(usize),
18}
19
/// Decoded MQTT fixed header: the first byte as read, plus the decoded
/// remaining length and the number of bytes its encoding occupied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct ParsedFixedHeader {
    /// First byte of the fixed header, stored exactly as read from the stream.
    pub byte1: u8,
    /// How many bytes the encoded remaining length field occupied.
    pub remaining_len_len: usize,
    /// Decoded remaining length: number of bytes that follow the fixed header.
    pub remaining_len: usize,
}

impl ParsedFixedHeader {
    /// Total number of bytes in the frame this header describes: one byte for
    /// `byte1`, the remaining length field, and the remaining bytes.
    #[must_use]
    pub const fn frame_length(self) -> usize {
        let header_bytes = 1 + self.remaining_len_len;
        header_bytes + self.remaining_len
    }
}
33
34/// Checks whether the buffer contains a complete MQTT frame header and payload.
35///
36/// # Errors
37///
38/// Returns [`Error::InsufficientBytes`] when the full frame has not arrived yet,
39/// or another framing error if the fixed header is malformed.
40pub fn check(stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
41    let stream_len = stream.len();
42    let fixed_header = parse_fixed_header(stream)?;
43
44    let frame_length = fixed_header.frame_length();
45    if stream_len < frame_length {
46        return Err(Error::InsufficientBytes(frame_length - stream_len));
47    }
48
49    Ok(fixed_header)
50}
51
52/// Parses the fixed header from the provided iterator.
53///
54/// # Errors
55///
56/// Returns an error when the header is incomplete or the remaining length field
57/// is malformed.
58///
59/// # Panics
60///
61/// Panics only if the iterator yields fewer than two bytes after the explicit
62/// length check above, which would indicate a broken iterator implementation.
63pub fn parse_fixed_header(mut stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
64    let stream_len = stream.len();
65    if stream_len < 2 {
66        return Err(Error::InsufficientBytes(2 - stream_len));
67    }
68
69    let byte1 = *stream.next().unwrap();
70    let (remaining_len_len, remaining_len) = length(stream)?;
71
72    Ok(ParsedFixedHeader {
73        byte1,
74        remaining_len_len,
75        remaining_len,
76    })
77}
78
79/// Parses variable byte integer in the stream and returns the length
80/// and number of bytes that make it. Used for remaining length calculation
81/// as well as for calculating property lengths
82///
83/// # Errors
84///
85/// Returns an error when the variable-length integer is incomplete or exceeds
86/// the MQTT maximum encoding width.
87pub fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
88    let mut len: usize = 0;
89    let mut len_len = 0;
90    let mut done = false;
91    let mut shift = 0;
92
93    // Use continuation bit at position 7 to continue reading next
94    // byte to frame 'length'.
95    // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will
96    // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx
97    for byte in stream {
98        len_len += 1;
99        let byte = *byte as usize;
100        len += (byte & 0x7F) << shift;
101
102        // stop when continue bit is 0
103        done = (byte & 0x80) == 0;
104        if done {
105            break;
106        }
107
108        shift += 7;
109
110        // Only a max of 4 bytes allowed for remaining length
111        // more than 4 shifts (0, 7, 14, 21) implies bad length
112        if shift > 21 {
113            return Err(Error::MalformedRemainingLength);
114        }
115    }
116
117    // Not enough bytes to frame remaining length. wait for
118    // one more byte
119    if !done {
120        return Err(Error::InsufficientBytes(1));
121    }
122
123    Ok((len_len, len))
124}
125
126/// Reads a series of bytes with a length from a byte stream
127///
128/// # Errors
129///
130/// Returns an error when the stream does not contain enough bytes for the
131/// length prefix or the declared payload length crosses the packet boundary.
132pub fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
133    let len = read_u16(stream)? as usize;
134
135    // Prevent attacks with wrong remaining length. This method is used in
136    // `packet.assembly()` with (enough) bytes to frame packet. Ensures that
137    // reading variable len string or bytes doesn't cross promised boundary
138    // with `read_fixed_header()`
139    if len > stream.len() {
140        return Err(Error::BoundaryCrossed(len));
141    }
142
143    Ok(stream.split_to(len))
144}
145
146/// Reads a string from bytes stream
147///
148/// # Errors
149///
150/// Returns an error when the stream does not contain a complete MQTT string or
151/// when the bytes are not valid UTF-8.
152pub fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
153    let s = read_mqtt_bytes(stream)?;
154    let s = std::str::from_utf8(&s).map_err(|_| Error::TopicNotUtf8)?;
155    Ok(s.to_owned())
156}
157
158/// Serializes bytes to stream (including length)
159///
160/// # Panics
161///
162/// Panics if `bytes.len()` exceeds the MQTT maximum encoded string length of
163/// `u16::MAX`.
164pub fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
165    let len = u16::try_from(bytes.len()).expect("MQTT string/bytes length must fit in u16");
166    stream.put_u16(len);
167    stream.extend_from_slice(bytes);
168}
169
170/// Serializes a string to stream
171pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
172    write_mqtt_bytes(stream, string.as_bytes());
173}
174
175/// Writes remaining length to stream and returns number of bytes for remaining length
176///
177/// # Errors
178///
179/// Returns [`Error::PayloadTooLong`] when `len` exceeds the MQTT remaining
180/// length limit.
181///
182/// # Panics
183///
184/// Panics only if converting a remainder in `0..=127` to `u8` fails, which
185/// cannot happen for valid Rust integer conversions.
186pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
187    if len > 268_435_455 {
188        return Err(Error::PayloadTooLong);
189    }
190
191    let mut done = false;
192    let mut x = len;
193    let mut count = 0;
194
195    while !done {
196        let mut byte = u8::try_from(x % 128).expect("remainder in 0..=127 always fits in u8");
197        x /= 128;
198        if x > 0 {
199            byte |= 128;
200        }
201
202        stream.put_u8(byte);
203        count += 1;
204        done = x == 0;
205    }
206
207    Ok(count)
208}
209
/// Returns how many bytes the MQTT variable byte integer encoding of `len`
/// occupies (1 to 4).
///
/// Values above 268,435,455 cannot be encoded at all; this function still
/// reports 4 for them instead of signalling an error.
#[must_use]
pub const fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
223
224/// After collecting enough bytes to frame a packet, the packet payload itself
225/// can still be malformed.
226///
227/// For example, a packet may be missing an expected packet identifier or `QoS`
228/// field. These pre-checks prevent `bytes` panics when `read_mqtt_string` or
229/// `read_mqtt_bytes` exhaust the remaining length before the packet parser
230/// reaches the next expected field.
231///
232/// # Errors
233///
234/// Returns [`Error::MalformedPacket`] when fewer than two bytes remain.
235pub fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
236    if stream.len() < 2 {
237        return Err(Error::MalformedPacket);
238    }
239
240    Ok(stream.get_u16())
241}
242
243/// Reads the next byte from the stream.
244///
245/// # Errors
246///
247/// Returns [`Error::MalformedPacket`] when the stream is empty.
248pub fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
249    if stream.is_empty() {
250        return Err(Error::MalformedPacket);
251    }
252
253    Ok(stream.get_u8())
254}
255
256/// Reads the next big-endian `u32` from the stream.
257///
258/// # Errors
259///
260/// Returns [`Error::MalformedPacket`] when fewer than four bytes remain.
261pub fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
262    if stream.len() < 4 {
263        return Err(Error::MalformedPacket);
264    }
265
266    Ok(stream.get_u32())
267}
268
#[cfg(test)]
mod tests {
    // Unit tests for the framing, decoding, and encoding primitives above.
    use bytes::{Bytes, BytesMut};

    use super::*;

    #[test]
    fn len_len_matches_expected_thresholds() {
        assert_eq!(len_len(0), 1);
        assert_eq!(len_len(127), 1);
        assert_eq!(len_len(128), 2);
        assert_eq!(len_len(16_383), 2);
        assert_eq!(len_len(16_384), 3);
        assert_eq!(len_len(2_097_151), 3);
        assert_eq!(len_len(2_097_152), 4);
    }

    #[test]
    fn write_remaining_length_round_trip() {
        for len in [0usize, 127, 128, 321, 16_384, 268_435_455] {
            let mut b = BytesMut::new();
            let count = write_remaining_length(&mut b, len).unwrap();
            let (decoded_count, decoded) = length(b.iter()).unwrap();
            assert_eq!(count, decoded_count);
            assert_eq!(decoded, len);
        }
    }

    #[test]
    fn write_remaining_length_rejects_oversized_values() {
        let mut b = BytesMut::new();
        assert_eq!(write_remaining_length(&mut b, 268_435_456), Err(Error::PayloadTooLong));
        // Nothing may be written when the value is rejected.
        assert!(b.is_empty());
    }

    #[test]
    fn length_rejects_more_than_four_bytes() {
        let b = [0xFFu8, 0xFF, 0xFF, 0xFF, 0x01];
        assert_eq!(length(b.iter()), Err(Error::MalformedRemainingLength));
    }

    #[test]
    fn length_waits_for_terminating_byte() {
        assert_eq!(length([0x80u8].iter()), Err(Error::InsufficientBytes(1)));
        assert_eq!(length([0u8; 0].iter()), Err(Error::InsufficientBytes(1)));
    }

    #[test]
    fn parse_fixed_header_requires_two_bytes() {
        assert_eq!(parse_fixed_header([0u8; 0].iter()), Err(Error::InsufficientBytes(2)));
        assert_eq!(parse_fixed_header([0x30u8].iter()), Err(Error::InsufficientBytes(1)));
    }

    #[test]
    fn check_reports_missing_bytes() {
        let b = [0x30u8, 0x05, 1, 2];
        let result = check(b.iter());
        assert_eq!(result, Err(Error::InsufficientBytes(3)));
    }

    #[test]
    fn check_accepts_complete_frame_with_trailing_bytes() {
        let b = [0x30u8, 0x02, 1, 2, 0xFF];
        let header = check(b.iter()).unwrap();
        assert_eq!(header.byte1, 0x30);
        assert_eq!(header.frame_length(), 4);
    }

    #[test]
    fn mqtt_string_round_trip() {
        let mut b = BytesMut::new();
        write_mqtt_string(&mut b, "a/b/c");
        assert_eq!(&b[..], &[0x00, 0x05, b'a', b'/', b'b', b'/', b'c']);
        let mut stream = b.freeze();
        assert_eq!(read_mqtt_string(&mut stream).unwrap(), "a/b/c");
        assert!(stream.is_empty());
    }

    #[test]
    fn read_mqtt_bytes_rejects_declared_length_past_end() {
        let mut stream = Bytes::from_static(&[0x00, 0x05, 1, 2]);
        assert_eq!(read_mqtt_bytes(&mut stream), Err(Error::BoundaryCrossed(5)));
    }

    #[test]
    fn read_mqtt_string_rejects_invalid_utf8() {
        let mut stream = Bytes::from_static(&[0x00, 0x01, 0xFF]);
        assert_eq!(read_mqtt_string(&mut stream), Err(Error::TopicNotUtf8));
    }

    #[test]
    fn fixed_width_reads_report_short_buffers() {
        let mut stream = Bytes::from_static(&[0x12, 0x34, 0x56]);
        assert_eq!(read_u32(&mut stream), Err(Error::MalformedPacket));
        assert_eq!(read_u16(&mut stream), Ok(0x1234));
        assert_eq!(read_u16(&mut stream), Err(Error::MalformedPacket));
        assert_eq!(read_u8(&mut stream), Ok(0x56));
        assert_eq!(read_u8(&mut stream), Err(Error::MalformedPacket));
    }
}