1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod unsuback;
16pub mod unsubscribe;
17
18pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
19
20#[cfg(test)]
21mod property_tests;
22
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    // Property tests for the bit-packed `MqttTypeAndFlags` byte: whatever is
    // serialised with `to_be_bytes` must come back unchanged from
    // `try_from_be_bytes`.
    proptest! {
        // Every in-range combination of the four bit fields survives a
        // serialise/deserialise round trip.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let header = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            let encoded = header.to_be_bytes();
            let (round_tripped, _consumed) =
                MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap();

            prop_assert_eq!(header, round_tripped);
        }

        // A header built from a known packet type decodes back to the same
        // header and reports the same packet type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(kind) = PacketType::from_u8(packet_type) {
                let header = MqttTypeAndFlags::for_packet_type(kind);
                let encoded = header.to_be_bytes();
                let (round_tripped, _consumed) =
                    MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap();

                prop_assert_eq!(header, round_tripped);
                prop_assert_eq!(round_tripped.packet_type(), Some(kind));
            }
        }

        // PUBLISH headers keep their QoS, DUP and RETAIN settings across a
        // round trip.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let header = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let encoded = header.to_be_bytes();
            let (round_tripped, _consumed) =
                MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap();

            prop_assert_eq!(header, round_tripped);
            prop_assert_eq!(round_tripped.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(round_tripped.qos, qos);
            prop_assert_eq!(round_tripped.is_dup(), dup);
            prop_assert_eq!(round_tripped.is_retain(), retain);
        }
    }
}
79
80use crate::encoding::{decode_variable_int, encode_variable_int};
81use crate::error::{MqttError, Result};
82use crate::prelude::{Box, ToString, Vec};
83use bebytes::BeBytes;
84use bytes::{Buf, BufMut};
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
89pub struct AckPacketHeader {
90 pub packet_id: u16,
92 pub reason_code: u8,
94}
95
96impl AckPacketHeader {
97 #[must_use]
99 pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
100 Self {
101 packet_id,
102 reason_code: u8::from(reason_code),
103 }
104 }
105
106 #[must_use]
108 pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
109 crate::types::ReasonCode::from_u8(self.reason_code)
110 }
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
115pub struct MqttTypeAndFlags {
116 #[bits(4)]
118 pub message_type: u8,
119 #[bits(1)]
121 pub dup: u8,
122 #[bits(2)]
124 pub qos: u8,
125 #[bits(1)]
127 pub retain: u8,
128}
129
130impl MqttTypeAndFlags {
131 #[must_use]
133 pub fn for_packet_type(packet_type: PacketType) -> Self {
134 Self {
135 message_type: packet_type as u8,
136 dup: 0,
137 qos: 0,
138 retain: 0,
139 }
140 }
141
142 #[must_use]
144 pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
145 Self {
146 message_type: PacketType::Publish as u8,
147 dup: u8::from(dup),
148 qos,
149 retain: u8::from(retain),
150 }
151 }
152
153 #[must_use]
155 pub fn packet_type(&self) -> Option<PacketType> {
156 PacketType::from_u8(self.message_type)
157 }
158
159 #[must_use]
161 pub fn is_dup(&self) -> bool {
162 self.dup != 0
163 }
164
165 #[must_use]
167 pub fn is_retain(&self) -> bool {
168 self.retain != 0
169 }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
173pub enum PacketType {
174 Connect = 1,
175 ConnAck = 2,
176 Publish = 3,
177 PubAck = 4,
178 PubRec = 5,
179 PubRel = 6,
180 PubComp = 7,
181 Subscribe = 8,
182 SubAck = 9,
183 Unsubscribe = 10,
184 UnsubAck = 11,
185 PingReq = 12,
186 PingResp = 13,
187 Disconnect = 14,
188 Auth = 15,
189}
190
191impl PacketType {
192 #[must_use]
194 pub fn from_u8(value: u8) -> Option<Self> {
195 Self::try_from(value).ok()
197 }
198}
199
200impl From<PacketType> for u8 {
201 fn from(packet_type: PacketType) -> Self {
202 packet_type as u8
203 }
204}
205
206#[derive(Debug, Clone, Copy, PartialEq, Eq)]
208pub struct FixedHeader {
209 pub packet_type: PacketType,
210 pub flags: u8,
211 pub remaining_length: u32,
212}
213
214impl FixedHeader {
215 #[must_use]
217 pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
218 Self {
219 packet_type,
220 flags,
221 remaining_length,
222 }
223 }
224
225 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
231 let byte1 =
232 (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
233 buf.put_u8(byte1);
234 encode_variable_int(buf, self.remaining_length)?;
235 Ok(())
236 }
237
238 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
247 if !buf.has_remaining() {
248 return Err(MqttError::MalformedPacket(
249 "No data for fixed header".to_string(),
250 ));
251 }
252
253 let byte1 = buf.get_u8();
254 let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
255 let flags = byte1 & crate::constants::masks::FLAGS;
256
257 let packet_type = PacketType::from_u8(packet_type_val)
258 .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
259
260 let remaining_length = decode_variable_int(buf)?;
261
262 Ok(Self {
263 packet_type,
264 flags,
265 remaining_length,
266 })
267 }
268
269 #[must_use]
271 pub fn validate_flags(&self) -> bool {
272 match self.packet_type {
273 PacketType::Publish => true, PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
275 self.flags == 0x02 }
277 _ => self.flags == 0,
278 }
279 }
280
281 #[must_use]
283 pub fn encoded_len(&self) -> usize {
284 1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
286 }
287}
288
289#[derive(Debug, Clone)]
291pub enum Packet {
292 Connect(Box<connect::ConnectPacket>),
293 ConnAck(connack::ConnAckPacket),
294 Publish(publish::PublishPacket),
295 PubAck(puback::PubAckPacket),
296 PubRec(pubrec::PubRecPacket),
297 PubRel(pubrel::PubRelPacket),
298 PubComp(pubcomp::PubCompPacket),
299 Subscribe(subscribe::SubscribePacket),
300 SubAck(suback::SubAckPacket),
301 Unsubscribe(unsubscribe::UnsubscribePacket),
302 UnsubAck(unsuback::UnsubAckPacket),
303 PingReq,
304 PingResp,
305 Disconnect(disconnect::DisconnectPacket),
306 Auth(auth::AuthPacket),
307}
308
309impl Packet {
310 pub fn decode_from_body<B: Buf>(
316 packet_type: PacketType,
317 fixed_header: &FixedHeader,
318 buf: &mut B,
319 ) -> Result<Self> {
320 match packet_type {
321 PacketType::Connect => {
322 let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
323 Ok(Packet::Connect(Box::new(packet)))
324 }
325 PacketType::ConnAck => {
326 let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
327 Ok(Packet::ConnAck(packet))
328 }
329 PacketType::Publish => {
330 let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
331 Ok(Packet::Publish(packet))
332 }
333 PacketType::PubAck => {
334 let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
335 Ok(Packet::PubAck(packet))
336 }
337 PacketType::PubRec => {
338 let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
339 Ok(Packet::PubRec(packet))
340 }
341 PacketType::PubRel => {
342 let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
343 Ok(Packet::PubRel(packet))
344 }
345 PacketType::PubComp => {
346 let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
347 Ok(Packet::PubComp(packet))
348 }
349 PacketType::Subscribe => {
350 let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
351 Ok(Packet::Subscribe(packet))
352 }
353 PacketType::SubAck => {
354 let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
355 Ok(Packet::SubAck(packet))
356 }
357 PacketType::Unsubscribe => {
358 let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
359 Ok(Packet::Unsubscribe(packet))
360 }
361 PacketType::UnsubAck => {
362 let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
363 Ok(Packet::UnsubAck(packet))
364 }
365 PacketType::PingReq => Ok(Packet::PingReq),
366 PacketType::PingResp => Ok(Packet::PingResp),
367 PacketType::Disconnect => {
368 let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
369 Ok(Packet::Disconnect(packet))
370 }
371 PacketType::Auth => {
372 let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
373 Ok(Packet::Auth(packet))
374 }
375 }
376 }
377
378 pub fn decode_from_body_with_version<B: Buf>(
384 packet_type: PacketType,
385 fixed_header: &FixedHeader,
386 buf: &mut B,
387 protocol_version: u8,
388 ) -> Result<Self> {
389 match packet_type {
390 PacketType::Publish => {
391 let packet = publish::PublishPacket::decode_body_with_version(
392 buf,
393 fixed_header,
394 protocol_version,
395 )?;
396 Ok(Packet::Publish(packet))
397 }
398 PacketType::Subscribe => {
399 let packet = subscribe::SubscribePacket::decode_body_with_version(
400 buf,
401 fixed_header,
402 protocol_version,
403 )?;
404 Ok(Packet::Subscribe(packet))
405 }
406 PacketType::SubAck => {
407 let packet = suback::SubAckPacket::decode_body_with_version(
408 buf,
409 fixed_header,
410 protocol_version,
411 )?;
412 Ok(Packet::SubAck(packet))
413 }
414 PacketType::Unsubscribe => {
415 let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
416 buf,
417 fixed_header,
418 protocol_version,
419 )?;
420 Ok(Packet::Unsubscribe(packet))
421 }
422 _ => Self::decode_from_body(packet_type, fixed_header, buf),
423 }
424 }
425}
426
427pub trait MqttPacket: Sized {
429 fn packet_type(&self) -> PacketType;
431
432 fn flags(&self) -> u8 {
434 0
435 }
436
437 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
443
444 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
450
451 fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
457 let mut body = Vec::new();
459 self.encode_body(&mut body)?;
460
461 let fixed_header = FixedHeader::new(
462 self.packet_type(),
463 self.flags(),
464 body.len().try_into().unwrap_or(u32::MAX),
465 );
466
467 fixed_header.encode(buf)?;
468 buf.put_slice(&body);
469 Ok(())
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Known values map to their variants; 0 and 16 are rejected.
    #[test]
    fn test_packet_type_from_u8() {
        let expectations = [
            (1u8, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];

        for &(raw, expected) in &expectations {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    /// A header with zero flags round-trips through encode/decode.
    #[test]
    fn test_fixed_header_encode_decode() {
        let mut wire = BytesMut::new();
        FixedHeader::new(PacketType::Connect, 0, 100)
            .encode(&mut wire)
            .unwrap();

        let parsed = FixedHeader::decode(&mut wire).unwrap();
        assert_eq!(parsed.packet_type, PacketType::Connect);
        assert_eq!(parsed.flags, 0);
        assert_eq!(parsed.remaining_length, 100);
    }

    /// Non-zero flag bits are preserved by encode/decode.
    #[test]
    fn test_fixed_header_with_flags() {
        let mut wire = BytesMut::new();
        FixedHeader::new(PacketType::Publish, 0x0D, 50)
            .encode(&mut wire)
            .unwrap();

        let parsed = FixedHeader::decode(&mut wire).unwrap();
        assert_eq!(parsed.packet_type, PacketType::Publish);
        assert_eq!(parsed.flags, 0x0D);
        assert_eq!(parsed.remaining_length, 50);
    }

    /// Flag validation per packet type: (type, flags, expected verdict).
    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0x00u8, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];

        for &(packet_type, flags, expected) in &cases {
            let verdict = FixedHeader::new(packet_type, flags, 0).validate_flags();
            assert_eq!(verdict, expected);
        }
    }

    /// Decoding from an empty buffer fails.
    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    /// A first byte whose type bits are zero is rejected.
    #[test]
    fn test_decode_invalid_packet_type() {
        let mut wire = BytesMut::new();
        wire.put_u8(0x00);
        wire.put_u8(0x00);

        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    /// `PacketType` serialises to a single byte holding its numeric value and
    /// deserialises back, consuming exactly one byte.
    #[test]
    fn test_packet_type_bebytes_serialization() {
        let samples = [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)];

        for &(packet_type, wire_value) in &samples {
            let encoded = packet_type.to_be_bytes();
            assert_eq!(encoded, vec![wire_value]);

            let (parsed, consumed) = PacketType::try_from_be_bytes(&encoded).unwrap();
            assert_eq!(parsed, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}