1use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
2use std::convert::TryFrom;
3use std::fmt;
4use std::hash::{Hash, Hasher};
5use std::io;
6use std::ops::Deref;
7use std::slice;
8use std::sync::Arc;
9
10use simdutf8::basic::from_utf8;
11use tokio::io::AsyncRead;
12
13use super::{read_bytes, read_u8};
14use crate::{Error, LEVEL_SEP, MATCH_ALL_CHAR, MATCH_ONE_CHAR, SHARED_PREFIX, SYS_PREFIX};
15
/// Protocol name paired with protocol level 3 (MQTT 3.1).
pub const MQISDP: &[u8] = b"MQIsdp";
/// Protocol name paired with protocol levels 4 (MQTT 3.1.1) and 5 (MQTT 5.0).
pub const MQTT: &[u8] = b"MQTT";
18
/// A value that knows how to serialize itself into a byte stream.
pub trait Encodable {
    /// Serializes `self` into `writer`.
    ///
    /// # Errors
    /// Propagates any I/O error raised by `writer`.
    fn encode<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write;

    /// Returns the number of bytes [`Encodable::encode`] produces for `self`.
    fn encode_len(&self) -> usize;
}
26
/// MQTT protocol version, identified on the wire by a (name, level) pair.
///
/// The discriminant of each variant is its protocol level.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Protocol {
    /// MQTT 3.1: protocol name `MQIsdp`, level 3.
    V310 = 3,
    /// MQTT 3.1.1: protocol name `MQTT`, level 4.
    V311 = 4,
    /// MQTT 5.0: protocol name `MQTT`, level 5.
    V500 = 5,
}
46
47impl Protocol {
48 pub fn new(name: &[u8], level: u8) -> Result<Protocol, Error> {
49 match (name, level) {
50 (MQISDP, 3) => Ok(Protocol::V310),
51 (MQTT, 4) => Ok(Protocol::V311),
52 (MQTT, 5) => Ok(Protocol::V500),
53 _ => {
54 let name = from_utf8(name).map_err(|_| Error::InvalidString)?;
55 Err(Error::InvalidProtocol(name.into(), level))
56 }
57 }
58 }
59
60 pub fn to_pair(self) -> (&'static [u8], u8) {
61 match self {
62 Self::V310 => (MQISDP, 3),
63 Self::V311 => (MQTT, 4),
64 Self::V500 => (MQTT, 5),
65 }
66 }
67
68 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
69 let name_buf = read_bytes(reader).await?;
70 let level = read_u8(reader).await?;
71 Protocol::new(&name_buf, level)
72 }
73}
74
75impl fmt::Display for Protocol {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 let output = match self {
78 Self::V310 => "v3.1",
79 Self::V311 => "v3.1.1",
80 Self::V500 => "v5.0",
81 };
82 write!(f, "{output}")
83 }
84}
85
86impl Encodable for Protocol {
87 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
88 let (name, level) = self.to_pair();
89 writer.write_all(&(name.len() as u16).to_be_bytes())?;
90 writer.write_all(name)?;
91 writer.write_all(slice::from_ref(&level))?;
92 Ok(())
93 }
94
95 fn encode_len(&self) -> usize {
96 match self {
97 Self::V310 => 2 + 6 + 1,
98 Self::V311 => 2 + 4 + 1,
99 Self::V500 => 2 + 4 + 1,
100 }
101 }
102}
103
/// A packet identifier wrapping a raw `u16`.
///
/// `Default` yields `1`, and `TryFrom<u16>` rejects `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Pid(u16);

impl Pid {
    /// Returns the raw identifier.
    pub fn value(self) -> u16 {
        let Pid(raw) = self;
        raw
    }
}

impl Default for Pid {
    /// The first identifier, `1`.
    fn default() -> Self {
        Self(1)
    }
}
121
122impl TryFrom<u16> for Pid {
123 type Error = Error;
124 fn try_from(value: u16) -> Result<Self, Error> {
125 if value == 0 {
126 Err(Error::ZeroPid)
127 } else {
128 Ok(Pid(value))
129 }
130 }
131}
132
133impl core::ops::Add<u16> for Pid {
134 type Output = Pid;
135
136 fn add(self, u: u16) -> Pid {
138 let n = match self.0.overflowing_add(u) {
139 (n, false) => n,
140 (n, true) => n + 1,
141 };
142 Pid(n)
143 }
144}
145
146impl core::ops::AddAssign<u16> for Pid {
147 fn add_assign(&mut self, other: u16) {
148 *self = *self + other;
149 }
150}
151
152impl core::ops::Sub<u16> for Pid {
153 type Output = Pid;
154
155 fn sub(self, u: u16) -> Pid {
157 let n = match self.0.overflowing_sub(u) {
158 (0, _) => core::u16::MAX,
159 (n, false) => n,
160 (n, true) => n - 1,
161 };
162 Pid(n)
163 }
164}
165
166impl core::ops::SubAssign<u16> for Pid {
167 fn sub_assign(&mut self, other: u16) {
168 *self = *self - other;
169 }
170}
171
172#[repr(u8)]
176#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
177#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
178pub enum QoS {
179 Level0 = 0,
181 Level1 = 1,
183 Level2 = 2,
185}
186
187impl QoS {
188 pub fn from_u8(byte: u8) -> Result<QoS, Error> {
189 match byte {
190 0 => Ok(QoS::Level0),
191 1 => Ok(QoS::Level1),
192 2 => Ok(QoS::Level2),
193 n => Err(Error::InvalidQos(n)),
194 }
195 }
196}
197
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
206#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
207pub enum QosPid {
208 Level0,
209 Level1(Pid),
210 Level2(Pid),
211}
212
213impl QosPid {
214 pub fn pid(self) -> Option<Pid> {
218 match self {
219 QosPid::Level0 => None,
220 QosPid::Level1(p) => Some(p),
221 QosPid::Level2(p) => Some(p),
222 }
223 }
224
225 pub fn qos(self) -> QoS {
229 match self {
230 QosPid::Level0 => QoS::Level0,
231 QosPid::Level1(_) => QoS::Level1,
232 QosPid::Level2(_) => QoS::Level2,
233 }
234 }
235}
236
237#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
243#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
244pub struct TopicName(Arc<String>);
245
246impl TopicName {
247 pub fn is_invalid(value: &str) -> bool {
249 if value.len() > u16::max_value() as usize {
250 return true;
251 }
252 value.contains(|c| c == MATCH_ONE_CHAR || c == MATCH_ALL_CHAR || c == '\0')
253 }
254
255 pub fn is_shared(&self) -> bool {
256 self.0.starts_with(SHARED_PREFIX)
257 }
258 pub fn is_sys(&self) -> bool {
259 self.0.starts_with(SYS_PREFIX)
260 }
261}
262
263impl fmt::Display for TopicName {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 write!(f, "{}", self.0)
266 }
267}
268
269impl TryFrom<String> for TopicName {
270 type Error = Error;
271 fn try_from(value: String) -> Result<Self, Error> {
272 if TopicName::is_invalid(value.as_str()) {
273 Err(Error::InvalidTopicName(value))
274 } else {
275 Ok(TopicName(Arc::new(value)))
276 }
277 }
278}
279
280impl Deref for TopicName {
281 type Target = str;
282 fn deref(&self) -> &str {
283 self.0.as_str()
284 }
285}
286
287#[derive(Debug, Clone)]
296#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
297pub struct TopicFilter {
298 inner: Arc<String>,
299 shared_filter_sep: u16,
300}
301
302impl TopicFilter {
303 pub fn is_invalid(value: &str) -> (bool, u16) {
307 if value.len() > u16::max_value() as usize {
308 return (true, 0);
309 }
310
311 const SHARED_PREFIX_CHARS: [char; 7] = ['$', 's', 'h', 'a', 'r', 'e', '/'];
312
313 if value.is_empty() {
315 return (true, 0);
316 }
317
318 let mut last_sep: Option<usize> = None;
319 let mut has_all = false;
320 let mut has_one = false;
321 let mut byte_idx = 0;
322 let mut is_shared = true;
323 let mut shared_group_sep = 0;
324 let mut shared_filter_sep = 0;
325 for (char_idx, c) in value.chars().enumerate() {
326 if c == '\0' {
327 return (true, 0);
328 }
329 if has_all {
331 return (true, 0);
332 }
333
334 if is_shared && char_idx < 7 && c != SHARED_PREFIX_CHARS[char_idx] {
335 is_shared = false;
336 }
337
338 if c == LEVEL_SEP {
339 if is_shared {
340 if shared_group_sep == 0 {
341 shared_group_sep = byte_idx as u16;
342 } else if shared_filter_sep == 0 {
343 shared_filter_sep = byte_idx as u16;
344 }
345 }
346 if has_one && Some(char_idx) != last_sep.map(|v| v + 2) && char_idx != 1 {
348 return (true, 0);
349 }
350 last_sep = Some(char_idx);
351 has_one = false;
352 } else if c == MATCH_ALL_CHAR {
353 if shared_group_sep > 0 && shared_filter_sep == 0 {
355 return (true, 0);
356 }
357 if has_one {
358 return (true, 0);
360 } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
361 has_all = true;
362 } else {
363 return (true, 0);
365 }
366 } else if c == MATCH_ONE_CHAR {
367 if shared_group_sep > 0 && shared_filter_sep == 0 {
369 return (true, 0);
370 }
371 if has_one {
372 return (true, 0);
374 } else if Some(char_idx) == last_sep.map(|v| v + 1) || char_idx == 0 {
375 has_one = true;
376 } else {
377 return (true, 0);
378 }
379 }
380
381 byte_idx += c.len_utf8();
382 }
383
384 if shared_filter_sep > 0 && shared_filter_sep as usize == value.len() - 1 {
386 return (true, 0);
387 }
388 if shared_group_sep > 0 && shared_filter_sep == 0 {
390 return (true, 0);
391 }
392 if shared_group_sep + 1 == shared_filter_sep {
394 return (true, 0);
395 }
396
397 debug_assert!(shared_group_sep == 0 || shared_group_sep == 6);
398
399 (false, shared_filter_sep)
400 }
401
402 pub fn is_shared(&self) -> bool {
403 self.shared_filter_sep > 0
404 }
405 pub fn is_sys(&self) -> bool {
406 self.inner.starts_with(SYS_PREFIX)
407 }
408
409 pub fn shared_group_name(&self) -> Option<&str> {
410 if self.is_shared() {
411 let group_end = self.shared_filter_sep as usize;
412 Some(&self.inner[7..group_end])
413 } else {
414 None
415 }
416 }
417
418 pub fn shared_filter(&self) -> Option<&str> {
419 if self.is_shared() {
420 let filter_begin = self.shared_filter_sep as usize + 1;
421 Some(&self.inner[filter_begin..])
422 } else {
423 None
424 }
425 }
426
427 pub fn shared_info(&self) -> Option<(&str, &str)> {
429 if self.is_shared() {
430 let group_end = self.shared_filter_sep as usize;
431 let filter_begin = self.shared_filter_sep as usize + 1;
432 Some((&self.inner[7..group_end], &self.inner[filter_begin..]))
433 } else {
434 None
435 }
436 }
437}
438
439impl Hash for TopicFilter {
440 fn hash<H: Hasher>(&self, state: &mut H) {
441 self.inner.hash(state);
442 }
443}
444
445impl Ord for TopicFilter {
446 fn cmp(&self, other: &Self) -> Ordering {
447 self.inner.cmp(&other.inner)
448 }
449}
450
451impl PartialOrd for TopicFilter {
452 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
453 Some(self.cmp(other))
454 }
455}
456
457impl PartialEq for TopicFilter {
458 fn eq(&self, other: &Self) -> bool {
459 self.inner.eq(&other.inner)
460 }
461}
462
463impl Eq for TopicFilter {}
464
465impl fmt::Display for TopicFilter {
466 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467 write!(f, "{}", self.inner)
468 }
469}
470
471impl TryFrom<String> for TopicFilter {
472 type Error = Error;
473 fn try_from(value: String) -> Result<Self, Error> {
474 let (is_invalid, shared_filter_sep) = TopicFilter::is_invalid(value.as_str());
475 if is_invalid {
476 Err(Error::InvalidTopicFilter(value))
477 } else {
478 Ok(TopicFilter {
479 inner: Arc::new(value),
480 shared_filter_sep,
481 })
482 }
483 }
484}
485
486impl Deref for TopicFilter {
487 type Target = str;
488 fn deref(&self) -> &str {
489 self.inner.as_str()
490 }
491}
492
/// A byte string held either in a heap `Vec` or inline in a 2- or 4-byte array.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VarBytes {
    /// Any number of bytes, heap-allocated.
    Dynamic(Vec<u8>),
    /// Exactly two bytes, stored inline.
    Fixed2([u8; 2]),
    /// Exactly four bytes, stored inline.
    Fixed4([u8; 4]),
}

impl AsRef<[u8]> for VarBytes {
    /// Borrows the contained bytes as a slice, whatever the variant.
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Dynamic(bytes) => bytes.as_slice(),
            Self::Fixed2(bytes) => bytes.as_slice(),
            Self::Fixed4(bytes) => bytes.as_slice(),
        }
    }
}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pid_add_sub() {
        // (current, delta, current - delta, current + delta): arithmetic wraps
        // within 1..=u16::MAX and never yields 0.
        let t: Vec<(u16, u16, u16, u16)> = vec![
            (2, 1, 1, 3),
            (100, 1, 99, 101),
            (1, 1, u16::MAX, 2),
            (1, 2, u16::MAX - 1, 3),
            (1, 3, u16::MAX - 2, 4),
            (u16::MAX, 1, u16::MAX - 1, 1),
            (u16::MAX, 2, u16::MAX - 2, 2),
            (10, u16::MAX, 10, 10),
            (10, 0, 10, 10),
            (1, 0, 1, 1),
            (u16::MAX, 0, u16::MAX, u16::MAX),
        ];
        for (cur, d, prev, next) in t {
            let cur = Pid::try_from(cur).unwrap();
            let sub = cur - d;
            let add = cur + d;
            assert_eq!(prev, sub.value(), "{:?} - {} should be {}", cur, d, prev);
            assert_eq!(next, add.value(), "{:?} + {} should be {}", cur, d, next);
        }
    }

    #[test]
    fn test_valid_topic_name() {
        // Valid topic names.
        assert!(!TopicName::is_invalid("/abc/def"));
        assert!(!TopicName::is_invalid("abc/def"));
        assert!(!TopicName::is_invalid("abc"));
        assert!(!TopicName::is_invalid("/"));
        assert!(!TopicName::is_invalid("//"));
        // The empty topic name is accepted.
        assert!(!TopicName::is_invalid(""));
        assert!(!TopicName::is_invalid(
            "a".repeat(u16::MAX as usize).as_str()
        ));

        // Invalid topic names: wildcards, NUL, or longer than u16::MAX bytes.
        assert!(TopicName::is_invalid("#"));
        assert!(TopicName::is_invalid("+"));
        assert!(TopicName::is_invalid("/+"));
        assert!(TopicName::is_invalid("/#"));
        assert!(TopicName::is_invalid("abc/\0"));
        assert!(TopicName::is_invalid("abc\0def"));
        assert!(TopicName::is_invalid("abc#def"));
        assert!(TopicName::is_invalid("abc+def"));
        assert!(TopicName::is_invalid(
            "a".repeat(u16::MAX as usize + 1).as_str()
        ));
    }

    #[test]
    fn test_valid_topic_filter() {
        let string_65535 = "a".repeat(u16::MAX as usize);
        let string_65536 = "a".repeat(u16::MAX as usize + 1);
        for (is_invalid, topic) in [
            // Valid filters.
            (false, "abc/def"),
            (false, "abc/+"),
            (false, "abc/#"),
            (false, "#"),
            (false, "+"),
            (false, "+/"),
            (false, "+/+"),
            (false, "///"),
            (false, "//+/"),
            (false, "//abc/"),
            (false, "//+//#"),
            (false, "/abc/+//#"),
            (false, "+/abc/+"),
            (false, string_65535.as_str()),
            // Invalid filters.
            (true, ""),
            (true, "abc\0def"),
            (true, "abc/\0def"),
            (true, "++"),
            (true, "++/"),
            (true, "/++"),
            (true, "abc/++"),
            (true, "abc/++/"),
            (true, "#/abc"),
            (true, "/ab#"),
            (true, "##"),
            (true, "/abc/ab#"),
            (true, "/+#"),
            (true, "//+#"),
            (true, "/abc/+#"),
            (true, "xxx/abc/+#"),
            (true, "xxx/a+bc/"),
            (true, "x+x/abc/"),
            (true, "x+/abc/"),
            (true, "+x/abc/"),
            (true, "+/abc/++"),
            (true, "+/a+c/+"),
            (true, string_65536.as_str()),
        ] {
            assert_eq!((is_invalid, 0), TopicFilter::is_invalid(topic));
        }
    }

    #[test]
    fn test_valid_shared_topic_filter() {
        // Every plain filter keeps its validity under a `$share/xyz/` prefix;
        // valid ones report the separator at byte 10.
        for (is_invalid, topic) in [
            (false, "abc/def"),
            (false, "abc/+"),
            (false, "abc/#"),
            (false, "#"),
            (false, "+"),
            (false, "+/"),
            (false, "+/+"),
            (false, "///"),
            (false, "//+/"),
            (false, "//abc/"),
            (false, "//+//#"),
            (false, "/abc/+//#"),
            (false, "+/abc/+"),
            (true, "abc\0def"),
            (true, "abc/\0def"),
            (true, "++"),
            (true, "++/"),
            (true, "/++"),
            (true, "abc/++"),
            (true, "abc/++/"),
            (true, "#/abc"),
            (true, "/ab#"),
            (true, "##"),
            (true, "/abc/ab#"),
            (true, "/+#"),
            (true, "//+#"),
            (true, "/abc/+#"),
            (true, "xxx/abc/+#"),
            (true, "xxx/a+bc/"),
            (true, "x+x/abc/"),
            (true, "x+/abc/"),
            (true, "+x/abc/"),
            (true, "+/abc/++"),
            (true, "+/a+c/+"),
        ] {
            let result = if is_invalid { (true, 0) } else { (false, 10) };
            assert_eq!(
                result,
                TopicFilter::is_invalid(format!("$share/xyz/{}", topic).as_str()),
            );
        }

        // `None` means the raw filter must be rejected.
        for (result, raw_filter) in [
            (Some((None, None)), "$abc/a/b"),
            (Some((None, None)), "$abc/a/b/xyz/def"),
            (Some((None, None)), "$sys/abc"),
            (Some((Some("abc"), Some("xyz"))), "$share/abc/xyz"),
            (Some((Some("abc"), Some("xyz/ijk"))), "$share/abc/xyz/ijk"),
            (Some((Some("abc"), Some("/xyz"))), "$share/abc//xyz"),
            (Some((Some("abc"), Some("/#"))), "$share/abc//#"),
            (Some((Some("abc"), Some("/a/x/+"))), "$share/abc//a/x/+"),
            (Some((Some("abc"), Some("+"))), "$share/abc/+"),
            (Some((Some("你好"), Some("+"))), "$share/你好/+"),
            (Some((Some("你好"), Some("你好"))), "$share/你好/你好"),
            (Some((Some("abc"), Some("#"))), "$share/abc/#"),
            (None, "$share/abc/"),
            (None, "$share/abc"),
            (None, "$share/+/y"),
            (None, "$share/+/+"),
            (None, "$share//y"),
            (None, "$share//+"),
        ] {
            if let Some((shared_group, shared_filter)) = result {
                let filter = TopicFilter::try_from(raw_filter.to_owned()).unwrap();
                assert_eq!(filter.shared_group_name(), shared_group);
                assert_eq!(filter.shared_filter(), shared_filter);
                if let Some(group_name) = shared_group {
                    assert_eq!(
                        filter.shared_info(),
                        Some((group_name, shared_filter.unwrap()))
                    );
                } else {
                    assert_eq!(filter.shared_info(), None);
                }
            } else {
                assert_eq!((true, 0), TopicFilter::is_invalid(raw_filter));
            }
        }
    }
}