1use std::io::{self, Read, Write};
4use std::ops::Deref;
5
6use crate::topic_name::TopicNameRef;
7use crate::{Decodable, Encodable};
8
/// Returns `true` if `topic` is **not** a valid MQTT topic filter.
///
/// A valid filter:
/// * is non-empty and at most 65535 bytes long (the length limit of an MQTT
///   UTF-8 encoded string),
/// * contains no NUL character (`U+0000`), which the MQTT specification
///   forbids in topic names and topic filters,
/// * uses the single-level wildcard `+` only as an entire level, and
/// * uses the multi-level wildcard `#` only as an entire level that is also
///   the last level.
#[inline]
fn is_invalid_topic_filter(topic: &str) -> bool {
    if topic.is_empty() || topic.len() > 65535 || topic.contains('\0') {
        return true;
    }

    let mut levels = topic.split('/');
    while let Some(level) = levels.next() {
        match level {
            // `#` is only allowed as the final level: any level after it is an error.
            "#" => return levels.next().is_some(),
            "+" => {}
            // Wildcard characters embedded inside a literal level are illegal.
            _ if level.contains(['#', '+']) => return true,
            _ => {}
        }
    }

    false
}
34
/// An owned topic filter (e.g. `sport/tennis/+`, `sport/#`, `$SYS/#`).
///
/// Instances built with [`TopicFilter::new`] have been checked by
/// `is_invalid_topic_filter`; `TopicFilter::new_unchecked` skips that check.
/// Derefs to [`TopicFilterRef`] for borrowed access and matching.
#[derive(Debug, Eq, PartialEq, Clone, Hash, Ord, PartialOrd)]
pub struct TopicFilter(String);
48
49impl TopicFilter {
50 pub fn new<S: Into<String>>(topic: S) -> Result<TopicFilter, TopicFilterError> {
53 let topic = topic.into();
54 if is_invalid_topic_filter(&topic) {
55 Err(TopicFilterError(topic))
56 } else {
57 Ok(TopicFilter(topic))
58 }
59 }
60
61 pub unsafe fn new_unchecked<S: Into<String>>(topic: S) -> TopicFilter {
68 TopicFilter(topic.into())
69 }
70}
71
72impl From<TopicFilter> for String {
73 fn from(topic: TopicFilter) -> String {
74 topic.0
75 }
76}
77
78impl Encodable for TopicFilter {
79 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
80 (&self.0[..]).encode(writer)
81 }
82
83 fn encoded_length(&self) -> u32 {
84 (&self.0[..]).encoded_length()
85 }
86}
87
88impl Decodable for TopicFilter {
89 type Error = TopicFilterDecodeError;
90 type Cond = ();
91
92 fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<TopicFilter, TopicFilterDecodeError> {
93 let topic_filter = String::decode(reader)?;
94 Ok(TopicFilter::new(topic_filter)?)
95 }
96}
97
98impl Deref for TopicFilter {
99 type Target = TopicFilterRef;
100
101 fn deref(&self) -> &TopicFilterRef {
102 unsafe { TopicFilterRef::new_unchecked(&self.0) }
103 }
104}
105
/// A borrowed topic filter: the unsized counterpart of [`TopicFilter`].
///
/// `#[repr(transparent)]` gives it the same layout as `str`, which is what
/// makes the `&str` -> `&TopicFilterRef` pointer casts in this module sound.
#[derive(Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
pub struct TopicFilterRef(str);
110
111impl TopicFilterRef {
112 pub fn new<S: AsRef<str> + ?Sized>(topic: &S) -> Result<&TopicFilterRef, TopicFilterError> {
115 let topic = topic.as_ref();
116 if is_invalid_topic_filter(topic) {
117 Err(TopicFilterError(topic.to_owned()))
118 } else {
119 Ok(unsafe { &*(topic as *const str as *const TopicFilterRef) })
120 }
121 }
122
123 pub unsafe fn new_unchecked<S: AsRef<str> + ?Sized>(topic: &S) -> &TopicFilterRef {
130 let topic = topic.as_ref();
131 &*(topic as *const str as *const TopicFilterRef)
132 }
133
134 pub fn get_matcher(&self) -> TopicFilterMatcher<'_> {
136 TopicFilterMatcher::new(&self.0)
137 }
138}
139
140impl Deref for TopicFilterRef {
141 type Target = str;
142
143 fn deref(&self) -> &str {
144 &self.0
145 }
146}
147
/// Returned when a string is rejected as a topic filter.
///
/// The field holds the rejected string so callers can report it.
#[derive(Debug, thiserror::Error)]
#[error("invalid topic filter ({0})")]
pub struct TopicFilterError(pub String);
151
/// Error produced while decoding a [`TopicFilter`] from a reader.
///
/// `#[error(transparent)]` forwards `Display` and `source` to the wrapped
/// error. The `#[from]` attributes let `?` convert either underlying error
/// into this type.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum TopicFilterDecodeError {
    /// An I/O error from the underlying reader.
    IoError(#[from] io::Error),
    /// The decoded string was not a valid topic filter.
    InvalidTopicFilter(#[from] TopicFilterError),
}
159
/// Tests topic names against a borrowed topic filter.
///
/// Obtained from [`TopicFilterRef::get_matcher`]. It only borrows the filter
/// text, so it is cheap to copy.
#[derive(Debug, Copy, Clone)]
pub struct TopicFilterMatcher<'a> {
    // The filter text, split on '/' on every `is_match` call.
    topic_filter: &'a str,
}
165
166impl<'a> TopicFilterMatcher<'a> {
167 fn new(filter: &'a str) -> TopicFilterMatcher<'a> {
168 TopicFilterMatcher { topic_filter: filter }
169 }
170
171 pub fn is_match(&self, topic_name: &TopicNameRef) -> bool {
173 let mut tn_itr = topic_name.split('/');
174 let mut ft_itr = self.topic_filter.split('/');
175
176 let first_ft = ft_itr.next().unwrap();
180 let first_tn = tn_itr.next().unwrap();
181
182 if first_tn.starts_with('$') {
183 if first_tn != first_ft {
184 return false;
185 }
186 } else {
187 match first_ft {
188 "#" => return true,
190 "+" => {}
191 _ => {
192 if first_tn != first_ft {
193 return false;
194 }
195 }
196 }
197 }
198
199 loop {
200 match (ft_itr.next(), tn_itr.next()) {
201 (Some(ft), Some(tn)) => match ft {
202 "#" => break,
203 "+" => {}
204 _ => {
205 if ft != tn {
206 return false;
207 }
208 }
209 },
210 (Some(ft), None) => {
211 if ft != "#" {
212 return false;
213 } else {
214 break;
215 }
216 }
217 (None, Some(..)) => return false,
218 (None, None) => break,
219 }
220 }
221
222 true
223 }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    /// Filters that `TopicFilter::new` must accept.
    const VALID_FILTERS: &[&str] = &[
        "#",
        "sport/tennis/player1",
        "sport/tennis/player1/ranking",
        "sport/tennis/player1/#",
        "#",
        "sport/tennis/#",
        "+",
        "+/tennis/#",
        "sport/+/player1",
        "+/+",
        "$SYS/#",
        "$SYS",
    ];

    /// Filters that `TopicFilter::new` must reject.
    const INVALID_FILTERS: &[&str] = &["sport/tennis#", "sport/tennis/#/ranking", "sport+"];

    #[test]
    fn topic_filter_validate() {
        for &filter in VALID_FILTERS {
            TopicFilter::new(filter).unwrap();
        }
        for &filter in INVALID_FILTERS {
            assert!(TopicFilter::new(filter).is_err(), "{:?} should be rejected", filter);
        }
    }

    #[test]
    fn topic_filter_matcher() {
        // (filter, topic name, expected match result)
        let cases: &[(&str, &str, bool)] = &[
            ("sport/#", "sport", true),
            ("#", "sport", true),
            ("#", "/", true),
            ("#", "abc/def", true),
            ("#", "$SYS", false),
            ("#", "$SYS/abc", false),
            ("+/monitor/Clients", "$SYS/monitor/Clients", false),
            ("$SYS/#", "$SYS/monitor/Clients", true),
            ("$SYS/#", "$SYS", true),
            ("$SYS/monitor/+", "$SYS/monitor/Clients", true),
        ];

        for &(filter_text, topic, expected) in cases {
            let filter = TopicFilter::new(filter_text).unwrap();
            let matcher = filter.get_matcher();
            let name = TopicNameRef::new(topic).unwrap();
            assert_eq!(
                matcher.is_match(name),
                expected,
                "filter {:?} vs topic {:?}",
                filter_text,
                topic
            );
        }
    }
}