mqtt_proto/common/
types.rs

1use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
2use std::convert::TryFrom;
3use std::fmt;
4use std::hash::{Hash, Hasher};
5use std::io;
6use std::ops::Deref;
7use std::slice;
8use std::sync::Arc;
9
10use simdutf8::basic::from_utf8;
11use tokio::io::AsyncRead;
12
13use super::{read_bytes, read_u8};
14use crate::{Error, LEVEL_SEP, MATCH_ALL_CHAR, MATCH_ONE_CHAR, SHARED_PREFIX, SYS_PREFIX};
15
/// Protocol name paired with protocol level 3 ([`Protocol::V310`]).
pub const MQISDP: &[u8] = "MQIsdp".as_bytes();
/// Protocol name paired with protocol levels 4 and 5 ([`Protocol::V311`], [`Protocol::V500`]).
pub const MQTT: &[u8] = "MQTT".as_bytes();
18
/// Types that can serialize themselves into an [`io::Write`] sink and report
/// how many bytes that serialization produces.
pub trait Encodable {
    /// Write the encoded form of `self` into `writer`.
    fn encode<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write;

    /// Number of bytes [`Encodable::encode`] writes for `self`.
    fn encode_len(&self) -> usize;
}
26
/// MQTT protocol version.
///
/// Each variant's discriminant equals the protocol level that identifies it
/// (3, 4 and 5), so variants also order from oldest to newest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Protocol {
    /// [MQTT 3.1], identified by protocol name `MQIsdp` and level 3.
    ///
    /// [MQTT 3.1]: https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
    V310 = 3,

    /// [MQTT 3.1.1], identified by protocol name `MQTT` and level 4; the most
    /// commonly implemented version.
    ///
    /// [MQTT 3.1.1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
    V311 = 4,

    /// [MQTT 5.0], identified by protocol name `MQTT` and level 5; the latest version.
    ///
    /// [MQTT 5.0]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html
    V500 = 5,
}
46
47impl Protocol {
48    pub fn new(name: &[u8], level: u8) -> Result<Protocol, Error> {
49        match (name, level) {
50            (MQISDP, 3) => Ok(Protocol::V310),
51            (MQTT, 4) => Ok(Protocol::V311),
52            (MQTT, 5) => Ok(Protocol::V500),
53            _ => {
54                let name = from_utf8(name).map_err(|_| Error::InvalidString)?;
55                Err(Error::InvalidProtocol(name.into(), level))
56            }
57        }
58    }
59
60    pub fn to_pair(self) -> (&'static [u8], u8) {
61        match self {
62            Self::V310 => (MQISDP, 3),
63            Self::V311 => (MQTT, 4),
64            Self::V500 => (MQTT, 5),
65        }
66    }
67
68    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
69        let name_buf = read_bytes(reader).await?;
70        let level = read_u8(reader).await?;
71        Protocol::new(&name_buf, level)
72    }
73}
74
75impl fmt::Display for Protocol {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        let output = match self {
78            Self::V310 => "v3.1",
79            Self::V311 => "v3.1.1",
80            Self::V500 => "v5.0",
81        };
82        write!(f, "{output}")
83    }
84}
85
86impl Encodable for Protocol {
87    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
88        let (name, level) = self.to_pair();
89        writer.write_all(&(name.len() as u16).to_be_bytes())?;
90        writer.write_all(name)?;
91        writer.write_all(slice::from_ref(&level))?;
92        Ok(())
93    }
94
95    fn encode_len(&self) -> usize {
96        match self {
97            Self::V310 => 2 + 6 + 1,
98            Self::V311 => 2 + 4 + 1,
99            Self::V500 => 2 + 4 + 1,
100        }
101    }
102}
103
/// Packet identifier.
///
/// Values built through `TryFrom<u16>` or `Default` are never `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Pid(u16);
108
109impl Pid {
110    /// Get the `Pid` as a raw `u16`.
111    pub fn value(self) -> u16 {
112        self.0
113    }
114}
115
116impl Default for Pid {
117    fn default() -> Pid {
118        Pid(1)
119    }
120}
121
122impl TryFrom<u16> for Pid {
123    type Error = Error;
124    fn try_from(value: u16) -> Result<Self, Error> {
125        if value == 0 {
126            Err(Error::ZeroPid)
127        } else {
128            Ok(Pid(value))
129        }
130    }
131}
132
133impl core::ops::Add<u16> for Pid {
134    type Output = Pid;
135
136    /// Adding a `u16` to a `Pid` will wrap around and avoid 0.
137    fn add(self, u: u16) -> Pid {
138        let n = match self.0.overflowing_add(u) {
139            (n, false) => n,
140            (n, true) => n + 1,
141        };
142        Pid(n)
143    }
144}
145
146impl core::ops::AddAssign<u16> for Pid {
147    fn add_assign(&mut self, other: u16) {
148        *self = *self + other;
149    }
150}
151
152impl core::ops::Sub<u16> for Pid {
153    type Output = Pid;
154
155    /// Subing a `u16` to a `Pid` will wrap around and avoid 0.
156    fn sub(self, u: u16) -> Pid {
157        let n = match self.0.overflowing_sub(u) {
158            (0, _) => core::u16::MAX,
159            (n, false) => n,
160            (n, true) => n - 1,
161        };
162        Pid(n)
163    }
164}
165
166impl core::ops::SubAssign<u16> for Pid {
167    fn sub_assign(&mut self, other: u16) {
168        *self = *self - other;
169    }
170}
171
/// Packet delivery [Quality of Service] level.
///
/// The `u8` representation is the level number itself (`0`, `1` or `2`).
///
/// [Quality of Service]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718099
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum QoS {
    /// `QoS 0`: at most once, no acknowledgement.
    Level0 = 0,
    /// `QoS 1`: at least once, one acknowledgement.
    Level1 = 1,
    /// `QoS 2`: exactly once, two acknowledgements.
    Level2 = 2,
}
186
187impl QoS {
188    pub fn from_u8(byte: u8) -> Result<QoS, Error> {
189        match byte {
190            0 => Ok(QoS::Level0),
191            1 => Ok(QoS::Level1),
192            2 => Ok(QoS::Level2),
193            n => Err(Error::InvalidQos(n)),
194        }
195    }
196}
197
198/// Combined [`QoS`] and [`Pid`].
199///
200/// Used only in [`Publish`] packets.
201///
202/// [`Publish`]: struct.Publish.html
203/// [`QoS`]: enum.QoS.html
204/// [`Pid`]: struct.Pid.html
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
206#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
207pub enum QosPid {
208    Level0,
209    Level1(Pid),
210    Level2(Pid),
211}
212
213impl QosPid {
214    /// Extract the [`Pid`] from a `QosPid`, if any.
215    ///
216    /// [`Pid`]: struct.Pid.html
217    pub fn pid(self) -> Option<Pid> {
218        match self {
219            QosPid::Level0 => None,
220            QosPid::Level1(p) => Some(p),
221            QosPid::Level2(p) => Some(p),
222        }
223    }
224
225    /// Extract the [`QoS`] from a `QosPid`.
226    ///
227    /// [`QoS`]: enum.QoS.html
228    pub fn qos(self) -> QoS {
229        match self {
230            QosPid::Level0 => QoS::Level0,
231            QosPid::Level1(_) => QoS::Level1,
232            QosPid::Level2(_) => QoS::Level2,
233        }
234    }
235}
236
/// Topic name.
///
/// See [MQTT 4.7]. The value is stored as an `Arc<String>`, so clones share
/// the same allocation.
///
/// NOTE(review): with the `arbitrary` feature the derived `Arbitrary` impl
/// builds values without going through `TopicName::is_invalid` — confirm that
/// is intended for fuzzing.
///
/// [MQTT 4.7]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicName(Arc<String>);
245
246impl TopicName {
247    /// Check if the topic name is invalid.
248    pub fn is_invalid(value: &str) -> bool {
249        if value.len() > u16::max_value() as usize {
250            return true;
251        }
252        value.contains(|c| c == MATCH_ONE_CHAR || c == MATCH_ALL_CHAR || c == '\0')
253    }
254
255    pub fn is_shared(&self) -> bool {
256        self.0.starts_with(SHARED_PREFIX)
257    }
258    pub fn is_sys(&self) -> bool {
259        self.0.starts_with(SYS_PREFIX)
260    }
261}
262
263impl fmt::Display for TopicName {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        write!(f, "{}", self.0)
266    }
267}
268
269impl TryFrom<String> for TopicName {
270    type Error = Error;
271    fn try_from(value: String) -> Result<Self, Error> {
272        if TopicName::is_invalid(value.as_str()) {
273            Err(Error::InvalidTopicName(value))
274        } else {
275            Ok(TopicName(Arc::new(value)))
276        }
277    }
278}
279
280impl Deref for TopicName {
281    type Target = str;
282    fn deref(&self) -> &str {
283        self.0.as_str()
284    }
285}
286
/// Topic filter.
///
/// See [MQTT 4.7]. The filter string is stored as an `Arc<String>` along with
/// a cached byte index, `shared_filter_sep`. For a shared subscription
/// (`$share/{ShareName}/{filter}`) it holds the index of the `/` that
/// separates the share name from the filter. Otherwise it is `0`.
///
/// `Hash`, `Ord`, `PartialOrd`, `Eq` and `PartialEq` are implemented by hand
/// and look only at the string value, never at the cached index.
///
/// NOTE(review): with the `arbitrary` feature the derived `Arbitrary` impl can
/// produce an `inner`/`shared_filter_sep` pair that was never validated by
/// `TopicFilter::is_invalid` — confirm that is acceptable for fuzzing.
///
/// [MQTT 4.7]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106
#[derive(Debug, Clone)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicFilter {
    inner: Arc<String>,
    shared_filter_sep: u16,
}
301
302impl TopicFilter {
303    /// Check if the topic filter is invalid.
304    ///
305    ///   * The u16 returned is where the bytes index of '/' char before shared topic filter
306    pub fn is_invalid(value: &str) -> (bool, u16) {
307        if value.len() > u16::max_value() as usize {
308            return (true, 0);
309        }
310
311        const SHARED_PREFIX_CHARS: [char; 7] = ['$', 's', 'h', 'a', 'r', 'e', '/'];
312
313        // v5.0 [MQTT-4.7.3-1]
314        if value.is_empty() {
315            return (true, 0);
316        }
317
318        let mut last_sep: Option<usize> = None;
319        let mut has_all = false;
320        let mut has_one = false;
321        let mut byte_idx = 0;
322        let mut is_shared = true;
323        let mut shared_group_sep = 0;
324        let mut shared_filter_sep = 0;
325        for (char_idx, c) in value.chars().enumerate() {
326            if c == '\0' {
327                return (true, 0);
328            }
329            // "#" must be last char
330            if has_all {
331                return (true, 0);
332            }
333
334            if is_shared && char_idx < 7 && c != SHARED_PREFIX_CHARS[char_idx] {
335                is_shared = false;
336            }
337
338            if c == LEVEL_SEP {
339                if is_shared {
340                    if shared_group_sep == 0 {
341                        shared_group_sep = byte_idx as u16;
342                    } else if shared_filter_sep == 0 {
343                        shared_filter_sep = byte_idx as u16;
344                    }
345                }
346                // "+" must occupy an entire level of the filter
347                if has_one && Some(char_idx) != last_sep.map(|v| v + 2) && char_idx != 1 {
348                    return (true, 0);
349                }
350                last_sep = Some(char_idx);
351                has_one = false;
352            } else if c == MATCH_ALL_CHAR {
353                // v5.0 [MQTT-4.8.2-2]
354                if shared_group_sep > 0 && shared_filter_sep == 0 {
355                    return (true, 0);
356                }
357                if has_one {
358                    // invalid topic filter: "/+#"
359                    return (true, 0);
360                } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
361                    has_all = true;
362                } else {
363                    // invalid topic filter: "/ab#"
364                    return (true, 0);
365                }
366            } else if c == MATCH_ONE_CHAR {
367                // v5.0 [MQTT-4.8.2-2]
368                if shared_group_sep > 0 && shared_filter_sep == 0 {
369                    return (true, 0);
370                }
371                if has_one {
372                    // invalid topic filter: "/++"
373                    return (true, 0);
374                } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
375                    has_one = true;
376                } else {
377                    return (true, 0);
378                }
379            }
380
381            byte_idx += c.len_utf8();
382        }
383
384        // v5.0 [MQTT-4.7.3-1]
385        if shared_filter_sep > 0 && shared_filter_sep as usize == value.len() - 1 {
386            return (true, 0);
387        }
388        // v5.0 [MQTT-4.8.2-2]
389        if shared_group_sep > 0 && shared_filter_sep == 0 {
390            return (true, 0);
391        }
392        // v5.0 [MQTT-4.8.2-1]
393        if shared_group_sep + 1 == shared_filter_sep {
394            return (true, 0);
395        }
396
397        debug_assert!(shared_group_sep == 0 || shared_group_sep == 6);
398
399        (false, shared_filter_sep)
400    }
401
402    pub fn is_shared(&self) -> bool {
403        self.shared_filter_sep > 0
404    }
405    pub fn is_sys(&self) -> bool {
406        self.inner.starts_with(SYS_PREFIX)
407    }
408
409    pub fn shared_group_name(&self) -> Option<&str> {
410        if self.is_shared() {
411            let group_end = self.shared_filter_sep as usize;
412            Some(&self.inner[7..group_end])
413        } else {
414            None
415        }
416    }
417
418    pub fn shared_filter(&self) -> Option<&str> {
419        if self.is_shared() {
420            let filter_begin = self.shared_filter_sep as usize + 1;
421            Some(&self.inner[filter_begin..])
422        } else {
423            None
424        }
425    }
426
427    /// return (shared group name, shared filter)
428    pub fn shared_info(&self) -> Option<(&str, &str)> {
429        if self.is_shared() {
430            let group_end = self.shared_filter_sep as usize;
431            let filter_begin = self.shared_filter_sep as usize + 1;
432            Some((&self.inner[7..group_end], &self.inner[filter_begin..]))
433        } else {
434            None
435        }
436    }
437}
438
439impl Hash for TopicFilter {
440    fn hash<H: Hasher>(&self, state: &mut H) {
441        self.inner.hash(state);
442    }
443}
444
445impl Ord for TopicFilter {
446    fn cmp(&self, other: &Self) -> Ordering {
447        self.inner.cmp(&other.inner)
448    }
449}
450
451impl PartialOrd for TopicFilter {
452    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
453        Some(self.cmp(other))
454    }
455}
456
457impl PartialEq for TopicFilter {
458    fn eq(&self, other: &Self) -> bool {
459        self.inner.eq(&other.inner)
460    }
461}
462
463impl Eq for TopicFilter {}
464
465impl fmt::Display for TopicFilter {
466    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467        write!(f, "{}", self.inner)
468    }
469}
470
471impl TryFrom<String> for TopicFilter {
472    type Error = Error;
473    fn try_from(value: String) -> Result<Self, Error> {
474        let (is_invalid, shared_filter_sep) = TopicFilter::is_invalid(value.as_str());
475        if is_invalid {
476            Err(Error::InvalidTopicFilter(value))
477        } else {
478            Ok(TopicFilter {
479                inner: Arc::new(value),
480                shared_filter_sep,
481            })
482        }
483    }
484}
485
486impl Deref for TopicFilter {
487    type Target = str;
488    fn deref(&self) -> &str {
489        self.inner.as_str()
490    }
491}
492
/// A byte buffer that is either a heap-allocated vector or a small inline
/// array of exactly 2 or 4 bytes.
///
/// Equality and hashing are derived, so they compare the variant as well as
/// the bytes: `Fixed2([1, 2])` is not equal to `Dynamic(vec![1, 2])`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VarBytes {
    Dynamic(Vec<u8>),
    Fixed2([u8; 2]),
    Fixed4([u8; 4]),
}
500
501impl AsRef<[u8]> for VarBytes {
502    /// Return the slice of the internal bytes.
503    fn as_ref(&self) -> &[u8] {
504        match self {
505            VarBytes::Dynamic(vec) => vec,
506            VarBytes::Fixed2(arr) => &arr[..],
507            VarBytes::Fixed4(arr) => &arr[..],
508        }
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Filters accepted both on their own and behind a `$share/xyz/` prefix.
    const VALID_FILTERS: [&str; 13] = [
        "abc/def",
        "abc/+",
        "abc/#",
        "#",
        "+",
        "+/",
        "+/+",
        "///",
        "//+/",
        "//abc/",
        "//+//#",
        "/abc/+//#",
        "+/abc/+",
    ];

    /// Filters rejected both on their own and behind a `$share/xyz/` prefix.
    const INVALID_FILTERS: [&str; 21] = [
        "abc\0def",
        "abc/\0def",
        "++",
        "++/",
        "/++",
        "abc/++",
        "abc/++/",
        "#/abc",
        "/ab#",
        "##",
        "/abc/ab#",
        "/+#",
        "//+#",
        "/abc/+#",
        "xxx/abc/+#",
        "xxx/a+bc/",
        "x+x/abc/",
        "x+/abc/",
        "+x/abc/",
        "+/abc/++",
        "+/a+c/+",
    ];

    #[test]
    fn pid_add_sub() {
        // (current, delta, current - delta, current + delta)
        let cases: [(u16, u16, u16, u16); 11] = [
            (2, 1, 1, 3),
            (100, 1, 99, 101),
            (1, 1, u16::MAX, 2),
            (1, 2, u16::MAX - 1, 3),
            (1, 3, u16::MAX - 2, 4),
            (u16::MAX, 1, u16::MAX - 1, 1),
            (u16::MAX, 2, u16::MAX - 2, 2),
            (10, u16::MAX, 10, 10),
            (10, 0, 10, 10),
            (1, 0, 1, 1),
            (u16::MAX, 0, u16::MAX, u16::MAX),
        ];
        for &(current, delta, expected_prev, expected_next) in cases.iter() {
            let pid = Pid::try_from(current).unwrap();
            let prev = pid - delta;
            let next = pid + delta;
            assert_eq!(
                expected_prev,
                prev.value(),
                "{:?} - {} should be {}",
                pid,
                delta,
                expected_prev
            );
            assert_eq!(
                expected_next,
                next.value(),
                "{:?} + {} should be {}",
                pid,
                delta,
                expected_next
            );
        }
    }

    #[test]
    fn test_valid_topic_name() {
        let longest = "a".repeat(usize::from(u16::MAX));
        let too_long = "a".repeat(usize::from(u16::MAX) + 1);

        // The empty name passes here. With v5.0 topic aliases an empty name
        // can be legal, so rejecting it is left to the caller.
        let accepted = ["/abc/def", "abc/def", "abc", "/", "//", "", longest.as_str()];
        for name in accepted.iter() {
            assert!(!TopicName::is_invalid(name));
        }

        let rejected = [
            "#",
            "+",
            "/+",
            "/#",
            "abc/\0",
            "abc\0def",
            "abc#def",
            "abc+def",
            too_long.as_str(),
        ];
        for name in rejected.iter() {
            assert!(TopicName::is_invalid(name));
        }
    }

    #[test]
    fn test_valid_topic_filter() {
        let longest = "a".repeat(usize::from(u16::MAX));
        let too_long = "a".repeat(usize::from(u16::MAX) + 1);

        for topic in VALID_FILTERS.iter() {
            assert_eq!((false, 0), TopicFilter::is_invalid(topic));
        }
        assert_eq!((false, 0), TopicFilter::is_invalid(&longest));

        for topic in INVALID_FILTERS.iter() {
            assert_eq!((true, 0), TopicFilter::is_invalid(topic));
        }
        assert_eq!((true, 0), TopicFilter::is_invalid(""));
        assert_eq!((true, 0), TopicFilter::is_invalid(&too_long));
    }

    #[test]
    fn test_valid_shared_topic_filter() {
        for topic in VALID_FILTERS.iter() {
            let shared = format!("$share/xyz/{}", topic);
            assert_eq!((false, 10), TopicFilter::is_invalid(shared.as_str()));
        }
        for topic in INVALID_FILTERS.iter() {
            let shared = format!("$share/xyz/{}", topic);
            assert_eq!((true, 0), TopicFilter::is_invalid(shared.as_str()));
        }

        // `None` means the raw filter must be rejected outright.
        for (expected, raw_filter) in [
            (Some((None, None)), "$abc/a/b"),
            (Some((None, None)), "$abc/a/b/xyz/def"),
            (Some((None, None)), "$sys/abc"),
            (Some((Some("abc"), Some("xyz"))), "$share/abc/xyz"),
            (Some((Some("abc"), Some("xyz/ijk"))), "$share/abc/xyz/ijk"),
            (Some((Some("abc"), Some("/xyz"))), "$share/abc//xyz"),
            (Some((Some("abc"), Some("/#"))), "$share/abc//#"),
            (Some((Some("abc"), Some("/a/x/+"))), "$share/abc//a/x/+"),
            (Some((Some("abc"), Some("+"))), "$share/abc/+"),
            (Some((Some("你好"), Some("+"))), "$share/你好/+"),
            (Some((Some("你好"), Some("你好"))), "$share/你好/你好"),
            (Some((Some("abc"), Some("#"))), "$share/abc/#"),
            (None, "$share/abc/"),
            (None, "$share/abc"),
            (None, "$share/+/y"),
            (None, "$share/+/+"),
            (None, "$share//y"),
            (None, "$share//+"),
        ] {
            match expected {
                None => assert_eq!((true, 0), TopicFilter::is_invalid(raw_filter)),
                Some((shared_group, shared_filter)) => {
                    let filter = TopicFilter::try_from(raw_filter.to_owned()).unwrap();
                    assert_eq!(filter.shared_group_name(), shared_group);
                    assert_eq!(filter.shared_filter(), shared_filter);
                    if let Some(group_name) = shared_group {
                        assert_eq!(
                            filter.shared_info(),
                            Some((group_name, shared_filter.unwrap()))
                        );
                    }
                }
            }
        }
    }
}