1pub mod auth;
2pub mod connack;
3pub mod connect;
4pub mod disconnect;
5pub mod pingreq;
6pub mod pingresp;
7pub mod puback;
8pub mod pubcomp;
9pub mod publish;
10pub mod pubrec;
11pub mod pubrel;
12pub mod suback;
13pub mod subscribe;
14pub mod unsuback;
15pub mod unsubscribe;
16
17#[cfg(test)]
18mod property_tests;
19
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    /// Serializes `header` with bebytes and immediately parses it back.
    fn reencode(header: MqttTypeAndFlags) -> MqttTypeAndFlags {
        let wire = header.to_be_bytes();
        let (parsed, _consumed) = MqttTypeAndFlags::try_from_be_bytes(&wire).unwrap();
        parsed
    }

    proptest! {
        // Every in-range combination of the four bit fields survives a round trip.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let header = MqttTypeAndFlags { message_type, dup, qos, retain };
            prop_assert_eq!(header, reencode(header));
        }

        // A header built for a packet type round-trips and reports that same type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(pt) = PacketType::from_u8(packet_type) {
                let header = MqttTypeAndFlags::for_packet_type(pt);
                let parsed = reencode(header);

                prop_assert_eq!(header, parsed);
                prop_assert_eq!(parsed.packet_type(), Some(pt));
            }
        }

        // PUBLISH headers keep QoS, DUP and RETAIN exactly as supplied.
        #[test]
        fn prop_publish_flags_round_trip(qos in 0u8..=3, dup: bool, retain: bool) {
            let header = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let parsed = reencode(header);

            prop_assert_eq!(header, parsed);
            prop_assert_eq!(parsed.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(parsed.qos, qos);
            prop_assert_eq!(parsed.is_dup(), dup);
            prop_assert_eq!(parsed.is_retain(), retain);
        }
    }
}
76
77use crate::encoding::{decode_variable_int, encode_variable_int};
78use crate::error::{MqttError, Result};
79use bebytes::BeBytes;
80use bytes::{Buf, BufMut};
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
85pub struct AckPacketHeader {
86 pub packet_id: u16,
88 pub reason_code: u8,
90}
91
92impl AckPacketHeader {
93 #[must_use]
95 pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
96 Self {
97 packet_id,
98 reason_code: u8::from(reason_code),
99 }
100 }
101
102 #[must_use]
104 pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
105 crate::types::ReasonCode::from_u8(self.reason_code)
106 }
107}
108
109#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
111pub struct MqttTypeAndFlags {
112 #[bits(4)]
114 pub message_type: u8,
115 #[bits(1)]
117 pub dup: u8,
118 #[bits(2)]
120 pub qos: u8,
121 #[bits(1)]
123 pub retain: u8,
124}
125
126impl MqttTypeAndFlags {
127 #[must_use]
129 pub fn for_packet_type(packet_type: PacketType) -> Self {
130 Self {
131 message_type: packet_type as u8,
132 dup: 0,
133 qos: 0,
134 retain: 0,
135 }
136 }
137
138 #[must_use]
140 pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
141 Self {
142 message_type: PacketType::Publish as u8,
143 dup: u8::from(dup),
144 qos,
145 retain: u8::from(retain),
146 }
147 }
148
149 #[must_use]
151 pub fn packet_type(&self) -> Option<PacketType> {
152 PacketType::from_u8(self.message_type)
153 }
154
155 #[must_use]
157 pub fn is_dup(&self) -> bool {
158 self.dup != 0
159 }
160
161 #[must_use]
163 pub fn is_retain(&self) -> bool {
164 self.retain != 0
165 }
166}
167
168#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
169pub enum PacketType {
170 Connect = 1,
171 ConnAck = 2,
172 Publish = 3,
173 PubAck = 4,
174 PubRec = 5,
175 PubRel = 6,
176 PubComp = 7,
177 Subscribe = 8,
178 SubAck = 9,
179 Unsubscribe = 10,
180 UnsubAck = 11,
181 PingReq = 12,
182 PingResp = 13,
183 Disconnect = 14,
184 Auth = 15,
185}
186
187impl PacketType {
188 #[must_use]
190 pub fn from_u8(value: u8) -> Option<Self> {
191 Self::try_from(value).ok()
193 }
194}
195
196impl From<PacketType> for u8 {
197 fn from(packet_type: PacketType) -> Self {
198 packet_type as u8
199 }
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Eq)]
204pub struct FixedHeader {
205 pub packet_type: PacketType,
206 pub flags: u8,
207 pub remaining_length: u32,
208}
209
210impl FixedHeader {
211 #[must_use]
213 pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
214 Self {
215 packet_type,
216 flags,
217 remaining_length,
218 }
219 }
220
221 pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
231 let byte1 =
232 (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
233 buf.put_u8(byte1);
234 encode_variable_int(buf, self.remaining_length)?;
235 Ok(())
236 }
237
238 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
251 if !buf.has_remaining() {
252 return Err(MqttError::MalformedPacket(
253 "No data for fixed header".to_string(),
254 ));
255 }
256
257 let byte1 = buf.get_u8();
258 let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
259 let flags = byte1 & crate::constants::masks::FLAGS;
260
261 let packet_type = PacketType::from_u8(packet_type_val)
262 .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
263
264 let remaining_length = decode_variable_int(buf)?;
265
266 Ok(Self {
267 packet_type,
268 flags,
269 remaining_length,
270 })
271 }
272
273 #[must_use]
275 pub fn validate_flags(&self) -> bool {
276 match self.packet_type {
277 PacketType::Publish => true, PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
279 self.flags == 0x02 }
281 _ => self.flags == 0,
282 }
283 }
284
285 #[must_use]
287 pub fn encoded_len(&self) -> usize {
288 1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
290 }
291}
292
293#[derive(Debug, Clone)]
295pub enum Packet {
296 Connect(Box<connect::ConnectPacket>),
297 ConnAck(connack::ConnAckPacket),
298 Publish(publish::PublishPacket),
299 PubAck(puback::PubAckPacket),
300 PubRec(pubrec::PubRecPacket),
301 PubRel(pubrel::PubRelPacket),
302 PubComp(pubcomp::PubCompPacket),
303 Subscribe(subscribe::SubscribePacket),
304 SubAck(suback::SubAckPacket),
305 Unsubscribe(unsubscribe::UnsubscribePacket),
306 UnsubAck(unsuback::UnsubAckPacket),
307 PingReq,
308 PingResp,
309 Disconnect(disconnect::DisconnectPacket),
310 Auth(auth::AuthPacket),
311}
312
313impl Packet {
314 pub fn decode_from_body<B: Buf>(
320 packet_type: PacketType,
321 fixed_header: &FixedHeader,
322 buf: &mut B,
323 ) -> Result<Self> {
324 match packet_type {
325 PacketType::Connect => {
326 let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
327 Ok(Packet::Connect(Box::new(packet)))
328 }
329 PacketType::ConnAck => {
330 let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
331 Ok(Packet::ConnAck(packet))
332 }
333 PacketType::Publish => {
334 let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
335 Ok(Packet::Publish(packet))
336 }
337 PacketType::PubAck => {
338 let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
339 Ok(Packet::PubAck(packet))
340 }
341 PacketType::PubRec => {
342 let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
343 Ok(Packet::PubRec(packet))
344 }
345 PacketType::PubRel => {
346 let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
347 Ok(Packet::PubRel(packet))
348 }
349 PacketType::PubComp => {
350 let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
351 Ok(Packet::PubComp(packet))
352 }
353 PacketType::Subscribe => {
354 let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
355 Ok(Packet::Subscribe(packet))
356 }
357 PacketType::SubAck => {
358 let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
359 Ok(Packet::SubAck(packet))
360 }
361 PacketType::Unsubscribe => {
362 let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
363 Ok(Packet::Unsubscribe(packet))
364 }
365 PacketType::UnsubAck => {
366 let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
367 Ok(Packet::UnsubAck(packet))
368 }
369 PacketType::PingReq => Ok(Packet::PingReq),
370 PacketType::PingResp => Ok(Packet::PingResp),
371 PacketType::Disconnect => {
372 let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
373 Ok(Packet::Disconnect(packet))
374 }
375 PacketType::Auth => {
376 let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
377 Ok(Packet::Auth(packet))
378 }
379 }
380 }
381}
382
383pub trait MqttPacket: Sized {
385 fn packet_type(&self) -> PacketType;
387
388 fn flags(&self) -> u8 {
390 0
391 }
392
393 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
403
404 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
410
411 fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
417 let mut body = Vec::new();
419 self.encode_body(&mut body)?;
420
421 let fixed_header = FixedHeader::new(
422 self.packet_type(),
423 self.flags(),
424 body.len().try_into().unwrap_or(u32::MAX),
425 );
426
427 fixed_header.encode(buf)?;
428 buf.put_slice(&body);
429 Ok(())
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it straight back.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut wire = BytesMut::new();
        header.encode(&mut wire).unwrap();
        FixedHeader::decode(&mut wire).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1u8, Some(PacketType::Connect)),
            (2u8, Some(PacketType::ConnAck)),
            (15u8, Some(PacketType::Auth)),
            (0u8, None),
            (16u8, None),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(PacketType::from_u8(raw), expected, "raw value {}", raw);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let header = FixedHeader::new(PacketType::Connect, 0, 100);
        let decoded = round_trip(header);

        assert_eq!(decoded, header);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let header = FixedHeader::new(PacketType::Publish, 0x0D, 50);
        let decoded = round_trip(header);

        assert_eq!(decoded, header);
    }

    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0x00, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for &(packet_type, flags, expected) in &cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(
                header.validate_flags(),
                expected,
                "{:?} with flags {:#04x}",
                packet_type,
                flags
            );
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // High nibble 0 is not a packet type; second byte is a zero remaining length.
        let mut wire = BytesMut::new();
        wire.extend_from_slice(&[0x00, 0x00]);
        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        for &(packet_type, wire_value) in &[(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)] {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire_value]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}