1use core::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
2use core::convert::TryFrom;
3use core::fmt;
4use core::hash::{Hash, Hasher};
5use core::ops::Deref;
6
7use alloc::string::String;
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10
11use simdutf8::basic::from_utf8;
12
13use crate::{
14 write_bytes, write_u8, AsyncRead, Error, SyncWrite, LEVEL_SEP, MATCH_ALL_CHAR, MATCH_ONE_CHAR,
15 SHARED_PREFIX, SYS_PREFIX,
16};
17
18use super::{read_bytes, read_u8};
19
/// Protocol name bytes that [`Protocol::new`] accepts together with level 3
/// ([`Protocol::V310`]).
pub const MQISDP: &[u8] = b"MQIsdp";
/// Protocol name bytes that [`Protocol::new`] accepts together with level 4 or 5
/// ([`Protocol::V311`] / [`Protocol::V500`]).
pub const MQTT: &[u8] = b"MQTT";
22
/// A value that can be written to a [`SyncWrite`] sink.
pub trait Encodable {
    /// Writes `self` to `writer`, returning an [`Error`] when that fails.
    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error>;

    /// Length of the encoded form.
    ///
    /// NOTE(review): presumably the exact number of bytes [`Encodable::encode`] writes;
    /// confirm against the implementors.
    fn encode_len(&self) -> usize;
}
30
/// Protocol version, identified on the wire by a name/level pair.
///
/// Each discriminant equals the level byte used by [`Protocol::new`] and
/// [`Protocol::to_pair`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Protocol {
    /// Name `MQIsdp`, level 3; displayed as `v3.1`.
    V310 = 3,

    /// Name `MQTT`, level 4; displayed as `v3.1.1`.
    V311 = 4,

    /// Name `MQTT`, level 5; displayed as `v5.0`.
    V500 = 5,
}
50
51impl Protocol {
52 pub fn new(name: &[u8], level: u8) -> Result<Protocol, Error> {
53 match (name, level) {
54 (MQISDP, 3) => Ok(Protocol::V310),
55 (MQTT, 4) => Ok(Protocol::V311),
56 (MQTT, 5) => Ok(Protocol::V500),
57 _ => {
58 let name = from_utf8(name).map_err(|_| Error::InvalidString)?;
59 Err(Error::InvalidProtocol(name.into(), level))
60 }
61 }
62 }
63
64 pub fn to_pair(self) -> (&'static [u8], u8) {
65 match self {
66 Self::V310 => (MQISDP, 3),
67 Self::V311 => (MQTT, 4),
68 Self::V500 => (MQTT, 5),
69 }
70 }
71
72 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
73 let name_buf = read_bytes(reader).await?;
74 let level = read_u8(reader).await?;
75 Protocol::new(&name_buf, level)
76 }
77}
78
79impl fmt::Display for Protocol {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 let output = match self {
82 Self::V310 => "v3.1",
83 Self::V311 => "v3.1.1",
84 Self::V500 => "v5.0",
85 };
86 write!(f, "{output}")
87 }
88}
89
90impl Encodable for Protocol {
91 fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
92 let (name, level) = self.to_pair();
93 write_bytes(writer, name)?;
94 write_u8(writer, level)?;
95 Ok(())
96 }
97
98 fn encode_len(&self) -> usize {
99 match self {
100 Self::V310 => 2 + 6 + 1,
101 Self::V311 => 2 + 4 + 1,
102 Self::V500 => 2 + 4 + 1,
103 }
104 }
105}
106
/// A 16-bit identifier wrapped in a newtype.
///
/// Within this file, `TryFrom<u16>` refuses `0`, `Default` yields `1`, and the
/// `Add`/`Sub` operators wrap around in a way that does not produce `0` for a
/// non-zero starting value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Pid(u16);
111
112impl Pid {
113 pub fn value(self) -> u16 {
115 self.0
116 }
117}
118
119impl Default for Pid {
120 fn default() -> Pid {
121 Pid(1)
122 }
123}
124
125impl TryFrom<u16> for Pid {
126 type Error = Error;
127 fn try_from(value: u16) -> Result<Self, Error> {
128 if value == 0 {
129 Err(Error::ZeroPid)
130 } else {
131 Ok(Pid(value))
132 }
133 }
134}
135
136impl core::ops::Add<u16> for Pid {
137 type Output = Pid;
138
139 fn add(self, u: u16) -> Pid {
141 let n = match self.0.overflowing_add(u) {
142 (n, false) => n,
143 (n, true) => n + 1,
144 };
145 Pid(n)
146 }
147}
148
149impl core::ops::AddAssign<u16> for Pid {
150 fn add_assign(&mut self, other: u16) {
151 *self = *self + other;
152 }
153}
154
155impl core::ops::Sub<u16> for Pid {
156 type Output = Pid;
157
158 fn sub(self, u: u16) -> Pid {
160 let n = match self.0.overflowing_sub(u) {
161 (0, _) => u16::MAX,
162 (n, false) => n,
163 (n, true) => n - 1,
164 };
165 Pid(n)
166 }
167}
168
169impl core::ops::SubAssign<u16> for Pid {
170 fn sub_assign(&mut self, other: u16) {
171 *self = *self - other;
172 }
173}
174
/// Quality-of-service level, stored as a `u8` with discriminants `0`, `1` and `2`.
///
/// Use [`QoS::from_u8`] to convert from a raw byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum QoS {
    /// Level 0.
    Level0 = 0,
    /// Level 1.
    Level1 = 1,
    /// Level 2.
    Level2 = 2,
}
189
190impl QoS {
191 pub fn from_u8(byte: u8) -> Result<QoS, Error> {
192 match byte {
193 0 => Ok(QoS::Level0),
194 1 => Ok(QoS::Level1),
195 2 => Ok(QoS::Level2),
196 n => Err(Error::InvalidQos(n)),
197 }
198 }
199}
200
/// A [`QoS`] level combined with the [`Pid`] that accompanies it.
///
/// `Level0` carries no identifier; `Level1` and `Level2` each carry one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum QosPid {
    /// Level 0, without an identifier.
    Level0,
    /// Level 1 with its identifier.
    Level1(Pid),
    /// Level 2 with its identifier.
    Level2(Pid),
}
215
216impl QosPid {
217 pub fn pid(self) -> Option<Pid> {
221 match self {
222 QosPid::Level0 => None,
223 QosPid::Level1(p) => Some(p),
224 QosPid::Level2(p) => Some(p),
225 }
226 }
227
228 pub fn qos(self) -> QoS {
232 match self {
233 QosPid::Level0 => QoS::Level0,
234 QosPid::Level1(_) => QoS::Level1,
235 QosPid::Level2(_) => QoS::Level2,
236 }
237 }
238}
239
/// A topic name backed by a shared (`Arc`) string.
///
/// Construct it with `TryFrom<String>`, which applies [`TopicName::is_invalid`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicName(Arc<String>);
248
249impl TopicName {
250 pub fn is_invalid(value: &str) -> bool {
252 if value.len() > u16::MAX as usize {
253 return true;
254 }
255 value.contains([MATCH_ONE_CHAR, MATCH_ALL_CHAR, '\0'])
256 }
257
258 pub fn is_shared(&self) -> bool {
259 self.0.starts_with(SHARED_PREFIX)
260 }
261 pub fn is_sys(&self) -> bool {
262 self.0.starts_with(SYS_PREFIX)
263 }
264}
265
266impl fmt::Display for TopicName {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 write!(f, "{}", self.0)
269 }
270}
271
272impl TryFrom<String> for TopicName {
273 type Error = Error;
274 fn try_from(value: String) -> Result<Self, Error> {
275 if TopicName::is_invalid(value.as_str()) {
276 Err(Error::InvalidTopicName(value))
277 } else {
278 Ok(TopicName(Arc::new(value)))
279 }
280 }
281}
282
283impl Deref for TopicName {
284 type Target = str;
285 fn deref(&self) -> &str {
286 self.0.as_str()
287 }
288}
289
/// A topic filter backed by a shared (`Arc`) string.
///
/// Construct it with `TryFrom<String>`, which applies [`TopicFilter::is_invalid`].
/// Equality, ordering and hashing are implemented on `inner` only.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct TopicFilter {
    /// The complete filter text.
    inner: Arc<String>,
    /// Byte offset in `inner` of the `/` that separates the group name from the filter
    /// in a `$share/{group}/{filter}` filter; `0` when the filter is not shared.
    shared_filter_sep: u16,
}
304
305impl TopicFilter {
306 pub fn is_invalid(value: &str) -> (bool, u16) {
310 if value.len() > u16::MAX as usize {
311 return (true, 0);
312 }
313
314 const SHARED_PREFIX_CHARS: [char; 7] = ['$', 's', 'h', 'a', 'r', 'e', '/'];
315
316 if value.is_empty() {
318 return (true, 0);
319 }
320
321 let mut last_sep: Option<usize> = None;
322 let mut has_all = false;
323 let mut has_one = false;
324 let mut byte_idx = 0;
325 let mut is_shared = true;
326 let mut shared_group_sep = 0;
327 let mut shared_filter_sep = 0;
328 for (char_idx, c) in value.chars().enumerate() {
329 if c == '\0' {
330 return (true, 0);
331 }
332 if has_all {
334 return (true, 0);
335 }
336
337 if is_shared && char_idx < 7 && c != SHARED_PREFIX_CHARS[char_idx] {
338 is_shared = false;
339 }
340
341 if c == LEVEL_SEP {
342 if is_shared {
343 if shared_group_sep == 0 {
344 shared_group_sep = byte_idx as u16;
345 } else if shared_filter_sep == 0 {
346 shared_filter_sep = byte_idx as u16;
347 }
348 }
349 if has_one && Some(char_idx) != last_sep.map(|v| v + 2) && char_idx != 1 {
351 return (true, 0);
352 }
353 last_sep = Some(char_idx);
354 has_one = false;
355 } else if c == MATCH_ALL_CHAR {
356 if shared_group_sep > 0 && shared_filter_sep == 0 {
358 return (true, 0);
359 }
360 if has_one {
361 return (true, 0);
363 } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
364 has_all = true;
365 } else {
366 return (true, 0);
368 }
369 } else if c == MATCH_ONE_CHAR {
370 if shared_group_sep > 0 && shared_filter_sep == 0 {
372 return (true, 0);
373 }
374 if has_one {
375 return (true, 0);
377 } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
378 has_one = true;
379 } else {
380 return (true, 0);
381 }
382 }
383
384 byte_idx += c.len_utf8();
385 }
386
387 if shared_filter_sep > 0 && shared_filter_sep as usize == value.len() - 1 {
389 return (true, 0);
390 }
391 if shared_group_sep > 0 && shared_filter_sep == 0 {
393 return (true, 0);
394 }
395 if shared_group_sep + 1 == shared_filter_sep {
397 return (true, 0);
398 }
399
400 debug_assert!(shared_group_sep == 0 || shared_group_sep == 6);
401
402 (false, shared_filter_sep)
403 }
404
405 pub fn is_shared(&self) -> bool {
406 self.shared_filter_sep > 0
407 }
408 pub fn is_sys(&self) -> bool {
409 self.inner.starts_with(SYS_PREFIX)
410 }
411
412 pub fn shared_group_name(&self) -> Option<&str> {
413 if self.is_shared() {
414 let group_end = self.shared_filter_sep as usize;
415 Some(&self.inner[7..group_end])
416 } else {
417 None
418 }
419 }
420
421 pub fn shared_filter(&self) -> Option<&str> {
422 if self.is_shared() {
423 let filter_begin = self.shared_filter_sep as usize + 1;
424 Some(&self.inner[filter_begin..])
425 } else {
426 None
427 }
428 }
429
430 pub fn shared_info(&self) -> Option<(&str, &str)> {
432 if self.is_shared() {
433 let group_end = self.shared_filter_sep as usize;
434 let filter_begin = self.shared_filter_sep as usize + 1;
435 Some((&self.inner[7..group_end], &self.inner[filter_begin..]))
436 } else {
437 None
438 }
439 }
440}
441
442impl Hash for TopicFilter {
443 fn hash<H: Hasher>(&self, state: &mut H) {
444 self.inner.hash(state);
445 }
446}
447
448impl Ord for TopicFilter {
449 fn cmp(&self, other: &Self) -> Ordering {
450 self.inner.cmp(&other.inner)
451 }
452}
453
454impl PartialOrd for TopicFilter {
455 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
456 Some(self.cmp(other))
457 }
458}
459
460impl PartialEq for TopicFilter {
461 fn eq(&self, other: &Self) -> bool {
462 self.inner.eq(&other.inner)
463 }
464}
465
/// Marker impl: the `PartialEq` impl compares the filter text, so equality is total.
impl Eq for TopicFilter {}
467
468impl fmt::Display for TopicFilter {
469 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
470 write!(f, "{}", self.inner)
471 }
472}
473
474impl TryFrom<String> for TopicFilter {
475 type Error = Error;
476 fn try_from(value: String) -> Result<Self, Error> {
477 let (is_invalid, shared_filter_sep) = TopicFilter::is_invalid(value.as_str());
478 if is_invalid {
479 Err(Error::InvalidTopicFilter(value))
480 } else {
481 Ok(TopicFilter {
482 inner: Arc::new(value),
483 shared_filter_sep,
484 })
485 }
486 }
487}
488
489impl Deref for TopicFilter {
490 type Target = str;
491 fn deref(&self) -> &str {
492 self.inner.as_str()
493 }
494}
495
/// A byte buffer that is either heap-allocated or a small fixed-size array.
///
/// All variants can be viewed as `&[u8]` through `AsRef<[u8]>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VarBytes {
    /// Arbitrary-length bytes held in a `Vec`.
    Dynamic(Vec<u8>),
    /// Exactly two bytes held inline.
    Fixed2([u8; 2]),
    /// Exactly four bytes held inline.
    Fixed4([u8; 4]),
}
503
504impl AsRef<[u8]> for VarBytes {
505 fn as_ref(&self) -> &[u8] {
507 match self {
508 VarBytes::Dynamic(vec) => vec,
509 VarBytes::Fixed2(arr) => &arr[..],
510 VarBytes::Fixed4(arr) => &arr[..],
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;

    use super::*;

    /// `Pid` addition and subtraction wrap around without producing zero.
    #[test]
    fn pid_add_sub() {
        // (current, delta, expected `current - delta`, expected `current + delta`)
        let t: [(u16, u16, u16, u16); 11] = [
            (2, 1, 1, 3),
            (100, 1, 99, 101),
            (1, 1, u16::MAX, 2),
            (1, 2, u16::MAX - 1, 3),
            (1, 3, u16::MAX - 2, 4),
            (u16::MAX, 1, u16::MAX - 1, 1),
            (u16::MAX, 2, u16::MAX - 2, 2),
            (10, u16::MAX, 10, 10),
            (10, 0, 10, 10),
            (1, 0, 1, 1),
            (u16::MAX, 0, u16::MAX, u16::MAX),
        ];
        for (cur, d, prev, next) in t {
            let cur = Pid::try_from(cur).unwrap();
            let sub = cur - d;
            let add = cur + d;
            assert_eq!(prev, sub.value(), "{:?} - {} should be {}", cur, d, prev);
            assert_eq!(next, add.value(), "{:?} + {} should be {}", cur, d, next);
        }
    }

    /// `TopicName::is_invalid` accepts plain names and rejects wildcards, NUL and
    /// over-long names.
    #[test]
    fn test_valid_topic_name() {
        // Accepted names.
        assert!(!TopicName::is_invalid("/abc/def"));
        assert!(!TopicName::is_invalid("abc/def"));
        assert!(!TopicName::is_invalid("abc"));
        assert!(!TopicName::is_invalid("/"));
        assert!(!TopicName::is_invalid("//"));
        assert!(!TopicName::is_invalid(""));
        assert!(!TopicName::is_invalid(
            "a".repeat(u16::MAX as usize).as_str()
        ));

        // Rejected names.
        assert!(TopicName::is_invalid("#"));
        assert!(TopicName::is_invalid("+"));
        assert!(TopicName::is_invalid("/+"));
        assert!(TopicName::is_invalid("/#"));
        assert!(TopicName::is_invalid("abc/\0"));
        assert!(TopicName::is_invalid("abc\0def"));
        assert!(TopicName::is_invalid("abc#def"));
        assert!(TopicName::is_invalid("abc+def"));
        assert!(TopicName::is_invalid(
            "a".repeat(u16::MAX as usize + 1).as_str()
        ));
    }

    /// `TopicFilter::is_invalid` on filters without the shared prefix.
    #[test]
    fn test_valid_topic_filter() {
        let string_65535 = "a".repeat(u16::MAX as usize);
        let string_65536 = "a".repeat(u16::MAX as usize + 1);
        for (is_invalid, topic) in [
            // Accepted filters.
            (false, "abc/def"),
            (false, "abc/+"),
            (false, "abc/#"),
            (false, "#"),
            (false, "+"),
            (false, "+/"),
            (false, "+/+"),
            (false, "///"),
            (false, "//+/"),
            (false, "//abc/"),
            (false, "//+//#"),
            (false, "/abc/+//#"),
            (false, "+/abc/+"),
            (false, string_65535.as_str()),
            // Rejected filters.
            (true, ""),
            (true, "abc\0def"),
            (true, "abc/\0def"),
            (true, "++"),
            (true, "++/"),
            (true, "/++"),
            (true, "abc/++"),
            (true, "abc/++/"),
            (true, "#/abc"),
            (true, "/ab#"),
            (true, "##"),
            (true, "/abc/ab#"),
            (true, "/+#"),
            (true, "//+#"),
            (true, "/abc/+#"),
            (true, "xxx/abc/+#"),
            (true, "xxx/a+bc/"),
            (true, "x+x/abc/"),
            (true, "x+/abc/"),
            (true, "+x/abc/"),
            (true, "+/abc/++"),
            (true, "+/a+c/+"),
            (true, string_65536.as_str()),
        ] {
            assert_eq!((is_invalid, 0), TopicFilter::is_invalid(topic));
        }
    }

    /// `TopicFilter::is_invalid` and the shared-filter accessors on `$share/...` filters.
    #[test]
    fn test_valid_shared_topic_filter() {
        for (is_invalid, topic) in [
            // Accepted filter parts.
            (false, "abc/def"),
            (false, "abc/+"),
            (false, "abc/#"),
            (false, "#"),
            (false, "+"),
            (false, "+/"),
            (false, "+/+"),
            (false, "///"),
            (false, "//+/"),
            (false, "//abc/"),
            (false, "//+//#"),
            (false, "/abc/+//#"),
            (false, "+/abc/+"),
            // Rejected filter parts.
            (true, "abc\0def"),
            (true, "abc/\0def"),
            (true, "++"),
            (true, "++/"),
            (true, "/++"),
            (true, "abc/++"),
            (true, "abc/++/"),
            (true, "#/abc"),
            (true, "/ab#"),
            (true, "##"),
            (true, "/abc/ab#"),
            (true, "/+#"),
            (true, "//+#"),
            (true, "/abc/+#"),
            (true, "xxx/abc/+#"),
            (true, "xxx/a+bc/"),
            (true, "x+x/abc/"),
            (true, "x+/abc/"),
            (true, "+x/abc/"),
            (true, "+/abc/++"),
            (true, "+/a+c/+"),
        ] {
            // With the `$share/xyz/` prefix the group/filter separator sits at byte 10.
            let result = if is_invalid { (true, 0) } else { (false, 10) };
            assert_eq!(
                result,
                TopicFilter::is_invalid(alloc::format!("$share/xyz/{}", topic).as_str()),
            );
        }

        // `Some((group, filter))` lists the expected accessor results for a valid filter;
        // `None` marks a filter that must be rejected.
        for (result, raw_filter) in [
            (Some((None, None)), "$abc/a/b"),
            (Some((None, None)), "$abc/a/b/xyz/def"),
            (Some((None, None)), "$sys/abc"),
            (Some((Some("abc"), Some("xyz"))), "$share/abc/xyz"),
            (Some((Some("abc"), Some("xyz/ijk"))), "$share/abc/xyz/ijk"),
            (Some((Some("abc"), Some("/xyz"))), "$share/abc//xyz"),
            (Some((Some("abc"), Some("/#"))), "$share/abc//#"),
            (Some((Some("abc"), Some("/a/x/+"))), "$share/abc//a/x/+"),
            (Some((Some("abc"), Some("+"))), "$share/abc/+"),
            (Some((Some("你好"), Some("+"))), "$share/你好/+"),
            (Some((Some("你好"), Some("你好"))), "$share/你好/你好"),
            (Some((Some("abc"), Some("#"))), "$share/abc/#"),
            (None, "$share/abc/"),
            (None, "$share/abc"),
            (None, "$share/+/y"),
            (None, "$share/+/+"),
            (None, "$share//y"),
            (None, "$share//+"),
        ] {
            if let Some((shared_group, shared_filter)) = result {
                let filter = TopicFilter::try_from(raw_filter.to_owned()).unwrap();
                assert_eq!(filter.shared_group_name(), shared_group);
                assert_eq!(filter.shared_filter(), shared_filter);
                if let Some(group_name) = shared_group {
                    assert_eq!(
                        filter.shared_info(),
                        Some((group_name, shared_filter.unwrap()))
                    );
                }
            } else {
                assert_eq!((true, 0), TopicFilter::is_invalid(raw_filter));
            }
        }
    }
}