ntex_mqtt/
topic.rs

1use std::{fmt, fmt::Write, io};
2
3use ntex_bytes::ByteString;
4
/// Checks that `topic` is a syntactically valid MQTT topic filter.
///
/// The filter must be non-empty, and every `/`-separated level must either be
/// free of wildcard characters, consist solely of `+`, or consist solely of
/// `#` — the latter only as the very last level.
pub(crate) fn is_valid(topic: &str) -> bool {
    if topic.is_empty() {
        return false;
    }

    let mut levels = topic.split('/').peekable();
    while let Some(level) = levels.next() {
        match level {
            // `#` must be the final level: nothing, not even `/`, may follow it.
            "#" => return levels.peek().is_none(),
            // A lone `+` is a valid single-level wildcard.
            "+" => {}
            // Wildcard characters mixed with other text are never allowed.
            _ if level.contains(['+', '#']) => return false,
            _ => {}
        }
    }
    true
}
32
/// Error returned when a topic filter fails validation.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TopicFilterError {
    /// The filter as a whole is malformed (e.g. empty, `#` not in the last
    /// position, or a `$` level not in the first position).
    InvalidTopic,
    /// A single level mixes wildcard characters with other text (e.g. `a+b`).
    InvalidLevel,
}

impl fmt::Display for TopicFilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            TopicFilterError::InvalidTopic => "invalid topic filter",
            TopicFilterError::InvalidLevel => "invalid topic filter level",
        })
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for TopicFilterError {}
38
39#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
40pub enum TopicFilterLevel {
41    Normal(ByteString),
42    System(ByteString),
43    Blank,
44    SingleWildcard, // Single level wildcard +
45    MultiWildcard,  // Multi-level wildcard #
46}
47
48impl TopicFilterLevel {
49    fn is_valid(&self) -> bool {
50        match *self {
51            TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
52                !s.contains(['+', '#'])
53            }
54            _ => true,
55        }
56    }
57}
58
59fn match_topic<T: MatchLevel, L: Iterator<Item = T>>(
60    superset: &TopicFilter,
61    subset: L,
62) -> bool {
63    let mut superset = superset.0.iter();
64
65    for (index, subset_level) in subset.enumerate() {
66        match superset.next() {
67            Some(TopicFilterLevel::SingleWildcard) => {
68                if !subset_level.match_level(&TopicFilterLevel::SingleWildcard, index) {
69                    return false;
70                }
71            }
72            Some(TopicFilterLevel::MultiWildcard) => {
73                return subset_level.match_level(&TopicFilterLevel::MultiWildcard, index);
74            }
75            Some(level) if subset_level.match_level(level, index) => continue,
76            _ => return false,
77        }
78    }
79
80    match superset.next() {
81        Some(&TopicFilterLevel::MultiWildcard) => true,
82        Some(_) => false,
83        None => true,
84    }
85}
86
87#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
88pub struct TopicFilter(Vec<TopicFilterLevel>);
89
90impl TopicFilter {
91    pub fn levels(&self) -> &[TopicFilterLevel] {
92        &self.0
93    }
94
95    fn is_valid(&self) -> bool {
96        self.0
97            .iter()
98            .position(|level| !level.is_valid())
99            .or_else(|| {
100                self.0.iter().enumerate().position(|(pos, level)| match *level {
101                    TopicFilterLevel::MultiWildcard => pos != self.0.len() - 1,
102                    TopicFilterLevel::System(_) => pos != 0,
103                    _ => false,
104                })
105            })
106            .is_none()
107    }
108
109    pub fn matches_filter(&self, topic: &TopicFilter) -> bool {
110        match_topic(self, topic.0.iter())
111    }
112
113    pub fn matches_topic<S: AsRef<str> + ?Sized>(&self, topic: &S) -> bool {
114        match_topic(self, topic.as_ref().split('/'))
115    }
116}
117
118impl TryFrom<&[TopicFilterLevel]> for TopicFilter {
119    type Error = TopicFilterError;
120
121    fn try_from(s: &[TopicFilterLevel]) -> Result<Self, Self::Error> {
122        let mut v = vec![];
123        v.extend_from_slice(s);
124
125        TopicFilter::try_from(v)
126    }
127}
128
129impl TryFrom<Vec<TopicFilterLevel>> for TopicFilter {
130    type Error = TopicFilterError;
131
132    fn try_from(v: Vec<TopicFilterLevel>) -> Result<Self, Self::Error> {
133        let tf = TopicFilter(v);
134        if tf.is_valid() { Ok(tf) } else { Err(TopicFilterError::InvalidTopic) }
135    }
136}
137
138impl From<TopicFilter> for Vec<TopicFilterLevel> {
139    fn from(t: TopicFilter) -> Self {
140        t.0
141    }
142}
143
144trait MatchLevel {
145    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool;
146}
147
148impl MatchLevel for TopicFilterLevel {
149    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
150        match_level_impl(self, level, index)
151    }
152}
153
154impl MatchLevel for &TopicFilterLevel {
155    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
156        match_level_impl(self, level, index)
157    }
158}
159
160fn match_level_impl(
161    subset_level: &TopicFilterLevel,
162    superset_level: &TopicFilterLevel,
163    _index: usize,
164) -> bool {
165    match superset_level {
166        TopicFilterLevel::Normal(rhs) => {
167            matches!(subset_level, TopicFilterLevel::Normal(lhs) if lhs == rhs)
168        }
169        TopicFilterLevel::System(rhs) => {
170            matches!(subset_level, TopicFilterLevel::System(lhs) if lhs == rhs)
171        }
172        TopicFilterLevel::Blank => *subset_level == TopicFilterLevel::Blank,
173        TopicFilterLevel::SingleWildcard => *subset_level != TopicFilterLevel::MultiWildcard,
174        TopicFilterLevel::MultiWildcard => true,
175    }
176}
177
178impl<T: AsRef<str>> MatchLevel for T {
179    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
180        match level {
181            TopicFilterLevel::Normal(lhs) => lhs == self.as_ref(),
182            TopicFilterLevel::System(lhs) => is_system(self) && lhs == self.as_ref(),
183            TopicFilterLevel::Blank => self.as_ref().is_empty(),
184            TopicFilterLevel::SingleWildcard | TopicFilterLevel::MultiWildcard => {
185                !(index == 0 && is_system(self))
186            }
187        }
188    }
189}
190
191impl TryFrom<ByteString> for TopicFilter {
192    type Error = TopicFilterError;
193
194    fn try_from(value: ByteString) -> Result<Self, Self::Error> {
195        if value.is_empty() {
196            return Err(TopicFilterError::InvalidTopic);
197        }
198
199        value
200            .split('/')
201            .enumerate()
202            .map(|(idx, level)| match level {
203                "+" => Ok(TopicFilterLevel::SingleWildcard),
204                "#" => Ok(TopicFilterLevel::MultiWildcard),
205                "" => Ok(TopicFilterLevel::Blank),
206                _ => {
207                    if level.contains(['+', '#']) {
208                        Err(TopicFilterError::InvalidLevel)
209                    } else if idx == 0 && is_system(level) {
210                        Ok(TopicFilterLevel::System(recover_bstr(&value, level)))
211                    } else {
212                        Ok(TopicFilterLevel::Normal(recover_bstr(&value, level)))
213                    }
214                }
215            })
216            .collect::<Result<Vec<_>, TopicFilterError>>()
217            .map(TopicFilter)
218            .and_then(|topic| {
219                if topic.is_valid() { Ok(topic) } else { Err(TopicFilterError::InvalidTopic) }
220            })
221    }
222}
223
224impl std::str::FromStr for TopicFilter {
225    type Err = TopicFilterError;
226
227    fn from_str(value: &str) -> Result<Self, Self::Err> {
228        let s: ByteString = value.into();
229        TopicFilter::try_from(s)
230    }
231}
232
233impl fmt::Display for TopicFilterLevel {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        match self {
236            TopicFilterLevel::Normal(s) | TopicFilterLevel::System(s) => {
237                f.write_str(s.as_str())
238            }
239            TopicFilterLevel::Blank => Ok(()),
240            TopicFilterLevel::SingleWildcard => f.write_char('+'),
241            TopicFilterLevel::MultiWildcard => f.write_char('#'),
242        }
243    }
244}
245
246impl fmt::Display for TopicFilter {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        let mut iter = self.0.iter();
249        let mut level = iter.next().unwrap();
250        loop {
251            level.fmt(f)?;
252            if let Some(l) = iter.next() {
253                level = l;
254                f.write_char('/')?;
255            } else {
256                break;
257            }
258        }
259        Ok(())
260    }
261}
262
263#[allow(dead_code)]
264pub(crate) trait WriteTopicExt: io::Write {
265    fn write_level(&mut self, level: &TopicFilterLevel) -> io::Result<usize> {
266        match *level {
267            TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
268                self.write(s.as_str().as_bytes())
269            }
270            TopicFilterLevel::Blank => Ok(0),
271            TopicFilterLevel::SingleWildcard => self.write(b"+"),
272            TopicFilterLevel::MultiWildcard => self.write(b"#"),
273        }
274    }
275
276    fn write_topic(&mut self, topic: &TopicFilter) -> io::Result<usize> {
277        let mut n = 0;
278        let mut iter = topic.0.iter();
279        let mut level = iter.next().unwrap();
280        loop {
281            n += self.write_level(level)?;
282            if let Some(l) = iter.next() {
283                level = l;
284                n += self.write(b"/")?;
285            } else {
286                break;
287            }
288        }
289        Ok(n)
290    }
291}
292
293impl<W: io::Write + ?Sized> WriteTopicExt for W {}
294
/// Returns `true` if the level or segment starts with `$`, the MQTT marker for
/// system topics.
fn is_system<T: AsRef<str>>(s: T) -> bool {
    // `$` is ASCII, so comparing the first byte is equivalent to comparing the
    // first char.
    s.as_ref().bytes().next() == Some(b'$')
}
298
299fn recover_bstr(superset: &ByteString, subset: &str) -> ByteString {
300    unsafe {
301        ByteString::from_bytes_unchecked(superset.as_bytes().slice_ref(subset.as_bytes()))
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case("abc" => true; "pass_norm1")]
    #[test_case("a/b" => true; "pass_norm2")]
    #[test_case("/" => true; "pass_norm3")]
    #[test_case("//" => true; "pass_norm4")]
    #[test_case("a/b/+" => true; "pass_plus1")]
    #[test_case("+/a" => true; "pass_plus2")]
    #[test_case("+" => true; "pass_plus3")]
    #[test_case("+//+" => true; "pass_plus4")]
    #[test_case("a/b/#" => true; "pass_hash1")]
    #[test_case("#" => true; "pass_hash2")]
    #[test_case("/#" => true; "pass_hash3")]
    #[test_case("++" => false; "fail_plus1")]
    #[test_case("b+/" => false; "fail_plus2")]
    #[test_case("a/+b" => false; "fail_plus3")]
    #[test_case("+#" => false; "fail_hash1")]
    #[test_case("a#" => false; "fail_hash2")]
    #[test_case("a/#/" => false; "fail_hash3")]
    #[test_case("a/#b" => false; "fail_hash4")]
    #[test_case("a/##" => false; "fail_hash5")]
    #[test_case("a/#+" => false; "fail_hash6")]
    fn check_is_valid(topic_filter: &'static str) -> bool {
        is_valid(topic_filter)
    }

    /// Builds a `Normal` level, panicking if `s` contains a wildcard character.
    fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid normal level `{}` contains +|#", s.as_ref());
        }

        TopicFilterLevel::Normal(s.as_ref().into())
    }

    /// Builds a `System` level, panicking if `s` contains a wildcard character
    /// or does not start with `$`.
    fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid system level `{}` contains +|#", s.as_ref());
        }

        if !s.as_ref().starts_with('$') {
            panic!("invalid system level `{}` does not start with $", s.as_ref())
        }

        TopicFilterLevel::System(s.as_ref().into())
    }

    /// Parses a static filter string, panicking if it is invalid.
    fn topic(topic: &'static str) -> TopicFilter {
        TopicFilter::try_from(ByteString::from_static(topic)).unwrap()
    }

    #[test_case("level" => Ok(vec![lvl_normal("level")]) ; "1")]
    #[test_case("level/+" => Ok(vec![lvl_normal("level"), TopicFilterLevel::SingleWildcard]) ; "2")]
    #[test_case("a//#" => Ok(vec![lvl_normal("a"), TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "3")]
    #[test_case("$a///#" => Ok(vec![lvl_sys("$a"), TopicFilterLevel::Blank, TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "4")]
    #[test_case("$a/#/" => Err(TopicFilterError::InvalidTopic) ; "5")]
    #[test_case("a+b" => Err(TopicFilterError::InvalidLevel) ; "6")]
    #[test_case("a/+b" => Err(TopicFilterError::InvalidLevel) ; "7")]
    #[test_case("$a/$b/" => Ok(vec![lvl_sys("$a"), lvl_normal("$b"), TopicFilterLevel::Blank]) ; "8")]
    #[test_case("#/a" => Err(TopicFilterError::InvalidTopic) ; "10")]
    #[test_case("" => Err(TopicFilterError::InvalidTopic) ; "11")]
    #[test_case("/finance" => Ok(vec![TopicFilterLevel::Blank, lvl_normal("finance")]) ; "12")]
    #[test_case("finance/" => Ok(vec![lvl_normal("finance"), TopicFilterLevel::Blank]) ; "13")]
    fn parsing(input: &str) -> Result<Vec<TopicFilterLevel>, TopicFilterError> {
        TopicFilter::try_from(ByteString::from(input)).map(|t| t.levels().to_vec())
    }

    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), lvl_normal("player1")] => true; "1")]
    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), TopicFilterLevel::MultiWildcard] => true; "2")]
    #[test_case(vec![lvl_sys("$SYS"), lvl_normal("tennis"), lvl_normal("player1")] => true; "3")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::SingleWildcard, lvl_normal("player1")] => true; "4")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::MultiWildcard, lvl_normal("player1")] => false; "5")]
    #[test_case(vec![lvl_normal("sport"), lvl_sys("$SYS"), lvl_normal("player1")] => false; "6")]
    fn topic_is_valid(levels: Vec<TopicFilterLevel>) -> bool {
        TopicFilter::try_from(levels).is_ok()
    }

    #[test]
    fn test_multi_wildcard_topic() {
        assert!(topic("sport/tennis/#").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/tennis/#").matches_topic("sport/tennis"));

        assert!(topic("#").matches_filter(&TopicFilter(vec![TopicFilterLevel::MultiWildcard])));
    }

    #[test]
    fn test_single_wildcard_topic() {
        assert!(topic("+").matches_filter(
            &TopicFilter::try_from(vec![TopicFilterLevel::SingleWildcard]).unwrap()
        ));

        assert!(topic("+/tennis/#").matches_filter(&TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/+/player1").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            TopicFilterLevel::SingleWildcard,
            lvl_normal("player1")
        ])));
    }

    #[test]
    fn test_write_topic() {
        let mut v = vec![];
        let t = TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard,
        ]);

        assert_eq!(v.write_topic(&t).unwrap(), 10);
        assert_eq!(v, b"+/tennis/#");

        assert_eq!(format!("{}", t), "+/tennis/#");
        assert_eq!(t.to_string(), "+/tennis/#");
    }

    #[test_case("test", "test" => true)]
    #[test_case("$SYS", "$SYS" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score/wimbledon" => true)]
    #[test_case("sport/#", "sport" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player2" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1/ranking" => false)]
    #[test_case("sport/+", "sport" => false; "single1")]
    #[test_case("sport/+", "sport/" => true; "single2")]
    #[test_case("+/+", "/finance" => true; "single3")]
    #[test_case("/+", "/finance" => true; "single4")]
    #[test_case("+", "/finance" => false; "single5")]
    #[test_case("#", "$SYS" => false; "sys1")]
    #[test_case("+/monitor/Clients", "$SYS/monitor/Clients" => false; "sys2")]
    #[test_case("$SYS/#", "$SYS/" => true; "sys3")]
    #[test_case("$SYS/monitor/+", "$SYS/monitor/Clients" => true; "sys4")]
    #[test_case("#", "/$SYS/monitor/Clients" => true; "sys5")]
    #[test_case("+", "$SYS" => false; "sys6")]
    #[test_case("+/#", "$SYS" => false; "sys7")]
    fn matches_topic(filter: &'static str, topic_str: &'static str) -> bool {
        topic(filter).matches_topic(topic_str)
    }

    #[test_case("a/b", "a/b" => true; "1")]
    #[test_case("a/b", "a/+" => false; "2")]
    #[test_case("a/b", "a/#" => false; "3")]
    #[test_case("a/+", "a/#" => false; "4")]
    #[test_case("a/+", "a/b" => true; "5")]
    #[test_case("+/+", "/" => true; "6")]
    #[test_case("+/+", "#" => false; "7")]
    #[test_case("+", "#" => false; "8")]
    #[test_case("#", "+" => true; "9")]
    #[test_case("#", "#" => true; "10")]
    #[test_case("a/#", "a/+/+" => true; "11")]
    #[test_case("a/+/normal/+", "a/$not_sys/normal/+" => true; "12")]
    #[test_case("a/+/#", "a/b" => true; "13")]
    fn matches_filter(superset_filter: &'static str, subset_filter: &'static str) -> bool {
        topic(superset_filter).matches_filter(&topic(subset_filter))
    }
}