1use std::slice::Iter;
2
3pub use self::{
4 auth::{Auth, AuthProperties, AuthReasonCode},
5 codec::Codec,
6 connack::{ConnAck, ConnAckProperties, ConnectReturnCode},
7 connect::{Connect, ConnectAuth, ConnectProperties, LastWill, LastWillProperties},
8 disconnect::{Disconnect, DisconnectProperties, DisconnectReasonCode},
9 ping::{PingReq, PingResp},
10 puback::{PubAck, PubAckProperties, PubAckReason},
11 pubcomp::{PubComp, PubCompProperties, PubCompReason},
12 publish::{Publish, PublishProperties},
13 pubrec::{PubRec, PubRecProperties, PubRecReason},
14 pubrel::{PubRel, PubRelProperties, PubRelReason},
15 suback::{SubAck, SubAckProperties, SubscribeReasonCode},
16 subscribe::{Filter, RetainForwardRule, Subscribe, SubscribeProperties},
17 unsuback::{UnsubAck, UnsubAckProperties, UnsubAckReason},
18 unsubscribe::{Unsubscribe, UnsubscribeProperties},
19};
20
21use super::{Error, QoS, qos};
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23use mqttbytes_core::primitives::{self as core_primitives, Error as PrimitiveError};
24
25#[allow(clippy::missing_errors_doc)]
26mod auth;
27#[allow(clippy::missing_errors_doc)]
28mod codec;
29#[allow(clippy::missing_errors_doc)]
30mod connack;
31#[allow(clippy::missing_errors_doc)]
32mod connect;
33#[allow(clippy::missing_errors_doc)]
34mod disconnect;
35#[allow(clippy::missing_errors_doc)]
36mod ping;
37#[allow(clippy::missing_errors_doc)]
38mod puback;
39#[allow(clippy::missing_errors_doc)]
40mod pubcomp;
41#[allow(clippy::missing_errors_doc)]
42mod publish;
43#[allow(clippy::missing_errors_doc)]
44mod pubrec;
45#[allow(clippy::missing_errors_doc)]
46mod pubrel;
47#[allow(clippy::missing_errors_doc)]
48mod suback;
49#[allow(clippy::missing_errors_doc)]
50mod subscribe;
51#[allow(clippy::missing_errors_doc)]
52mod unsuback;
53#[allow(clippy::missing_errors_doc)]
54mod unsubscribe;
55
/// A decoded control packet, one variant per packet type handled by this module.
///
/// Every variant wraps the struct re-exported from the matching submodule. `Connect` is the
/// only variant with more than one field: the connect data, an optional last will, and the
/// connect authentication data travel side by side.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Packet {
    Auth(Auth),
    /// Connect data, optional last will, and authentication data, in that order.
    Connect(Connect, Option<LastWill>, ConnectAuth),
    ConnAck(ConnAck),
    Publish(Publish),
    PubAck(PubAck),
    PingReq(PingReq),
    PingResp(PingResp),
    Subscribe(Subscribe),
    SubAck(SubAck),
    PubRec(PubRec),
    PubRel(PubRel),
    PubComp(PubComp),
    Unsubscribe(Unsubscribe),
    UnsubAck(UnsubAck),
    Disconnect(Disconnect),
}
74
75impl From<PrimitiveError> for Error {
76 fn from(error: PrimitiveError) -> Self {
77 match error {
78 PrimitiveError::PayloadTooLong => Self::PayloadTooLong,
79 PrimitiveError::BoundaryCrossed(len) => Self::BoundaryCrossed(len),
80 PrimitiveError::MalformedPacket => Self::MalformedPacket,
81 PrimitiveError::MalformedRemainingLength => Self::MalformedRemainingLength,
82 PrimitiveError::TopicNotUtf8 => Self::TopicNotUtf8,
83 PrimitiveError::InsufficientBytes(required) => Self::InsufficientBytes(required),
84 }
85 }
86}
87
88impl Packet {
89 pub fn read(stream: &mut BytesMut, max_size: Option<u32>) -> Result<Self, Error> {
96 let fixed_header = check(stream.iter(), max_size)?;
97
98 let packet = stream.split_to(fixed_header.frame_length());
100 let packet_type = fixed_header.packet_type()?;
101
102 if fixed_header.remaining_len == 0 {
103 return match packet_type {
105 PacketType::PingReq => Ok(Self::PingReq(PingReq)),
106 PacketType::PingResp => Ok(Self::PingResp(PingResp)),
107 PacketType::Disconnect => {
108 Disconnect::read(fixed_header, packet.freeze()).map(Self::Disconnect)
109 }
110 _ => Err(Error::PayloadRequired),
111 };
112 }
113
114 let packet = packet.freeze();
115 let packet = match packet_type {
116 PacketType::Connect => {
117 let (connect, will, auth) = Connect::read(fixed_header, packet)?;
118 Self::Connect(connect, will, auth)
119 }
120 PacketType::Publish => {
121 let publish = Publish::read(fixed_header, packet)?;
122 Self::Publish(publish)
123 }
124 PacketType::Subscribe => {
125 let subscribe = Subscribe::read(fixed_header, packet)?;
126 Self::Subscribe(subscribe)
127 }
128 PacketType::Unsubscribe => {
129 let unsubscribe = Unsubscribe::read(fixed_header, packet)?;
130 Self::Unsubscribe(unsubscribe)
131 }
132 PacketType::ConnAck => {
133 let connack = ConnAck::read(fixed_header, packet)?;
134 Self::ConnAck(connack)
135 }
136 PacketType::PubAck => {
137 let puback = PubAck::read(fixed_header, packet)?;
138 Self::PubAck(puback)
139 }
140 PacketType::PubRec => {
141 let pubrec = PubRec::read(fixed_header, packet)?;
142 Self::PubRec(pubrec)
143 }
144 PacketType::PubRel => {
145 let pubrel = PubRel::read(fixed_header, packet)?;
146 Self::PubRel(pubrel)
147 }
148 PacketType::PubComp => {
149 let pubcomp = PubComp::read(fixed_header, packet)?;
150 Self::PubComp(pubcomp)
151 }
152 PacketType::SubAck => {
153 let suback = SubAck::read(fixed_header, packet)?;
154 Self::SubAck(suback)
155 }
156 PacketType::UnsubAck => {
157 let unsuback = UnsubAck::read(fixed_header, packet)?;
158 Self::UnsubAck(unsuback)
159 }
160 PacketType::PingReq => Self::PingReq(PingReq),
161 PacketType::PingResp => Self::PingResp(PingResp),
162 PacketType::Disconnect => {
163 let disconnect = Disconnect::read(fixed_header, packet)?;
164 Self::Disconnect(disconnect)
165 }
166 PacketType::Auth => {
167 let auth = Auth::read(fixed_header, packet)?;
168 Self::Auth(auth)
169 }
170 };
171
172 Ok(packet)
173 }
174
175 pub fn write(&self, write: &mut BytesMut, max_size: Option<u32>) -> Result<usize, Error> {
182 if let Some(max_size) = max_size
183 && self.size() > max_size as usize
184 {
185 return Err(Error::OutgoingPacketTooLarge {
186 pkt_size: u32::try_from(self.size()).unwrap_or(u32::MAX),
187 max: max_size,
188 });
189 }
190
191 match self {
192 Self::Auth(auth) => auth.write(write),
193 Self::Publish(publish) => publish.write(write),
194 Self::Subscribe(subscription) => subscription.write(write),
195 Self::Unsubscribe(unsubscribe) => unsubscribe.write(write),
196 Self::ConnAck(ack) => ack.write(write),
197 Self::PubAck(ack) => ack.write(write),
198 Self::SubAck(ack) => ack.write(write),
199 Self::UnsubAck(unsuback) => unsuback.write(write),
200 Self::PubRec(pubrec) => pubrec.write(write),
201 Self::PubRel(pubrel) => pubrel.write(write),
202 Self::PubComp(pubcomp) => pubcomp.write(write),
203 Self::Connect(connect, will, auth) => connect.write(will, auth, write),
204 Self::PingReq(_) => PingReq::write(write),
205 Self::PingResp(_) => PingResp::write(write),
206 Self::Disconnect(disconnect) => disconnect.write(write),
207 }
208 }
209
210 pub fn size(&self) -> usize {
211 match self {
212 Self::Auth(auth) => auth.size(),
213 Self::Publish(publish) => publish.size(),
214 Self::Subscribe(subscription) => subscription.size(),
215 Self::Unsubscribe(unsubscribe) => unsubscribe.size(),
216 Self::ConnAck(ack) => ack.size(),
217 Self::PubAck(ack) => ack.size(),
218 Self::SubAck(ack) => ack.size(),
219 Self::UnsubAck(unsuback) => unsuback.size(),
220 Self::PubRec(pubrec) => pubrec.size(),
221 Self::PubRel(pubrel) => pubrel.size(),
222 Self::PubComp(pubcomp) => pubcomp.size(),
223 Self::Connect(connect, will, auth) => connect.size(will, auth),
224 Self::PingReq(req) => req.size(),
225 Self::PingResp(resp) => resp.size(),
226 Self::Disconnect(disconnect) => disconnect.size(),
227 }
228 }
229}
230
/// Packet type as carried in the upper four bits of the first fixed-header byte.
///
/// Discriminants start at 1 and increase by one per variant, so they run from `Connect = 1`
/// to `Auth = 15`, the same numbers [`FixedHeader::packet_type`] matches on.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck,
    Publish,
    PubAck,
    PubRec,
    PubRel,
    PubComp,
    Subscribe,
    SubAck,
    Unsubscribe,
    UnsubAck,
    PingReq,
    PingResp,
    Disconnect,
    Auth,
}
251
/// Property identifiers recognised by this module.
///
/// Each discriminant is the identifier byte that `property` maps to the variant; the numbering
/// is sparse, and bytes not listed here are rejected by `property`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PropertyType {
    PayloadFormatIndicator = 1,
    MessageExpiryInterval = 2,
    ContentType = 3,
    ResponseTopic = 8,
    CorrelationData = 9,
    SubscriptionIdentifier = 11,
    SessionExpiryInterval = 17,
    AssignedClientIdentifier = 18,
    ServerKeepAlive = 19,
    AuthenticationMethod = 21,
    AuthenticationData = 22,
    RequestProblemInformation = 23,
    WillDelayInterval = 24,
    RequestResponseInformation = 25,
    ResponseInformation = 26,
    ServerReference = 28,
    ReasonString = 31,
    ReceiveMaximum = 33,
    TopicAliasMaximum = 34,
    TopicAlias = 35,
    MaximumQos = 36,
    RetainAvailable = 37,
    UserProperty = 38,
    MaximumPacketSize = 39,
    WildcardSubscriptionAvailable = 40,
    SubscriptionIdentifierAvailable = 41,
    SharedSubscriptionAvailable = 42,
}
283
/// Parsed fixed header of a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the header; its upper four bits select the packet type.
    byte1: u8,
    /// Length of the fixed header itself: the first byte plus the bytes of the
    /// remaining-length field.
    header_len: usize,
    /// Value of the remaining-length field, i.e. the number of bytes that follow the
    /// fixed header within the frame.
    remaining_len: usize,
}
309
310impl FixedHeader {
311 #[must_use]
312 pub const fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> Self {
313 Self {
314 byte1,
315 header_len: remaining_len_len + 1,
316 remaining_len,
317 }
318 }
319
320 pub const fn packet_type(&self) -> Result<PacketType, Error> {
327 let num = self.byte1 >> 4;
328 match num {
329 1 => Ok(PacketType::Connect),
330 2 => Ok(PacketType::ConnAck),
331 3 => Ok(PacketType::Publish),
332 4 => Ok(PacketType::PubAck),
333 5 => Ok(PacketType::PubRec),
334 6 => Ok(PacketType::PubRel),
335 7 => Ok(PacketType::PubComp),
336 8 => Ok(PacketType::Subscribe),
337 9 => Ok(PacketType::SubAck),
338 10 => Ok(PacketType::Unsubscribe),
339 11 => Ok(PacketType::UnsubAck),
340 12 => Ok(PacketType::PingReq),
341 13 => Ok(PacketType::PingResp),
342 14 => Ok(PacketType::Disconnect),
343 15 => Ok(PacketType::Auth),
344 _ => Err(Error::InvalidPacketType(num)),
345 }
346 }
347
348 #[must_use]
351 pub const fn frame_length(&self) -> usize {
352 self.header_len + self.remaining_len
353 }
354}
355
356const fn property(num: u8) -> Result<PropertyType, Error> {
357 let property = match num {
358 1 => PropertyType::PayloadFormatIndicator,
359 2 => PropertyType::MessageExpiryInterval,
360 3 => PropertyType::ContentType,
361 8 => PropertyType::ResponseTopic,
362 9 => PropertyType::CorrelationData,
363 11 => PropertyType::SubscriptionIdentifier,
364 17 => PropertyType::SessionExpiryInterval,
365 18 => PropertyType::AssignedClientIdentifier,
366 19 => PropertyType::ServerKeepAlive,
367 21 => PropertyType::AuthenticationMethod,
368 22 => PropertyType::AuthenticationData,
369 23 => PropertyType::RequestProblemInformation,
370 24 => PropertyType::WillDelayInterval,
371 25 => PropertyType::RequestResponseInformation,
372 26 => PropertyType::ResponseInformation,
373 28 => PropertyType::ServerReference,
374 31 => PropertyType::ReasonString,
375 33 => PropertyType::ReceiveMaximum,
376 34 => PropertyType::TopicAliasMaximum,
377 35 => PropertyType::TopicAlias,
378 36 => PropertyType::MaximumQos,
379 37 => PropertyType::RetainAvailable,
380 38 => PropertyType::UserProperty,
381 39 => PropertyType::MaximumPacketSize,
382 40 => PropertyType::WildcardSubscriptionAvailable,
383 41 => PropertyType::SubscriptionIdentifierAvailable,
384 42 => PropertyType::SharedSubscriptionAvailable,
385 num => return Err(Error::InvalidPropertyType(num)),
386 };
387
388 Ok(property)
389}
390
391pub fn check(stream: Iter<u8>, max_packet_size: Option<u32>) -> Result<FixedHeader, Error> {
404 let stream_len = stream.len();
405 let fixed_header = parse_fixed_header(stream)?;
406
407 if let Some(max_size) = max_packet_size
410 && fixed_header.remaining_len > max_size as usize
411 {
412 return Err(Error::PayloadSizeLimitExceeded {
413 pkt_size: fixed_header.remaining_len,
414 max: max_size,
415 });
416 }
417
418 let frame_length = fixed_header.frame_length();
419 if stream_len < frame_length {
420 return Err(Error::InsufficientBytes(frame_length - stream_len));
421 }
422
423 Ok(fixed_header)
424}
425
426fn parse_fixed_header(stream: Iter<u8>) -> Result<FixedHeader, Error> {
427 let fixed_header = core_primitives::parse_fixed_header(stream).map_err(Error::from)?;
428 Ok(FixedHeader::new(
429 fixed_header.byte1,
430 fixed_header.remaining_len_len,
431 fixed_header.remaining_len,
432 ))
433}
434
435fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
439 core_primitives::length(stream).map_err(Error::from)
440}
441
442fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
444 core_primitives::read_mqtt_bytes(stream).map_err(Error::from)
445}
446
447fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
449 core_primitives::read_mqtt_string(stream).map_err(Error::from)
450}
451
/// Appends `bytes` to `stream` by delegating to `core_primitives::write_mqtt_bytes`.
///
/// Whatever the core function returns is discarded; this wrapper returns nothing.
fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    core_primitives::write_mqtt_bytes(stream, bytes);
}
456
/// Appends `string` to `stream` by delegating to `core_primitives::write_mqtt_string`.
///
/// Whatever the core function returns is discarded; this wrapper returns nothing.
fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    core_primitives::write_mqtt_string(stream, string);
}
461
462fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
464 core_primitives::write_remaining_length(stream, len).map_err(Error::from)
465}
466
/// Delegates to `core_primitives::len_len`.
///
/// NOTE(review): presumably the number of bytes needed to encode `len` as a remaining-length
/// field — confirm against `mqttbytes_core`.
const fn len_len(len: usize) -> usize {
    core_primitives::len_len(len)
}
471
472fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
478 core_primitives::read_u16(stream).map_err(Error::from)
479}
480
481fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
482 core_primitives::read_u8(stream).map_err(Error::from)
483}
484
485fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
486 core_primitives::read_u32(stream).map_err(Error::from)
487}
488
/// Fixture constants for a user property key/value pair.
///
/// NOTE(review): this module is not gated behind `#[cfg(test)]`, and the code visible here
/// does not reference either constant, hence the `dead_code` allowances. Presumably the
/// packet submodules' tests consume them — confirm before changing either value.
mod test {
    #[allow(dead_code)]
    pub const USER_PROP_KEY: &str = "property";
    #[allow(dead_code)]
    pub const USER_PROP_VAL: &str = "a value thats really long............................................................................................................";
}
496
#[cfg(test)]
mod tests {
    use super::{Error, check};

    /// The size limit must be enforced from the header alone, before any body bytes arrive.
    #[test]
    fn check_rejects_oversized_packet_on_partial_frame() {
        // Header only: first byte 0x30 and a remaining length of 0x14 (20); no body follows.
        let header_only = [0x30, 0x14];
        let limit = Some(10);

        let rejected = matches!(
            check(header_only.iter(), limit),
            Err(Error::PayloadSizeLimitExceeded {
                pkt_size: 20,
                max: 10,
            })
        );

        assert!(rejected);
    }
}