1use std::io::{self, Cursor, Read};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum ProtocolError {
13 #[error("Invalid packet type: {0}")]
14 InvalidPacketType(u8),
15
16 #[error("Invalid remaining length encoding")]
17 InvalidRemainingLength,
18
19 #[error("Invalid protocol name: {0}")]
20 InvalidProtocolName(String),
21
22 #[error("Invalid protocol version: {0}")]
23 InvalidProtocolVersion(u8),
24
25 #[error("Invalid QoS level: {0}")]
26 InvalidQoS(u8),
27
28 #[error("Invalid UTF-8 string")]
29 InvalidUtf8,
30
31 #[error("Packet too large: {0} bytes")]
32 PacketTooLarge(usize),
33
34 #[error("Incomplete packet: expected {expected} bytes, got {got}")]
35 IncompletePacket { expected: usize, got: usize },
36
37 #[error("IO error: {0}")]
38 Io(#[from] io::Error),
39
40 #[error("Invalid connect flags")]
41 InvalidConnectFlags,
42
43 #[error("Malformed packet")]
44 MalformedPacket,
45}
46
47pub type ProtocolResult<T> = Result<T, ProtocolError>;
49
/// MQTT control packet types, numbered by the 4-bit code carried in the high nibble of
/// the first fixed-header byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    /// Client request to open a session.
    Connect = 1,
    /// Server acknowledgement of CONNECT.
    Connack = 2,
    /// Application message.
    Publish = 3,
    /// Acknowledgement of a QoS 1 PUBLISH.
    Puback = 4,
    /// QoS 2 handshake: publish received.
    Pubrec = 5,
    /// QoS 2 handshake: publish release.
    Pubrel = 6,
    /// QoS 2 handshake: publish complete.
    Pubcomp = 7,
    /// Subscription request.
    Subscribe = 8,
    /// Subscription acknowledgement.
    Suback = 9,
    /// Unsubscription request.
    Unsubscribe = 10,
    /// Unsubscription acknowledgement.
    Unsuback = 11,
    /// Keep-alive ping request.
    Pingreq = 12,
    /// Keep-alive ping response.
    Pingresp = 13,
    /// Client disconnect notification.
    Disconnect = 14,
}
69
70impl TryFrom<u8> for PacketType {
71 type Error = ProtocolError;
72
73 fn try_from(value: u8) -> ProtocolResult<Self> {
74 match value {
75 1 => Ok(PacketType::Connect),
76 2 => Ok(PacketType::Connack),
77 3 => Ok(PacketType::Publish),
78 4 => Ok(PacketType::Puback),
79 5 => Ok(PacketType::Pubrec),
80 6 => Ok(PacketType::Pubrel),
81 7 => Ok(PacketType::Pubcomp),
82 8 => Ok(PacketType::Subscribe),
83 9 => Ok(PacketType::Suback),
84 10 => Ok(PacketType::Unsubscribe),
85 11 => Ok(PacketType::Unsuback),
86 12 => Ok(PacketType::Pingreq),
87 13 => Ok(PacketType::Pingresp),
88 14 => Ok(PacketType::Disconnect),
89 _ => Err(ProtocolError::InvalidPacketType(value)),
90 }
91 }
92}
93
/// MQTT delivery guarantee; the discriminant is the 2-bit wire value.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum QoS {
    /// QoS 0 (the default).
    #[default]
    AtMostOnce = 0,
    /// QoS 1.
    AtLeastOnce = 1,
    /// QoS 2.
    ExactlyOnce = 2,
}
103
104impl TryFrom<u8> for QoS {
105 type Error = ProtocolError;
106
107 fn try_from(value: u8) -> ProtocolResult<Self> {
108 match value {
109 0 => Ok(QoS::AtMostOnce),
110 1 => Ok(QoS::AtLeastOnce),
111 2 => Ok(QoS::ExactlyOnce),
112 _ => Err(ProtocolError::InvalidQoS(value)),
113 }
114 }
115}
116
/// CONNACK return codes; the discriminant is the byte sent on the wire.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnackCode {
    /// Connection accepted.
    Accepted = 0,
    /// Refused: protocol level not supported.
    UnacceptableProtocolVersion = 1,
    /// Refused: client identifier rejected.
    IdentifierRejected = 2,
    /// Refused: server unavailable.
    ServerUnavailable = 3,
    /// Refused: bad username or password.
    BadUsernamePassword = 4,
    /// Refused: not authorized.
    NotAuthorized = 5,
}
128
129impl From<ConnackCode> for u8 {
130 fn from(code: ConnackCode) -> u8 {
131 code as u8
132 }
133}
134
135#[derive(Debug, Clone)]
137pub enum Packet {
138 Connect(ConnectPacket),
139 Connack(ConnackPacket),
140 Publish(PublishPacket),
141 Puback(PubackPacket),
142 Pubrec(PubrecPacket),
143 Pubrel(PubrelPacket),
144 Pubcomp(PubcompPacket),
145 Subscribe(SubscribePacket),
146 Suback(SubackPacket),
147 Unsubscribe(UnsubscribePacket),
148 Unsuback(UnsubackPacket),
149 Pingreq,
150 Pingresp,
151 Disconnect,
152}
153
154#[derive(Debug, Clone)]
156pub struct ConnectPacket {
157 pub protocol_name: String,
158 pub protocol_level: u8,
159 pub clean_session: bool,
160 pub keep_alive: u16,
161 pub client_id: String,
162 pub will: Option<Will>,
163 pub username: Option<String>,
164 pub password: Option<Vec<u8>>,
165}
166
167#[derive(Debug, Clone)]
169pub struct Will {
170 pub topic: String,
171 pub message: Vec<u8>,
172 pub qos: QoS,
173 pub retain: bool,
174}
175
176#[derive(Debug, Clone)]
178pub struct ConnackPacket {
179 pub session_present: bool,
180 pub return_code: ConnackCode,
181}
182
183#[derive(Debug, Clone)]
185pub struct PublishPacket {
186 pub dup: bool,
187 pub qos: QoS,
188 pub retain: bool,
189 pub topic: String,
190 pub packet_id: Option<u16>,
191 pub payload: Vec<u8>,
192}
193
/// Contents of a PUBACK packet (acknowledges a QoS 1 PUBLISH).
#[derive(Debug, Clone)]
pub struct PubackPacket {
    /// Identifier of the PUBLISH being acknowledged.
    pub packet_id: u16,
}
199
/// Contents of a PUBREC packet (first reply in the QoS 2 handshake).
#[derive(Debug, Clone)]
pub struct PubrecPacket {
    /// Identifier of the PUBLISH being acknowledged.
    pub packet_id: u16,
}
205
/// Contents of a PUBREL packet (reply to PUBREC in the QoS 2 handshake).
#[derive(Debug, Clone)]
pub struct PubrelPacket {
    /// Identifier of the PUBLISH being released.
    pub packet_id: u16,
}
211
/// Contents of a PUBCOMP packet (final step of the QoS 2 handshake).
#[derive(Debug, Clone)]
pub struct PubcompPacket {
    /// Identifier of the PUBLISH being completed.
    pub packet_id: u16,
}
217
218#[derive(Debug, Clone)]
220pub struct SubscribePacket {
221 pub packet_id: u16,
222 pub subscriptions: Vec<(String, QoS)>,
223}
224
225#[derive(Debug, Clone)]
227pub struct SubackPacket {
228 pub packet_id: u16,
229 pub return_codes: Vec<SubackReturnCode>,
230}
231
/// Per-subscription result reported in a SUBACK packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubackReturnCode {
    /// Subscription granted at QoS 0 (wire byte `0x00`).
    SuccessQoS0,
    /// Subscription granted at QoS 1 (wire byte `0x01`).
    SuccessQoS1,
    /// Subscription granted at QoS 2 (wire byte `0x02`).
    SuccessQoS2,
    /// Subscription refused (wire byte `0x80`).
    Failure,
}
240
241impl From<SubackReturnCode> for u8 {
242 fn from(code: SubackReturnCode) -> u8 {
243 match code {
244 SubackReturnCode::SuccessQoS0 => 0x00,
245 SubackReturnCode::SuccessQoS1 => 0x01,
246 SubackReturnCode::SuccessQoS2 => 0x02,
247 SubackReturnCode::Failure => 0x80,
248 }
249 }
250}
251
252impl SubackReturnCode {
253 pub fn success(qos: QoS) -> Self {
255 match qos {
256 QoS::AtMostOnce => SubackReturnCode::SuccessQoS0,
257 QoS::AtLeastOnce => SubackReturnCode::SuccessQoS1,
258 QoS::ExactlyOnce => SubackReturnCode::SuccessQoS2,
259 }
260 }
261}
262
/// Contents of an UNSUBSCRIBE packet.
#[derive(Debug, Clone)]
pub struct UnsubscribePacket {
    /// Identifier echoed back in the matching UNSUBACK.
    pub packet_id: u16,
    /// Topic filters to remove, in wire order.
    pub topics: Vec<String>,
}
269
/// Contents of an UNSUBACK packet.
#[derive(Debug, Clone)]
pub struct UnsubackPacket {
    /// Identifier of the UNSUBSCRIBE being acknowledged.
    pub packet_id: u16,
}
275
276pub struct PacketDecoder;
278
279impl PacketDecoder {
280 pub fn decode(buffer: &[u8]) -> ProtocolResult<Option<(Packet, usize)>> {
284 if buffer.is_empty() {
285 return Ok(None);
286 }
287
288 let first_byte = buffer[0];
290 let packet_type = PacketType::try_from(first_byte >> 4)?;
291 let flags = first_byte & 0x0F;
292
293 let (remaining_length, header_len) = Self::decode_remaining_length(&buffer[1..])?;
295
296 let total_len = 1 + header_len + remaining_length;
297 if buffer.len() < total_len {
298 return Ok(None); }
300
301 let payload = &buffer[1 + header_len..total_len];
302
303 let packet = match packet_type {
304 PacketType::Connect => Packet::Connect(Self::decode_connect(payload)?),
305 PacketType::Connack => Packet::Connack(Self::decode_connack(payload)?),
306 PacketType::Publish => Packet::Publish(Self::decode_publish(flags, payload)?),
307 PacketType::Puback => Packet::Puback(Self::decode_puback(payload)?),
308 PacketType::Pubrec => Packet::Pubrec(Self::decode_pubrec(payload)?),
309 PacketType::Pubrel => Packet::Pubrel(Self::decode_pubrel(payload)?),
310 PacketType::Pubcomp => Packet::Pubcomp(Self::decode_pubcomp(payload)?),
311 PacketType::Subscribe => Packet::Subscribe(Self::decode_subscribe(payload)?),
312 PacketType::Suback => Packet::Suback(Self::decode_suback(payload)?),
313 PacketType::Unsubscribe => Packet::Unsubscribe(Self::decode_unsubscribe(payload)?),
314 PacketType::Unsuback => Packet::Unsuback(Self::decode_unsuback(payload)?),
315 PacketType::Pingreq => Packet::Pingreq,
316 PacketType::Pingresp => Packet::Pingresp,
317 PacketType::Disconnect => Packet::Disconnect,
318 };
319
320 Ok(Some((packet, total_len)))
321 }
322
323 fn decode_remaining_length(buffer: &[u8]) -> ProtocolResult<(usize, usize)> {
325 let mut multiplier = 1usize;
326 let mut value = 0usize;
327 let mut pos = 0;
328
329 loop {
330 if pos >= buffer.len() {
331 return Err(ProtocolError::IncompletePacket {
332 expected: pos + 1,
333 got: buffer.len(),
334 });
335 }
336
337 let byte = buffer[pos];
338 value += (byte & 0x7F) as usize * multiplier;
339
340 if multiplier > 128 * 128 * 128 {
341 return Err(ProtocolError::InvalidRemainingLength);
342 }
343
344 multiplier *= 128;
345 pos += 1;
346
347 if byte & 0x80 == 0 {
348 break;
349 }
350 }
351
352 Ok((value, pos))
353 }
354
355 fn decode_string(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<String> {
357 let mut len_buf = [0u8; 2];
358 cursor.read_exact(&mut len_buf)?;
359 let len = u16::from_be_bytes(len_buf) as usize;
360
361 let mut str_buf = vec![0u8; len];
362 cursor.read_exact(&mut str_buf)?;
363
364 String::from_utf8(str_buf).map_err(|_| ProtocolError::InvalidUtf8)
365 }
366
367 fn decode_binary(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<Vec<u8>> {
369 let mut len_buf = [0u8; 2];
370 cursor.read_exact(&mut len_buf)?;
371 let len = u16::from_be_bytes(len_buf) as usize;
372
373 let mut data = vec![0u8; len];
374 cursor.read_exact(&mut data)?;
375 Ok(data)
376 }
377
378 fn decode_u16(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<u16> {
380 let mut buf = [0u8; 2];
381 cursor.read_exact(&mut buf)?;
382 Ok(u16::from_be_bytes(buf))
383 }
384
385 fn decode_connect(payload: &[u8]) -> ProtocolResult<ConnectPacket> {
386 let mut cursor = Cursor::new(payload);
387
388 let protocol_name = Self::decode_string(&mut cursor)?;
390 if protocol_name != "MQTT" && protocol_name != "MQIsdp" {
391 return Err(ProtocolError::InvalidProtocolName(protocol_name));
392 }
393
394 let mut level_buf = [0u8; 1];
396 cursor.read_exact(&mut level_buf)?;
397 let protocol_level = level_buf[0];
398
399 if protocol_level != 4 && protocol_level != 3 {
402 return Err(ProtocolError::InvalidProtocolVersion(protocol_level));
403 }
404
405 let mut flags_buf = [0u8; 1];
407 cursor.read_exact(&mut flags_buf)?;
408 let flags = flags_buf[0];
409
410 let clean_session = (flags & 0x02) != 0;
411 let will_flag = (flags & 0x04) != 0;
412 let will_qos = QoS::try_from((flags >> 3) & 0x03)?;
413 let will_retain = (flags & 0x20) != 0;
414 let password_flag = (flags & 0x40) != 0;
415 let username_flag = (flags & 0x80) != 0;
416
417 if flags & 0x01 != 0 {
419 return Err(ProtocolError::InvalidConnectFlags);
420 }
421
422 let keep_alive = Self::decode_u16(&mut cursor)?;
424
425 let client_id = Self::decode_string(&mut cursor)?;
427
428 let will = if will_flag {
430 let topic = Self::decode_string(&mut cursor)?;
431 let message = Self::decode_binary(&mut cursor)?;
432 Some(Will {
433 topic,
434 message,
435 qos: will_qos,
436 retain: will_retain,
437 })
438 } else {
439 None
440 };
441
442 let username = if username_flag {
444 Some(Self::decode_string(&mut cursor)?)
445 } else {
446 None
447 };
448
449 let password = if password_flag {
451 Some(Self::decode_binary(&mut cursor)?)
452 } else {
453 None
454 };
455
456 Ok(ConnectPacket {
457 protocol_name,
458 protocol_level,
459 clean_session,
460 keep_alive,
461 client_id,
462 will,
463 username,
464 password,
465 })
466 }
467
468 fn decode_connack(payload: &[u8]) -> ProtocolResult<ConnackPacket> {
469 if payload.len() < 2 {
470 return Err(ProtocolError::MalformedPacket);
471 }
472
473 let session_present = (payload[0] & 0x01) != 0;
474 let return_code = match payload[1] {
475 0 => ConnackCode::Accepted,
476 1 => ConnackCode::UnacceptableProtocolVersion,
477 2 => ConnackCode::IdentifierRejected,
478 3 => ConnackCode::ServerUnavailable,
479 4 => ConnackCode::BadUsernamePassword,
480 5 => ConnackCode::NotAuthorized,
481 _ => return Err(ProtocolError::MalformedPacket),
482 };
483
484 Ok(ConnackPacket {
485 session_present,
486 return_code,
487 })
488 }
489
490 fn decode_publish(flags: u8, payload: &[u8]) -> ProtocolResult<PublishPacket> {
491 let dup = (flags & 0x08) != 0;
492 let qos = QoS::try_from((flags >> 1) & 0x03)?;
493 let retain = (flags & 0x01) != 0;
494
495 let mut cursor = Cursor::new(payload);
496
497 let topic = Self::decode_string(&mut cursor)?;
499
500 let packet_id = if qos != QoS::AtMostOnce {
502 Some(Self::decode_u16(&mut cursor)?)
503 } else {
504 None
505 };
506
507 let pos = cursor.position() as usize;
509 let message_payload = payload[pos..].to_vec();
510
511 Ok(PublishPacket {
512 dup,
513 qos,
514 retain,
515 topic,
516 packet_id,
517 payload: message_payload,
518 })
519 }
520
521 fn decode_puback(payload: &[u8]) -> ProtocolResult<PubackPacket> {
522 if payload.len() < 2 {
523 return Err(ProtocolError::MalformedPacket);
524 }
525 Ok(PubackPacket {
526 packet_id: u16::from_be_bytes([payload[0], payload[1]]),
527 })
528 }
529
530 fn decode_pubrec(payload: &[u8]) -> ProtocolResult<PubrecPacket> {
531 if payload.len() < 2 {
532 return Err(ProtocolError::MalformedPacket);
533 }
534 Ok(PubrecPacket {
535 packet_id: u16::from_be_bytes([payload[0], payload[1]]),
536 })
537 }
538
539 fn decode_pubrel(payload: &[u8]) -> ProtocolResult<PubrelPacket> {
540 if payload.len() < 2 {
541 return Err(ProtocolError::MalformedPacket);
542 }
543 Ok(PubrelPacket {
544 packet_id: u16::from_be_bytes([payload[0], payload[1]]),
545 })
546 }
547
548 fn decode_pubcomp(payload: &[u8]) -> ProtocolResult<PubcompPacket> {
549 if payload.len() < 2 {
550 return Err(ProtocolError::MalformedPacket);
551 }
552 Ok(PubcompPacket {
553 packet_id: u16::from_be_bytes([payload[0], payload[1]]),
554 })
555 }
556
557 fn decode_subscribe(payload: &[u8]) -> ProtocolResult<SubscribePacket> {
558 let mut cursor = Cursor::new(payload);
559
560 let packet_id = Self::decode_u16(&mut cursor)?;
561 let mut subscriptions = Vec::new();
562
563 while (cursor.position() as usize) < payload.len() {
564 let topic = Self::decode_string(&mut cursor)?;
565 let mut qos_buf = [0u8; 1];
566 cursor.read_exact(&mut qos_buf)?;
567 let qos = QoS::try_from(qos_buf[0] & 0x03)?;
568 subscriptions.push((topic, qos));
569 }
570
571 if subscriptions.is_empty() {
572 return Err(ProtocolError::MalformedPacket);
573 }
574
575 Ok(SubscribePacket {
576 packet_id,
577 subscriptions,
578 })
579 }
580
581 fn decode_suback(payload: &[u8]) -> ProtocolResult<SubackPacket> {
582 if payload.len() < 3 {
583 return Err(ProtocolError::MalformedPacket);
584 }
585
586 let packet_id = u16::from_be_bytes([payload[0], payload[1]]);
587 let mut return_codes = Vec::new();
588
589 for &byte in &payload[2..] {
590 let code = match byte {
591 0x00 => SubackReturnCode::SuccessQoS0,
592 0x01 => SubackReturnCode::SuccessQoS1,
593 0x02 => SubackReturnCode::SuccessQoS2,
594 0x80 => SubackReturnCode::Failure,
595 _ => return Err(ProtocolError::MalformedPacket),
596 };
597 return_codes.push(code);
598 }
599
600 Ok(SubackPacket {
601 packet_id,
602 return_codes,
603 })
604 }
605
606 fn decode_unsubscribe(payload: &[u8]) -> ProtocolResult<UnsubscribePacket> {
607 let mut cursor = Cursor::new(payload);
608
609 let packet_id = Self::decode_u16(&mut cursor)?;
610 let mut topics = Vec::new();
611
612 while (cursor.position() as usize) < payload.len() {
613 let topic = Self::decode_string(&mut cursor)?;
614 topics.push(topic);
615 }
616
617 if topics.is_empty() {
618 return Err(ProtocolError::MalformedPacket);
619 }
620
621 Ok(UnsubscribePacket { packet_id, topics })
622 }
623
624 fn decode_unsuback(payload: &[u8]) -> ProtocolResult<UnsubackPacket> {
625 if payload.len() < 2 {
626 return Err(ProtocolError::MalformedPacket);
627 }
628 Ok(UnsubackPacket {
629 packet_id: u16::from_be_bytes([payload[0], payload[1]]),
630 })
631 }
632}
633
634pub struct PacketEncoder;
636
637impl PacketEncoder {
638 pub fn encode(packet: &Packet) -> ProtocolResult<Vec<u8>> {
640 match packet {
641 Packet::Connect(p) => Self::encode_connect(p),
642 Packet::Connack(p) => Self::encode_connack(p),
643 Packet::Publish(p) => Self::encode_publish(p),
644 Packet::Puback(p) => Self::encode_puback(p),
645 Packet::Pubrec(p) => Self::encode_pubrec(p),
646 Packet::Pubrel(p) => Self::encode_pubrel(p),
647 Packet::Pubcomp(p) => Self::encode_pubcomp(p),
648 Packet::Subscribe(p) => Self::encode_subscribe(p),
649 Packet::Suback(p) => Self::encode_suback(p),
650 Packet::Unsubscribe(p) => Self::encode_unsubscribe(p),
651 Packet::Unsuback(p) => Self::encode_unsuback(p),
652 Packet::Pingreq => Self::encode_pingreq(),
653 Packet::Pingresp => Self::encode_pingresp(),
654 Packet::Disconnect => Self::encode_disconnect(),
655 }
656 }
657
658 fn encode_remaining_length(length: usize) -> Vec<u8> {
660 let mut result = Vec::new();
661 let mut x = length;
662
663 loop {
664 let mut byte = (x % 128) as u8;
665 x /= 128;
666 if x > 0 {
667 byte |= 0x80;
668 }
669 result.push(byte);
670 if x == 0 {
671 break;
672 }
673 }
674
675 result
676 }
677
678 fn encode_string(s: &str) -> Vec<u8> {
680 let bytes = s.as_bytes();
681 let len = bytes.len() as u16;
682 let mut result = Vec::with_capacity(2 + bytes.len());
683 result.extend_from_slice(&len.to_be_bytes());
684 result.extend_from_slice(bytes);
685 result
686 }
687
688 fn encode_binary(data: &[u8]) -> Vec<u8> {
690 let len = data.len() as u16;
691 let mut result = Vec::with_capacity(2 + data.len());
692 result.extend_from_slice(&len.to_be_bytes());
693 result.extend_from_slice(data);
694 result
695 }
696
697 fn encode_connect(packet: &ConnectPacket) -> ProtocolResult<Vec<u8>> {
698 let mut payload = Vec::new();
699
700 payload.extend(Self::encode_string(&packet.protocol_name));
702
703 payload.push(packet.protocol_level);
705
706 let mut flags = 0u8;
708 if packet.clean_session {
709 flags |= 0x02;
710 }
711 if let Some(ref will) = packet.will {
712 flags |= 0x04; flags |= (will.qos as u8) << 3;
714 if will.retain {
715 flags |= 0x20;
716 }
717 }
718 if packet.password.is_some() {
719 flags |= 0x40;
720 }
721 if packet.username.is_some() {
722 flags |= 0x80;
723 }
724 payload.push(flags);
725
726 payload.extend_from_slice(&packet.keep_alive.to_be_bytes());
728
729 payload.extend(Self::encode_string(&packet.client_id));
731
732 if let Some(ref will) = packet.will {
734 payload.extend(Self::encode_string(&will.topic));
735 payload.extend(Self::encode_binary(&will.message));
736 }
737
738 if let Some(ref username) = packet.username {
740 payload.extend(Self::encode_string(username));
741 }
742
743 if let Some(ref password) = packet.password {
745 payload.extend(Self::encode_binary(password));
746 }
747
748 let mut result = Vec::new();
750 result.push(0x10); result.extend(Self::encode_remaining_length(payload.len()));
752 result.extend(payload);
753
754 Ok(result)
755 }
756
757 fn encode_connack(packet: &ConnackPacket) -> ProtocolResult<Vec<u8>> {
758 let mut result = Vec::new();
759 result.push(0x20); result.push(0x02); let ack_flags = if packet.session_present { 0x01 } else { 0x00 };
763 result.push(ack_flags);
764 result.push(packet.return_code as u8);
765
766 Ok(result)
767 }
768
769 fn encode_publish(packet: &PublishPacket) -> ProtocolResult<Vec<u8>> {
770 let mut payload = Vec::new();
771
772 payload.extend(Self::encode_string(&packet.topic));
774
775 if let Some(packet_id) = packet.packet_id {
777 payload.extend_from_slice(&packet_id.to_be_bytes());
778 }
779
780 payload.extend_from_slice(&packet.payload);
782
783 let mut first_byte = 0x30u8; if packet.dup {
786 first_byte |= 0x08;
787 }
788 first_byte |= (packet.qos as u8) << 1;
789 if packet.retain {
790 first_byte |= 0x01;
791 }
792
793 let mut result = Vec::new();
794 result.push(first_byte);
795 result.extend(Self::encode_remaining_length(payload.len()));
796 result.extend(payload);
797
798 Ok(result)
799 }
800
801 fn encode_puback(packet: &PubackPacket) -> ProtocolResult<Vec<u8>> {
802 let mut result = Vec::new();
803 result.push(0x40); result.push(0x02); result.extend_from_slice(&packet.packet_id.to_be_bytes());
806 Ok(result)
807 }
808
809 fn encode_pubrec(packet: &PubrecPacket) -> ProtocolResult<Vec<u8>> {
810 let mut result = Vec::new();
811 result.push(0x50); result.push(0x02); result.extend_from_slice(&packet.packet_id.to_be_bytes());
814 Ok(result)
815 }
816
817 fn encode_pubrel(packet: &PubrelPacket) -> ProtocolResult<Vec<u8>> {
818 let mut result = Vec::new();
819 result.push(0x62); result.push(0x02); result.extend_from_slice(&packet.packet_id.to_be_bytes());
822 Ok(result)
823 }
824
825 fn encode_pubcomp(packet: &PubcompPacket) -> ProtocolResult<Vec<u8>> {
826 let mut result = Vec::new();
827 result.push(0x70); result.push(0x02); result.extend_from_slice(&packet.packet_id.to_be_bytes());
830 Ok(result)
831 }
832
833 fn encode_subscribe(packet: &SubscribePacket) -> ProtocolResult<Vec<u8>> {
834 let mut payload = Vec::new();
835
836 payload.extend_from_slice(&packet.packet_id.to_be_bytes());
838
839 for (topic, qos) in &packet.subscriptions {
841 payload.extend(Self::encode_string(topic));
842 payload.push(*qos as u8);
843 }
844
845 let mut result = Vec::new();
846 result.push(0x82); result.extend(Self::encode_remaining_length(payload.len()));
848 result.extend(payload);
849
850 Ok(result)
851 }
852
853 fn encode_suback(packet: &SubackPacket) -> ProtocolResult<Vec<u8>> {
854 let mut payload = Vec::new();
855
856 payload.extend_from_slice(&packet.packet_id.to_be_bytes());
858
859 for code in &packet.return_codes {
861 payload.push((*code).into());
862 }
863
864 let mut result = Vec::new();
865 result.push(0x90); result.extend(Self::encode_remaining_length(payload.len()));
867 result.extend(payload);
868
869 Ok(result)
870 }
871
872 fn encode_unsubscribe(packet: &UnsubscribePacket) -> ProtocolResult<Vec<u8>> {
873 let mut payload = Vec::new();
874
875 payload.extend_from_slice(&packet.packet_id.to_be_bytes());
877
878 for topic in &packet.topics {
880 payload.extend(Self::encode_string(topic));
881 }
882
883 let mut result = Vec::new();
884 result.push(0xA2); result.extend(Self::encode_remaining_length(payload.len()));
886 result.extend(payload);
887
888 Ok(result)
889 }
890
891 fn encode_unsuback(packet: &UnsubackPacket) -> ProtocolResult<Vec<u8>> {
892 let mut result = Vec::new();
893 result.push(0xB0); result.push(0x02); result.extend_from_slice(&packet.packet_id.to_be_bytes());
896 Ok(result)
897 }
898
899 fn encode_pingreq() -> ProtocolResult<Vec<u8>> {
900 Ok(vec![0xC0, 0x00]) }
902
903 fn encode_pingresp() -> ProtocolResult<Vec<u8>> {
904 Ok(vec![0xD0, 0x00]) }
906
907 fn encode_disconnect() -> ProtocolResult<Vec<u8>> {
908 Ok(vec![0xE0, 0x00]) }
910}
911
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `packet`, decodes the bytes again, and asserts that the decoder consumed
    /// exactly the encoded length. Returns the decoded packet for field-level checks.
    fn roundtrip(packet: &Packet) -> Packet {
        let encoded = PacketEncoder::encode(packet).expect("packet should encode");
        let (decoded, consumed) = PacketDecoder::decode(&encoded)
            .expect("encoded packet should decode")
            .expect("encoded packet should be complete");
        assert_eq!(consumed, encoded.len(), "decoder must consume exactly one packet");
        decoded
    }

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::try_from(1).unwrap(), PacketType::Connect);
        assert_eq!(PacketType::try_from(2).unwrap(), PacketType::Connack);
        assert_eq!(PacketType::try_from(3).unwrap(), PacketType::Publish);
        assert_eq!(PacketType::try_from(14).unwrap(), PacketType::Disconnect);
        assert!(PacketType::try_from(0).is_err());
        assert!(PacketType::try_from(15).is_err());
    }

    #[test]
    fn test_qos_from_u8() {
        assert_eq!(QoS::try_from(0).unwrap(), QoS::AtMostOnce);
        assert_eq!(QoS::try_from(1).unwrap(), QoS::AtLeastOnce);
        assert_eq!(QoS::try_from(2).unwrap(), QoS::ExactlyOnce);
        assert!(QoS::try_from(3).is_err());
    }

    #[test]
    fn test_encode_decode_connect() {
        let connect = ConnectPacket {
            protocol_name: "MQTT".to_string(),
            protocol_level: 4,
            clean_session: true,
            keep_alive: 60,
            client_id: "test-client".to_string(),
            will: None,
            username: None,
            password: None,
        };

        match roundtrip(&Packet::Connect(connect.clone())) {
            Packet::Connect(decoded) => {
                assert_eq!(decoded.protocol_name, connect.protocol_name);
                assert_eq!(decoded.protocol_level, connect.protocol_level);
                assert_eq!(decoded.clean_session, connect.clean_session);
                assert_eq!(decoded.keep_alive, connect.keep_alive);
                assert_eq!(decoded.client_id, connect.client_id);
                assert!(decoded.will.is_none());
                assert!(decoded.username.is_none());
                assert!(decoded.password.is_none());
            }
            other => panic!("Expected Connect packet, got {:?}", other),
        }
    }

    #[test]
    fn test_encode_decode_connack() {
        let cases = [
            (false, ConnackCode::Accepted),
            (true, ConnackCode::Accepted),
            (false, ConnackCode::NotAuthorized),
        ];

        for &(session_present, return_code) in &cases {
            let connack = ConnackPacket {
                session_present,
                return_code,
            };
            match roundtrip(&Packet::Connack(connack)) {
                Packet::Connack(decoded) => {
                    assert_eq!(decoded.session_present, session_present);
                    assert_eq!(decoded.return_code, return_code);
                }
                other => panic!("Expected Connack packet, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_encode_decode_publish_qos0() {
        let publish = PublishPacket {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic: "test/topic".to_string(),
            packet_id: None,
            payload: b"Hello, MQTT!".to_vec(),
        };

        match roundtrip(&Packet::Publish(publish.clone())) {
            Packet::Publish(decoded) => {
                assert_eq!(decoded.topic, publish.topic);
                assert_eq!(decoded.payload, publish.payload);
                assert_eq!(decoded.qos, publish.qos);
                assert_eq!(decoded.packet_id, None);
                assert!(!decoded.dup);
                assert!(!decoded.retain);
            }
            other => panic!("Expected Publish packet, got {:?}", other),
        }
    }

    #[test]
    fn test_encode_decode_publish_qos1() {
        let publish = PublishPacket {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: true,
            topic: "sensor/temp".to_string(),
            packet_id: Some(1234),
            payload: b"25.5".to_vec(),
        };

        match roundtrip(&Packet::Publish(publish.clone())) {
            Packet::Publish(decoded) => {
                assert_eq!(decoded.topic, publish.topic);
                assert_eq!(decoded.packet_id, publish.packet_id);
                assert_eq!(decoded.retain, publish.retain);
                assert_eq!(decoded.qos, publish.qos);
                assert_eq!(decoded.dup, publish.dup);
                assert_eq!(decoded.payload, publish.payload);
            }
            other => panic!("Expected Publish packet, got {:?}", other),
        }
    }

    #[test]
    fn test_publish_with_multi_byte_remaining_length() {
        let publish = PublishPacket {
            dup: true,
            qos: QoS::ExactlyOnce,
            retain: false,
            topic: "big/payload".to_string(),
            packet_id: Some(7),
            payload: vec![0xAB; 200],
        };

        // Body = 2 + 11 (topic) + 2 (packet id) + 200 = 215 bytes, so the length takes
        // two bytes: 215 = 87 + 1 * 128 -> [0x80 | 87, 0x01].
        let encoded = PacketEncoder::encode(&Packet::Publish(publish.clone())).unwrap();
        assert_eq!(&encoded[1..3], &[0xD7u8, 0x01][..]);

        match roundtrip(&Packet::Publish(publish.clone())) {
            Packet::Publish(decoded) => {
                assert!(decoded.dup);
                assert_eq!(decoded.qos, QoS::ExactlyOnce);
                assert_eq!(decoded.packet_id, Some(7));
                assert_eq!(decoded.payload, publish.payload);
            }
            other => panic!("Expected Publish packet, got {:?}", other),
        }
    }

    #[test]
    fn test_encode_decode_subscribe() {
        let subscribe = SubscribePacket {
            packet_id: 100,
            subscriptions: vec![
                ("topic/a".to_string(), QoS::AtMostOnce),
                ("topic/b/#".to_string(), QoS::AtLeastOnce),
            ],
        };

        match roundtrip(&Packet::Subscribe(subscribe.clone())) {
            Packet::Subscribe(decoded) => {
                assert_eq!(decoded.packet_id, subscribe.packet_id);
                assert_eq!(decoded.subscriptions, subscribe.subscriptions);
            }
            other => panic!("Expected Subscribe packet, got {:?}", other),
        }
    }

    #[test]
    fn test_subscribe_with_invalid_qos_is_rejected() {
        // SUBSCRIBE, remaining length 6: packet id 1, topic filter "a", requested QoS 3.
        let bytes = [0x82, 0x06, 0x00, 0x01, 0x00, 0x01, b'a', 0x03];
        assert!(PacketDecoder::decode(&bytes).is_err());
    }

    #[test]
    fn test_encode_decode_suback() {
        let suback = SubackPacket {
            packet_id: 100,
            return_codes: vec![
                SubackReturnCode::SuccessQoS0,
                SubackReturnCode::SuccessQoS2,
                SubackReturnCode::Failure,
            ],
        };

        match roundtrip(&Packet::Suback(suback.clone())) {
            Packet::Suback(decoded) => {
                assert_eq!(decoded.packet_id, suback.packet_id);
                assert_eq!(decoded.return_codes, suback.return_codes);
            }
            other => panic!("Expected Suback packet, got {:?}", other),
        }
    }

    #[test]
    fn test_encode_decode_unsubscribe() {
        let unsubscribe = UnsubscribePacket {
            packet_id: 200,
            topics: vec!["topic/a".to_string(), "topic/b".to_string()],
        };

        match roundtrip(&Packet::Unsubscribe(unsubscribe.clone())) {
            Packet::Unsubscribe(decoded) => {
                assert_eq!(decoded.packet_id, unsubscribe.packet_id);
                assert_eq!(decoded.topics, unsubscribe.topics);
            }
            other => panic!("Expected Unsubscribe packet, got {:?}", other),
        }
    }

    #[test]
    fn test_encode_decode_pingreq() {
        assert!(matches!(roundtrip(&Packet::Pingreq), Packet::Pingreq));
    }

    #[test]
    fn test_encode_decode_pingresp() {
        assert!(matches!(roundtrip(&Packet::Pingresp), Packet::Pingresp));
    }

    #[test]
    fn test_encode_decode_disconnect() {
        assert!(matches!(roundtrip(&Packet::Disconnect), Packet::Disconnect));
    }

    #[test]
    fn test_incomplete_packet() {
        // CONNECT header announcing 10 body bytes, none of which have arrived yet.
        let partial = vec![0x10, 0x0A];
        let result = PacketDecoder::decode(&partial);
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_empty_buffer_needs_more_data() {
        assert!(PacketDecoder::decode(&[]).unwrap().is_none());
    }

    #[test]
    fn test_decode_stops_at_packet_boundary() {
        let mut stream = PacketEncoder::encode(&Packet::Pingreq).unwrap();
        stream.extend(
            PacketEncoder::encode(&Packet::Puback(PubackPacket { packet_id: 5 })).unwrap(),
        );

        let (first, consumed) = PacketDecoder::decode(&stream).unwrap().unwrap();
        assert!(matches!(first, Packet::Pingreq));
        assert_eq!(consumed, 2);

        let (second, consumed_second) =
            PacketDecoder::decode(&stream[consumed..]).unwrap().unwrap();
        match second {
            Packet::Puback(puback) => assert_eq!(puback.packet_id, 5),
            other => panic!("Expected Puback packet, got {:?}", other),
        }
        assert_eq!(consumed + consumed_second, stream.len());
    }

    #[test]
    fn test_unknown_packet_type_is_rejected() {
        let result = PacketDecoder::decode(&[0x00, 0x00]);
        assert!(matches!(result, Err(ProtocolError::InvalidPacketType(0))));
    }

    #[test]
    fn test_remaining_length_encoding() {
        assert_eq!(PacketEncoder::encode_remaining_length(0), vec![0x00]);
        assert_eq!(PacketEncoder::encode_remaining_length(127), vec![0x7F]);
        assert_eq!(PacketEncoder::encode_remaining_length(128), vec![0x80, 0x01]);
        assert_eq!(PacketEncoder::encode_remaining_length(16383), vec![0xFF, 0x7F]);
        assert_eq!(
            PacketEncoder::encode_remaining_length(16384),
            vec![0x80, 0x80, 0x01]
        );
        assert_eq!(
            PacketEncoder::encode_remaining_length(268_435_455),
            vec![0xFF, 0xFF, 0xFF, 0x7F]
        );
    }

    #[test]
    fn test_connect_with_credentials() {
        let connect = ConnectPacket {
            protocol_name: "MQTT".to_string(),
            protocol_level: 4,
            clean_session: false,
            keep_alive: 120,
            client_id: "secure-client".to_string(),
            will: None,
            username: Some("user".to_string()),
            password: Some(b"pass".to_vec()),
        };

        match roundtrip(&Packet::Connect(connect)) {
            Packet::Connect(decoded) => {
                assert_eq!(decoded.username, Some("user".to_string()));
                assert_eq!(decoded.password, Some(b"pass".to_vec()));
                assert!(!decoded.clean_session);
                assert_eq!(decoded.keep_alive, 120);
                assert_eq!(decoded.client_id, "secure-client");
            }
            other => panic!("Expected Connect packet, got {:?}", other),
        }
    }

    #[test]
    fn test_connect_with_will() {
        let connect = ConnectPacket {
            protocol_name: "MQTT".to_string(),
            protocol_level: 4,
            clean_session: true,
            keep_alive: 60,
            client_id: "will-client".to_string(),
            will: Some(Will {
                topic: "last/will".to_string(),
                message: b"goodbye".to_vec(),
                qos: QoS::AtLeastOnce,
                retain: true,
            }),
            username: None,
            password: None,
        };

        match roundtrip(&Packet::Connect(connect)) {
            Packet::Connect(decoded) => {
                let will = decoded.will.expect("will should round-trip");
                assert_eq!(will.topic, "last/will");
                assert_eq!(will.message, b"goodbye");
                assert_eq!(will.qos, QoS::AtLeastOnce);
                assert!(will.retain);
            }
            other => panic!("Expected Connect packet, got {:?}", other),
        }
    }

    #[test]
    fn test_puback_roundtrip() {
        let puback = PubackPacket { packet_id: 12345 };
        match roundtrip(&Packet::Puback(puback.clone())) {
            Packet::Puback(decoded) => assert_eq!(decoded.packet_id, puback.packet_id),
            other => panic!("Expected Puback packet, got {:?}", other),
        }
    }

    #[test]
    fn test_qos2_handshake_packets() {
        match roundtrip(&Packet::Pubrec(PubrecPacket { packet_id: 1000 })) {
            Packet::Pubrec(decoded) => assert_eq!(decoded.packet_id, 1000),
            other => panic!("Expected Pubrec, got {:?}", other),
        }

        match roundtrip(&Packet::Pubrel(PubrelPacket { packet_id: 1000 })) {
            Packet::Pubrel(decoded) => assert_eq!(decoded.packet_id, 1000),
            other => panic!("Expected Pubrel, got {:?}", other),
        }

        match roundtrip(&Packet::Pubcomp(PubcompPacket { packet_id: 1000 })) {
            Packet::Pubcomp(decoded) => assert_eq!(decoded.packet_id, 1000),
            other => panic!("Expected Pubcomp, got {:?}", other),
        }
    }

    #[test]
    fn test_unsuback_roundtrip() {
        let unsuback = UnsubackPacket { packet_id: 999 };
        match roundtrip(&Packet::Unsuback(unsuback.clone())) {
            Packet::Unsuback(decoded) => assert_eq!(decoded.packet_id, unsuback.packet_id),
            other => panic!("Expected Unsuback packet, got {:?}", other),
        }
    }
}