1pub mod auth;
2pub mod connack;
3pub mod connect;
4pub mod disconnect;
5pub mod pingreq;
6pub mod pingresp;
7pub mod puback;
8pub mod pubcomp;
9pub mod publish;
10pub mod pubrec;
11pub mod pubrel;
12pub mod suback;
13pub mod subscribe;
14pub mod unsuback;
15pub mod unsubscribe;
16
17#[cfg(test)]
18mod property_tests;
19
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    /// Serializes `value` with `to_be_bytes` and parses the bytes back with
    /// `try_from_be_bytes`, returning the decoded copy.
    ///
    /// Panics (failing the calling test) if decoding reports an error.
    fn round_trip(value: MqttTypeAndFlags) -> MqttTypeAndFlags {
        let encoded = value.to_be_bytes();
        let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap();
        decoded
    }

    proptest! {
        // Any combination of in-range bit-field values survives a round trip.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let header = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            prop_assert_eq!(header, round_trip(header));
        }

        // A header built from a `PacketType` decodes to the same packet type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(pt) = PacketType::from_u8(packet_type) {
                let header = MqttTypeAndFlags::for_packet_type(pt);
                let decoded = round_trip(header);

                prop_assert_eq!(header, decoded);
                prop_assert_eq!(decoded.packet_type(), Some(pt));
            }
        }

        // PUBLISH headers keep their QoS, DUP and RETAIN values after a round trip.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let header = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let decoded = round_trip(header);

            prop_assert_eq!(header, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
76
77use crate::encoding::{decode_variable_int, encode_variable_int};
78use crate::error::{MqttError, Result};
79use bebytes::BeBytes;
80use bytes::{Buf, BufMut};
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
85pub struct AckPacketHeader {
86 #[bebytes(big_endian)]
88 pub packet_id: u16,
89 pub reason_code: u8,
91}
92
93impl AckPacketHeader {
94 #[must_use]
96 pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
97 Self {
98 packet_id,
99 reason_code: u8::from(reason_code),
100 }
101 }
102
103 #[must_use]
105 pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
106 crate::types::ReasonCode::from_u8(self.reason_code)
107 }
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
112pub struct MqttTypeAndFlags {
113 #[bits(4)]
115 pub message_type: u8,
116 #[bits(1)]
118 pub dup: u8,
119 #[bits(2)]
121 pub qos: u8,
122 #[bits(1)]
124 pub retain: u8,
125}
126
127impl MqttTypeAndFlags {
128 #[must_use]
130 pub fn for_packet_type(packet_type: PacketType) -> Self {
131 Self {
132 message_type: packet_type as u8,
133 dup: 0,
134 qos: 0,
135 retain: 0,
136 }
137 }
138
139 #[must_use]
141 pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
142 Self {
143 message_type: PacketType::Publish as u8,
144 dup: u8::from(dup),
145 qos,
146 retain: u8::from(retain),
147 }
148 }
149
150 #[must_use]
152 pub fn packet_type(&self) -> Option<PacketType> {
153 PacketType::from_u8(self.message_type)
154 }
155
156 #[must_use]
158 pub fn is_dup(&self) -> bool {
159 self.dup != 0
160 }
161
162 #[must_use]
164 pub fn is_retain(&self) -> bool {
165 self.retain != 0
166 }
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
170pub enum PacketType {
171 Connect = 1,
172 ConnAck = 2,
173 Publish = 3,
174 PubAck = 4,
175 PubRec = 5,
176 PubRel = 6,
177 PubComp = 7,
178 Subscribe = 8,
179 SubAck = 9,
180 Unsubscribe = 10,
181 UnsubAck = 11,
182 PingReq = 12,
183 PingResp = 13,
184 Disconnect = 14,
185 Auth = 15,
186}
187
188impl PacketType {
189 #[must_use]
191 pub fn from_u8(value: u8) -> Option<Self> {
192 Self::try_from(value).ok()
194 }
195}
196
197impl From<PacketType> for u8 {
198 fn from(packet_type: PacketType) -> Self {
199 packet_type as u8
200 }
201}
202
203#[derive(Debug, Clone, Copy, PartialEq, Eq)]
205pub struct FixedHeader {
206 pub packet_type: PacketType,
207 pub flags: u8,
208 pub remaining_length: u32,
209}
210
211impl FixedHeader {
212 #[must_use]
214 pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
215 Self {
216 packet_type,
217 flags,
218 remaining_length,
219 }
220 }
221
222 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
232 let byte1 =
233 (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
234 buf.put_u8(byte1);
235 encode_variable_int(buf, self.remaining_length)?;
236 Ok(())
237 }
238
239 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
252 if !buf.has_remaining() {
253 return Err(MqttError::MalformedPacket(
254 "No data for fixed header".to_string(),
255 ));
256 }
257
258 let byte1 = buf.get_u8();
259 let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
260 let flags = byte1 & crate::constants::masks::FLAGS;
261
262 let packet_type = PacketType::from_u8(packet_type_val)
263 .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
264
265 let remaining_length = decode_variable_int(buf)?;
266
267 Ok(Self {
268 packet_type,
269 flags,
270 remaining_length,
271 })
272 }
273
274 #[must_use]
276 pub fn validate_flags(&self) -> bool {
277 match self.packet_type {
278 PacketType::Publish => true, PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
280 self.flags == 0x02 }
282 _ => self.flags == 0,
283 }
284 }
285
286 #[must_use]
288 pub fn encoded_len(&self) -> usize {
289 1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
291 }
292}
293
294#[derive(Debug, Clone)]
296pub enum Packet {
297 Connect(Box<connect::ConnectPacket>),
298 ConnAck(connack::ConnAckPacket),
299 Publish(publish::PublishPacket),
300 PubAck(puback::PubAckPacket),
301 PubRec(pubrec::PubRecPacket),
302 PubRel(pubrel::PubRelPacket),
303 PubComp(pubcomp::PubCompPacket),
304 Subscribe(subscribe::SubscribePacket),
305 SubAck(suback::SubAckPacket),
306 Unsubscribe(unsubscribe::UnsubscribePacket),
307 UnsubAck(unsuback::UnsubAckPacket),
308 PingReq,
309 PingResp,
310 Disconnect(disconnect::DisconnectPacket),
311 Auth(auth::AuthPacket),
312}
313
314impl Packet {
315 pub fn decode_from_body<B: Buf>(
321 packet_type: PacketType,
322 fixed_header: &FixedHeader,
323 buf: &mut B,
324 ) -> Result<Self> {
325 match packet_type {
326 PacketType::Connect => {
327 let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
328 Ok(Packet::Connect(Box::new(packet)))
329 }
330 PacketType::ConnAck => {
331 let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
332 Ok(Packet::ConnAck(packet))
333 }
334 PacketType::Publish => {
335 let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
336 Ok(Packet::Publish(packet))
337 }
338 PacketType::PubAck => {
339 let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
340 Ok(Packet::PubAck(packet))
341 }
342 PacketType::PubRec => {
343 let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
344 Ok(Packet::PubRec(packet))
345 }
346 PacketType::PubRel => {
347 let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
348 Ok(Packet::PubRel(packet))
349 }
350 PacketType::PubComp => {
351 let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
352 Ok(Packet::PubComp(packet))
353 }
354 PacketType::Subscribe => {
355 let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
356 Ok(Packet::Subscribe(packet))
357 }
358 PacketType::SubAck => {
359 let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
360 Ok(Packet::SubAck(packet))
361 }
362 PacketType::Unsubscribe => {
363 let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
364 Ok(Packet::Unsubscribe(packet))
365 }
366 PacketType::UnsubAck => {
367 let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
368 Ok(Packet::UnsubAck(packet))
369 }
370 PacketType::PingReq => Ok(Packet::PingReq),
371 PacketType::PingResp => Ok(Packet::PingResp),
372 PacketType::Disconnect => {
373 let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
374 Ok(Packet::Disconnect(packet))
375 }
376 PacketType::Auth => {
377 let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
378 Ok(Packet::Auth(packet))
379 }
380 }
381 }
382}
383
384pub trait MqttPacket: Sized {
386 fn packet_type(&self) -> PacketType;
388
389 fn flags(&self) -> u8 {
391 0
392 }
393
394 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
404
405 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
411
412 fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
418 let mut body = Vec::new();
420 self.encode_body(&mut body)?;
421
422 let fixed_header = FixedHeader::new(
423 self.packet_type(),
424 self.flags(),
425 body.len().try_into().unwrap_or(u32::MAX),
426 );
427
428 fixed_header.encode(buf)?;
429 buf.put_slice(&body);
430 Ok(())
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it straight back.
    fn encode_then_decode(header: &FixedHeader) -> FixedHeader {
        let mut buf = BytesMut::new();
        header.encode(&mut buf).unwrap();
        FixedHeader::decode(&mut buf).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            // Values outside 1..=15 are not packet types.
            (0, None),
            (16, None),
        ];

        for &(raw, expected) in &cases {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let header = FixedHeader::new(PacketType::Connect, 0, 100);
        let decoded = encode_then_decode(&header);

        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let header = FixedHeader::new(PacketType::Publish, 0x0D, 50);
        let decoded = encode_then_decode(&header);

        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_validate_flags() {
        // (packet type, flags, expected result of validate_flags)
        let cases = [
            (PacketType::Connect, 0, true),
            (PacketType::Connect, 1, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];

        for &(packet_type, flags, expected) in &cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(header.validate_flags(), expected);
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        let mut buf = BytesMut::new();
        // First byte: type bits are 0, which is not a packet type.
        buf.put_u8(0x00);
        // Second byte: remaining length of 0.
        buf.put_u8(0x00);

        assert!(FixedHeader::decode(&mut buf).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        let cases = [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)];

        for &(packet_type, wire) in &cases {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}