mqtt_proto/common/
types.rs

1use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
2use core::convert::TryFrom;
3use core::fmt;
4use core::hash::{Hash, Hasher};
5use core::ops::Deref;
6
7use alloc::string::String;
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10
11use simdutf8::basic::from_utf8;
12
13use crate::{
14    write_bytes, write_u8, AsyncRead, Error, SyncWrite, LEVEL_SEP, MATCH_ALL_CHAR, MATCH_ONE_CHAR,
15    SHARED_PREFIX, SYS_PREFIX,
16};
17
18use super::{read_bytes, read_u8};
19
/// Protocol name bytes that `Protocol::new` pairs with level `3` (MQTT 3.1).
pub const MQISDP: &[u8] = b"MQIsdp";
/// Protocol name bytes that `Protocol::new` pairs with levels `4` and `5`
/// (MQTT 3.1.1 and MQTT 5.0).
pub const MQTT: &[u8] = b"MQTT";
22
23/// The ability of encoding type into write trait, and calculating encoded size.
24pub trait Encodable {
25    /// Encode type into writer
26    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error>;
27    /// Calculate the encoded size.
28    fn encode_len(&self) -> usize;
29}
30
/// MQTT protocol revision.
///
/// The discriminant of each variant is the protocol level byte that
/// `Protocol::to_pair` reports for it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Protocol {
    /// [MQTT 3.1], level byte `3`.
    ///
    /// [MQTT 3.1]: https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
    V310 = 3,

    /// [MQTT 3.1.1], level byte `4`; the most commonly implemented version.
    ///
    /// [MQTT 3.1.1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
    V311 = 4,

    /// [MQTT 5.0], level byte `5`; the latest version.
    ///
    /// [MQTT 5.0]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html
    V500 = 5,
}
50
51impl Protocol {
52    pub fn new(name: &[u8], level: u8) -> Result<Protocol, Error> {
53        match (name, level) {
54            (MQISDP, 3) => Ok(Protocol::V310),
55            (MQTT, 4) => Ok(Protocol::V311),
56            (MQTT, 5) => Ok(Protocol::V500),
57            _ => {
58                let name = from_utf8(name).map_err(|_| Error::InvalidString)?;
59                Err(Error::InvalidProtocol(name.into(), level))
60            }
61        }
62    }
63
64    pub fn to_pair(self) -> (&'static [u8], u8) {
65        match self {
66            Self::V310 => (MQISDP, 3),
67            Self::V311 => (MQTT, 4),
68            Self::V500 => (MQTT, 5),
69        }
70    }
71
72    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
73        let name_buf = read_bytes(reader).await?;
74        let level = read_u8(reader).await?;
75        Protocol::new(&name_buf, level)
76    }
77}
78
79impl fmt::Display for Protocol {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        let output = match self {
82            Self::V310 => "v3.1",
83            Self::V311 => "v3.1.1",
84            Self::V500 => "v5.0",
85        };
86        write!(f, "{output}")
87    }
88}
89
90impl Encodable for Protocol {
91    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
92        let (name, level) = self.to_pair();
93        write_bytes(writer, name)?;
94        write_u8(writer, level)?;
95        Ok(())
96    }
97
98    fn encode_len(&self) -> usize {
99        match self {
100            Self::V310 => 2 + 6 + 1,
101            Self::V311 => 2 + 4 + 1,
102            Self::V500 => 2 + 4 + 1,
103        }
104    }
105}
106
/// Packet identifier: a newtype around a raw `u16`.
///
/// Construct one with `Pid::default()` (value `1`) or `Pid::try_from`,
/// which rejects `0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Pid(u16);
111
112impl Pid {
113    /// Get the `Pid` as a raw `u16`.
114    pub fn value(self) -> u16 {
115        self.0
116    }
117}
118
119impl Default for Pid {
120    fn default() -> Pid {
121        Pid(1)
122    }
123}
124
125impl TryFrom<u16> for Pid {
126    type Error = Error;
127    fn try_from(value: u16) -> Result<Self, Error> {
128        if value == 0 {
129            Err(Error::ZeroPid)
130        } else {
131            Ok(Pid(value))
132        }
133    }
134}
135
136impl core::ops::Add<u16> for Pid {
137    type Output = Pid;
138
139    /// Adding a `u16` to a `Pid` will wrap around and avoid 0.
140    fn add(self, u: u16) -> Pid {
141        let n = match self.0.overflowing_add(u) {
142            (n, false) => n,
143            (n, true) => n + 1,
144        };
145        Pid(n)
146    }
147}
148
149impl core::ops::AddAssign<u16> for Pid {
150    fn add_assign(&mut self, other: u16) {
151        *self = *self + other;
152    }
153}
154
155impl core::ops::Sub<u16> for Pid {
156    type Output = Pid;
157
158    /// Subing a `u16` to a `Pid` will wrap around and avoid 0.
159    fn sub(self, u: u16) -> Pid {
160        let n = match self.0.overflowing_sub(u) {
161            (0, _) => u16::MAX,
162            (n, false) => n,
163            (n, true) => n - 1,
164        };
165        Pid(n)
166    }
167}
168
169impl core::ops::SubAssign<u16> for Pid {
170    fn sub_assign(&mut self, other: u16) {
171        *self = *self - other;
172    }
173}
174
/// [Quality of Service] level of a packet delivery.
///
/// The discriminants are the byte values accepted by `QoS::from_u8`.
///
/// [Quality of Service]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718099
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum QoS {
    /// `QoS 0`: at most once, no acknowledgement needed.
    Level0 = 0,
    /// `QoS 1`: at least once, one acknowledgement needed.
    Level1 = 1,
    /// `QoS 2`: exactly once, two acknowledgements needed.
    Level2 = 2,
}
189
190impl QoS {
191    pub fn from_u8(byte: u8) -> Result<QoS, Error> {
192        match byte {
193            0 => Ok(QoS::Level0),
194            1 => Ok(QoS::Level1),
195            2 => Ok(QoS::Level2),
196            n => Err(Error::InvalidQos(n)),
197        }
198    }
199}
200
201/// Combined [`QoS`] and [`Pid`].
202///
203/// Used only in [`Publish`] packets.
204///
205/// [`Publish`]: struct.Publish.html
206/// [`QoS`]: enum.QoS.html
207/// [`Pid`]: struct.Pid.html
208#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
209#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
210pub enum QosPid {
211    Level0,
212    Level1(Pid),
213    Level2(Pid),
214}
215
216impl QosPid {
217    /// Extract the [`Pid`] from a `QosPid`, if any.
218    ///
219    /// [`Pid`]: struct.Pid.html
220    pub fn pid(self) -> Option<Pid> {
221        match self {
222            QosPid::Level0 => None,
223            QosPid::Level1(p) => Some(p),
224            QosPid::Level2(p) => Some(p),
225        }
226    }
227
228    /// Extract the [`QoS`] from a `QosPid`.
229    ///
230    /// [`QoS`]: enum.QoS.html
231    pub fn qos(self) -> QoS {
232        match self {
233            QosPid::Level0 => QoS::Level0,
234            QosPid::Level1(_) => QoS::Level1,
235            QosPid::Level2(_) => QoS::Level2,
236        }
237    }
238}
239
/// Topic name, stored as an `Arc<String>`.
///
/// See [MQTT 4.7] for the topic name rules.
///
/// [MQTT 4.7]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicName(Arc<String>);
248
249impl TopicName {
250    /// Check if the topic name is invalid.
251    pub fn is_invalid(value: &str) -> bool {
252        if value.len() > u16::MAX as usize {
253            return true;
254        }
255        value.contains([MATCH_ONE_CHAR, MATCH_ALL_CHAR, '\0'])
256    }
257
258    pub fn is_shared(&self) -> bool {
259        self.0.starts_with(SHARED_PREFIX)
260    }
261    pub fn is_sys(&self) -> bool {
262        self.0.starts_with(SYS_PREFIX)
263    }
264}
265
266impl fmt::Display for TopicName {
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        write!(f, "{}", self.0)
269    }
270}
271
272impl TryFrom<String> for TopicName {
273    type Error = Error;
274    fn try_from(value: String) -> Result<Self, Error> {
275        if TopicName::is_invalid(value.as_str()) {
276            Err(Error::InvalidTopicName(value))
277        } else {
278            Ok(TopicName(Arc::new(value)))
279        }
280    }
281}
282
283impl Deref for TopicName {
284    type Target = str;
285    fn deref(&self) -> &str {
286        self.0.as_str()
287    }
288}
289
/// Topic filter.
///
/// See [MQTT 4.7]. Alongside the filter string (an `Arc<String>`) the struct
/// caches the byte index of the separator in front of the shared topic
/// filter. `Hash`, `Ord`, `PartialOrd`, `Eq` and `PartialEq` are implemented
/// by hand so that they consider only the string value and ignore the cache.
///
/// [MQTT 4.7]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106
#[derive(Clone, Debug)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicFilter {
    /// The complete filter string.
    inner: Arc<String>,
    /// Byte index of the `/` in front of the shared topic filter; `0` when
    /// the filter is not shared.
    shared_filter_sep: u16,
}
304
305impl TopicFilter {
306    /// Check if the topic filter is invalid.
307    ///
308    ///   * The u16 returned is where the bytes index of '/' char before shared topic filter
309    pub fn is_invalid(value: &str) -> (bool, u16) {
310        if value.len() > u16::MAX as usize {
311            return (true, 0);
312        }
313
314        const SHARED_PREFIX_CHARS: [char; 7] = ['$', 's', 'h', 'a', 'r', 'e', '/'];
315
316        // v5.0 [MQTT-4.7.3-1]
317        if value.is_empty() {
318            return (true, 0);
319        }
320
321        let mut last_sep: Option<usize> = None;
322        let mut has_all = false;
323        let mut has_one = false;
324        let mut byte_idx = 0;
325        let mut is_shared = true;
326        let mut shared_group_sep = 0;
327        let mut shared_filter_sep = 0;
328        for (char_idx, c) in value.chars().enumerate() {
329            if c == '\0' {
330                return (true, 0);
331            }
332            // "#" must be last char
333            if has_all {
334                return (true, 0);
335            }
336
337            if is_shared && char_idx < 7 && c != SHARED_PREFIX_CHARS[char_idx] {
338                is_shared = false;
339            }
340
341            if c == LEVEL_SEP {
342                if is_shared {
343                    if shared_group_sep == 0 {
344                        shared_group_sep = byte_idx as u16;
345                    } else if shared_filter_sep == 0 {
346                        shared_filter_sep = byte_idx as u16;
347                    }
348                }
349                // "+" must occupy an entire level of the filter
350                if has_one && Some(char_idx) != last_sep.map(|v| v + 2) && char_idx != 1 {
351                    return (true, 0);
352                }
353                last_sep = Some(char_idx);
354                has_one = false;
355            } else if c == MATCH_ALL_CHAR {
356                // v5.0 [MQTT-4.8.2-2]
357                if shared_group_sep > 0 && shared_filter_sep == 0 {
358                    return (true, 0);
359                }
360                if has_one {
361                    // invalid topic filter: "/+#"
362                    return (true, 0);
363                } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
364                    has_all = true;
365                } else {
366                    // invalid topic filter: "/ab#"
367                    return (true, 0);
368                }
369            } else if c == MATCH_ONE_CHAR {
370                // v5.0 [MQTT-4.8.2-2]
371                if shared_group_sep > 0 && shared_filter_sep == 0 {
372                    return (true, 0);
373                }
374                if has_one {
375                    // invalid topic filter: "/++"
376                    return (true, 0);
377                } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
378                    has_one = true;
379                } else {
380                    return (true, 0);
381                }
382            }
383
384            byte_idx += c.len_utf8();
385        }
386
387        // v5.0 [MQTT-4.7.3-1]
388        if shared_filter_sep > 0 && shared_filter_sep as usize == value.len() - 1 {
389            return (true, 0);
390        }
391        // v5.0 [MQTT-4.8.2-2]
392        if shared_group_sep > 0 && shared_filter_sep == 0 {
393            return (true, 0);
394        }
395        // v5.0 [MQTT-4.8.2-1]
396        if shared_group_sep + 1 == shared_filter_sep {
397            return (true, 0);
398        }
399
400        debug_assert!(shared_group_sep == 0 || shared_group_sep == 6);
401
402        (false, shared_filter_sep)
403    }
404
405    pub fn is_shared(&self) -> bool {
406        self.shared_filter_sep > 0
407    }
408    pub fn is_sys(&self) -> bool {
409        self.inner.starts_with(SYS_PREFIX)
410    }
411
412    pub fn shared_group_name(&self) -> Option<&str> {
413        if self.is_shared() {
414            let group_end = self.shared_filter_sep as usize;
415            Some(&self.inner[7..group_end])
416        } else {
417            None
418        }
419    }
420
421    pub fn shared_filter(&self) -> Option<&str> {
422        if self.is_shared() {
423            let filter_begin = self.shared_filter_sep as usize + 1;
424            Some(&self.inner[filter_begin..])
425        } else {
426            None
427        }
428    }
429
430    /// return (shared group name, shared filter)
431    pub fn shared_info(&self) -> Option<(&str, &str)> {
432        if self.is_shared() {
433            let group_end = self.shared_filter_sep as usize;
434            let filter_begin = self.shared_filter_sep as usize + 1;
435            Some((&self.inner[7..group_end], &self.inner[filter_begin..]))
436        } else {
437            None
438        }
439    }
440}
441
442impl Hash for TopicFilter {
443    fn hash<H: Hasher>(&self, state: &mut H) {
444        self.inner.hash(state);
445    }
446}
447
448impl Ord for TopicFilter {
449    fn cmp(&self, other: &Self) -> Ordering {
450        self.inner.cmp(&other.inner)
451    }
452}
453
454impl PartialOrd for TopicFilter {
455    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
456        Some(self.cmp(other))
457    }
458}
459
460impl PartialEq for TopicFilter {
461    fn eq(&self, other: &Self) -> bool {
462        self.inner.eq(&other.inner)
463    }
464}
465
466impl Eq for TopicFilter {}
467
468impl fmt::Display for TopicFilter {
469    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
470        write!(f, "{}", self.inner)
471    }
472}
473
474impl TryFrom<String> for TopicFilter {
475    type Error = Error;
476    fn try_from(value: String) -> Result<Self, Error> {
477        let (is_invalid, shared_filter_sep) = TopicFilter::is_invalid(value.as_str());
478        if is_invalid {
479            Err(Error::InvalidTopicFilter(value))
480        } else {
481            Ok(TopicFilter {
482                inner: Arc::new(value),
483                shared_filter_sep,
484            })
485        }
486    }
487}
488
489impl Deref for TopicFilter {
490    type Target = str;
491    fn deref(&self) -> &str {
492        self.inner.as_str()
493    }
494}
495
/// Byte storage that is either a growable vector or a fixed-size array of
/// two or four bytes.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum VarBytes {
    /// Heap-allocated bytes of any length.
    Dynamic(Vec<u8>),
    /// Exactly two bytes.
    Fixed2([u8; 2]),
    /// Exactly four bytes.
    Fixed4([u8; 4]),
}
503
504impl AsRef<[u8]> for VarBytes {
505    /// Return the slice of the internal bytes.
506    fn as_ref(&self) -> &[u8] {
507        match self {
508            VarBytes::Dynamic(vec) => vec,
509            VarBytes::Fixed2(arr) => &arr[..],
510            VarBytes::Fixed4(arr) => &arr[..],
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;

    use super::*;

    /// `(is_invalid, filter)` verdicts that must hold both for the bare
    /// filter and for the same filter behind a `$share/xyz/` prefix.
    const COMMON_FILTER_CASES: &[(bool, &str)] = &[
        // valid topic filter
        (false, "abc/def"),
        (false, "abc/+"),
        (false, "abc/#"),
        (false, "#"),
        (false, "+"),
        (false, "+/"),
        (false, "+/+"),
        (false, "///"),
        (false, "//+/"),
        (false, "//abc/"),
        (false, "//+//#"),
        (false, "/abc/+//#"),
        (false, "+/abc/+"),
        // invalid topic filter
        (true, "abc\0def"),
        (true, "abc/\0def"),
        (true, "++"),
        (true, "++/"),
        (true, "/++"),
        (true, "abc/++"),
        (true, "abc/++/"),
        (true, "#/abc"),
        (true, "/ab#"),
        (true, "##"),
        (true, "/abc/ab#"),
        (true, "/+#"),
        (true, "//+#"),
        (true, "/abc/+#"),
        (true, "xxx/abc/+#"),
        (true, "xxx/a+bc/"),
        (true, "x+x/abc/"),
        (true, "x+/abc/"),
        (true, "+x/abc/"),
        (true, "+/abc/++"),
        (true, "+/a+c/+"),
    ];

    /// `Pid` arithmetic wraps around `u16::MAX` and never yields 0.
    #[test]
    fn pid_add_sub() {
        // (current, delta, current - delta, current + delta)
        let cases: [(u16, u16, u16, u16); 11] = [
            (2, 1, 1, 3),
            (100, 1, 99, 101),
            (1, 1, u16::MAX, 2),
            (1, 2, u16::MAX - 1, 3),
            (1, 3, u16::MAX - 2, 4),
            (u16::MAX, 1, u16::MAX - 1, 1),
            (u16::MAX, 2, u16::MAX - 2, 2),
            (10, u16::MAX, 10, 10),
            (10, 0, 10, 10),
            (1, 0, 1, 1),
            (u16::MAX, 0, u16::MAX, u16::MAX),
        ];
        for (cur, d, prev, next) in cases {
            let cur = Pid::try_from(cur).unwrap();
            let sub = cur - d;
            let add = cur + d;
            assert_eq!(prev, sub.value(), "{:?} - {} should be {}", cur, d, prev);
            assert_eq!(next, add.value(), "{:?} + {} should be {}", cur, d, next);
        }
    }

    #[test]
    fn test_valid_topic_name() {
        // valid topic name
        assert!(!TopicName::is_invalid("/abc/def"));
        assert!(!TopicName::is_invalid("abc/def"));
        assert!(!TopicName::is_invalid("abc"));
        assert!(!TopicName::is_invalid("/"));
        assert!(!TopicName::is_invalid("//"));
        // NOTE: because of v5.0 topic aliases, rejecting an empty topic name
        // is left to the upper level.
        assert!(!TopicName::is_invalid(""));
        assert!(!TopicName::is_invalid(
            "a".repeat(u16::MAX as usize).as_str()
        ));

        // invalid topic name
        assert!(TopicName::is_invalid("#"));
        assert!(TopicName::is_invalid("+"));
        assert!(TopicName::is_invalid("/+"));
        assert!(TopicName::is_invalid("/#"));
        assert!(TopicName::is_invalid("abc/\0"));
        assert!(TopicName::is_invalid("abc\0def"));
        assert!(TopicName::is_invalid("abc#def"));
        assert!(TopicName::is_invalid("abc+def"));
        assert!(TopicName::is_invalid(
            "a".repeat(u16::MAX as usize + 1).as_str()
        ));
    }

    #[test]
    fn test_valid_topic_filter() {
        for &(is_invalid, topic) in COMMON_FILTER_CASES {
            assert_eq!(
                (is_invalid, 0),
                TopicFilter::is_invalid(topic),
                "filter: {:?}",
                topic
            );
        }

        // Emptiness and length limits only make sense without the prefix.
        let string_65535 = "a".repeat(u16::MAX as usize);
        let string_65536 = "a".repeat(u16::MAX as usize + 1);
        assert_eq!((false, 0), TopicFilter::is_invalid(string_65535.as_str()));
        assert_eq!((true, 0), TopicFilter::is_invalid(""));
        assert_eq!((true, 0), TopicFilter::is_invalid(string_65536.as_str()));
    }

    #[test]
    fn test_valid_shared_topic_filter() {
        for &(is_invalid, topic) in COMMON_FILTER_CASES {
            // In "$share/xyz/..." the separator in front of the filter part
            // sits at byte index 10.
            let expected = if is_invalid { (true, 0) } else { (false, 10) };
            let shared = alloc::format!("$share/xyz/{}", topic);
            assert_eq!(
                expected,
                TopicFilter::is_invalid(shared.as_str()),
                "filter: {:?}",
                shared
            );
        }

        // `Some((group, filter))`: valid, with the expected accessor results.
        // `None`: the raw filter must be rejected.
        for (result, raw_filter) in [
            (Some((None, None)), "$abc/a/b"),
            (Some((None, None)), "$abc/a/b/xyz/def"),
            (Some((None, None)), "$sys/abc"),
            (Some((Some("abc"), Some("xyz"))), "$share/abc/xyz"),
            (Some((Some("abc"), Some("xyz/ijk"))), "$share/abc/xyz/ijk"),
            (Some((Some("abc"), Some("/xyz"))), "$share/abc//xyz"),
            (Some((Some("abc"), Some("/#"))), "$share/abc//#"),
            (Some((Some("abc"), Some("/a/x/+"))), "$share/abc//a/x/+"),
            (Some((Some("abc"), Some("+"))), "$share/abc/+"),
            (Some((Some("你好"), Some("+"))), "$share/你好/+"),
            (Some((Some("你好"), Some("你好"))), "$share/你好/你好"),
            (Some((Some("abc"), Some("#"))), "$share/abc/#"),
            (None, "$share/abc/"),
            (None, "$share/abc"),
            (None, "$share/+/y"),
            (None, "$share/+/+"),
            (None, "$share//y"),
            (None, "$share//+"),
        ] {
            if let Some((shared_group, shared_filter)) = result {
                let filter = TopicFilter::try_from(raw_filter.to_owned()).unwrap();
                assert_eq!(filter.shared_group_name(), shared_group);
                assert_eq!(filter.shared_filter(), shared_filter);
                if let Some(group_name) = shared_group {
                    assert_eq!(
                        filter.shared_info(),
                        Some((group_name, shared_filter.unwrap()))
                    );
                }
            } else {
                assert_eq!((true, 0), TopicFilter::is_invalid(raw_filter));
            }
        }
    }
}