mqtt_format/v3/
packet.rs

1//
2//   This Source Code Form is subject to the terms of the Mozilla Public
3//   License, v. 2.0. If a copy of the MPL was not distributed with this
4//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
5//
6#![allow(clippy::forget_copy)]
7
8use std::pin::Pin;
9
10use futures::AsyncWriteExt;
11use nom::{
12    bits, bytes::complete::take, error::FromExternalError, number::complete::be_u16,
13    sequence::tuple, IResult, Parser,
14};
15
16use super::{
17    connect_return::{mconnectreturn, MConnectReturnCode},
18    errors::{MPacketHeaderError, MPacketWriteError},
19    header::{mfixedheader, MPacketHeader, MPacketKind},
20    identifier::{mpacketidentifier, MPacketIdentifier},
21    qos::{mquality_of_service, MQualityOfService},
22    strings::{mstring, MString},
23    subscription_acks::{msubscriptionacks, MSubscriptionAcks},
24    subscription_request::{msubscriptionrequests, MSubscriptionRequests},
25    unsubscription_request::{munsubscriptionrequests, MUnsubscriptionRequests},
26    will::MLastWill,
27    MSResult,
28};
29
/// Body of a CONNECT packet.
///
/// All string and binary fields borrow from the `'message` buffer the packet
/// was parsed from.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MConnect<'message> {
    /// Protocol name; the parser only accepts `"MQTT"`.
    pub protocol_name: MString<'message>,
    /// Protocol level; the parser only accepts `4`.
    pub protocol_level: u8,
    /// Clean-session bit of the connect flags.
    pub clean_session: bool,
    /// Last-will message; `Some` exactly when the will flag is set.
    pub will: Option<MLastWill<'message>>,
    /// User name; `Some` exactly when the user-name flag is set.
    pub username: Option<MString<'message>>,
    /// Password as raw bytes (16-bit length-prefixed on the wire); `Some`
    /// exactly when the password flag is set.
    pub password: Option<&'message [u8]>,
    /// Keep-alive value, a big-endian `u16` on the wire (presumably seconds,
    /// as in the MQTT spec).
    pub keep_alive: u16,
    /// Client identifier.
    pub client_id: MString<'message>,
}
42
/// Body of a CONNACK packet.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MConnack {
    /// Session-present bit: the lowest bit of the first body byte. When
    /// parsing, the other seven bits of that byte must be zero.
    pub session_present: bool,
    /// Return code carried in the second body byte.
    pub connect_return_code: MConnectReturnCode,
}
49
/// A PUBLISH packet: fixed-header flags, variable header and payload.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPublish<'message> {
    /// DUP flag (fixed-header bit 3); the parser rejects it for QoS 0.
    pub dup: bool,
    /// Delivery QoS (fixed-header bits 2-1).
    pub qos: MQualityOfService,
    /// RETAIN flag (fixed-header bit 0).
    pub retain: bool,
    /// Topic the message is published to.
    pub topic_name: MString<'message>,
    /// Packet identifier. The parser reads one only when `qos` is not
    /// `AtMostOnce`; the writer emits it whenever it is `Some`.
    pub id: Option<MPacketIdentifier>,
    /// Application payload: every body byte after the variable header.
    pub payload: &'message [u8],
}
60
/// Body of a PUBACK packet: only the packet identifier.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPuback {
    /// Identifier of the packet being acknowledged.
    pub id: MPacketIdentifier,
}
66
/// Body of a PUBREC packet: only the packet identifier.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPubrec {
    /// Identifier of the packet this PUBREC refers to.
    pub id: MPacketIdentifier,
}
72
/// Body of a PUBREL packet: only the packet identifier.
///
/// Written with fixed-header flags `0b0010`.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPubrel {
    /// Identifier of the packet this PUBREL refers to.
    pub id: MPacketIdentifier,
}
78
/// Body of a PUBCOMP packet: only the packet identifier.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPubcomp {
    /// Identifier of the packet this PUBCOMP refers to.
    pub id: MPacketIdentifier,
}
84
/// Body of a SUBSCRIBE packet: a packet identifier followed by the
/// subscription requests.
///
/// Written with fixed-header flags `0b0010`.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MSubscribe<'message> {
    /// Packet identifier.
    pub id: MPacketIdentifier,
    /// Requested subscriptions, borrowed from the message buffer.
    pub subscriptions: MSubscriptionRequests<'message>,
}
91
/// Body of a SUBACK packet: a packet identifier followed by the subscription
/// acknowledgements.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MSuback<'message> {
    /// Packet identifier.
    pub id: MPacketIdentifier,
    /// Acknowledgements, borrowed from the message buffer.
    pub subscription_acks: MSubscriptionAcks<'message>,
}
98
/// Body of an UNSUBSCRIBE packet: a packet identifier followed by the topic
/// filters to unsubscribe from.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MUnsubscribe<'message> {
    /// Packet identifier.
    pub id: MPacketIdentifier,
    /// Unsubscription requests, borrowed from the message buffer.
    pub unsubscriptions: MUnsubscriptionRequests<'message>,
}
105
/// Body of an UNSUBACK packet: only the packet identifier.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MUnsuback {
    /// Identifier of the UNSUBSCRIBE being acknowledged.
    pub id: MPacketIdentifier,
}
111
/// PINGREQ packet. It has no body; on the wire it is just the two
/// fixed-header bytes.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPingreq;
115
/// PINGRESP packet. It has no body; on the wire it is just the two
/// fixed-header bytes.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MPingresp;
119
/// DISCONNECT packet. It has no body; on the wire it is just the two
/// fixed-header bytes.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MDisconnect;
123
/// A decoded MQTT packet, with one variant per packet type.
///
/// Produced by [`mpacket`] and serialized with [`MPacket::write_to`]. Variants
/// carrying a `'message` lifetime borrow from the buffer the packet was parsed
/// from.
#[cfg_attr(feature = "yoke", derive(yoke::Yokeable))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MPacket<'message> {
    Connect(MConnect<'message>),
    Connack(MConnack),
    Publish(MPublish<'message>),
    Puback(MPuback),
    Pubrec(MPubrec),
    Pubrel(MPubrel),
    Pubcomp(MPubcomp),
    Subscribe(MSubscribe<'message>),
    Suback(MSuback<'message>),
    Unsubscribe(MUnsubscribe<'message>),
    Unsuback(MUnsuback),
    Pingreq(MPingreq),
    Pingresp(MPingresp),
    Disconnect(MDisconnect),
}
142
// Generates the conversions between an `MPacket` variant and its payload type:
//
// * `TryFrom<MPacket>` for the payload, taking it by value;
// * `TryFrom<&MPacket>` for `&Payload`, borrowing it in place;
// * `TryFrom<&MPacket>` for the payload, copying it out;
// * `From<Payload>` for `MPacket`, wrapping it in the variant.
//
// Every fallible conversion fails with `Err(())` when the packet holds a
// different variant.
macro_rules! impl_conversion_packet {
    ($var:ident => $kind:ty) => {
        impl<'message> TryFrom<MPacket<'message>> for $kind {
            type Error = ();

            fn try_from(value: MPacket<'message>) -> Result<Self, Self::Error> {
                match value {
                    MPacket::$var(inner) => Ok(inner),
                    _ => Err(()),
                }
            }
        }

        impl<'other, 'message> TryFrom<&'other MPacket<'message>> for &'other $kind {
            type Error = ();

            fn try_from(value: &'other MPacket<'message>) -> Result<Self, Self::Error> {
                match value {
                    MPacket::$var(inner) => Ok(inner),
                    _ => Err(()),
                }
            }
        }

        impl<'other, 'message> TryFrom<&'other MPacket<'message>> for $kind {
            type Error = ();

            fn try_from(value: &'other MPacket<'message>) -> Result<Self, Self::Error> {
                // Payload types are `Copy`, so binding by value copies them out.
                match *value {
                    MPacket::$var(inner) => Ok(inner),
                    _ => Err(()),
                }
            }
        }

        impl<'message> From<$kind> for MPacket<'message> {
            fn from(inner: $kind) -> Self {
                MPacket::$var(inner)
            }
        }
    };
}
188
189impl_conversion_packet!(Connect => MConnect<'message>);
190impl_conversion_packet!(Connack => MConnack);
191impl_conversion_packet!(Publish => MPublish<'message>);
192impl_conversion_packet!(Puback => MPuback);
193impl_conversion_packet!(Pubrec => MPubrec);
194impl_conversion_packet!(Pubrel => MPubrel);
195impl_conversion_packet!(Pubcomp => MPubcomp);
196impl_conversion_packet!(Subscribe => MSubscribe<'message>);
197impl_conversion_packet!(Suback => MSuback<'message>);
198impl_conversion_packet!(Unsuback => MUnsuback);
199impl_conversion_packet!(Pingreq => MPingreq);
200impl_conversion_packet!(Pingresp => MPingresp);
201impl_conversion_packet!(Disconnect => MDisconnect);
202
203impl<'message> MPacket<'message> {
204    pub async fn write_to<W: futures::AsyncWrite>(
205        &self,
206        mut writer: Pin<&mut W>,
207    ) -> Result<(), MPacketWriteError> {
208        macro_rules! write_remaining_length {
209            ($writer:ident, $length:expr) => {
210                match $length {
211                    len @ 0..=127 => {
212                        $writer.write_all(&[len as u8]).await?;
213                    }
214                    len @ 128..=16383 => {
215                        let first = len % 128 | 0b1000_0000;
216                        let second = len / 128;
217                        $writer.write_all(&[first as u8, second as u8]).await?;
218                    }
219                    len @ 16384..=2_097_151 => {
220                        let first = len % 128 | 0b1000_0000;
221                        let second = (len / 128) % 128 | 0b1000_0000;
222                        let third = len / (128 * 128);
223                        $writer
224                            .write_all(&[first as u8, second as u8, third as u8])
225                            .await?;
226                    }
227                    len @ 2_097_152..=268_435_455 => {
228                        let first = len % 128 | 0b1000_0000;
229                        let second = (len / 128) % 128 | 0b1000_0000;
230                        let third = (len / (128 * 128)) % 128 | 0b1000_0000;
231                        let fourth = len / (128 * 128 * 128);
232                        $writer
233                            .write_all(&[first as u8, second as u8, third as u8, fourth as u8])
234                            .await?;
235                    }
236                    size => {
237                        return Err(MPacketWriteError::InvalidSize(size));
238                    }
239                }
240            };
241        }
242
243        match self {
244            MPacket::Connect(MConnect {
245                protocol_name,
246                protocol_level,
247                clean_session,
248                will,
249                username,
250                password,
251                keep_alive,
252                client_id,
253            }) => {
254                let packet_type = 0b0001_0000;
255
256                // Header 1
257                writer.write_all(&[packet_type]).await?;
258
259                let remaining_length = 10
260                    + MString::get_len(client_id)
261                    + will.as_ref().map(MLastWill::get_len).unwrap_or_default()
262                    + username.as_ref().map(MString::get_len).unwrap_or_default()
263                    + password.as_ref().map(|p| 2 + p.len()).unwrap_or_default();
264
265                // Header 2-5
266                write_remaining_length!(writer, remaining_length);
267
268                // Variable 1-6
269                MString::write_to(protocol_name, &mut writer).await?;
270                // Variable 7
271                writer.write_all(&[*protocol_level]).await?;
272                let connect_flags = bools_to_u8([
273                    username.is_some(),
274                    password.is_some(),
275                    will.as_ref().map(|w| w.retain).unwrap_or_default(),
276                    will.as_ref()
277                        .map(|w| w.qos == MQualityOfService::ExactlyOnce)
278                        .unwrap_or_default(),
279                    will.as_ref()
280                        .map(|w| w.qos != MQualityOfService::ExactlyOnce)
281                        .unwrap_or_default(),
282                    will.is_some(),
283                    *clean_session,
284                    false,
285                ]);
286                // Variable 8
287                writer.write_all(&[connect_flags]).await?;
288                // Variable 9-10
289                writer.write_all(&keep_alive.to_be_bytes()).await?;
290
291                // Payload Client
292                MString::write_to(client_id, &mut writer).await?;
293
294                // Payload Will
295                if let Some(will) = will {
296                    MString::write_to(&will.topic, &mut writer).await?;
297                    writer
298                        .write_all(&(will.payload.len() as u16).to_be_bytes())
299                        .await?;
300                    writer.write_all(will.payload).await?;
301                }
302
303                // Payload Username
304                if let Some(username) = username {
305                    MString::write_to(username, &mut writer).await?;
306                }
307
308                if let Some(password) = password {
309                    writer
310                        .write_all(&(password.len() as u16).to_be_bytes())
311                        .await?;
312                    writer.write_all(password).await?;
313                }
314            }
315            MPacket::Connack(MConnack {
316                session_present,
317                connect_return_code,
318            }) => {
319                let packet_type = 0b0010_0000;
320
321                // Header 1
322                writer.write_all(&[packet_type]).await?;
323
324                let remaining_length = 2;
325
326                // Header 2-5
327                write_remaining_length!(writer, remaining_length);
328
329                // Variable 1-6
330                writer
331                    .write_all(&[*session_present as u8, *connect_return_code as u8])
332                    .await?;
333            }
334            MPacket::Publish(MPublish {
335                dup,
336                qos,
337                retain,
338                topic_name,
339                id,
340                payload,
341            }) => {
342                let packet_type = 0b0011_0000;
343                let dup_mask = if *dup { 0b0000_1000 } else { 0 };
344                let qos_mask = qos.to_byte() << 1;
345                let retain_mask = *retain as u8;
346
347                // Header 1
348                writer
349                    .write_all(&[packet_type | dup_mask | qos_mask | retain_mask])
350                    .await?;
351
352                let remaining_length = MString::get_len(topic_name)
353                    + id.as_ref().map(MPacketIdentifier::get_len).unwrap_or(0)
354                    + payload.len();
355
356                // Header 2-5
357                write_remaining_length!(writer, remaining_length);
358
359                // Variable Header
360                MString::write_to(topic_name, &mut writer).await?;
361                if let Some(id) = id {
362                    MPacketIdentifier::write_to(id, &mut writer).await?;
363                }
364                writer.write_all(payload).await?;
365            }
366            MPacket::Puback(MPuback { id }) => {
367                let packet_type = 0b0100_0000;
368
369                // Header 1
370                writer.write_all(&[packet_type]).await?;
371
372                let remaining_length = 2;
373
374                // Header 2-5
375                write_remaining_length!(writer, remaining_length);
376
377                // Variable 1-6
378                id.write_to(&mut writer).await?;
379            }
380            MPacket::Pubrec(MPubrec { id }) => {
381                let packet_type = 0b0101_0000;
382
383                // Header 1
384                writer.write_all(&[packet_type]).await?;
385
386                let remaining_length = 2;
387
388                // Header 2-5
389                write_remaining_length!(writer, remaining_length);
390
391                // Variable 1-6
392                id.write_to(&mut writer).await?;
393            }
394            MPacket::Pubrel(MPubrel { id }) => {
395                let packet_type = 0b0110_0010;
396
397                // Header 1
398                writer.write_all(&[packet_type]).await?;
399
400                let remaining_length = 2;
401
402                // Header 2-5
403                write_remaining_length!(writer, remaining_length);
404
405                // Variable 1-6
406                id.write_to(&mut writer).await?;
407            }
408            MPacket::Pubcomp(MPubcomp { id }) => {
409                let packet_type = 0b0111_0000;
410
411                // Header 1
412                writer.write_all(&[packet_type]).await?;
413
414                let remaining_length = 2;
415
416                // Header 2-5
417                write_remaining_length!(writer, remaining_length);
418
419                // Variable 1-6
420                id.write_to(&mut writer).await?;
421            }
422            MPacket::Subscribe(MSubscribe { id, subscriptions }) => {
423                let packet_type = 0b1000_0010;
424
425                // Header 1
426                writer.write_all(&[packet_type]).await?;
427
428                let remaining_length = id.get_len() + subscriptions.get_len();
429
430                // Header 2-5
431                write_remaining_length!(writer, remaining_length);
432
433                // Variable header
434
435                id.write_to(&mut writer).await?;
436
437                subscriptions.write_to(&mut writer).await?;
438            }
439            MPacket::Suback(MSuback {
440                id,
441                subscription_acks,
442            }) => {
443                let packet_type = 0b1001_0000;
444
445                // Header 1
446                writer.write_all(&[packet_type]).await?;
447
448                let remaining_length = id.get_len() + subscription_acks.get_len();
449
450                // Header 2-5
451                write_remaining_length!(writer, remaining_length);
452
453                // Variable header
454
455                id.write_to(&mut writer).await?;
456
457                subscription_acks.write_to(&mut writer).await?;
458            }
459            MPacket::Unsubscribe(MUnsubscribe {
460                id: _,
461                unsubscriptions: _,
462            }) => todo!(),
463            MPacket::Unsuback(MUnsuback { id: _ }) => todo!(),
464            MPacket::Pingreq(MPingreq) => {
465                let packet_type = 0b1100_0000;
466                let variable_length = 0b0;
467
468                // Header
469                writer.write_all(&[packet_type, variable_length]).await?;
470            }
471            MPacket::Pingresp(MPingresp) => {
472                let packet_type = 0b1101_0000;
473                let variable_length = 0b0;
474
475                // Header
476                writer.write_all(&[packet_type, variable_length]).await?;
477            }
478            MPacket::Disconnect(MDisconnect) => todo!(),
479        }
480
481        Ok(())
482    }
483}
484
/// Packs eight flags into one byte, most significant bit first.
///
/// `bools[0]` becomes bit 7 and `bools[7]` becomes bit 0.
fn bools_to_u8(bools: [bool; 8]) -> u8 {
    bools
        .iter()
        .fold(0u8, |byte, &flag| (byte << 1) | u8::from(flag))
}
495
496fn mpayload(input: &[u8]) -> IResult<&[u8], &[u8]> {
497    let (input, len) = be_u16(input)?;
498    take(len)(input)
499}
500
501fn mpacketdata(fixed_header: MPacketHeader, input: &[u8]) -> IResult<&[u8], MPacket> {
502    let (input, info) = match fixed_header.kind {
503        MPacketKind::Connect => {
504            let (input, protocol_name) = mstring(input)?;
505
506            if &*protocol_name != "MQTT" {
507                return Err(nom::Err::Error(nom::error::Error::from_external_error(
508                    input,
509                    nom::error::ErrorKind::MapRes,
510                    MPacketHeaderError::InvalidProtocolName(protocol_name.to_string()),
511                )));
512            }
513
514            let (input, protocol_level) = nom::number::complete::u8(input)?;
515
516            if protocol_level != 4 {
517                return Err(nom::Err::Error(nom::error::Error::from_external_error(
518                    input,
519                    nom::error::ErrorKind::MapRes,
520                    MPacketHeaderError::InvalidProtocolLevel(protocol_level),
521                )));
522            }
523
524            let (
525                input,
526                (
527                    user_name_flag,
528                    password_flag,
529                    will_retain,
530                    will_qos,
531                    will_flag,
532                    clean_session,
533                    reserved,
534                ),
535            ): (_, (u8, u8, u8, _, u8, u8, u8)) =
536                bits::<_, _, nom::error::Error<(&[u8], usize)>, _, _>(tuple((
537                    nom::bits::complete::take(1usize),
538                    nom::bits::complete::take(1usize),
539                    nom::bits::complete::take(1usize),
540                    nom::bits::complete::take(2usize),
541                    nom::bits::complete::take(1usize),
542                    nom::bits::complete::take(1usize),
543                    nom::bits::complete::take(1usize),
544                )))(input)?;
545
546            if reserved != 0 {
547                return Err(nom::Err::Error(nom::error::Error::from_external_error(
548                    input,
549                    nom::error::ErrorKind::MapRes,
550                    MPacketHeaderError::ForbiddenReservedValue,
551                )));
552            }
553
554            if will_flag == 0 && will_qos != 0 {
555                return Err(nom::Err::Error(nom::error::Error::from_external_error(
556                    input,
557                    nom::error::ErrorKind::MapRes,
558                    MPacketHeaderError::InconsistentWillFlag,
559                )));
560            }
561
562            let (input, keep_alive) = be_u16(input)?;
563
564            // Payload
565
566            let (input, client_id) = mstring(input)?;
567
568            let (input, will) = if will_flag == 1 {
569                let (input, topic) = mstring(input)?;
570                let (input, payload) = mpayload(input)?;
571                let retain = will_retain != 0;
572
573                (
574                    input,
575                    Some(MLastWill {
576                        topic,
577                        payload,
578                        retain,
579                        qos: mquality_of_service(will_qos).map_err(|e| {
580                            nom::Err::Error(nom::error::Error::from_external_error(
581                                input,
582                                nom::error::ErrorKind::MapRes,
583                                e,
584                            ))
585                        })?,
586                    }),
587                )
588            } else {
589                (input, None)
590            };
591
592            let (input, username) = if user_name_flag == 1 {
593                mstring.map(Some).parse(input)?
594            } else {
595                (input, None)
596            };
597
598            let (input, password) = if password_flag == 1 {
599                mpayload.map(Some).parse(input)?
600            } else {
601                (input, None)
602            };
603
604            (
605                input,
606                MPacket::Connect(MConnect {
607                    protocol_name,
608                    protocol_level,
609                    clean_session: clean_session == 1,
610                    will,
611                    username,
612                    password,
613                    client_id,
614                    keep_alive,
615                }),
616            )
617        }
618        MPacketKind::Connack => {
619            let (input, (reserved, session_present)): (_, (u8, u8)) =
620                bits::<_, _, nom::error::Error<(&[u8], usize)>, _, _>(tuple((
621                    nom::bits::complete::take(7usize),
622                    nom::bits::complete::take(1usize),
623                )))(input)?;
624
625            if reserved != 0 {
626                return Err(nom::Err::Error(nom::error::Error::from_external_error(
627                    input,
628                    nom::error::ErrorKind::MapRes,
629                    MPacketHeaderError::ForbiddenReservedValue,
630                )));
631            }
632
633            let (input, connect_return_code) = mconnectreturn(input)?;
634
635            (
636                input,
637                MPacket::Connack(MConnack {
638                    session_present: session_present == 1,
639                    connect_return_code,
640                }),
641            )
642        }
643        MPacketKind::Publish { dup, qos, retain } => {
644            let variable_header_start = input;
645
646            let (input, topic_name) = mstring(input)?;
647
648            let (input, id) = if qos != MQualityOfService::AtMostOnce {
649                let (input, id) = mpacketidentifier(input)?;
650
651                (input, Some(id))
652            } else {
653                (input, None)
654            };
655
656            if dup && qos == MQualityOfService::AtMostOnce {
657                return Err(nom::Err::Error(nom::error::Error::from_external_error(
658                    input,
659                    nom::error::ErrorKind::MapRes,
660                    MPacketHeaderError::InvalidDupFlag,
661                )));
662            }
663
664            let variable_header_end = input;
665            let variable_header_len = variable_header_start.len() - variable_header_end.len();
666
667            // Payload
668
669            let payload_length = match fixed_header
670                .remaining_length
671                .checked_sub(variable_header_len as u32)
672            {
673                Some(len) => len,
674                None => {
675                    return Err(nom::Err::Error(nom::error::Error::from_external_error(
676                        input,
677                        nom::error::ErrorKind::MapRes,
678                        MPacketHeaderError::InvalidPacketLength,
679                    )))
680                }
681            };
682            let (input, payload) = take(payload_length)(input)?;
683
684            (
685                input,
686                MPacket::Publish(MPublish {
687                    qos,
688                    dup,
689                    retain,
690                    id,
691                    topic_name,
692                    payload,
693                }),
694            )
695        }
696        MPacketKind::Puback => {
697            let (input, id) = mpacketidentifier(input)?;
698
699            (input, MPacket::Puback(MPuback { id }))
700        }
701        MPacketKind::Pubrec => {
702            let (input, id) = mpacketidentifier(input)?;
703
704            (input, MPacket::Pubrec(MPubrec { id }))
705        }
706        MPacketKind::Pubrel => {
707            let (input, id) = mpacketidentifier(input)?;
708
709            (input, MPacket::Pubrel(MPubrel { id }))
710        }
711        MPacketKind::Pubcomp => {
712            let (input, id) = mpacketidentifier(input)?;
713
714            (input, MPacket::Pubcomp(MPubcomp { id }))
715        }
716        MPacketKind::Subscribe => {
717            let (input, id) = mpacketidentifier(input)?;
718
719            let (input, subscriptions) = msubscriptionrequests(input)?;
720
721            (input, MPacket::Subscribe(MSubscribe { id, subscriptions }))
722        }
723        MPacketKind::Suback => {
724            let (input, id) = mpacketidentifier(input)?;
725
726            let (input, subscription_acks) = msubscriptionacks(input)?;
727
728            (
729                input,
730                MPacket::Suback(MSuback {
731                    id,
732                    subscription_acks,
733                }),
734            )
735        }
736        MPacketKind::Unsubscribe => {
737            let (input, id) = mpacketidentifier(input)?;
738
739            let (input, unsubscriptions) = munsubscriptionrequests(input)?;
740
741            (
742                input,
743                MPacket::Unsubscribe(MUnsubscribe {
744                    id,
745                    unsubscriptions,
746                }),
747            )
748        }
749        MPacketKind::Unsuback => {
750            let (input, id) = mpacketidentifier(input)?;
751
752            (input, MPacket::Unsuback(MUnsuback { id }))
753        }
754        MPacketKind::Pingreq => (input, MPacket::Pingreq(MPingreq)),
755        MPacketKind::Pingresp => (input, MPacket::Pingresp(MPingresp)),
756        MPacketKind::Disconnect => (input, MPacket::Disconnect(MDisconnect)),
757    };
758
759    Ok((input, info))
760}
761
762pub fn mpacket(input: &[u8]) -> MSResult<'_, MPacket<'_>> {
763    let (input, header) = mfixedheader(input)?;
764
765    let data = nom::bytes::complete::take(header.remaining_length);
766
767    let (input, packet) = data
768        .and_then(|input| mpacketdata(header, input))
769        .parse(input)?;
770
771    Ok((input, packet))
772}
773
#[cfg(test)]
mod tests {
    use crate::v3::{
        packet::{MConnect, MDisconnect, MPacket},
        strings::MString,
        will::MLastWill,
    };

    use super::mpacket;
    use std::pin::Pin;

    use pretty_assertions::assert_eq;

    /// A DISCONNECT is just its two fixed-header bytes, so parsing it must
    /// consume the whole input.
    #[test]
    fn check_complete_length() {
        let wire = &[0b1110_0000, 0b0000_0000];

        let (leftover, packet) = mpacket(wire).unwrap();

        assert_eq!(leftover, &[]);
        assert_eq!(packet, MPacket::Disconnect(MDisconnect));
    }

    /// A CONNECT whose will QoS bits are set while the will flag is clear
    /// must be rejected.
    #[test]
    fn check_will_consistency() {
        let wire = &[
            // Fixed header: CONNECT, remaining length 17
            0b0001_0000, 17,
            // Protocol name "MQTT"
            0x0, 0x4, b'M', b'Q', b'T', b'T',
            // Protocol level
            0x4,
            // Connect flags: will QoS = 1 but will flag = 0
            0b0000_1000,
            // Keep alive, in seconds
            0x0, 0x10,
            // Client identifier "HELLO"
            0x0, 0x5, b'H', b'E', b'L', b'L', b'O',
        ];

        mpacket(wire).unwrap_err();
    }

    /// Parsing a CONNECT with every optional field present and writing it back
    /// must reproduce the original bytes exactly.
    #[tokio::test]
    async fn check_connect_roundtrip() {
        let wire = &[
            // Fixed header: CONNECT, remaining length 37
            0b0001_0000, 37,
            // Protocol name "MQTT"
            0x0, 0x4, b'M', b'Q', b'T', b'T',
            // Protocol level
            0x4,
            // Connect flags: username, password, will retain, will QoS 2,
            // will, clean session
            0b1111_0110,
            // Keep alive, in seconds
            0x0, 0x10,
            // Client identifier "HELLO"
            0x0, 0x5, b'H', b'E', b'L', b'L', b'O',
            // Will topic "WORLD"
            0x0, 0x5, b'W', b'O', b'R', b'L', b'D',
            // Will payload [0xFF]
            0x0, 0x1, 0xFF,
            // Username "ADMIN"
            0x0, 0x5, b'A', b'D', b'M', b'I', b'N',
            // Password [0xF0]
            0x0, 0x1, 0xF0,
        ];

        let (_rest, connect) = mpacket(wire).unwrap();

        let expected = MPacket::Connect(MConnect {
            protocol_name: MString { value: "MQTT" },
            protocol_level: 4,
            clean_session: true,
            will: Some(MLastWill {
                topic: MString { value: "WORLD" },
                payload: &[0xFF],
                qos: crate::v3::qos::MQualityOfService::ExactlyOnce,
                retain: true,
            }),
            username: Some(MString { value: "ADMIN" }),
            password: Some(&[0xF0]),
            keep_alive: 16,
            client_id: MString { value: "HELLO" },
        });
        assert_eq!(connect, expected);

        let mut written = vec![];
        connect.write_to(Pin::new(&mut written)).await.unwrap();

        assert_eq!(wire, &written[..]);
    }
}