1extern crate alloc;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use core::fmt::{self, Display, Formatter};
5use std::slice::Iter;
6
7pub mod v4;
8pub mod v5;
9mod topic;
10
11pub use topic::*;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum Error {
16 NotConnect(PacketType),
17 UnexpectedConnect,
18 InvalidConnectReturnCode(u8),
19 InvalidReason(u8),
20 InvalidProtocol,
21 InvalidProtocolLevel(u8),
22 IncorrectPacketFormat,
23 InvalidPacketType(u8),
24 InvalidPropertyType(u8),
25 InvalidRetainForwardRule(u8),
26 InvalidQoS(u8),
27 InvalidSubscribeReasonCode(u8),
28 PacketIdZero,
29 SubscriptionIdZero,
30 PayloadSizeIncorrect,
31 PayloadTooLong,
32 PayloadSizeLimitExceeded(usize),
33 PayloadRequired,
34 TopicNotUtf8,
35 BoundaryCrossed(usize),
36 MalformedPacket,
37 MalformedRemainingLength,
38 InsufficientBytes(usize),
42}
43
/// Control packet kinds, numbered 1 through 14.
///
/// The discriminant of each variant is the value that
/// `FixedHeader::packet_type` expects in the upper four bits of the first
/// header byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
63
/// Protocol revision selector; the variants mirror the `v4` and `v5`
/// submodules declared by this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    V4,
    V5,
}
70
/// Quality-of-service level, stored as a `u8` with values 0, 1 and 2.
///
/// The variants are declared in ascending order, so the derived
/// `PartialOrd` ranks `AtMostOnce < AtLeastOnce < ExactlyOnce`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub enum QoS {
    AtMostOnce = 0,
    AtLeastOnce = 1,
    ExactlyOnce = 2,
}
79
/// Decoded fixed header of a packet.
///
/// Field order is significant: the derived `PartialOrd` compares fields in
/// declaration order.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the packet; its upper four bits select the packet type.
    byte1: u8,
    /// Size of the fixed header itself: the first byte plus the bytes used
    /// by the remaining-length field.
    fixed_header_len: usize,
    /// Value of the remaining-length field.
    remaining_len: usize,
}
105
106impl FixedHeader {
107 pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
108 FixedHeader {
109 byte1,
110 fixed_header_len: remaining_len_len + 1,
111 remaining_len,
112 }
113 }
114
115 pub fn packet_type(&self) -> Result<PacketType, Error> {
116 let num = self.byte1 >> 4;
117 match num {
118 1 => Ok(PacketType::Connect),
119 2 => Ok(PacketType::ConnAck),
120 3 => Ok(PacketType::Publish),
121 4 => Ok(PacketType::PubAck),
122 5 => Ok(PacketType::PubRec),
123 6 => Ok(PacketType::PubRel),
124 7 => Ok(PacketType::PubComp),
125 8 => Ok(PacketType::Subscribe),
126 9 => Ok(PacketType::SubAck),
127 10 => Ok(PacketType::Unsubscribe),
128 11 => Ok(PacketType::UnsubAck),
129 12 => Ok(PacketType::PingReq),
130 13 => Ok(PacketType::PingResp),
131 14 => Ok(PacketType::Disconnect),
132 _ => Err(Error::InvalidPacketType(num)),
133 }
134 }
135
136 pub fn frame_length(&self) -> usize {
139 self.fixed_header_len + self.remaining_len
140 }
141}
142
143pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
149 let stream_len = stream.len();
152 let fixed_header = parse_fixed_header(stream)?;
153
154 if fixed_header.remaining_len > max_packet_size {
157 return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
158 }
159
160 let frame_length = fixed_header.frame_length();
163 if stream_len < frame_length {
164 return Err(Error::InsufficientBytes(frame_length - stream_len));
165 }
166
167 Ok(fixed_header)
168}
169
170fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
172 let stream_len = stream.len();
174 if stream_len < 2 {
175 return Err(Error::InsufficientBytes(2 - stream_len));
176 }
177
178 let byte1 = stream.next().unwrap();
179 let (len_len, len) = length(stream)?;
180
181 Ok(FixedHeader::new(*byte1, len_len, len))
182}
183
184fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
188 let mut len: usize = 0;
189 let mut len_len = 0;
190 let mut done = false;
191 let mut shift = 0;
192
193 for byte in stream {
198 len_len += 1;
199 let byte = *byte as usize;
200 len += (byte & 0x7F) << shift;
201
202 done = (byte & 0x80) == 0;
204 if done {
205 break;
206 }
207
208 shift += 7;
209
210 if shift > 21 {
213 return Err(Error::MalformedRemainingLength);
214 }
215 }
216
217 if !done {
220 return Err(Error::InsufficientBytes(1));
221 }
222
223 Ok((len_len, len))
224}
225
226fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
228 let len = read_u16(stream)? as usize;
229
230 if len > stream.len() {
235 return Err(Error::BoundaryCrossed(len));
236 }
237
238 Ok(stream.split_to(len))
239}
240
241fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
243 let s = read_mqtt_bytes(stream)?;
244 match String::from_utf8(s.to_vec()) {
245 Ok(v) => Ok(v),
246 Err(_e) => Err(Error::TopicNotUtf8),
247 }
248}
249
250fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
252 stream.put_u16(bytes.len() as u16);
253 stream.extend_from_slice(bytes);
254}
255
256fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
258 write_mqtt_bytes(stream, string.as_bytes());
259}
260
261fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
263 if len > 268_435_455 {
264 return Err(Error::PayloadTooLong);
265 }
266
267 let mut done = false;
268 let mut x = len;
269 let mut count = 0;
270
271 while !done {
272 let mut byte = (x % 128) as u8;
273 x /= 128;
274 if x > 0 {
275 byte |= 128;
276 }
277
278 stream.put_u8(byte);
279 count += 1;
280 done = x == 0;
281 }
282
283 Ok(count)
284}
285
/// Number of bytes (1 to 4) that the variable-length encoding of `len`
/// occupies.
///
/// Every value from 2_097_152 upwards maps to 4; no upper bound is checked
/// here.
fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
298
299pub fn qos(num: u8) -> Result<QoS, Error> {
301 match num {
302 0 => Ok(QoS::AtMostOnce),
303 1 => Ok(QoS::AtLeastOnce),
304 2 => Ok(QoS::ExactlyOnce),
305 qos => Err(Error::InvalidQoS(qos)),
306 }
307}
308
309fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
315 if stream.len() < 2 {
316 return Err(Error::MalformedPacket);
317 }
318
319 Ok(stream.get_u16())
320}
321
322fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
323 if stream.len() < 1 {
324 return Err(Error::MalformedPacket);
325 }
326
327 Ok(stream.get_u8())
328}
329
330fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
331 if stream.len() < 4 {
332 return Err(Error::MalformedPacket);
333 }
334
335 Ok(stream.get_u32())
336}
337
338impl Display for Error {
339 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
340 write!(f, "Error = {:?}", self)
341 }
342}