mqttbytes_core/
primitives.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::slice::Iter;
3
4#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
5pub enum Error {
6 #[error("Payload is too long")]
7 PayloadTooLong,
8 #[error("Promised boundary crossed, contains {0} bytes")]
9 BoundaryCrossed(usize),
10 #[error("Packet is malformed")]
11 MalformedPacket,
12 #[error("Remaining length is malformed")]
13 MalformedRemainingLength,
14 #[error("Topic not utf-8")]
15 TopicNotUtf8,
16 #[error("Insufficient number of bytes to frame packet, {0} more bytes required")]
17 InsufficientBytes(usize),
18}
19
/// Result of parsing a packet's fixed header: the first byte plus the decoded
/// remaining-length field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct ParsedFixedHeader {
    /// First byte of the packet, stored exactly as read.
    pub byte1: u8,
    /// Number of bytes the encoded remaining-length field occupied.
    pub remaining_len_len: usize,
    /// Decoded value of the remaining-length field.
    pub remaining_len: usize,
}

impl ParsedFixedHeader {
    /// Total size of the frame in bytes: the first byte, the remaining-length
    /// field itself, and the `remaining_len` bytes that follow it.
    #[must_use]
    pub const fn frame_length(self) -> usize {
        let Self {
            remaining_len_len,
            remaining_len,
            ..
        } = self;

        // The literal `1` accounts for `byte1`.
        remaining_len + remaining_len_len + 1
    }
}
33
34pub fn check(stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
41 let stream_len = stream.len();
42 let fixed_header = parse_fixed_header(stream)?;
43
44 let frame_length = fixed_header.frame_length();
45 if stream_len < frame_length {
46 return Err(Error::InsufficientBytes(frame_length - stream_len));
47 }
48
49 Ok(fixed_header)
50}
51
52pub fn parse_fixed_header(mut stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
64 let stream_len = stream.len();
65 if stream_len < 2 {
66 return Err(Error::InsufficientBytes(2 - stream_len));
67 }
68
69 let byte1 = *stream.next().unwrap();
70 let (remaining_len_len, remaining_len) = length(stream)?;
71
72 Ok(ParsedFixedHeader {
73 byte1,
74 remaining_len_len,
75 remaining_len,
76 })
77}
78
79pub fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
88 let mut len: usize = 0;
89 let mut len_len = 0;
90 let mut done = false;
91 let mut shift = 0;
92
93 for byte in stream {
98 len_len += 1;
99 let byte = *byte as usize;
100 len += (byte & 0x7F) << shift;
101
102 done = (byte & 0x80) == 0;
104 if done {
105 break;
106 }
107
108 shift += 7;
109
110 if shift > 21 {
113 return Err(Error::MalformedRemainingLength);
114 }
115 }
116
117 if !done {
120 return Err(Error::InsufficientBytes(1));
121 }
122
123 Ok((len_len, len))
124}
125
126pub fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
133 let len = read_u16(stream)? as usize;
134
135 if len > stream.len() {
140 return Err(Error::BoundaryCrossed(len));
141 }
142
143 Ok(stream.split_to(len))
144}
145
146pub fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
153 let s = read_mqtt_bytes(stream)?;
154 let s = std::str::from_utf8(&s).map_err(|_| Error::TopicNotUtf8)?;
155 Ok(s.to_owned())
156}
157
158pub fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
165 let len = u16::try_from(bytes.len()).expect("MQTT string/bytes length must fit in u16");
166 stream.put_u16(len);
167 stream.extend_from_slice(bytes);
168}
169
170pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
172 write_mqtt_bytes(stream, string.as_bytes());
173}
174
175pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
187 if len > 268_435_455 {
188 return Err(Error::PayloadTooLong);
189 }
190
191 let mut done = false;
192 let mut x = len;
193 let mut count = 0;
194
195 while !done {
196 let mut byte = u8::try_from(x % 128).expect("remainder in 0..=127 always fits in u8");
197 x /= 128;
198 if x > 0 {
199 byte |= 128;
200 }
201
202 stream.put_u8(byte);
203 count += 1;
204 done = x == 0;
205 }
206
207 Ok(count)
208}
209
/// Returns how many bytes (1 to 4) the variable-length encoding of `len`
/// occupies, at seven payload bits per byte.
///
/// Values of `2_097_152` and above all report `4`; no upper bound is
/// enforced here.
#[must_use]
pub const fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
223
224pub fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
236 if stream.len() < 2 {
237 return Err(Error::MalformedPacket);
238 }
239
240 Ok(stream.get_u16())
241}
242
243pub fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
249 if stream.is_empty() {
250 return Err(Error::MalformedPacket);
251 }
252
253 Ok(stream.get_u8())
254}
255
256pub fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
262 if stream.len() < 4 {
263 return Err(Error::MalformedPacket);
264 }
265
266 Ok(stream.get_u32())
267}
268
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;

    /// `len_len` switches to the next size exactly at each 7-bit boundary.
    #[test]
    fn len_len_matches_expected_thresholds() {
        let cases = [
            (0, 1),
            (127, 1),
            (128, 2),
            (16_383, 2),
            (16_384, 3),
            (2_097_151, 3),
            (2_097_152, 4),
        ];

        for (len, expected) in cases {
            assert_eq!(len_len(len), expected);
        }
    }

    /// Whatever `write_remaining_length` emits, `length` decodes back to the
    /// same value and reports the same byte count.
    #[test]
    fn write_remaining_length_round_trip() {
        for len in [0usize, 127, 128, 321, 16_384, 268_435_455] {
            let mut encoded = BytesMut::new();
            let written = write_remaining_length(&mut encoded, len).unwrap();

            assert_eq!(length(encoded.iter()), Ok((written, len)));
        }
    }

    /// The header announces 5 body bytes but only 2 are present, so the
    /// 7-byte frame is 3 bytes short.
    #[test]
    fn check_reports_missing_bytes() {
        let frame = [0x30u8, 0x05, 1, 2];

        assert_eq!(check(frame.iter()), Err(Error::InsufficientBytes(3)));
    }
}