1use bytes::{BufMut, Bytes, BytesMut};
7use core::fmt;
8use mqttbytes_core::primitives::{self as core_primitives, Error as PrimitiveError};
9use std::slice::Iter;
10
11pub mod v4;
12
13pub use mqttbytes_core::{QoS, has_wildcards, matches, valid_filter, valid_topic};
14
15#[derive(Debug, thiserror::Error)]
17pub enum Error {
18 #[error("Expected Connect, received: {0:?}")]
19 NotConnect(PacketType),
20 #[error("Unexpected Connect")]
21 UnexpectedConnect,
22 #[error("Invalid Connect return code: {0}")]
23 InvalidConnectReturnCode(u8),
24 #[error("Invalid protocol")]
25 InvalidProtocol,
26 #[error("Invalid protocol level: {0}")]
27 InvalidProtocolLevel(u8),
28 #[error("Incorrect packet format")]
29 IncorrectPacketFormat,
30 #[error("Invalid packet type: {0}")]
31 InvalidPacketType(u8),
32 #[error("Invalid property type: {0}")]
33 InvalidPropertyType(u8),
34 #[error("Invalid QoS level: {0}")]
35 InvalidQoS(u8),
36 #[error("Invalid subscribe reason code: {0}")]
37 InvalidSubscribeReasonCode(u8),
38 #[error("Packet id Zero")]
39 PacketIdZero,
40 #[error("Payload size is incorrect")]
41 PayloadSizeIncorrect,
42 #[error("payload is too long")]
43 PayloadTooLong,
44 #[error("payload size limit exceeded: {0}")]
45 PayloadSizeLimitExceeded(usize),
46 #[error("Payload required")]
47 PayloadRequired,
48 #[error("Topic is not UTF-8")]
49 TopicNotUtf8,
50 #[error("Promised boundary crossed: {0}")]
51 BoundaryCrossed(usize),
52 #[error("Malformed packet")]
53 MalformedPacket,
54 #[error("Malformed remaining length")]
55 MalformedRemainingLength,
56 #[error("A Subscribe packet must contain atleast one filter")]
57 EmptySubscription,
58 #[error("At least {0} more bytes required to frame packet")]
62 InsufficientBytes(usize),
63 #[error("IO: {0}")]
64 Io(#[from] std::io::Error),
65 #[error(
66 "Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'"
67 )]
68 OutgoingPacketTooLarge { pkt_size: usize, max: usize },
69}
70
71impl From<PrimitiveError> for Error {
72 fn from(error: PrimitiveError) -> Self {
73 match error {
74 PrimitiveError::PayloadTooLong => Self::PayloadTooLong,
75 PrimitiveError::BoundaryCrossed(len) => Self::BoundaryCrossed(len),
76 PrimitiveError::MalformedPacket => Self::MalformedPacket,
77 PrimitiveError::MalformedRemainingLength => Self::MalformedRemainingLength,
78 PrimitiveError::TopicNotUtf8 => Self::TopicNotUtf8,
79 PrimitiveError::InsufficientBytes(required) => Self::InsufficientBytes(required),
80 }
81 }
82}
83
/// MQTT control packet types.
///
/// Each discriminant equals the value that `FixedHeader::packet_type` decodes
/// from the upper four bits of the fixed header's first byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
103
/// Protocol versions distinguished by this crate.
///
/// NOTE(review): presumably `V4` is MQTT 3.1.1 (protocol level 4) and `V5`
/// is MQTT 5 — confirm against the callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    V4,
    V5,
}
110
/// Parsed MQTT fixed header: the first byte plus the decoded remaining length.
///
/// Field order is significant: the derived `PartialOrd` compares `byte1`,
/// then `header_len`, then `remaining_len`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the header; its upper four bits select the packet type.
    byte1: u8,
    /// Total header size in bytes: one for `byte1` plus the bytes used to
    /// encode the remaining length.
    header_len: usize,
    /// Number of bytes that follow the fixed header in this frame.
    remaining_len: usize,
}
136
137impl FixedHeader {
138 #[must_use]
139 pub const fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> Self {
140 Self {
141 byte1,
142 header_len: remaining_len_len + 1,
143 remaining_len,
144 }
145 }
146
147 pub const fn packet_type(&self) -> Result<PacketType, Error> {
154 let num = self.byte1 >> 4;
155 match num {
156 1 => Ok(PacketType::Connect),
157 2 => Ok(PacketType::ConnAck),
158 3 => Ok(PacketType::Publish),
159 4 => Ok(PacketType::PubAck),
160 5 => Ok(PacketType::PubRec),
161 6 => Ok(PacketType::PubRel),
162 7 => Ok(PacketType::PubComp),
163 8 => Ok(PacketType::Subscribe),
164 9 => Ok(PacketType::SubAck),
165 10 => Ok(PacketType::Unsubscribe),
166 11 => Ok(PacketType::UnsubAck),
167 12 => Ok(PacketType::PingReq),
168 13 => Ok(PacketType::PingResp),
169 14 => Ok(PacketType::Disconnect),
170 _ => Err(Error::InvalidPacketType(num)),
171 }
172 }
173
174 #[must_use]
177 pub const fn frame_length(&self) -> usize {
178 self.header_len + self.remaining_len
179 }
180}
181
182pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
195 let stream_len = stream.len();
196 let fixed_header = parse_fixed_header(stream)?;
197
198 if fixed_header.remaining_len > max_packet_size {
201 return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
202 }
203
204 let frame_length = fixed_header.frame_length();
205 if stream_len < frame_length {
206 return Err(Error::InsufficientBytes(frame_length - stream_len));
207 }
208
209 Ok(fixed_header)
210}
211
212fn parse_fixed_header(stream: Iter<u8>) -> Result<FixedHeader, Error> {
213 let fixed_header = core_primitives::parse_fixed_header(stream).map_err(Error::from)?;
214 Ok(FixedHeader::new(
215 fixed_header.byte1,
216 fixed_header.remaining_len_len,
217 fixed_header.remaining_len,
218 ))
219}
220
221fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
223 core_primitives::read_mqtt_bytes(stream).map_err(Error::from)
224}
225
226fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
228 core_primitives::read_mqtt_string(stream).map_err(Error::from)
229}
230
/// Appends `bytes` to `stream` as a byte field by delegating to
/// `core_primitives::write_mqtt_bytes`.
fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    core_primitives::write_mqtt_bytes(stream, bytes);
}
235
/// Appends `string` to `stream` as a string field by delegating to
/// `core_primitives::write_mqtt_string`.
fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    core_primitives::write_mqtt_string(stream, string);
}
240
241fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
243 core_primitives::write_remaining_length(stream, len).map_err(Error::from)
244}
245
246pub fn qos(num: u8) -> Result<QoS, Error> {
253 mqttbytes_core::qos(num).ok_or(Error::InvalidQoS(num))
254}
255
256fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
262 core_primitives::read_u16(stream).map_err(Error::from)
263}
264
265fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
266 core_primitives::read_u8(stream).map_err(Error::from)
267}
268
#[cfg(test)]
mod tests {
    use super::{check, Error};

    /// An oversized packet must be rejected from its fixed header alone,
    /// before the rest of the frame has been received.
    #[test]
    fn check_rejects_oversized_packet_on_partial_frame() {
        // 0x30 decodes to Publish; 0x14 announces a remaining length of 20 bytes,
        // none of which are present yet.
        let header_only: [u8; 2] = [0x30, 0x14];
        let outcome = check(header_only.iter(), 10);

        assert!(
            matches!(outcome, Err(Error::PayloadSizeLimitExceeded(20))),
            "expected PayloadSizeLimitExceeded(20), got {:?}",
            outcome
        );
    }
}