1use std::{fmt, fmt::Write, io};
2
3use ntex_bytes::ByteString;
4
/// Checks the textual syntax of an MQTT topic filter.
///
/// Returns `false` for the empty string. Otherwise the filter is split on `/` and every
/// level must be one of:
/// * exactly `+` (single-level wildcard),
/// * exactly `#` (multi-level wildcard), allowed only as the final level,
/// * any text that contains neither `+` nor `#`.
pub(crate) fn is_valid(topic: &str) -> bool {
    if topic.is_empty() {
        return false;
    }

    // `split` always yields at least one item, so this cannot underflow.
    let last = topic.split('/').count() - 1;

    topic.split('/').enumerate().all(|(index, level)| match level {
        "+" => true,
        "#" => index == last,
        other => !other.contains(['+', '#']),
    })
}
32
/// Error returned when building a [`TopicFilter`] from a string or from a list of levels.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TopicFilterError {
    /// The filter as a whole is malformed (for example it is empty, or `#` is not the
    /// last level). Also returned by the `TryFrom<Vec<_>>` / `TryFrom<&[_]>`
    /// conversions for any invalid level.
    InvalidTopic,
    /// While parsing a string, a level mixed a wildcard character (`+` or `#`) with
    /// other text.
    InvalidLevel,
}

impl fmt::Display for TopicFilterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            TopicFilterError::InvalidTopic => "invalid topic filter",
            TopicFilterError::InvalidLevel => "invalid topic filter level",
        })
    }
}

// Lets the error flow through `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for TopicFilterError {}
38
39#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
40pub enum TopicFilterLevel {
41 Normal(ByteString),
42 System(ByteString),
43 Blank,
44 SingleWildcard, MultiWildcard, }
47
48impl TopicFilterLevel {
49 fn is_valid(&self) -> bool {
50 match *self {
51 TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
52 !s.contains(['+', '#'])
53 }
54 _ => true,
55 }
56 }
57}
58
59fn match_topic<T: MatchLevel, L: Iterator<Item = T>>(
60 superset: &TopicFilter,
61 subset: L,
62) -> bool {
63 let mut superset = superset.0.iter();
64
65 for (index, subset_level) in subset.enumerate() {
66 match superset.next() {
67 Some(TopicFilterLevel::SingleWildcard) => {
68 if !subset_level.match_level(&TopicFilterLevel::SingleWildcard, index) {
69 return false;
70 }
71 }
72 Some(TopicFilterLevel::MultiWildcard) => {
73 return subset_level.match_level(&TopicFilterLevel::MultiWildcard, index);
74 }
75 Some(level) if subset_level.match_level(level, index) => continue,
76 _ => return false,
77 }
78 }
79
80 match superset.next() {
81 Some(&TopicFilterLevel::MultiWildcard) => true,
82 Some(_) => false,
83 None => true,
84 }
85}
86
87#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
88pub struct TopicFilter(Vec<TopicFilterLevel>);
89
90impl TopicFilter {
91 pub fn levels(&self) -> &[TopicFilterLevel] {
92 &self.0
93 }
94
95 fn is_valid(&self) -> bool {
96 self.0
97 .iter()
98 .position(|level| !level.is_valid())
99 .or_else(|| {
100 self.0.iter().enumerate().position(|(pos, level)| match *level {
101 TopicFilterLevel::MultiWildcard => pos != self.0.len() - 1,
102 TopicFilterLevel::System(_) => pos != 0,
103 _ => false,
104 })
105 })
106 .is_none()
107 }
108
109 pub fn matches_filter(&self, topic: &TopicFilter) -> bool {
110 match_topic(self, topic.0.iter())
111 }
112
113 pub fn matches_topic<S: AsRef<str> + ?Sized>(&self, topic: &S) -> bool {
114 match_topic(self, topic.as_ref().split('/'))
115 }
116}
117
118impl TryFrom<&[TopicFilterLevel]> for TopicFilter {
119 type Error = TopicFilterError;
120
121 fn try_from(s: &[TopicFilterLevel]) -> Result<Self, Self::Error> {
122 let mut v = vec![];
123 v.extend_from_slice(s);
124
125 TopicFilter::try_from(v)
126 }
127}
128
129impl TryFrom<Vec<TopicFilterLevel>> for TopicFilter {
130 type Error = TopicFilterError;
131
132 fn try_from(v: Vec<TopicFilterLevel>) -> Result<Self, Self::Error> {
133 let tf = TopicFilter(v);
134 if tf.is_valid() { Ok(tf) } else { Err(TopicFilterError::InvalidTopic) }
135 }
136}
137
138impl From<TopicFilter> for Vec<TopicFilterLevel> {
139 fn from(t: TopicFilter) -> Self {
140 t.0
141 }
142}
143
144trait MatchLevel {
145 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool;
146}
147
148impl MatchLevel for TopicFilterLevel {
149 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
150 match_level_impl(self, level, index)
151 }
152}
153
154impl MatchLevel for &TopicFilterLevel {
155 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
156 match_level_impl(self, level, index)
157 }
158}
159
160fn match_level_impl(
161 subset_level: &TopicFilterLevel,
162 superset_level: &TopicFilterLevel,
163 _index: usize,
164) -> bool {
165 match superset_level {
166 TopicFilterLevel::Normal(rhs) => {
167 matches!(subset_level, TopicFilterLevel::Normal(lhs) if lhs == rhs)
168 }
169 TopicFilterLevel::System(rhs) => {
170 matches!(subset_level, TopicFilterLevel::System(lhs) if lhs == rhs)
171 }
172 TopicFilterLevel::Blank => *subset_level == TopicFilterLevel::Blank,
173 TopicFilterLevel::SingleWildcard => *subset_level != TopicFilterLevel::MultiWildcard,
174 TopicFilterLevel::MultiWildcard => true,
175 }
176}
177
178impl<T: AsRef<str>> MatchLevel for T {
179 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
180 match level {
181 TopicFilterLevel::Normal(lhs) => lhs == self.as_ref(),
182 TopicFilterLevel::System(lhs) => is_system(self) && lhs == self.as_ref(),
183 TopicFilterLevel::Blank => self.as_ref().is_empty(),
184 TopicFilterLevel::SingleWildcard | TopicFilterLevel::MultiWildcard => {
185 !(index == 0 && is_system(self))
186 }
187 }
188 }
189}
190
191impl TryFrom<ByteString> for TopicFilter {
192 type Error = TopicFilterError;
193
194 fn try_from(value: ByteString) -> Result<Self, Self::Error> {
195 if value.is_empty() {
196 return Err(TopicFilterError::InvalidTopic);
197 }
198
199 value
200 .split('/')
201 .enumerate()
202 .map(|(idx, level)| match level {
203 "+" => Ok(TopicFilterLevel::SingleWildcard),
204 "#" => Ok(TopicFilterLevel::MultiWildcard),
205 "" => Ok(TopicFilterLevel::Blank),
206 _ => {
207 if level.contains(['+', '#']) {
208 Err(TopicFilterError::InvalidLevel)
209 } else if idx == 0 && is_system(level) {
210 Ok(TopicFilterLevel::System(recover_bstr(&value, level)))
211 } else {
212 Ok(TopicFilterLevel::Normal(recover_bstr(&value, level)))
213 }
214 }
215 })
216 .collect::<Result<Vec<_>, TopicFilterError>>()
217 .map(TopicFilter)
218 .and_then(|topic| {
219 if topic.is_valid() { Ok(topic) } else { Err(TopicFilterError::InvalidTopic) }
220 })
221 }
222}
223
224impl std::str::FromStr for TopicFilter {
225 type Err = TopicFilterError;
226
227 fn from_str(value: &str) -> Result<Self, Self::Err> {
228 let s: ByteString = value.into();
229 TopicFilter::try_from(s)
230 }
231}
232
233impl fmt::Display for TopicFilterLevel {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 match self {
236 TopicFilterLevel::Normal(s) | TopicFilterLevel::System(s) => {
237 f.write_str(s.as_str())
238 }
239 TopicFilterLevel::Blank => Ok(()),
240 TopicFilterLevel::SingleWildcard => f.write_char('+'),
241 TopicFilterLevel::MultiWildcard => f.write_char('#'),
242 }
243 }
244}
245
246impl fmt::Display for TopicFilter {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 let mut iter = self.0.iter();
249 let mut level = iter.next().unwrap();
250 loop {
251 level.fmt(f)?;
252 if let Some(l) = iter.next() {
253 level = l;
254 f.write_char('/')?;
255 } else {
256 break;
257 }
258 }
259 Ok(())
260 }
261}
262
263#[allow(dead_code)]
264pub(crate) trait WriteTopicExt: io::Write {
265 fn write_level(&mut self, level: &TopicFilterLevel) -> io::Result<usize> {
266 match *level {
267 TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
268 self.write(s.as_str().as_bytes())
269 }
270 TopicFilterLevel::Blank => Ok(0),
271 TopicFilterLevel::SingleWildcard => self.write(b"+"),
272 TopicFilterLevel::MultiWildcard => self.write(b"#"),
273 }
274 }
275
276 fn write_topic(&mut self, topic: &TopicFilter) -> io::Result<usize> {
277 let mut n = 0;
278 let mut iter = topic.0.iter();
279 let mut level = iter.next().unwrap();
280 loop {
281 n += self.write_level(level)?;
282 if let Some(l) = iter.next() {
283 level = l;
284 n += self.write(b"/")?;
285 } else {
286 break;
287 }
288 }
289 Ok(n)
290 }
291}
292
293impl<W: io::Write + ?Sized> WriteTopicExt for W {}
294
/// Returns `true` when `s` begins with `$`, the prefix MQTT uses for server-specific
/// topics such as `$SYS`.
fn is_system<T: AsRef<str>>(s: T) -> bool {
    matches!(s.as_ref().as_bytes().first(), Some(b'$'))
}
298
299fn recover_bstr(superset: &ByteString, subset: &str) -> ByteString {
300 unsafe {
301 ByteString::from_bytes_unchecked(superset.as_bytes().slice_ref(subset.as_bytes()))
302 }
303}
304
305#[cfg(test)]
306mod tests {
307 use super::*;
308 use test_case::test_case;
309
310 #[test_case("abc" => true; "pass_norm1")]
311 #[test_case("a/b" => true; "pass_norm2")]
312 #[test_case("/" => true; "pass_norm3")]
313 #[test_case("//" => true; "pass_norm4")]
314 #[test_case("a/b/+" => true; "pass_plus1")]
315 #[test_case("+/a" => true; "pass_plus2")]
316 #[test_case("+" => true; "pass_plus3")]
317 #[test_case("+//+" => true; "pass_plus4")]
318 #[test_case("a/b/#" => true; "pass_hash1")]
319 #[test_case("#" => true; "pass_hash2")]
320 #[test_case("/#" => true; "pass_hash3")]
321 #[test_case("++" => false; "fail_plus1")]
322 #[test_case("b+/" => false; "fail_plus2")]
323 #[test_case("a/+b" => false; "fail_plus3")]
324 #[test_case("+#" => false; "fail_hash1")]
325 #[test_case("a#" => false; "fail_hash2")]
326 #[test_case("a/#/" => false; "fail_hash3")]
327 #[test_case("a/#b" => false; "fail_hash4")]
328 #[test_case("a/##" => false; "fail_hash5")]
329 #[test_case("a/#+" => false; "fail_hash6")]
330 fn check_is_valid(topic_filter: &'static str) -> bool {
331 is_valid(topic_filter)
332 }
333
334 fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
335 if s.as_ref().contains(['+', '#']) {
336 panic!("invalid normal level `{}` contains +|#", s.as_ref());
337 }
338
339 TopicFilterLevel::Normal(s.as_ref().into())
340 }
341
342 fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
343 if s.as_ref().contains(['+', '#']) {
344 panic!("invalid normal level `{}` contains +|#", s.as_ref());
345 }
346
347 if !s.as_ref().starts_with('$') {
348 panic!("invalid metadata level `{}` not starts with $", s.as_ref())
349 }
350
351 TopicFilterLevel::System(s.as_ref().into())
352 }
353
354 fn topic(topic: &'static str) -> TopicFilter {
355 TopicFilter::try_from(ByteString::from_static(topic)).unwrap()
356 }
357
358 #[test_case("level" => Ok(vec![lvl_normal("level")]) ; "1")]
359 #[test_case("level/+" => Ok(vec![lvl_normal("level"), TopicFilterLevel::SingleWildcard]) ; "2")]
360 #[test_case("a//#" => Ok(vec![lvl_normal("a"), TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "3")]
361 #[test_case("$a///#" => Ok(vec![lvl_sys("$a"), TopicFilterLevel::Blank, TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "4")]
362 #[test_case("$a/#/" => Err(TopicFilterError::InvalidTopic) ; "5")]
363 #[test_case("a+b" => Err(TopicFilterError::InvalidLevel) ; "6")]
364 #[test_case("a/+b" => Err(TopicFilterError::InvalidLevel) ; "7")]
365 #[test_case("$a/$b/" => Ok(vec![lvl_sys("$a"), lvl_normal("$b"), TopicFilterLevel::Blank]) ; "8")]
366 #[test_case("#/a" => Err(TopicFilterError::InvalidTopic) ; "10")]
367 #[test_case("" => Err(TopicFilterError::InvalidTopic) ; "11")]
368 #[test_case("/finance" => Ok(vec![TopicFilterLevel::Blank, lvl_normal("finance")]) ; "12")]
369 #[test_case("finance/" => Ok(vec![lvl_normal("finance"), TopicFilterLevel::Blank]) ; "13")]
370 fn parsing(input: &str) -> Result<Vec<TopicFilterLevel>, TopicFilterError> {
371 TopicFilter::try_from(ByteString::from(input)).map(|t| t.levels().to_vec())
372 }
373
374 #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), lvl_normal("player1")] => true; "1")]
375 #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), TopicFilterLevel::MultiWildcard] => true; "2")]
376 #[test_case(vec![lvl_sys("$SYS"), lvl_normal("tennis"), lvl_normal("player1")] => true; "3")]
377 #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::SingleWildcard, lvl_normal("player1")] => true; "4")]
378 #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::MultiWildcard, lvl_normal("player1")] => false; "5")]
379 #[test_case(vec![lvl_normal("sport"), lvl_sys("$SYS"), lvl_normal("player1")] => false; "6")]
380 fn topic_is_valid(levels: Vec<TopicFilterLevel>) -> bool {
381 TopicFilter::try_from(levels).is_ok()
382 }
383
384 #[test]
385 fn test_multi_wildcard_topic() {
386 assert!(topic("sport/tennis/#").matches_filter(&TopicFilter(vec![
387 lvl_normal("sport"),
388 lvl_normal("tennis"),
389 TopicFilterLevel::MultiWildcard
390 ])));
391
392 assert!(topic("sport/tennis/#").matches_topic("sport/tennis"));
393
394 assert!(topic("#").matches_filter(&TopicFilter(vec![TopicFilterLevel::MultiWildcard])));
395 }
396
397 #[test]
398 fn test_single_wildcard_topic() {
399 assert!(topic("+").matches_filter(
400 &TopicFilter::try_from(vec![TopicFilterLevel::SingleWildcard]).unwrap()
401 ));
402
403 assert!(topic("+/tennis/#").matches_filter(&TopicFilter(vec![
404 TopicFilterLevel::SingleWildcard,
405 lvl_normal("tennis"),
406 TopicFilterLevel::MultiWildcard
407 ])));
408
409 assert!(topic("sport/+/player1").matches_filter(&TopicFilter(vec![
410 lvl_normal("sport"),
411 TopicFilterLevel::SingleWildcard,
412 lvl_normal("player1")
413 ])));
414 }
415
416 #[test]
417 fn test_write_topic() {
418 let mut v = vec![];
419 let t = TopicFilter(vec![
420 TopicFilterLevel::SingleWildcard,
421 lvl_normal("tennis"),
422 TopicFilterLevel::MultiWildcard,
423 ]);
424
425 assert_eq!(v.write_topic(&t).unwrap(), 10);
426 assert_eq!(v, b"+/tennis/#");
427
428 assert_eq!(format!("{}", t), "+/tennis/#");
429 assert_eq!(t.to_string(), "+/tennis/#");
430 }
431
432 #[test_case("test", "test" => true)]
433 #[test_case("$SYS", "$SYS" => true)]
434 #[test_case("sport/tennis/player1/#", "sport/tennis/player1" => true)]
435 #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score" => true)]
436 #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score/wimbledon" => true)]
437 #[test_case("sport/#", "sport" => true)]
438 #[test_case("sport/tennis/+", "sport/tennis/player1" => true)]
439 #[test_case("sport/tennis/+", "sport/tennis/player2" => true)]
440 #[test_case("sport/tennis/+", "sport/tennis/player1/ranking" => false)]
441 #[test_case("sport/+", "sport" => false; "single1")]
442 #[test_case("sport/+", "sport/" => true; "single2")]
443 #[test_case("+/+", "/finance" => true; "single3")]
444 #[test_case("/+", "/finance" => true; "single4")]
445 #[test_case("+", "/finance" => false; "single5")]
446 #[test_case("#", "$SYS" => false; "sys1")]
447 #[test_case("+/monitor/Clients", "$SYS/monitor/Clients" => false; "sys2")]
448 #[test_case("$SYS/#", "$SYS/" => true; "sys3")]
449 #[test_case("$SYS/monitor/+", "$SYS/monitor/Clients" => true; "sys4")]
450 #[test_case("#", "/$SYS/monitor/Clients" => true; "sys5")]
451 #[test_case("+", "$SYS" => false; "sys6")]
452 #[test_case("+/#", "$SYS" => false; "sys7")]
453 fn matches_topic(filter: &'static str, topic_str: &'static str) -> bool {
454 topic(filter).matches_topic(topic_str)
455 }
456
457 #[test_case("a/b", "a/b" => true; "1")]
458 #[test_case("a/b", "a/+" => false; "2")]
459 #[test_case("a/b", "a/#" => false; "3")]
460 #[test_case("a/+", "a/#" => false; "4")]
461 #[test_case("a/+", "a/b" => true; "5")]
462 #[test_case("+/+", "/" => true; "6")]
463 #[test_case("+/+", "#" => false; "7")]
464 #[test_case("+", "#" => false; "8")]
465 #[test_case("#", "+" => true; "9")]
466 #[test_case("#", "#" => true; "10")]
467 #[test_case("a/#", "a/+/+" => true; "11")]
468 #[test_case("a/+/normal/+", "a/$not_sys/normal/+" => true; "12")]
469 #[test_case("a/+/#", "a/b" => true; "13")]
470 fn matches_filter(superset_filter: &'static str, subset_filter: &'static str) -> bool {
471 topic(superset_filter).matches_filter(&topic(subset_filter))
472 }
473}