Skip to main content

mqtt_codec/codec/
decode.rs

1use std::convert::TryFrom;
2use std::io::{Cursor, Read};
3
4use bytes::{buf::Buf, Bytes};
5use bytestring::ByteString;
6
7use crate::error::ParseError;
8use crate::packet::*;
9use crate::proto::*;
10
11use super::{ConnectAckFlags, ConnectFlags, FixedHeader, WILL_QOS_SHIFT};
12
13pub(crate) fn read_packet(
14    src: &mut Cursor<Bytes>,
15    header: FixedHeader,
16) -> Result<Packet, ParseError> {
17    match header.packet_type {
18        CONNECT => decode_connect_packet(src),
19        CONNACK => decode_connect_ack_packet(src),
20        PUBLISH => decode_publish_packet(src, header),
21        PUBACK => Ok(Packet::PublishAck {
22            packet_id: read_u16(src)?,
23        }),
24        PUBREC => Ok(Packet::PublishReceived {
25            packet_id: read_u16(src)?,
26        }),
27        PUBREL => Ok(Packet::PublishRelease {
28            packet_id: read_u16(src)?,
29        }),
30        PUBCOMP => Ok(Packet::PublishComplete {
31            packet_id: read_u16(src)?,
32        }),
33        SUBSCRIBE => decode_subscribe_packet(src),
34        SUBACK => decode_subscribe_ack_packet(src),
35        UNSUBSCRIBE => decode_unsubscribe_packet(src),
36        UNSUBACK => Ok(Packet::UnsubscribeAck {
37            packet_id: read_u16(src)?,
38        }),
39        PINGREQ => Ok(Packet::PingRequest),
40        PINGRESP => Ok(Packet::PingResponse),
41        DISCONNECT => Ok(Packet::Disconnect),
42        _ => Err(ParseError::UnsupportedPacketType),
43    }
44}
45
/// Evaluates to `true` when every bit of the flag constant `$flag` (anything with a
/// `bits()` method) is also set in the raw flag byte `$flags`.
macro_rules! check_flag {
    ($flags:expr, $flag:expr) => {{
        let mask = $flag.bits();
        (mask & $flags) == mask
    }};
}
51
/// Returns early from the enclosing function with `Err($err)` unless `$cond` holds.
///
/// Example: `ensure!(src.remaining() >= 2, ParseError::InvalidLength);`
///
/// A former second arm (`$cond, $fmt, $($arg)+`) expanded to `Err($fmt, args…)`, which
/// passes two arguments to `Err` and could never compile, so it has been removed. A
/// trailing comma after the error expression is accepted.
macro_rules! ensure {
    ($cond:expr, $err:expr $(,)?) => {
        if !($cond) {
            return Err($err);
        }
    };
}
64
65pub fn decode_variable_length(src: &[u8]) -> Result<Option<(usize, usize)>, ParseError> {
66    if let Some((len, consumed, more)) = src
67        .iter()
68        .enumerate()
69        .scan((0, true), |state, (idx, x)| {
70            if !state.1 || idx > 3 {
71                return None;
72            }
73            state.0 += ((x & 0x7F) as usize) << (idx * 7);
74            state.1 = x & 0x80 != 0;
75            Some((state.0, idx + 1, state.1))
76        })
77        .last()
78    {
79        ensure!(!more || consumed < 4, ParseError::InvalidLength);
80        return Ok(Some((len, consumed)));
81    }
82
83    Ok(None)
84}
85
86fn decode_connect_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
87    ensure!(src.remaining() >= 10, ParseError::InvalidLength);
88    let len = src.get_u16();
89    if len >= 4 {
90        let mut ver = [0u8; 4];
91        src.read_exact(&mut ver).unwrap();
92        if &ver[..] != b"MQTT" {
93            return Err(ParseError::InvalidProtocol);
94        }
95    } else {
96        return Err(ParseError::InvalidProtocol);
97    }
98
99    let level = src.get_u8();
100    ensure!(
101        level == DEFAULT_MQTT_LEVEL,
102        ParseError::UnsupportedProtocolLevel
103    );
104
105    let flags = src.get_u8();
106    ensure!((flags & 0x01) == 0, ParseError::ConnectReservedFlagSet);
107
108    let keep_alive = src.get_u16();
109    let client_id = decode_utf8_str(src)?;
110
111    ensure!(
112        !client_id.is_empty() || check_flag!(flags, ConnectFlags::CLEAN_SESSION),
113        ParseError::InvalidClientId
114    );
115
116    let topic = if check_flag!(flags, ConnectFlags::WILL) {
117        Some(decode_utf8_str(src)?)
118    } else {
119        None
120    };
121    let message = if check_flag!(flags, ConnectFlags::WILL) {
122        Some(decode_length_bytes(src)?)
123    } else {
124        None
125    };
126    let username = if check_flag!(flags, ConnectFlags::USERNAME) {
127        Some(decode_utf8_str(src)?)
128    } else {
129        None
130    };
131    let password = if check_flag!(flags, ConnectFlags::PASSWORD) {
132        Some(decode_length_bytes(src)?)
133    } else {
134        None
135    };
136    let last_will = if topic.is_some() {
137        Some(LastWill {
138            qos: QoS::from((flags & ConnectFlags::WILL_QOS.bits()) >> WILL_QOS_SHIFT),
139            retain: check_flag!(flags, ConnectFlags::WILL_RETAIN),
140            topic: topic.unwrap(),
141            message: message.unwrap(),
142        })
143    } else {
144        None
145    };
146
147    Ok(Packet::Connect(Connect {
148        protocol: Protocol::MQTT(level),
149        clean_session: check_flag!(flags, ConnectFlags::CLEAN_SESSION),
150        keep_alive,
151        client_id,
152        last_will,
153        username,
154        password,
155    }))
156}
157
158fn decode_connect_ack_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
159    ensure!(src.remaining() >= 2, ParseError::InvalidLength);
160    let flags = src.get_u8();
161    ensure!(
162        (flags & 0b1111_1110) == 0,
163        ParseError::ConnAckReservedFlagSet
164    );
165
166    let return_code = src.get_u8();
167    Ok(Packet::ConnectAck {
168        session_present: check_flag!(flags, ConnectAckFlags::SESSION_PRESENT),
169        return_code: ConnectCode::from(return_code),
170    })
171}
172
173fn decode_publish_packet(
174    src: &mut Cursor<Bytes>,
175    header: FixedHeader,
176) -> Result<Packet, ParseError> {
177    let topic = decode_utf8_str(src)?;
178    let qos = QoS::from((header.packet_flags & 0b0110) >> 1);
179    let packet_id = if qos == QoS::AtMostOnce {
180        None
181    } else {
182        Some(read_u16(src)?)
183    };
184
185    let len = src.remaining();
186    let payload = take(src, len);
187
188    Ok(Packet::Publish(Publish {
189        dup: (header.packet_flags & 0b1000) == 0b1000,
190        qos,
191        retain: (header.packet_flags & 0b0001) == 0b0001,
192        topic,
193        packet_id,
194        payload,
195    }))
196}
197
198fn decode_subscribe_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
199    let packet_id = read_u16(src)?;
200    let mut topic_filters = Vec::new();
201    while src.remaining() > 0 {
202        let topic = decode_utf8_str(src)?;
203        ensure!(src.remaining() >= 1, ParseError::InvalidLength);
204        let qos = QoS::from(src.get_u8() & 0x03);
205        topic_filters.push((topic, qos));
206    }
207
208    Ok(Packet::Subscribe {
209        packet_id,
210        topic_filters,
211    })
212}
213
214fn decode_subscribe_ack_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
215    let packet_id = read_u16(src)?;
216    let status = src
217        .bytes()
218        //.iter()
219        .map(|code| {
220            let code = code.unwrap();
221            if code == 0x80 {
222                SubscribeReturnCode::Failure
223            } else {
224                SubscribeReturnCode::Success(QoS::from(code & 0x03))
225            }
226        })
227        .collect();
228    Ok(Packet::SubscribeAck { packet_id, status })
229}
230
231fn decode_unsubscribe_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
232    let packet_id = read_u16(src)?;
233    let mut topic_filters = Vec::new();
234    while src.remaining() > 0 {
235        topic_filters.push(decode_utf8_str(src)?);
236    }
237    Ok(Packet::Unsubscribe {
238        packet_id,
239        topic_filters,
240    })
241}
242
243fn decode_length_bytes(src: &mut Cursor<Bytes>) -> Result<Bytes, ParseError> {
244    let len = read_u16(src)? as usize;
245    ensure!(src.remaining() >= len, ParseError::InvalidLength);
246    Ok(take(src, len))
247}
248
249fn decode_utf8_str(src: &mut Cursor<Bytes>) -> Result<ByteString, ParseError> {
250    Ok(ByteString::try_from(decode_length_bytes(src)?)?)
251}
252
253fn take(buf: &mut Cursor<Bytes>, n: usize) -> Bytes {
254    let pos = buf.position() as usize;
255    let ret = buf.get_ref().slice(pos..pos + n);
256    buf.set_position((pos + n) as u64);
257    ret
258}
259
260fn read_u16(src: &mut Cursor<Bytes>) -> Result<u16, ParseError> {
261    ensure!(src.remaining() >= 2, ParseError::InvalidLength);
262    Ok(src.get_u16())
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    // Feeds a complete wire-format packet (fixed header + body) through `read_packet`
    // and asserts that it decodes to `$res`. The fixed header is rebuilt from the first
    // byte (type nibble + flag nibble) and the variable-length remaining-length field.
    macro_rules! assert_decode_packet (
        ($bytes:expr, $res:expr) => {{
            let fixed = $bytes.as_ref()[0];
            let (_len, consumed) = decode_variable_length(&$bytes[1..]).unwrap().unwrap();
            let hdr = FixedHeader {
                packet_type: fixed >> 4,
                packet_flags: fixed & 0xF,
                remaining_length: $bytes.len() - consumed - 1,
            };
            let mut cur = Cursor::new(Bytes::from_static(&$bytes[consumed + 1..]));
            assert_eq!(read_packet(&mut cur, hdr), Ok($res));
        }};
    );

    #[test]
    fn test_decode_variable_length() {
        macro_rules! assert_variable_length (
            ($bytes:expr, $res:expr) => {{
                assert_eq!(decode_variable_length($bytes), Ok(Some($res)));
            }};
        );

        // Decoding stops at the first byte without the continuation bit; the bytes
        // after it are not counted as consumed.
        assert_variable_length!(b"\x7f\x7f", (127, 1));

        // More than four length bytes is invalid.
        assert_eq!(
            decode_variable_length(b"\xff\xff\xff\xff\xff\xff"),
            Err(ParseError::InvalidLength)
        );

        // Smallest and largest value for each encoded width (1 to 4 bytes).
        assert_variable_length!(b"\x00", (0, 1));
        assert_variable_length!(b"\x7f", (127, 1));
        assert_variable_length!(b"\x80\x01", (128, 2));
        assert_variable_length!(b"\xff\x7f", (16383, 2));
        assert_variable_length!(b"\x80\x80\x01", (16384, 3));
        assert_variable_length!(b"\xff\xff\x7f", (2097151, 3));
        assert_variable_length!(b"\x80\x80\x80\x01", (2097152, 4));
        assert_variable_length!(b"\xff\xff\xff\x7f", (268435455, 4));
    }

    #[test]
    fn test_decode_connect_packets() {
        // CONNECT with user name and password, no will.
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x04MQTT\x04\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass"
            ))),
            Ok(Packet::Connect(Connect {
                protocol: Protocol::MQTT(4),
                clean_session: false,
                keep_alive: 60,
                client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(),
                last_will: None,
                username: Some(ByteString::try_from(Bytes::from_static(b"user")).unwrap()),
                password: Some(Bytes::from(&b"pass"[..])),
            }))
        );

        // CONNECT with a QoS 2 will message.
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x04MQTT\x04\x14\x00\x3C\x00\x0512345\x00\x05topic\x00\x07message"
            ))),
            Ok(Packet::Connect(Connect {
                protocol: Protocol::MQTT(4),
                clean_session: false,
                keep_alive: 60,
                client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(),
                last_will: Some(LastWill {
                    qos: QoS::ExactlyOnce,
                    retain: false,
                    topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
                    message: Bytes::from(&b"message"[..]),
                }),
                username: None,
                password: None,
            }))
        );

        // Malformed protocol names, levels and flags.
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x02MQ00000000000000000000"
            ))),
            Err(ParseError::InvalidProtocol),
        );
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x10MQ00000000000000000000"
            ))),
            Err(ParseError::InvalidProtocol),
        );
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x04MQAA00000000000000000000"
            ))),
            Err(ParseError::InvalidProtocol),
        );
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x04MQTT\x0300000000000000000000"
            ))),
            Err(ParseError::UnsupportedProtocolLevel),
        );
        assert_eq!(
            decode_connect_packet(&mut Cursor::new(Bytes::from_static(
                b"\x00\x04MQTT\x04\xff00000000000000000000"
            ))),
            Err(ParseError::ConnectReservedFlagSet)
        );

        assert_eq!(
            decode_connect_ack_packet(&mut Cursor::new(Bytes::from_static(b"\x01\x04"))),
            Ok(Packet::ConnectAck {
                session_present: true,
                return_code: ConnectCode::BadUserNameOrPassword
            })
        );

        assert_eq!(
            decode_connect_ack_packet(&mut Cursor::new(Bytes::from_static(b"\x03\x04"))),
            Err(ParseError::ConnAckReservedFlagSet)
        );

        assert_decode_packet!(
            b"\x20\x02\x01\x04",
            Packet::ConnectAck {
                session_present: true,
                return_code: ConnectCode::BadUserNameOrPassword,
            }
        );

        assert_decode_packet!(b"\xe0\x00", Packet::Disconnect);
    }

    #[test]
    fn test_decode_publish_packets() {
        assert_decode_packet!(
            b"\x3d\x0D\x00\x05topic\x43\x21data",
            Packet::Publish(Publish {
                dup: true,
                retain: true,
                qos: QoS::ExactlyOnce,
                topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
                packet_id: Some(0x4321),
                payload: Bytes::from_static(b"data"),
            })
        );
        assert_decode_packet!(
            b"\x30\x0b\x00\x05topicdata",
            Packet::Publish(Publish {
                dup: false,
                retain: false,
                qos: QoS::AtMostOnce,
                topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
                packet_id: None,
                payload: Bytes::from_static(b"data"),
            })
        );

        assert_decode_packet!(
            b"\x40\x02\x43\x21",
            Packet::PublishAck { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x50\x02\x43\x21",
            Packet::PublishReceived { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x60\x02\x43\x21",
            Packet::PublishRelease { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x70\x02\x43\x21",
            Packet::PublishComplete { packet_id: 0x4321 }
        );
    }

    #[test]
    fn test_decode_subscribe_packets() {
        let p = Packet::Subscribe {
            packet_id: 0x1234,
            topic_filters: vec![
                (
                    ByteString::try_from(Bytes::from_static(b"test")).unwrap(),
                    QoS::AtLeastOnce,
                ),
                (
                    ByteString::try_from(Bytes::from_static(b"filter")).unwrap(),
                    QoS::ExactlyOnce,
                ),
            ],
        };

        assert_eq!(
            decode_subscribe_packet(&mut Cursor::new(Bytes::from_static(
                b"\x12\x34\x00\x04test\x01\x00\x06filter\x02"
            ))),
            Ok(p.clone())
        );
        assert_decode_packet!(b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02", p);

        let p = Packet::SubscribeAck {
            packet_id: 0x1234,
            status: vec![
                SubscribeReturnCode::Success(QoS::AtLeastOnce),
                SubscribeReturnCode::Failure,
                SubscribeReturnCode::Success(QoS::ExactlyOnce),
            ],
        };

        assert_eq!(
            decode_subscribe_ack_packet(&mut Cursor::new(Bytes::from_static(
                b"\x12\x34\x01\x80\x02"
            ))),
            Ok(p.clone())
        );
        assert_decode_packet!(b"\x90\x05\x12\x34\x01\x80\x02", p);

        let p = Packet::Unsubscribe {
            packet_id: 0x1234,
            topic_filters: vec![
                ByteString::try_from(Bytes::from_static(b"test")).unwrap(),
                ByteString::try_from(Bytes::from_static(b"filter")).unwrap(),
            ],
        };

        assert_eq!(
            decode_unsubscribe_packet(&mut Cursor::new(Bytes::from_static(
                b"\x12\x34\x00\x04test\x00\x06filter"
            ))),
            Ok(p.clone())
        );
        assert_decode_packet!(b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter", p);

        assert_decode_packet!(
            b"\xb0\x02\x43\x21",
            Packet::UnsubscribeAck { packet_id: 0x4321 }
        );
    }

    #[test]
    fn test_decode_ping_packets() {
        assert_decode_packet!(b"\xc0\x00", Packet::PingRequest);
        assert_decode_packet!(b"\xd0\x00", Packet::PingResponse);
    }
}