mqtt/
topic_filter.rs

//! MQTT topic filters: validation, wire encoding/decoding, and matching against topic names.
2
3use std::io::{self, Read, Write};
4use std::ops::Deref;
5
6use crate::topic_name::TopicNameRef;
7use crate::{Decodable, Encodable};
8
/// Returns `true` if `topic` is **not** a syntactically valid MQTT topic filter.
///
/// A valid filter:
/// * is non-empty and at most 65535 bytes long [MQTT-4.7.3-1], [MQTT-4.7.3-3];
/// * contains no null character U+0000 [MQTT-4.7.3-2];
/// * uses `+` only as a whole level [MQTT-4.7.1-3];
/// * uses `#` only as a whole level, and only as the last level [MQTT-4.7.1-2].
#[inline]
fn is_invalid_topic_filter(topic: &str) -> bool {
    // `str::len` is the UTF-8 byte length, which is what the 65535 limit applies to.
    if topic.is_empty() || topic.len() > 65535 || topic.contains('\0') {
        return true;
    }

    let mut levels = topic.split('/');
    while let Some(level) = levels.next() {
        match level {
            // The multi-level wildcard is only valid as the final level.
            "#" => return levels.next().is_some(),
            "+" => {}
            // Wildcards embedded inside a level, e.g. `sport+` or `tennis#`.
            _ if level.contains(['#', '+']) => return true,
            _ => {}
        }
    }

    false
}
34
/// An owned MQTT topic filter, validated on construction by [`TopicFilter::new`].
///
/// Syntax reference (MQTT 3.1.1):
/// <http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106>
///
/// ```rust
/// use mqtt::{TopicFilter, TopicNameRef};
///
/// let topic_filter = TopicFilter::new("sport/+/player1").unwrap();
/// let matcher = topic_filter.get_matcher();
/// assert!(matcher.is_match(TopicNameRef::new("sport/abc/player1").unwrap()));
/// ```
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TopicFilter(String);
48
49impl TopicFilter {
50    /// Creates a new topic filter from string
51    /// Return error if it is not a valid topic filter
52    pub fn new<S: Into<String>>(topic: S) -> Result<TopicFilter, TopicFilterError> {
53        let topic = topic.into();
54        if is_invalid_topic_filter(&topic) {
55            Err(TopicFilterError(topic))
56        } else {
57            Ok(TopicFilter(topic))
58        }
59    }
60
61    /// Creates a new topic filter from string without validation
62    ///
63    /// # Safety
64    ///
65    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
66    /// Creating a filter from raw string may cause errors
67    pub unsafe fn new_unchecked<S: Into<String>>(topic: S) -> TopicFilter {
68        TopicFilter(topic.into())
69    }
70}
71
72impl From<TopicFilter> for String {
73    fn from(topic: TopicFilter) -> String {
74        topic.0
75    }
76}
77
78impl Encodable for TopicFilter {
79    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
80        (&self.0[..]).encode(writer)
81    }
82
83    fn encoded_length(&self) -> u32 {
84        (&self.0[..]).encoded_length()
85    }
86}
87
88impl Decodable for TopicFilter {
89    type Error = TopicFilterDecodeError;
90    type Cond = ();
91
92    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<TopicFilter, TopicFilterDecodeError> {
93        let topic_filter = String::decode(reader)?;
94        Ok(TopicFilter::new(topic_filter)?)
95    }
96}
97
98impl Deref for TopicFilter {
99    type Target = TopicFilterRef;
100
101    fn deref(&self) -> &TopicFilterRef {
102        unsafe { TopicFilterRef::new_unchecked(&self.0) }
103    }
104}
105
/// Borrowed, unsized counterpart of [`TopicFilter`] (as `str` is to `String`).
///
/// `#[repr(transparent)]` gives it the same layout as `str`, which the
/// `*const str` -> `*const TopicFilterRef` casts in its constructors rely on.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct TopicFilterRef(str);
110
111impl TopicFilterRef {
112    /// Creates a new topic filter from string
113    /// Return error if it is not a valid topic filter
114    pub fn new<S: AsRef<str> + ?Sized>(topic: &S) -> Result<&TopicFilterRef, TopicFilterError> {
115        let topic = topic.as_ref();
116        if is_invalid_topic_filter(topic) {
117            Err(TopicFilterError(topic.to_owned()))
118        } else {
119            Ok(unsafe { &*(topic as *const str as *const TopicFilterRef) })
120        }
121    }
122
123    /// Creates a new topic filter from string without validation
124    ///
125    /// # Safety
126    ///
127    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
128    /// Creating a filter from raw string may cause errors
129    pub unsafe fn new_unchecked<S: AsRef<str> + ?Sized>(topic: &S) -> &TopicFilterRef {
130        let topic = topic.as_ref();
131        &*(topic as *const str as *const TopicFilterRef)
132    }
133
134    /// Get a matcher
135    pub fn get_matcher(&self) -> TopicFilterMatcher<'_> {
136        TopicFilterMatcher::new(&self.0)
137    }
138}
139
140impl Deref for TopicFilterRef {
141    type Target = str;
142
143    fn deref(&self) -> &str {
144        &self.0
145    }
146}
147
148#[derive(Debug, thiserror::Error)]
149#[error("invalid topic filter ({0})")]
150pub struct TopicFilterError(pub String);
151
152/// Errors while parsing topic filters
153#[derive(Debug, thiserror::Error)]
154#[error(transparent)]
155pub enum TopicFilterDecodeError {
156    IoError(#[from] io::Error),
157    InvalidTopicFilter(#[from] TopicFilterError),
158}
159
/// Tests topic names against a borrowed topic filter string.
///
/// Obtained from `TopicFilterRef::get_matcher`; cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct TopicFilterMatcher<'a> {
    // Filter text; split on `/` each time a name is matched.
    topic_filter: &'a str,
}
165
166impl<'a> TopicFilterMatcher<'a> {
167    fn new(filter: &'a str) -> TopicFilterMatcher<'a> {
168        TopicFilterMatcher { topic_filter: filter }
169    }
170
171    /// Check if this filter can match the `topic_name`
172    pub fn is_match(&self, topic_name: &TopicNameRef) -> bool {
173        let mut tn_itr = topic_name.split('/');
174        let mut ft_itr = self.topic_filter.split('/');
175
176        // The Server MUST NOT match Topic Filters starting with a wildcard character (# or +)
177        // with Topic Names beginning with a $ character [MQTT-4.7.2-1].
178
179        let first_ft = ft_itr.next().unwrap();
180        let first_tn = tn_itr.next().unwrap();
181
182        if first_tn.starts_with('$') {
183            if first_tn != first_ft {
184                return false;
185            }
186        } else {
187            match first_ft {
188                // Matches the whole topic
189                "#" => return true,
190                "+" => {}
191                _ => {
192                    if first_tn != first_ft {
193                        return false;
194                    }
195                }
196            }
197        }
198
199        loop {
200            match (ft_itr.next(), tn_itr.next()) {
201                (Some(ft), Some(tn)) => match ft {
202                    "#" => break,
203                    "+" => {}
204                    _ => {
205                        if ft != tn {
206                            return false;
207                        }
208                    }
209                },
210                (Some(ft), None) => {
211                    if ft != "#" {
212                        return false;
213                    } else {
214                        break;
215                    }
216                }
217                (None, Some(..)) => return false,
218                (None, None) => break,
219            }
220        }
221
222        true
223    }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    /// Builds `filter` and `name` (both must be valid) and reports whether they match.
    fn matches(filter: &str, name: &str) -> bool {
        let filter = TopicFilter::new(filter).unwrap();
        let name = TopicNameRef::new(name).unwrap();
        filter.get_matcher().is_match(name)
    }

    #[test]
    fn topic_filter_validate() {
        let valid = [
            "#",
            "sport/tennis/player1",
            "sport/tennis/player1/ranking",
            "sport/tennis/player1/#",
            "sport/tennis/#",
            "sport/#",
            "+",
            "+/tennis/#",
            "sport/+/player1",
            "+/+",
            "/+",
            "$SYS/#",
            "$SYS",
        ];
        for &topic in valid.iter() {
            assert!(TopicFilter::new(topic).is_ok(), "{:?} should be a valid topic filter", topic);
        }

        let invalid = [
            "",
            "sport/tennis#",
            "sport/tennis/#/ranking",
            "#/",
            "sport+",
            "sport/+player1",
        ];
        for &topic in invalid.iter() {
            assert!(TopicFilter::new(topic).is_err(), "{:?} should be rejected", topic);
        }

        // Length limit: at most 65535 bytes.
        assert!(TopicFilter::new("a".repeat(65535)).is_ok());
        assert!(TopicFilter::new("a".repeat(65536)).is_err());
    }

    #[test]
    fn topic_filter_ref_and_conversions() {
        let owned = TopicFilter::new("sport/+").unwrap();
        let borrowed = TopicFilterRef::new("sport/+").unwrap();
        assert_eq!(&*owned, borrowed);
        assert_eq!(&**borrowed, "sport/+");
        assert!(TopicFilterRef::new("sport+").is_err());
        assert_eq!(String::from(owned), "sport/+");
    }

    #[test]
    fn topic_filter_matcher() {
        assert!(matches("sport/#", "sport"));

        assert!(matches("#", "sport"));
        assert!(matches("#", "/"));
        assert!(matches("#", "abc/def"));
        assert!(!matches("#", "$SYS"));
        assert!(!matches("#", "$SYS/abc"));

        assert!(!matches("+/monitor/Clients", "$SYS/monitor/Clients"));

        assert!(matches("$SYS/#", "$SYS/monitor/Clients"));
        assert!(matches("$SYS/#", "$SYS"));

        assert!(matches("$SYS/monitor/+", "$SYS/monitor/Clients"));
        assert!(!matches("$SYS/monitor/+", "$SYS/monitor"));
    }

    /// Wildcard examples given in the MQTT 3.1.1 specification.
    #[test]
    fn topic_filter_matcher_spec_examples() {
        assert!(matches("sport/tennis/player1/#", "sport/tennis/player1"));
        assert!(matches("sport/tennis/player1/#", "sport/tennis/player1/ranking"));
        assert!(matches("sport/tennis/player1/#", "sport/tennis/player1/score/wimbledon"));

        assert!(matches("sport/tennis/+", "sport/tennis/player1"));
        assert!(matches("sport/tennis/+", "sport/tennis/player2"));
        assert!(!matches("sport/tennis/+", "sport/tennis/player1/ranking"));

        assert!(!matches("sport/+", "sport"));
        assert!(matches("sport/+", "sport/"));

        assert!(matches("+/+", "/finance"));
        assert!(matches("/+", "/finance"));
        assert!(!matches("+", "/finance"));
    }
}
305}