1use std::error::Error;
4use std::fmt::{self, Debug};
5use std::io::{self, Read, Write};
6
7#[cfg(feature = "tokio")]
8use tokio::io::{AsyncRead, AsyncReadExt};
9
10use crate::control::fixed_header::FixedHeaderError;
11use crate::control::variable_header::VariableHeaderError;
12use crate::control::ControlType;
13use crate::control::FixedHeader;
14use crate::topic_name::{TopicNameDecodeError, TopicNameError};
15use crate::{Decodable, Encodable};
16
/// Implements `EncodablePacket` for a packet struct given the ordered list of
/// its body fields, and adds a private `fix_header_remaining_len` helper.
///
/// The struct must have a `fixed_header` field. Every listed field must
/// implement `Encodable`; `encode_packet` writes them in the order given, and
/// `encoded_packet_length` is the sum of their encoded lengths (the fixed
/// header itself is not counted).
macro_rules! encodable_packet {
    ($typ:ident($($field:ident),* $(,)?)) => {
        impl $crate::packet::EncodablePacket for $typ {
            fn fixed_header(&self) -> &$crate::control::fixed_header::FixedHeader {
                &self.fixed_header
            }

            // `writer` is never referenced when the field list is empty.
            #[allow(unused)]
            fn encode_packet<W: ::std::io::Write>(&self, writer: &mut W) -> ::std::io::Result<()> {
                $(
                    $crate::encodable::Encodable::encode(&self.$field, writer)?;
                )*
                Ok(())
            }

            fn encoded_packet_length(&self) -> u32 {
                0 $(+ $crate::encodable::Encodable::encoded_length(&self.$field))*
            }
        }

        impl $typ {
            /// Sets `fixed_header.remaining_length` to the current encoded
            /// body length.
            // May be unused for some packet types, hence `allow(unused)`.
            #[allow(unused)]
            #[inline(always)]
            fn fix_header_remaining_len(&mut self) {
                let body_len = $crate::packet::EncodablePacket::encoded_packet_length(self);
                self.fixed_header.remaining_length = body_len;
            }
        }
    };
}
45
46pub use self::connack::ConnackPacket;
47pub use self::connect::ConnectPacket;
48pub use self::disconnect::DisconnectPacket;
49pub use self::pingreq::PingreqPacket;
50pub use self::pingresp::PingrespPacket;
51pub use self::puback::PubackPacket;
52pub use self::pubcomp::PubcompPacket;
53pub use self::publish::{PublishPacket, PublishPacketRef};
54pub use self::pubrec::PubrecPacket;
55pub use self::pubrel::PubrelPacket;
56pub use self::suback::SubackPacket;
57pub use self::subscribe::SubscribePacket;
58pub use self::unsuback::UnsubackPacket;
59pub use self::unsubscribe::UnsubscribePacket;
60
61pub use self::publish::QoSWithPacketIdentifier;
62
63pub mod connack;
64pub mod connect;
65pub mod disconnect;
66pub mod pingreq;
67pub mod pingresp;
68pub mod puback;
69pub mod pubcomp;
70pub mod publish;
71pub mod pubrec;
72pub mod pubrel;
73pub mod suback;
74pub mod subscribe;
75pub mod unsuback;
76pub mod unsubscribe;
77
78pub trait EncodablePacket {
83 fn fixed_header(&self) -> &FixedHeader;
85
86 fn encode_packet<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
88 Ok(())
89 }
90
91 fn encoded_packet_length(&self) -> u32 {
93 0
94 }
95}
96
97impl<T: EncodablePacket> Encodable for T {
98 fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
99 self.fixed_header().encode(writer)?;
100 self.encode_packet(writer)
101 }
102
103 fn encoded_length(&self) -> u32 {
104 self.fixed_header().encoded_length() + self.encoded_packet_length()
105 }
106}
107
108pub trait DecodablePacket: EncodablePacket + Sized {
109 type DecodePacketError: Error + 'static;
110
111 fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>>;
113}
114
115impl<T: DecodablePacket> Decodable for T {
116 type Error = PacketError<T>;
117 type Cond = Option<FixedHeader>;
118
119 fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond) -> Result<Self, Self::Error> {
120 let fixed_header: FixedHeader = if let Some(hdr) = fixed_header {
121 hdr
122 } else {
123 Decodable::decode(reader)?
124 };
125
126 <Self as DecodablePacket>::decode_packet(reader, fixed_header)
127 }
128}
129
130#[derive(thiserror::Error)]
132#[error(transparent)]
133pub enum PacketError<P>
134where
135 P: DecodablePacket,
136{
137 FixedHeaderError(#[from] FixedHeaderError),
138 VariableHeaderError(#[from] VariableHeaderError),
139 PayloadError(<P as DecodablePacket>::DecodePacketError),
140 IoError(#[from] io::Error),
141 TopicNameError(#[from] TopicNameError),
142}
143
144impl<P> Debug for PacketError<P>
145where
146 P: DecodablePacket,
147{
148 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149 match *self {
150 PacketError::FixedHeaderError(ref e) => f.debug_tuple("FixedHeaderError").field(e).finish(),
151 PacketError::VariableHeaderError(ref e) => f.debug_tuple("VariableHeaderError").field(e).finish(),
152 PacketError::PayloadError(ref e) => f.debug_tuple("PayloadError").field(e).finish(),
153 PacketError::IoError(ref e) => f.debug_tuple("IoError").field(e).finish(),
154 PacketError::TopicNameError(ref e) => f.debug_tuple("TopicNameError").field(e).finish(),
155 }
156 }
157}
158
159impl<P: DecodablePacket> From<TopicNameDecodeError> for PacketError<P> {
160 fn from(e: TopicNameDecodeError) -> Self {
161 match e {
162 TopicNameDecodeError::IoError(e) => e.into(),
163 TopicNameDecodeError::InvalidTopicName(e) => e.into(),
164 }
165 }
166}
167
// Generates `VariablePacket` (one variant per packet type), the aggregate
// error `VariablePacketError`, and the glue between them.
//
// Each input entry `Name & ErrName => Hdr,` contributes:
// - a `VariablePacket::Name(Name)` variant and `impl From<Name> for VariablePacket`;
// - a `VariablePacketError::ErrName(PacketError<Name>)` variant with `#[from]`;
// - a `ControlType::Hdr` arm in `decode_with_header` that dispatches to
//   `<Name as DecodablePacket>::decode_packet`.
// The generated `match` has no wildcard arm, so the invocation must list every
// `ControlType` variant.
macro_rules! impl_variable_packet {
    ($($name:ident & $errname:ident => $hdr:ident,)+) => {
        /// Any supported control packet, tagged by its concrete packet type.
        #[derive(Debug, Eq, PartialEq, Clone)]
        pub enum VariablePacket {
            $(
                $name($name),
            )+
        }

        #[cfg(feature = "tokio")]
        impl VariablePacket {
            /// Reads one packet from an async reader.
            ///
            /// The fixed header is parsed first; then exactly
            /// `remaining_length` bytes are read into a buffer, and the body is
            /// decoded synchronously from that buffer.
            pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, VariablePacketError> {
                use std::io::Cursor;
                let fixed_header = FixedHeader::parse(rdr).await?;

                // NOTE(review): the allocation size comes straight from the
                // wire, bounded only by what `FixedHeader::parse` accepts.
                let mut buffer = vec![0u8; fixed_header.remaining_length as usize];
                rdr.read_exact(&mut buffer).await?;

                decode_with_header(&mut Cursor::new(buffer), fixed_header)
            }
        }

        /// Decodes the packet body that follows `fixed_header`, dispatching on
        /// the header's control type.
        ///
        /// This function does not itself bound how many bytes it reads from `rdr`.
        #[inline]
        fn decode_with_header<R: io::Read>(rdr: &mut R, fixed_header: FixedHeader) -> Result<VariablePacket, VariablePacketError> {
            match fixed_header.packet_type.control_type() {
                $(
                    ControlType::$hdr => {
                        // `?` wraps `PacketError<$name>` into its `VariablePacketError` variant.
                        let pk = <$name as DecodablePacket>::decode_packet(rdr, fixed_header)?;
                        Ok(VariablePacket::$name(pk))
                    }
                )+
            }
        }

        $(
            impl From<$name> for VariablePacket {
                fn from(pk: $name) -> VariablePacket {
                    VariablePacket::$name(pk)
                }
            }
        )+

        // Every method forwards to the wrapped concrete packet.
        impl EncodablePacket for VariablePacket {
            fn fixed_header(&self) -> &FixedHeader {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.fixed_header(),
                    )+
                }
            }

            fn encode_packet<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encode_packet(writer),
                    )+
                }
            }

            fn encoded_packet_length(&self) -> u32 {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encoded_packet_length(),
                    )+
                }
            }
        }

        impl Decodable for VariablePacket {
            type Error = VariablePacketError;
            // An already-parsed fixed header, or `None` to read one from the reader.
            type Cond = Option<FixedHeader>;

            fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond)
                -> Result<VariablePacket, Self::Error> {
                let fixed_header = match fixed_header {
                    Some(fh) => fh,
                    None => {
                        match FixedHeader::decode(reader) {
                            Ok(header) => header,
                            // Reserved packet type: read up to `length` body bytes
                            // (the second field is used here as the body length) and
                            // return them inside the error instead of decoding.
                            Err(FixedHeaderError::ReservedType(code, length)) => {
                                let reader = &mut reader.take(length as u64);
                                let mut buf = Vec::with_capacity(length as usize);
                                reader.read_to_end(&mut buf)?;
                                return Err(VariablePacketError::ReservedPacket(code, buf));
                            },
                            Err(err) => return Err(From::from(err))
                        }
                    }
                };
                // Keep the body decoder from reading past this packet's remaining length.
                let reader = &mut reader.take(fixed_header.remaining_length as u64);

                decode_with_header(reader, fixed_header)
            }
        }

        /// Error produced while decoding a [`VariablePacket`].
        #[derive(Debug, thiserror::Error)]
        pub enum VariablePacketError {
            /// The fixed header could not be decoded.
            #[error(transparent)]
            FixedHeaderError(#[from] FixedHeaderError),
            /// A reserved packet type code, together with the body bytes that were read.
            #[error("reserved packet type ({0}), [u8, ..{}]", .1.len())]
            ReservedPacket(u8, Vec<u8>),
            /// An I/O error outside any specific packet decoder.
            #[error(transparent)]
            IoError(#[from] io::Error),
            $(
                /// Decoding the body of this specific packet type failed.
                #[error(transparent)]
                $errname(#[from] PacketError<$name>),
            )+
        }
    }
}
301
// Instantiates `VariablePacket` / `VariablePacketError` for every control
// packet. Each entry reads `PacketStruct & ErrorVariantName => ControlType`.
// The generated dispatch `match` has no wildcard arm, so every `ControlType`
// variant must appear here.
impl_variable_packet! {
    // Connection establishment.
    ConnectPacket & ConnectPacketError => Connect,
    ConnackPacket & ConnackPacketError => ConnectAcknowledgement,

    // Publish and its acknowledgement flow.
    PublishPacket & PublishPacketError => Publish,
    PubackPacket & PubackPacketError => PublishAcknowledgement,
    PubrecPacket & PubrecPacketError => PublishReceived,
    PubrelPacket & PubrelPacketError => PublishRelease,
    PubcompPacket & PubcompPacketError => PublishComplete,

    // Keep-alive.
    PingreqPacket & PingreqPacketError => PingRequest,
    PingrespPacket & PingrespPacketError => PingResponse,

    // Subscription management.
    SubscribePacket & SubscribePacketError => Subscribe,
    SubackPacket & SubackPacketError => SubscribeAcknowledgement,

    UnsubscribePacket & UnsubscribePacketError => Unsubscribe,
    UnsubackPacket & UnsubackPacketError => UnsubscribeAcknowledgement,

    // Session teardown.
    DisconnectPacket & DisconnectPacketError => Disconnect,
}
323
324impl VariablePacket {
325 pub fn new<T>(t: T) -> VariablePacket
326 where
327 VariablePacket: From<T>,
328 {
329 From::from(t)
330 }
331}
332
/// Adapters for framing MQTT packets with `tokio_util::codec`
/// (`FramedRead`, `FramedWrite`, `Framed`).
#[cfg(feature = "tokio-codec")]
mod tokio_codec {
    use super::*;
    use crate::control::packet_type::{PacketType, PacketTypeError};
    use bytes::{Buf, BufMut, BytesMut};
    use tokio_util::codec;

    /// Maximum number of bytes the fixed header's "remaining length" field
    /// may occupy (MQTT 3.1.1, section 2.2.3).
    const MAX_REMAINING_LENGTH_BYTES: usize = 4;

    /// Incremental decoder that turns a byte stream into [`VariablePacket`]s.
    ///
    /// The fixed header is consumed as soon as it is complete. The packet body
    /// is decoded only once all of its `remaining_length` bytes are buffered.
    pub struct MqttDecoder {
        state: DecodeState,
    }

    /// Progress through the packet currently being decoded.
    enum DecodeState {
        /// Waiting for a complete fixed header.
        Start,
        /// Fixed header consumed; waiting for `length` body bytes.
        Packet { length: u32, typ: DecodePacketType },
    }

    /// Packet type read from the first fixed-header byte.
    #[derive(Copy, Clone)]
    enum DecodePacketType {
        /// A packet type the crate can decode.
        Standard(PacketType),
        /// A reserved type code; its body is skipped and reported as
        /// `VariablePacketError::ReservedPacket`.
        Reserved(u8),
    }

    impl MqttDecoder {
        /// Creates a decoder that expects the start of a packet.
        pub const fn new() -> Self {
            MqttDecoder {
                state: DecodeState::Start,
            }
        }
    }

    impl Default for MqttDecoder {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Parses a fixed header from the front of `data`, which is not modified.
    ///
    /// Returns:
    /// - `None` if `data` does not yet contain a complete fixed header;
    /// - `Some(Err(_))` if the remaining-length field is over-long, or if
    ///   `PacketType::from_u8` rejects the type byte for a reason other than a
    ///   reserved type;
    /// - `Some(Ok((packet_type, remaining_length, header_size)))` otherwise,
    ///   where `header_size` is the number of bytes the header occupies.
    #[inline]
    fn decode_header(mut data: &[u8]) -> Option<Result<(DecodePacketType, u32, usize), FixedHeaderError>> {
        let mut header_size = 0;
        // Takes the next byte, or returns `None` from `decode_header` when
        // the header is still incomplete.
        macro_rules! read_u8 {
            () => {{
                let (&x, rest) = data.split_first()?;
                data = rest;
                header_size += 1;
                x
            }};
        }

        let type_val = read_u8!();

        // The remaining length is a variable-length integer: 7 value bits per
        // byte, least-significant group first, with the high bit set on every
        // byte except the last.
        let mut remaining_len = 0u32;
        let mut length_bytes = 0usize;
        loop {
            let byte = read_u8!();
            remaining_len |= u32::from(byte & 0x7F) << (7 * length_bytes);
            length_bytes += 1;

            if byte & 0x80 == 0 {
                break;
            }

            // The last permitted byte still has its continuation bit set. No
            // valid encoding can follow, so fail now rather than waiting for
            // another byte to arrive.
            if length_bytes >= MAX_REMAINING_LENGTH_BYTES {
                return Some(Err(FixedHeaderError::MalformedRemainingLength));
            }
        }

        let packet_type = match PacketType::from_u8(type_val) {
            Ok(ty) => DecodePacketType::Standard(ty),
            Err(PacketTypeError::ReservedType(ty, _)) => DecodePacketType::Reserved(ty),
            Err(err) => return Some(Err(err.into())),
        };
        Some(Ok((packet_type, remaining_len, header_size)))
    }

    impl codec::Decoder for MqttDecoder {
        type Item = VariablePacket;
        type Error = VariablePacketError;

        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
            loop {
                match self.state {
                    DecodeState::Start => match decode_header(&src[..]) {
                        Some(Ok((typ, length, header_size))) => {
                            src.advance(header_size);
                            // The next loop iteration handles the body.
                            self.state = DecodeState::Packet { length, typ };
                        }
                        Some(Err(e)) => return Err(e.into()),
                        None => return Ok(None),
                    },
                    DecodeState::Packet { length, typ } => {
                        let body_len = length as usize;
                        if src.len() < body_len {
                            return Ok(None);
                        }

                        self.state = DecodeState::Start;

                        // Detach exactly this packet's body from the buffer.
                        // This stops the body decoder from reading into the
                        // next packet or leaving stray bytes behind. It also
                        // keeps `src` on a packet boundary even when decoding
                        // fails.
                        let body = src.split_to(body_len);

                        return match typ {
                            DecodePacketType::Standard(packet_type) => {
                                let header = FixedHeader {
                                    packet_type,
                                    remaining_length: length,
                                };
                                decode_with_header(&mut body.reader(), header).map(Some)
                            }
                            DecodePacketType::Reserved(code) => {
                                Err(VariablePacketError::ReservedPacket(code, body.to_vec()))
                            }
                        };
                    }
                }
            }
        }
    }

    /// Encoder that serializes any [`EncodablePacket`] into the output buffer.
    pub struct MqttEncoder {
        _priv: (),
    }

    impl MqttEncoder {
        /// Creates a new encoder (it holds no state).
        pub const fn new() -> Self {
            MqttEncoder { _priv: () }
        }
    }

    impl Default for MqttEncoder {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T: EncodablePacket> codec::Encoder<T> for MqttEncoder {
        type Error = io::Error;

        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
            // Reserve the full encoded size up front so the writes below
            // don't need to grow the buffer.
            dst.reserve(packet.encoded_length() as usize);
            packet.encode(&mut dst.writer())
        }
    }

    /// An [`MqttDecoder`] and [`MqttEncoder`] combined into one codec, for
    /// use with `Framed`.
    pub struct MqttCodec {
        decode: MqttDecoder,
        encode: MqttEncoder,
    }

    impl MqttCodec {
        /// Creates a codec with a fresh decoder and encoder.
        pub const fn new() -> Self {
            MqttCodec {
                decode: MqttDecoder::new(),
                encode: MqttEncoder::new(),
            }
        }
    }

    impl Default for MqttCodec {
        fn default() -> Self {
            Self::new()
        }
    }

    impl codec::Decoder for MqttCodec {
        type Item = VariablePacket;
        type Error = VariablePacketError;

        #[inline]
        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
            self.decode.decode(src)
        }
    }

    impl<T: EncodablePacket> codec::Encoder<T> for MqttCodec {
        type Error = io::Error;

        #[inline]
        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
            self.encode.encode(packet, dst)
        }
    }
}
497
498#[cfg(feature = "tokio-codec")]
499pub use tokio_codec::{MqttCodec, MqttDecoder, MqttEncoder};
500
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::{Decodable, Encodable};

    /// Serializes `packet` (fixed header included) into a new byte vector.
    fn encode_to_vec(packet: &VariablePacket) -> Vec<u8> {
        let mut bytes = Vec::new();
        packet.encode(&mut bytes).unwrap();
        bytes
    }

    /// A CONNECT packet must survive an encode / synchronous-decode round trip.
    #[test]
    fn test_variable_packet_basic() {
        let expected = VariablePacket::new(ConnectPacket::new("1234".to_owned()));
        let bytes = encode_to_vec(&expected);

        let mut cursor = Cursor::new(bytes);
        let actual = VariablePacket::decode(&mut cursor).unwrap();

        assert_eq!(expected, actual);
    }

    /// Same round trip, decoded through the async `VariablePacket::parse`.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_variable_packet_async_parse() {
        let expected = VariablePacket::new(ConnectPacket::new("1234".to_owned()));
        let bytes = encode_to_vec(&expected);

        let mut input: &[u8] = &bytes;
        let actual = VariablePacket::parse(&mut input).await.unwrap();

        assert_eq!(expected, actual);
    }

    /// Two packets written through `MqttEncoder` must come back out of
    /// `MqttDecoder`, in order, followed by end of stream.
    #[cfg(feature = "tokio-codec")]
    #[tokio::test]
    async fn test_variable_packet_framed() {
        use crate::{QualityOfService, TopicFilter};
        use futures::{SinkExt, StreamExt};
        use tokio_util::codec::{FramedRead, FramedWrite};

        let conn_packet = ConnectPacket::new("1234".to_owned());
        let sub_packet = SubscribePacket::new(1, vec![(TopicFilter::new("foo/#").unwrap(), QualityOfService::Level0)]);

        // An 8-byte duplex buffer means any packet longer than 8 bytes has to
        // cross the pipe in several pieces.
        let (reader, writer) = tokio::io::duplex(8);

        let writer_task = tokio::spawn({
            let conn_packet = conn_packet.clone();
            let sub_packet = sub_packet.clone();
            async move {
                let mut sink = FramedWrite::new(writer, MqttEncoder::new());
                sink.send(conn_packet).await.unwrap();
                sink.send(sub_packet).await.unwrap();
                SinkExt::<VariablePacket>::flush(&mut sink).await.unwrap();
            }
        });

        let mut frames = FramedRead::new(reader, MqttDecoder::new());
        let decoded_conn = frames.next().await.unwrap().unwrap();
        let decoded_sub = frames.next().await.unwrap().unwrap();

        writer_task.await.unwrap();

        // The writer half was moved into the task and dropped with it, so the
        // stream must now be exhausted.
        assert!(frames.next().await.is_none());

        assert_eq!(decoded_conn, conn_packet.into());
        assert_eq!(decoded_sub, sub_packet.into());
    }
}