mqtt/packet/
mod.rs

1//! Specific packets
2
3use std::error::Error;
4use std::fmt::{self, Debug};
5use std::io::{self, Read, Write};
6
7#[cfg(feature = "tokio")]
8use tokio::io::{AsyncRead, AsyncReadExt};
9
10use crate::control::fixed_header::FixedHeaderError;
11use crate::control::variable_header::VariableHeaderError;
12use crate::control::ControlType;
13use crate::control::FixedHeader;
14use crate::topic_name::{TopicNameDecodeError, TopicNameError};
15use crate::{Decodable, Encodable};
16
// Implements `EncodablePacket` for a packet struct and adds a private helper that keeps the
// fixed header's `remaining_length` in sync with the packet body.
//
// Usage: `encodable_packet!(Type(field_a, field_b, ...))`. The struct must have a
// `fixed_header` field; the listed fields are encoded after the fixed header, in the order
// given. The list may be empty and may end with a trailing comma.
macro_rules! encodable_packet {
    ($typ:ident($($field:ident),* $(,)?)) => {
        impl $crate::packet::EncodablePacket for $typ {
            fn fixed_header(&self) -> &$crate::control::fixed_header::FixedHeader {
                &self.fixed_header
            }

            // With an empty field list the expansion never touches `writer`, hence the allow.
            #[allow(unused)]
            fn encode_packet<W: ::std::io::Write>(&self, writer: &mut W) -> ::std::io::Result<()> {
                // One `Encodable::encode` call per listed field, stopping at the first error.
                $($crate::encodable::Encodable::encode(&self.$field, writer)?;)*
                Ok(())
            }

            fn encoded_packet_length(&self) -> u32 {
                // Expands to `len(a) + len(b) + ... + 0`; the trailing `0` terminates the sum
                // and is the whole result when no fields are listed.
                $($crate::encodable::Encodable::encoded_length(&self.$field) +)*
                    0
            }
        }

        impl $typ {
            // Recomputes `fixed_header.remaining_length` from the current encoded length of the
            // listed fields.
            // NOTE(review): presumably meant to be called after any field that affects the
            // encoded size changes; `allow(unused)` suggests not every packet type calls it.
            #[allow(unused)]
            #[inline(always)]
            fn fix_header_remaining_len(&mut self) {
                self.fixed_header.remaining_length = $crate::packet::EncodablePacket::encoded_packet_length(self);
            }
        }
    };
}
45
46pub use self::connack::ConnackPacket;
47pub use self::connect::ConnectPacket;
48pub use self::disconnect::DisconnectPacket;
49pub use self::pingreq::PingreqPacket;
50pub use self::pingresp::PingrespPacket;
51pub use self::puback::PubackPacket;
52pub use self::pubcomp::PubcompPacket;
53pub use self::publish::{PublishPacket, PublishPacketRef};
54pub use self::pubrec::PubrecPacket;
55pub use self::pubrel::PubrelPacket;
56pub use self::suback::SubackPacket;
57pub use self::subscribe::SubscribePacket;
58pub use self::unsuback::UnsubackPacket;
59pub use self::unsubscribe::UnsubscribePacket;
60
61pub use self::publish::QoSWithPacketIdentifier;
62
63pub mod connack;
64pub mod connect;
65pub mod disconnect;
66pub mod pingreq;
67pub mod pingresp;
68pub mod puback;
69pub mod pubcomp;
70pub mod publish;
71pub mod pubrec;
72pub mod pubrel;
73pub mod suback;
74pub mod subscribe;
75pub mod unsuback;
76pub mod unsubscribe;
77
/// A complete MQTT packet that can be encoded, whether passed as `FooPacket` or as
/// `&FooPacket`.
///
/// This differs from [`Encodable`] in that it keeps you from accidentally passing a type that
/// is only meant to be encoded as *part* of a packet and has no header of its own, e.g.
/// `Vec<u8>`.
///
/// Every `EncodablePacket` is also [`Encodable`]: the encoded form is the fixed header
/// followed by whatever [`encode_packet`](EncodablePacket::encode_packet) writes.
pub trait EncodablePacket {
    /// Returns a reference to the packet's `FixedHeader`. Every MQTT packet must have a fixed
    /// header.
    fn fixed_header(&self) -> &FixedHeader;

    /// Encodes the packet data that follows the fixed header, i.e. the variable header and the
    /// payload.
    ///
    /// The default implementation writes nothing.
    fn encode_packet<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
        Ok(())
    }

    /// Length in bytes of the data that follows the fixed header, i.e. the variable header and
    /// the payload.
    ///
    /// The default implementation returns `0`, matching the default `encode_packet`.
    fn encoded_packet_length(&self) -> u32 {
        0
    }
}
96
97impl<T: EncodablePacket> Encodable for T {
98    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
99        self.fixed_header().encode(writer)?;
100        self.encode_packet(writer)
101    }
102
103    fn encoded_length(&self) -> u32 {
104        self.fixed_header().encoded_length() + self.encoded_packet_length()
105    }
106}
107
/// A packet that can be decoded from a reader once its `FixedHeader` is known.
///
/// Implementors automatically get a [`Decodable`] implementation whose condition is an
/// `Option<FixedHeader>`.
pub trait DecodablePacket: EncodablePacket + Sized {
    /// Packet-specific decoding error, carried by [`PacketError::PayloadError`].
    type DecodePacketError: Error + 'static;

    /// Decode packet given a `FixedHeader`
    fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>>;
}
114
115impl<T: DecodablePacket> Decodable for T {
116    type Error = PacketError<T>;
117    type Cond = Option<FixedHeader>;
118
119    fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond) -> Result<Self, Self::Error> {
120        let fixed_header: FixedHeader = if let Some(hdr) = fixed_header {
121            hdr
122        } else {
123            Decodable::decode(reader)?
124        };
125
126        <Self as DecodablePacket>::decode_packet(reader, fixed_header)
127    }
128}
129
/// Errors produced while parsing a packet of type `P`.
///
/// `Display` and `source()` are forwarded to the wrapped error (`#[error(transparent)]`).
/// `Debug` is implemented by hand below rather than derived.
#[derive(thiserror::Error)]
#[error(transparent)]
pub enum PacketError<P>
where
    P: DecodablePacket,
{
    /// Error related to the fixed header.
    FixedHeaderError(#[from] FixedHeaderError),
    /// Error related to the variable header.
    VariableHeaderError(#[from] VariableHeaderError),
    /// Packet-specific error, as declared by `P::DecodePacketError`.
    PayloadError(<P as DecodablePacket>::DecodePacketError),
    /// Underlying I/O failure.
    IoError(#[from] io::Error),
    /// A topic name was invalid.
    TopicNameError(#[from] TopicNameError),
}
143
144impl<P> Debug for PacketError<P>
145where
146    P: DecodablePacket,
147{
148    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149        match *self {
150            PacketError::FixedHeaderError(ref e) => f.debug_tuple("FixedHeaderError").field(e).finish(),
151            PacketError::VariableHeaderError(ref e) => f.debug_tuple("VariableHeaderError").field(e).finish(),
152            PacketError::PayloadError(ref e) => f.debug_tuple("PayloadError").field(e).finish(),
153            PacketError::IoError(ref e) => f.debug_tuple("IoError").field(e).finish(),
154            PacketError::TopicNameError(ref e) => f.debug_tuple("TopicNameError").field(e).finish(),
155        }
156    }
157}
158
159impl<P: DecodablePacket> From<TopicNameDecodeError> for PacketError<P> {
160    fn from(e: TopicNameDecodeError) -> Self {
161        match e {
162            TopicNameDecodeError::IoError(e) => e.into(),
163            TopicNameDecodeError::InvalidTopicName(e) => e.into(),
164        }
165    }
166}
167
// Generates `VariablePacket` (an enum with one variant per packet type), its error type
// `VariablePacketError`, and the conversions / encode / decode implementations between them.
//
// Each input row has the form `PacketType & ErrorVariantName => ControlTypeVariant,`:
//   * `$name`    - the packet struct; also used as the `VariablePacket` variant name,
//   * `$errname` - the `VariablePacketError` variant wrapping `PacketError<$name>`,
//   * `$hdr`     - the `ControlType` variant that selects this packet when decoding.
macro_rules! impl_variable_packet {
    ($($name:ident & $errname:ident => $hdr:ident,)+) => {
        /// A packet of any of the supported types, one variant per packet type.
        #[derive(Debug, Eq, PartialEq, Clone)]
        pub enum VariablePacket {
            $(
                $name($name),
            )+
        }

        #[cfg(feature = "tokio")]
        impl VariablePacket {
            /// Asynchronously parse a packet from a `tokio::io::AsyncRead`
            ///
            /// The fixed header is parsed first, then exactly `remaining_length` bytes are read
            /// into a buffer and the packet is decoded from that buffer.
            ///
            /// This requires mqtt-rs to be built with `feature = "tokio"`
            pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, VariablePacketError> {
                use std::io::Cursor;
                let fixed_header = FixedHeader::parse(rdr).await?;

                // The buffer is sized from the header, so the whole body is in memory before
                // the synchronous decoder runs.
                let mut buffer = vec![0u8; fixed_header.remaining_length as usize];
                rdr.read_exact(&mut buffer).await?;

                decode_with_header(&mut Cursor::new(buffer), fixed_header)
            }
        }

        // Dispatches on the control type in `fixed_header` to the matching packet decoder and
        // wraps the result in the corresponding `VariablePacket` variant. `rdr` must be
        // positioned right after the fixed header.
        #[inline]
        fn decode_with_header<R: io::Read>(rdr: &mut R, fixed_header: FixedHeader) -> Result<VariablePacket, VariablePacketError> {
            match fixed_header.packet_type.control_type() {
                $(
                    ControlType::$hdr => {
                        let pk = <$name as DecodablePacket>::decode_packet(rdr, fixed_header)?;
                        Ok(VariablePacket::$name(pk))
                    }
                )+
            }
        }

        // Every concrete packet converts into its `VariablePacket` variant.
        $(
            impl From<$name> for VariablePacket {
                fn from(pk: $name) -> VariablePacket {
                    VariablePacket::$name(pk)
                }
            }
        )+

        // Disabled: `VariablePacket` already receives `Encodable` through the blanket
        // `impl<T: EncodablePacket> Encodable for T`, because it implements `EncodablePacket`
        // below. A direct implementation would overlap with that blanket impl.
        //
        // impl Encodable for VariablePacket {
        //     fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        //         match *self {
        //             $(
        //                 VariablePacket::$name(ref pk) => pk.encode(writer),
        //             )+
        //         }
        //     }

        //     fn encoded_length(&self) -> u32 {
        //         match *self {
        //             $(
        //                 VariablePacket::$name(ref pk) => pk.encoded_length(),
        //             )+
        //         }
        //     }
        // }

        // Each method simply forwards to the wrapped packet.
        impl EncodablePacket for VariablePacket {
            fn fixed_header(&self) -> &FixedHeader {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.fixed_header(),
                    )+
                }
            }

            fn encode_packet<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encode_packet(writer),
                    )+
                }
            }

            fn encoded_packet_length(&self) -> u32 {
                match *self {
                    $(
                        VariablePacket::$name(ref pk) => pk.encoded_packet_length(),
                    )+
                }
            }
        }

        impl Decodable for VariablePacket {
            type Error = VariablePacketError;
            type Cond = Option<FixedHeader>;

            // Decodes one packet. If the caller did not supply a fixed header, it is decoded
            // from `reader` first.
            fn decode_with<R: Read>(reader: &mut R, fixed_header: Self::Cond)
                    -> Result<VariablePacket, Self::Error> {
                let fixed_header = match fixed_header {
                    Some(fh) => fh,
                    None => {
                        match FixedHeader::decode(reader) {
                            Ok(header) => header,
                            // A reserved packet type is reported together with up to `length`
                            // bytes read from `reader`, so the caller receives the raw body.
                            Err(FixedHeaderError::ReservedType(code, length)) => {
                                let reader = &mut reader.take(length as u64);
                                let mut buf = Vec::with_capacity(length as usize);
                                reader.read_to_end(&mut buf)?;
                                return Err(VariablePacketError::ReservedPacket(code, buf));
                            },
                            Err(err) => return Err(From::from(err))
                        }
                    }
                };
                // Limit the packet decoder to this packet's bytes so it cannot read into
                // whatever follows on the stream.
                let reader = &mut reader.take(fixed_header.remaining_length as u64);

                decode_with_header(reader, fixed_header)
            }
        }

        /// Parsing errors for variable packet
        #[derive(Debug, thiserror::Error)]
        pub enum VariablePacketError {
            /// The fixed header could not be decoded.
            #[error(transparent)]
            FixedHeaderError(#[from] FixedHeaderError),
            /// The packet type is reserved; carries the type code and the bytes read for it.
            #[error("reserved packet type ({0}), [u8, ..{}]", .1.len())]
            ReservedPacket(u8, Vec<u8>),
            /// Underlying I/O failure.
            #[error(transparent)]
            IoError(#[from] io::Error),
            // One variant per packet type, wrapping that packet's `PacketError`.
            $(
                #[error(transparent)]
                $errname(#[from] PacketError<$name>),
            )+
        }
    }
}
301
// Instantiates `VariablePacket` / `VariablePacketError` for every packet type in this module.
// Row format: packet struct & error variant name => `ControlType` variant.
impl_variable_packet! {
    // Connection setup
    ConnectPacket       & ConnectPacketError        => Connect,
    ConnackPacket       & ConnackPacketError        => ConnectAcknowledgement,

    // Publish and its acknowledgement packets
    PublishPacket       & PublishPacketError        => Publish,
    PubackPacket        & PubackPacketError         => PublishAcknowledgement,
    PubrecPacket        & PubrecPacketError         => PublishReceived,
    PubrelPacket        & PubrelPacketError         => PublishRelease,
    PubcompPacket       & PubcompPacketError        => PublishComplete,

    // Ping
    PingreqPacket       & PingreqPacketError        => PingRequest,
    PingrespPacket      & PingrespPacketError       => PingResponse,

    // Subscribe
    SubscribePacket     & SubscribePacketError      => Subscribe,
    SubackPacket        & SubackPacketError         => SubscribeAcknowledgement,

    // Unsubscribe
    UnsubscribePacket   & UnsubscribePacketError    => Unsubscribe,
    UnsubackPacket      & UnsubackPacketError       => UnsubscribeAcknowledgement,

    // Disconnect
    DisconnectPacket    & DisconnectPacketError     => Disconnect,
}
323
324impl VariablePacket {
325    pub fn new<T>(t: T) -> VariablePacket
326    where
327        VariablePacket: From<T>,
328    {
329        From::from(t)
330    }
331}
332
333#[cfg(feature = "tokio-codec")]
334mod tokio_codec {
335    use super::*;
336    use crate::control::packet_type::{PacketType, PacketTypeError};
337    use bytes::{Buf, BufMut, BytesMut};
338    use tokio_util::codec;
339
    /// A `tokio_util::codec::Decoder` that turns a byte stream into `VariablePacket`s.
    ///
    /// The decoder is stateful: it remembers whether it is waiting for a fixed header or for
    /// the body of a packet whose header it has already consumed.
    pub struct MqttDecoder {
        // Where in the header/body cycle the decoder currently is.
        state: DecodeState,
    }
343
    /// Progress of `MqttDecoder` through the current packet.
    enum DecodeState {
        /// Waiting for a complete fixed header.
        Start,
        /// The fixed header has been consumed; waiting for `length` body bytes of a packet of
        /// type `typ`.
        Packet { length: u32, typ: DecodePacketType },
    }
348
    /// Packet type as read from the first header byte.
    #[derive(Copy, Clone)]
    enum DecodePacketType {
        /// A type byte that `PacketType::from_u8` accepted.
        Standard(PacketType),
        /// A reserved type, carrying the code reported by `PacketTypeError::ReservedType`.
        Reserved(u8),
    }
354
355    impl MqttDecoder {
356        pub const fn new() -> Self {
357            MqttDecoder {
358                state: DecodeState::Start,
359            }
360        }
361    }
362
363    /// Like FixedHeader::decode(), but on a buffer instead of a stream. Returns None if it reaches
364    /// the end of the buffer before it finishes decoding the header.
365    #[inline]
366    fn decode_header(mut data: &[u8]) -> Option<Result<(DecodePacketType, u32, usize), FixedHeaderError>> {
367        let mut header_size = 0;
368        macro_rules! read_u8 {
369            () => {{
370                let (&x, rest) = data.split_first()?;
371                data = rest;
372                header_size += 1;
373                x
374            }};
375        }
376
377        let type_val = read_u8!();
378        let remaining_len = {
379            let mut cur = 0u32;
380            for i in 0.. {
381                let byte = read_u8!();
382                cur |= ((byte as u32) & 0x7F) << (7 * i);
383
384                if i >= 4 {
385                    return Some(Err(FixedHeaderError::MalformedRemainingLength));
386                }
387
388                if byte & 0x80 == 0 {
389                    break;
390                }
391            }
392
393            cur
394        };
395
396        let packet_type = match PacketType::from_u8(type_val) {
397            Ok(ty) => DecodePacketType::Standard(ty),
398            Err(PacketTypeError::ReservedType(ty, _)) => DecodePacketType::Reserved(ty),
399            Err(err) => return Some(Err(err.into())),
400        };
401        Some(Ok((packet_type, remaining_len, header_size)))
402    }
403
404    impl codec::Decoder for MqttDecoder {
405        type Item = VariablePacket;
406        type Error = VariablePacketError;
407        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
408            loop {
409                match &mut self.state {
410                    DecodeState::Start => match decode_header(&src[..]) {
411                        Some(Ok((typ, length, header_size))) => {
412                            src.advance(header_size);
413                            self.state = DecodeState::Packet { length, typ };
414                            continue;
415                        }
416                        Some(Err(e)) => return Err(e.into()),
417                        None => return Ok(None),
418                    },
419                    DecodeState::Packet { length, typ } => {
420                        let length = *length;
421                        if src.remaining() < length as usize {
422                            return Ok(None);
423                        }
424                        let typ = *typ;
425
426                        self.state = DecodeState::Start;
427
428                        match typ {
429                            DecodePacketType::Standard(typ) => {
430                                let header = FixedHeader {
431                                    packet_type: typ,
432                                    remaining_length: length,
433                                };
434                                return decode_with_header(&mut src.reader(), header).map(Some);
435                            }
436                            DecodePacketType::Reserved(code) => {
437                                let data = src[..length as usize].to_vec();
438                                src.advance(length as usize);
439                                return Err(VariablePacketError::ReservedPacket(code, data));
440                            }
441                        }
442                    }
443                }
444            }
445        }
446    }
447
    /// A `tokio_util::codec::Encoder` that serializes any `EncodablePacket` into the output
    /// buffer. It holds no state.
    pub struct MqttEncoder {
        // Private unit field: keeps construction going through `MqttEncoder::new()`.
        _priv: (),
    }
451
452    impl MqttEncoder {
453        pub const fn new() -> Self {
454            MqttEncoder { _priv: () }
455        }
456    }
457
458    impl<T: EncodablePacket> codec::Encoder<T> for MqttEncoder {
459        type Error = io::Error;
460        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
461            dst.reserve(packet.encoded_length() as usize);
462            packet.encode(&mut dst.writer())
463        }
464    }
465
    /// Combined codec: decodes with an `MqttDecoder` and encodes with an `MqttEncoder`.
    pub struct MqttCodec {
        // Handles the `Decoder` half.
        decode: MqttDecoder,
        // Handles the `Encoder` half.
        encode: MqttEncoder,
    }
470
471    impl MqttCodec {
472        pub const fn new() -> Self {
473            MqttCodec {
474                decode: MqttDecoder::new(),
475                encode: MqttEncoder::new(),
476            }
477        }
478    }
479
480    impl codec::Decoder for MqttCodec {
481        type Item = VariablePacket;
482        type Error = VariablePacketError;
483        #[inline]
484        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<VariablePacket>, VariablePacketError> {
485            self.decode.decode(src)
486        }
487    }
488
489    impl<T: EncodablePacket> codec::Encoder<T> for MqttCodec {
490        type Error = io::Error;
491        #[inline]
492        fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
493            self.encode.encode(packet, dst)
494        }
495    }
496}
497
498#[cfg(feature = "tokio-codec")]
499pub use tokio_codec::{MqttCodec, MqttDecoder, MqttEncoder};
500
501#[cfg(test)]
502mod test {
503    use super::*;
504
505    use std::io::Cursor;
506
507    use crate::{Decodable, Encodable};
508
509    #[test]
510    fn test_variable_packet_basic() {
511        let packet = ConnectPacket::new("1234".to_owned());
512
513        // Wrap it
514        let var_packet = VariablePacket::new(packet);
515
516        // Encode
517        let mut buf = Vec::new();
518        var_packet.encode(&mut buf).unwrap();
519
520        // Decode
521        let mut decode_buf = Cursor::new(buf);
522        let decoded_packet = VariablePacket::decode(&mut decode_buf).unwrap();
523
524        assert_eq!(var_packet, decoded_packet);
525    }
526
527    #[cfg(feature = "tokio")]
528    #[tokio::test]
529    async fn test_variable_packet_async_parse() {
530        let packet = ConnectPacket::new("1234".to_owned());
531
532        // Wrap it
533        let var_packet = VariablePacket::new(packet);
534
535        // Encode
536        let mut buf = Vec::new();
537        var_packet.encode(&mut buf).unwrap();
538
539        // Parse
540        let mut async_buf = buf.as_slice();
541        let decoded_packet = VariablePacket::parse(&mut async_buf).await.unwrap();
542
543        assert_eq!(var_packet, decoded_packet);
544    }
545
546    #[cfg(feature = "tokio-codec")]
547    #[tokio::test]
548    async fn test_variable_packet_framed() {
549        use crate::{QualityOfService, TopicFilter};
550        use futures::{SinkExt, StreamExt};
551        use tokio_util::codec::{FramedRead, FramedWrite};
552
553        let conn_packet = ConnectPacket::new("1234".to_owned());
554        let sub_packet = SubscribePacket::new(1, vec![(TopicFilter::new("foo/#").unwrap(), QualityOfService::Level0)]);
555
556        // small, to make sure buffering and stuff works
557        let (reader, writer) = tokio::io::duplex(8);
558
559        let task = tokio::spawn({
560            let (conn_packet, sub_packet) = (conn_packet.clone(), sub_packet.clone());
561            async move {
562                let mut sink = FramedWrite::new(writer, MqttEncoder::new());
563                sink.send(conn_packet).await.unwrap();
564                sink.send(sub_packet).await.unwrap();
565                SinkExt::<VariablePacket>::flush(&mut sink).await.unwrap();
566            }
567        });
568
569        let mut stream = FramedRead::new(reader, MqttDecoder::new());
570        let decoded_conn = stream.next().await.unwrap().unwrap();
571        let decoded_sub = stream.next().await.unwrap().unwrap();
572
573        task.await.unwrap();
574
575        assert!(stream.next().await.is_none());
576
577        assert_eq!(decoded_conn, conn_packet.into());
578        assert_eq!(decoded_sub, sub_packet.into());
579    }
580}