1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod subscribe_options;
16pub mod unsuback;
17pub mod unsubscribe;
18
19pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
20
21#[cfg(test)]
22mod property_tests;
23
/// Property-based round-trip checks for the `BeBytes`-derived header types.
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Arbitrary in-range bitfield values must survive encode -> decode unchanged.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let header = MqttTypeAndFlags { message_type, dup, qos, retain };
            let encoded = header.to_be_bytes();
            let round_tripped = MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap().0;

            prop_assert_eq!(header, round_tripped);
        }

        // Every packet type's flag-free header byte must round-trip and still
        // report the same packet type afterwards.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(kind) = PacketType::from_u8(packet_type) {
                let header = MqttTypeAndFlags::for_packet_type(kind);
                let encoded = header.to_be_bytes();
                let round_tripped = MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap().0;

                prop_assert_eq!(header, round_tripped);
                prop_assert_eq!(round_tripped.packet_type(), Some(kind));
            }
        }

        // PUBLISH headers built from (qos, dup, retain) must decode back to the
        // same flag values.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let header = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let encoded = header.to_be_bytes();
            let round_tripped = MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap().0;

            prop_assert_eq!(header, round_tripped);
            prop_assert_eq!(round_tripped.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(round_tripped.qos, qos);
            prop_assert_eq!(round_tripped.is_dup(), dup);
            prop_assert_eq!(round_tripped.is_retain(), retain);
        }
    }
}
80
81use crate::encoding::{decode_variable_int, encode_variable_int};
82use crate::error::{MqttError, Result};
83use crate::prelude::{format, Box, ToString, Vec};
84use bebytes::BeBytes;
85use bytes::{Buf, BufMut};
86
87#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
90pub struct AckPacketHeader {
91 pub packet_id: u16,
93 pub reason_code: u8,
95}
96
97impl AckPacketHeader {
98 #[must_use]
100 pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
101 Self {
102 packet_id,
103 reason_code: u8::from(reason_code),
104 }
105 }
106
107 #[must_use]
109 pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
110 crate::types::ReasonCode::from_u8(self.reason_code)
111 }
112}
113
114#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
116pub struct MqttTypeAndFlags {
117 #[bits(4)]
119 pub message_type: u8,
120 #[bits(1)]
122 pub dup: u8,
123 #[bits(2)]
125 pub qos: u8,
126 #[bits(1)]
128 pub retain: u8,
129}
130
131impl MqttTypeAndFlags {
132 #[must_use]
134 pub fn for_packet_type(packet_type: PacketType) -> Self {
135 Self {
136 message_type: packet_type as u8,
137 dup: 0,
138 qos: 0,
139 retain: 0,
140 }
141 }
142
143 #[must_use]
145 pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
146 Self {
147 message_type: PacketType::Publish as u8,
148 dup: u8::from(dup),
149 qos,
150 retain: u8::from(retain),
151 }
152 }
153
154 #[must_use]
156 pub fn packet_type(&self) -> Option<PacketType> {
157 PacketType::from_u8(self.message_type)
158 }
159
160 #[must_use]
162 pub fn is_dup(&self) -> bool {
163 self.dup != 0
164 }
165
166 #[must_use]
168 pub fn is_retain(&self) -> bool {
169 self.retain != 0
170 }
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
174pub enum PacketType {
175 Connect = 1,
176 ConnAck = 2,
177 Publish = 3,
178 PubAck = 4,
179 PubRec = 5,
180 PubRel = 6,
181 PubComp = 7,
182 Subscribe = 8,
183 SubAck = 9,
184 Unsubscribe = 10,
185 UnsubAck = 11,
186 PingReq = 12,
187 PingResp = 13,
188 Disconnect = 14,
189 Auth = 15,
190}
191
192impl PacketType {
193 #[must_use]
195 pub fn from_u8(value: u8) -> Option<Self> {
196 Self::try_from(value).ok()
198 }
199}
200
201impl From<PacketType> for u8 {
202 fn from(packet_type: PacketType) -> Self {
203 packet_type as u8
204 }
205}
206
207#[derive(Debug, Clone, Copy, PartialEq, Eq)]
209pub struct FixedHeader {
210 pub packet_type: PacketType,
211 pub flags: u8,
212 pub remaining_length: u32,
213}
214
215impl FixedHeader {
216 #[must_use]
218 pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
219 Self {
220 packet_type,
221 flags,
222 remaining_length,
223 }
224 }
225
226 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
232 let byte1 =
233 (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
234 buf.put_u8(byte1);
235 encode_variable_int(buf, self.remaining_length)?;
236 Ok(())
237 }
238
239 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
248 if !buf.has_remaining() {
249 return Err(MqttError::MalformedPacket(
250 "No data for fixed header".to_string(),
251 ));
252 }
253
254 let byte1 = buf.get_u8();
255 let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
256 let flags = byte1 & crate::constants::masks::FLAGS;
257
258 let packet_type = PacketType::from_u8(packet_type_val)
259 .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
260
261 let remaining_length = decode_variable_int(buf)?;
262
263 Ok(Self {
264 packet_type,
265 flags,
266 remaining_length,
267 })
268 }
269
270 #[must_use]
272 pub fn validate_flags(&self) -> bool {
273 match self.packet_type {
274 PacketType::Publish => true, PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
276 self.flags == 0x02 }
278 _ => self.flags == 0,
279 }
280 }
281
282 #[must_use]
284 pub fn encoded_len(&self) -> usize {
285 1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
287 }
288}
289
290#[derive(Debug, Clone)]
292pub enum Packet {
293 Connect(Box<connect::ConnectPacket>),
294 ConnAck(connack::ConnAckPacket),
295 Publish(publish::PublishPacket),
296 PubAck(puback::PubAckPacket),
297 PubRec(pubrec::PubRecPacket),
298 PubRel(pubrel::PubRelPacket),
299 PubComp(pubcomp::PubCompPacket),
300 Subscribe(subscribe::SubscribePacket),
301 SubAck(suback::SubAckPacket),
302 Unsubscribe(unsubscribe::UnsubscribePacket),
303 UnsubAck(unsuback::UnsubAckPacket),
304 PingReq,
305 PingResp,
306 Disconnect(disconnect::DisconnectPacket),
307 Auth(auth::AuthPacket),
308}
309
310impl Packet {
311 #[must_use]
312 pub fn packet_type_name(&self) -> &'static str {
313 match self {
314 Self::Connect(_) => "CONNECT",
315 Self::ConnAck(_) => "CONNACK",
316 Self::Publish(_) => "PUBLISH",
317 Self::PubAck(_) => "PUBACK",
318 Self::PubRec(_) => "PUBREC",
319 Self::PubRel(_) => "PUBREL",
320 Self::PubComp(_) => "PUBCOMP",
321 Self::Subscribe(_) => "SUBSCRIBE",
322 Self::SubAck(_) => "SUBACK",
323 Self::Unsubscribe(_) => "UNSUBSCRIBE",
324 Self::UnsubAck(_) => "UNSUBACK",
325 Self::PingReq => "PINGREQ",
326 Self::PingResp => "PINGRESP",
327 Self::Disconnect(_) => "DISCONNECT",
328 Self::Auth(_) => "AUTH",
329 }
330 }
331
332 pub fn decode_from_body<B: Buf>(
336 packet_type: PacketType,
337 fixed_header: &FixedHeader,
338 buf: &mut B,
339 ) -> Result<Self> {
340 if !fixed_header.validate_flags() {
341 return Err(MqttError::MalformedPacket(format!(
342 "Invalid fixed header flags 0x{:02X} for {:?}",
343 fixed_header.flags, packet_type
344 )));
345 }
346
347 match packet_type {
348 PacketType::Connect => {
349 let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
350 Ok(Packet::Connect(Box::new(packet)))
351 }
352 PacketType::ConnAck => {
353 let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
354 Ok(Packet::ConnAck(packet))
355 }
356 PacketType::Publish => {
357 let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
358 Ok(Packet::Publish(packet))
359 }
360 PacketType::PubAck => {
361 let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
362 Ok(Packet::PubAck(packet))
363 }
364 PacketType::PubRec => {
365 let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
366 Ok(Packet::PubRec(packet))
367 }
368 PacketType::PubRel => {
369 let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
370 Ok(Packet::PubRel(packet))
371 }
372 PacketType::PubComp => {
373 let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
374 Ok(Packet::PubComp(packet))
375 }
376 PacketType::Subscribe => {
377 let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
378 Ok(Packet::Subscribe(packet))
379 }
380 PacketType::SubAck => {
381 let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
382 Ok(Packet::SubAck(packet))
383 }
384 PacketType::Unsubscribe => {
385 let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
386 Ok(Packet::Unsubscribe(packet))
387 }
388 PacketType::UnsubAck => {
389 let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
390 Ok(Packet::UnsubAck(packet))
391 }
392 PacketType::PingReq => Ok(Packet::PingReq),
393 PacketType::PingResp => Ok(Packet::PingResp),
394 PacketType::Disconnect => {
395 let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
396 Ok(Packet::Disconnect(packet))
397 }
398 PacketType::Auth => {
399 let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
400 Ok(Packet::Auth(packet))
401 }
402 }
403 }
404
405 pub fn decode_from_body_with_version<B: Buf>(
411 packet_type: PacketType,
412 fixed_header: &FixedHeader,
413 buf: &mut B,
414 protocol_version: u8,
415 ) -> Result<Self> {
416 match packet_type {
417 PacketType::Publish => {
418 let packet = publish::PublishPacket::decode_body_with_version(
419 buf,
420 fixed_header,
421 protocol_version,
422 )?;
423 Ok(Packet::Publish(packet))
424 }
425 PacketType::Subscribe => {
426 let packet = subscribe::SubscribePacket::decode_body_with_version(
427 buf,
428 fixed_header,
429 protocol_version,
430 )?;
431 Ok(Packet::Subscribe(packet))
432 }
433 PacketType::SubAck => {
434 let packet = suback::SubAckPacket::decode_body_with_version(
435 buf,
436 fixed_header,
437 protocol_version,
438 )?;
439 Ok(Packet::SubAck(packet))
440 }
441 PacketType::Unsubscribe => {
442 let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
443 buf,
444 fixed_header,
445 protocol_version,
446 )?;
447 Ok(Packet::Unsubscribe(packet))
448 }
449 _ => Self::decode_from_body(packet_type, fixed_header, buf),
450 }
451 }
452}
453
454pub trait MqttPacket: Sized {
456 fn packet_type(&self) -> PacketType;
458
459 fn flags(&self) -> u8 {
461 0
462 }
463
464 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
470
471 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
477
478 fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
484 let mut body = Vec::new();
486 self.encode_body(&mut body)?;
487
488 let fixed_header = FixedHeader::new(
489 self.packet_type(),
490 self.flags(),
491 body.len().try_into().unwrap_or(u32::MAX),
492 );
493
494 fixed_header.encode(buf)?;
495 buf.put_slice(&body);
496 Ok(())
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it back.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut wire = BytesMut::new();
        header.encode(&mut wire).unwrap();
        FixedHeader::decode(&mut wire).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];
        for (raw, expected) in cases {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let decoded = round_trip(FixedHeader::new(PacketType::Connect, 0, 100));
        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let decoded = round_trip(FixedHeader::new(PacketType::Publish, 0x0D, 50));
        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0x00, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for (packet_type, flags, expected) in cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(
                header.validate_flags(),
                expected,
                "{:?} with flags {:#04x}",
                packet_type,
                flags
            );
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // Type nibble 0 is reserved, followed by a zero remaining length.
        let mut wire = BytesMut::from(&[0x00u8, 0x00][..]);
        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        let cases = [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)];
        for (packet_type, wire_value) in cases {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire_value]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}