wampire/router/pubsub/
patterns.rs

//! Contains the `SubscriptionPatternNode` struct, which is used to construct a trie
//! corresponding to pattern-based subscriptions.
3use std::{
4    collections::HashMap,
5    fmt::{self, Debug, Formatter},
6    mem,
7    slice::Iter,
8    sync::{Arc, Mutex},
9};
10
11use itertools::Itertools;
12
13use crate::{messages::Reason, MatchingPolicy, ID, URI};
14
15use super::super::{random_id, ConnectionInfo};
16
17/// Contains a trie corresponding to the subscription patterns that connections have requested.
18///
19/// Each level of the trie corresponds to a fragment of a uri between the '.' character.
20/// Thus each subscription that starts with 'com' for example will be grouped together.
21/// Subscriptions can be added and removed, and the connections that match a particular URI
22/// can be found using the `get_registrant_for()` method.
23///
24pub struct SubscriptionPatternNode<P: PatternData> {
25    edges: HashMap<String, SubscriptionPatternNode<P>>,
26    connections: Vec<DataWrapper<P>>,
27    prefix_connections: Vec<DataWrapper<P>>,
28    id: ID,
29    prefix_id: ID,
30}
31
32/// Represents data that a pattern trie will hold
33pub trait PatternData {
34    fn get_id(&self) -> ID;
35}
36
37struct DataWrapper<P: PatternData> {
38    subscriber: P,
39    policy: MatchingPolicy,
40}
41
42/// A lazy iterator that traverses the pattern trie.  See `SubscriptionPatternNode` for more.
43pub struct MatchIterator<'a, P>
44where
45    P: PatternData,
46{
47    uri: Vec<String>,
48    current: Box<StackFrame<'a, P>>,
49}
50
51struct StackFrame<'a, P>
52where
53    P: PatternData,
54{
55    node: &'a SubscriptionPatternNode<P>,
56    state: IterState<'a, P>,
57    depth: usize,
58    parent: Option<Box<StackFrame<'a, P>>>,
59}
60
61/// Represents an error caused during adding or removing patterns
62#[derive(Debug)]
63pub struct PatternError {
64    reason: Reason,
65}
66
67#[derive(Clone)]
68enum IterState<'a, P: PatternData>
69where
70    P: PatternData,
71{
72    None,
73    Wildcard,
74    Strict,
75    Prefix(Iter<'a, DataWrapper<P>>),
76    PrefixComplete,
77    Subs(Iter<'a, DataWrapper<P>>),
78    AllComplete,
79}
80
81impl PatternError {
82    #[inline]
83    pub fn new(reason: Reason) -> PatternError {
84        PatternError { reason }
85    }
86
87    pub fn reason(self) -> Reason {
88        self.reason
89    }
90}
91
92impl PatternData for Arc<Mutex<ConnectionInfo>> {
93    fn get_id(&self) -> ID {
94        self.lock().unwrap().id
95    }
96}
97
98impl<'a, P: PatternData> Debug for IterState<'a, P> {
99    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100        write!(
101            f,
102            "{}",
103            match *self {
104                IterState::None => "None",
105                IterState::Wildcard => "Wildcard",
106                IterState::Strict => "Strict",
107                IterState::Prefix(_) => "Prefix",
108                IterState::PrefixComplete => "PrefixComplete",
109                IterState::Subs(_) => "Subs",
110                IterState::AllComplete => "AllComplete",
111            }
112        )
113    }
114}
115
116impl<P: PatternData> Debug for SubscriptionPatternNode<P> {
117    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
118        self.fmt_with_indent(f, 0)
119    }
120}
121
122impl<P: PatternData> SubscriptionPatternNode<P> {
123    fn fmt_with_indent(&self, f: &mut Formatter<'_>, indent: usize) -> fmt::Result {
124        writeln!(
125            f,
126            "{} pre: {:?} subs: {:?}",
127            self.id,
128            self.prefix_connections
129                .iter()
130                .map(|sub| sub.subscriber.get_id())
131                .join(","),
132            self.connections
133                .iter()
134                .map(|sub| sub.subscriber.get_id())
135                .join(","),
136        )?;
137        for (chunk, node) in &self.edges {
138            for _ in 0..indent * 2 {
139                write!(f, "  ")?;
140            }
141            write!(f, "{} - ", chunk)?;
142            node.fmt_with_indent(f, indent + 1)?;
143        }
144        Ok(())
145    }
146
147    /// Add a new subscription to the pattern trie with the given pattern and matching policy.
148    pub fn subscribe_with(
149        &mut self,
150        topic: &URI,
151        subscriber: P,
152        matching_policy: MatchingPolicy,
153    ) -> Result<ID, PatternError> {
154        let mut uri_bits = topic.uri.split('.');
155        let initial = match uri_bits.next() {
156            Some(initial) => initial,
157            None => return Err(PatternError::new(Reason::InvalidURI)),
158        };
159        let edge = self
160            .edges
161            .entry(initial.to_string())
162            .or_insert_with(SubscriptionPatternNode::new);
163        edge.add_subscription(uri_bits, subscriber, matching_policy)
164    }
165
166    /// Removes a subscription from the pattern trie.
167    pub fn unsubscribe_with(
168        &mut self,
169        topic: &str,
170        subscriber: &P,
171        is_prefix: bool,
172    ) -> Result<ID, PatternError> {
173        let uri_bits = topic.split('.');
174        self.remove_subscription(uri_bits, subscriber.get_id(), is_prefix)
175    }
176
177    /// Constructs a new SubscriptionPatternNode to be used as the root of the trie
178    #[inline]
179    pub fn new() -> SubscriptionPatternNode<P> {
180        SubscriptionPatternNode {
181            edges: HashMap::new(),
182            connections: Vec::new(),
183            prefix_connections: Vec::new(),
184            id: random_id(),
185            prefix_id: random_id(),
186        }
187    }
188
189    fn add_subscription<'a, I>(
190        &mut self,
191        mut uri_bits: I,
192        subscriber: P,
193        matching_policy: MatchingPolicy,
194    ) -> Result<ID, PatternError>
195    where
196        I: Iterator<Item = &'a str>,
197    {
198        match uri_bits.next() {
199            Some(uri_bit) => {
200                if uri_bit.is_empty() && matching_policy != MatchingPolicy::Wildcard {
201                    return Err(PatternError::new(Reason::InvalidURI));
202                }
203                let edge = self
204                    .edges
205                    .entry(uri_bit.to_string())
206                    .or_insert_with(SubscriptionPatternNode::new);
207                edge.add_subscription(uri_bits, subscriber, matching_policy)
208            }
209            None => {
210                if matching_policy == MatchingPolicy::Prefix {
211                    self.prefix_connections.push(DataWrapper {
212                        subscriber,
213                        policy: matching_policy,
214                    });
215                    Ok(self.prefix_id)
216                } else {
217                    self.connections.push(DataWrapper {
218                        subscriber,
219                        policy: matching_policy,
220                    });
221                    Ok(self.id)
222                }
223            }
224        }
225    }
226
227    fn remove_subscription<'a, I>(
228        &mut self,
229        mut uri_bits: I,
230        subscriber_id: u64,
231        is_prefix: bool,
232    ) -> Result<ID, PatternError>
233    where
234        I: Iterator<Item = &'a str>,
235    {
236        // TODO consider deleting nodes in the tree if they are no longer in use.
237        match uri_bits.next() {
238            Some(uri_bit) => {
239                if let Some(edge) = self.edges.get_mut(uri_bit) {
240                    edge.remove_subscription(uri_bits, subscriber_id, is_prefix)
241                } else {
242                    Err(PatternError::new(Reason::InvalidURI))
243                }
244            }
245            None => {
246                if is_prefix {
247                    self.prefix_connections
248                        .retain(|sub| sub.subscriber.get_id() != subscriber_id);
249                    Ok(self.prefix_id)
250                } else {
251                    self.connections
252                        .retain(|sub| sub.subscriber.get_id() != subscriber_id);
253                    Ok(self.id)
254                }
255            }
256        }
257    }
258
259    /// Constructs a lazy iterator over all of the connections whose subscription patterns
260    /// match the given uri.
261    ///
262    /// This iterator returns a triple with the connection info, the id of the subscription and
263    /// the matching policy used when the subscription was created.
264    pub fn filter(&self, topic: URI) -> MatchIterator<'_, P> {
265        MatchIterator {
266            current: Box::new(StackFrame {
267                node: self,
268                depth: 0,
269                state: IterState::None,
270                parent: None,
271            }),
272            uri: topic.uri.split('.').map(|s| s.to_string()).collect(),
273        }
274    }
275}
276
277impl<'a, P: PatternData> MatchIterator<'a, P> {
278    fn push(&mut self, child: &'a SubscriptionPatternNode<P>) {
279        let new_node = Box::new(StackFrame {
280            parent: None,
281            depth: self.current.depth + 1,
282            node: child,
283            state: IterState::None,
284        });
285        let parent = mem::replace(&mut self.current, new_node);
286        self.current.parent = Some(parent);
287    }
288
289    /// Moves through the subscription tree, looking for the next set of connections that match the
290    /// given uri.
291    fn traverse(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
292        // This method functions as a push down automata.  For each node, it starts by iterating
293        // through the data that match a prefix of the uri
294        // Then when that's done, it checks if the uri has been fully processed, and if so, iterates
295        // through the connections that require exact matching
296        // Otherwise, it pushes the current node on the stack, consumes another chunk of the uri
297        // and moves on to any children that use wildcard matching.
298        // Once it is finished traversing that part of the tree, it re-consumes the same chunk
299        // of the URI, and moves on to any children that match the chunk exactly.
300        // After all that is exhausted, it will pop the node of the stack and return to its parent
301        match self.current.state {
302            IterState::None => {
303                self.current.state = IterState::Prefix(self.current.node.prefix_connections.iter())
304            }
305            IterState::Prefix(_) => {
306                self.current.state = IterState::PrefixComplete;
307            }
308            IterState::PrefixComplete => {
309                if self.current.depth == self.uri.len() {
310                    self.current.state = IterState::Subs(self.current.node.connections.iter());
311                } else if let Some(child) = self.current.node.edges.get("") {
312                    self.current.state = IterState::Wildcard;
313                    self.push(child);
314                } else if let Some(child) =
315                    self.current.node.edges.get(&self.uri[self.current.depth])
316                {
317                    self.current.state = IterState::Strict;
318                    self.push(child);
319                } else {
320                    self.current.state = IterState::AllComplete;
321                }
322            }
323            IterState::Wildcard => {
324                if self.current.depth == self.uri.len() {
325                    self.current.state = IterState::AllComplete;
326                } else if let Some(child) =
327                    self.current.node.edges.get(&self.uri[self.current.depth])
328                {
329                    self.current.state = IterState::Strict;
330                    self.push(child);
331                } else {
332                    self.current.state = IterState::AllComplete;
333                }
334            }
335            IterState::Strict => {
336                self.current.state = IterState::AllComplete;
337            }
338            IterState::Subs(_) => {
339                self.current.state = IterState::AllComplete;
340            }
341            IterState::AllComplete => {
342                if self.current.depth == 0 {
343                    return None;
344                } else {
345                    let parent = self.current.parent.take();
346                    let _ = mem::replace(&mut self.current, parent.unwrap());
347                }
348            }
349        };
350        self.next()
351    }
352}
353
354impl<'a, P: PatternData> Iterator for MatchIterator<'a, P> {
355    type Item = (&'a P, ID, MatchingPolicy);
356
357    fn next(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
358        let prefix_id = self.current.node.prefix_id;
359        let node_id = self.current.node.id;
360        // If we are currently iterating through connections, continue iterating
361        match self.current.state {
362            IterState::Prefix(ref mut prefix_iter) => {
363                let next = prefix_iter.next();
364                if let Some(next) = next {
365                    return Some((&next.subscriber, prefix_id, next.policy));
366                }
367            }
368            IterState::Subs(ref mut sub_iter) => {
369                let next = sub_iter.next();
370                if let Some(next) = next {
371                    return Some((&next.subscriber, node_id, next.policy));
372                }
373            }
374            _ => {}
375        };
376
377        // Otherwise, it is time to traverse through the tree.
378        self.traverse()
379    }
380}
381
#[cfg(test)]
mod test {
    use super::{PatternData, SubscriptionPatternNode};
    use crate::{MatchingPolicy, ID, URI};

    /// Minimal `PatternData` implementation, identified only by its id.
    #[derive(Clone)]
    struct MockData {
        id: ID,
    }

    impl PatternData for MockData {
        fn get_id(&self) -> ID {
            self.id
        }
    }

    impl MockData {
        pub fn new(id: ID) -> MockData {
            MockData { id }
        }
    }

    /// Builds the trie shared by the tests. Returns it together with the subscription ids handed
    /// out, in order, for: the wildcard pattern (subscriber 1), the strict pattern
    /// (subscriber 2), and the two prefix patterns (subscribers 3 and 4).
    fn populated_trie() -> (SubscriptionPatternNode<MockData>, [ID; 4]) {
        let mut root = SubscriptionPatternNode::new();
        let mut subscribe = |pattern: &str, subscriber: ID, policy: MatchingPolicy| {
            root.subscribe_with(&URI::new(pattern), MockData::new(subscriber), policy)
                .unwrap()
        };
        let ids = [
            subscribe("com.example.test..topic", 1, MatchingPolicy::Wildcard),
            subscribe("com.example.test.specific.topic", 2, MatchingPolicy::Strict),
            subscribe("com.example", 3, MatchingPolicy::Prefix),
            subscribe("com.example.test", 4, MatchingPolicy::Prefix),
        ];
        (root, ids)
    }

    /// Collects the subscription ids that `root` reports for `uri`, in iteration order.
    fn matching_ids(root: &SubscriptionPatternNode<MockData>, uri: &str) -> Vec<ID> {
        root.filter(URI::new(uri))
            .map(|(_connection, id, _policy)| id)
            .collect()
    }

    #[test]
    fn adding_patterns() {
        let (root, ids) = populated_trie();

        assert_eq!(
            matching_ids(&root, "com.example.test.specific.topic"),
            vec![ids[2], ids[3], ids[0], ids[1]]
        );
    }

    #[test]
    fn removing_patterns() {
        let (mut root, ids) = populated_trie();

        root.unsubscribe_with("com.example.test..topic", &MockData::new(1), false)
            .unwrap();
        root.unsubscribe_with("com.example.test", &MockData::new(4), true)
            .unwrap();

        assert_eq!(
            matching_ids(&root, "com.example.test.specific.topic"),
            vec![ids[2], ids[1]]
        );
    }
}