1use std::{
4 collections::HashMap,
5 fmt::{self, Debug, Formatter},
6 mem,
7 slice::Iter,
8 sync::{Arc, Mutex},
9};
10
11use itertools::Itertools;
12
13use crate::{messages::Reason, MatchingPolicy, ID, URI};
14
15use super::super::{random_id, ConnectionInfo};
16
17pub struct SubscriptionPatternNode<P: PatternData> {
25 edges: HashMap<String, SubscriptionPatternNode<P>>,
26 connections: Vec<DataWrapper<P>>,
27 prefix_connections: Vec<DataWrapper<P>>,
28 id: ID,
29 prefix_id: ID,
30}
31
32pub trait PatternData {
34 fn get_id(&self) -> ID;
35}
36
37struct DataWrapper<P: PatternData> {
38 subscriber: P,
39 policy: MatchingPolicy,
40}
41
42pub struct MatchIterator<'a, P>
44where
45 P: PatternData,
46{
47 uri: Vec<String>,
48 current: Box<StackFrame<'a, P>>,
49}
50
51struct StackFrame<'a, P>
52where
53 P: PatternData,
54{
55 node: &'a SubscriptionPatternNode<P>,
56 state: IterState<'a, P>,
57 depth: usize,
58 parent: Option<Box<StackFrame<'a, P>>>,
59}
60
61#[derive(Debug)]
63pub struct PatternError {
64 reason: Reason,
65}
66
67#[derive(Clone)]
68enum IterState<'a, P: PatternData>
69where
70 P: PatternData,
71{
72 None,
73 Wildcard,
74 Strict,
75 Prefix(Iter<'a, DataWrapper<P>>),
76 PrefixComplete,
77 Subs(Iter<'a, DataWrapper<P>>),
78 AllComplete,
79}
80
81impl PatternError {
82 #[inline]
83 pub fn new(reason: Reason) -> PatternError {
84 PatternError { reason }
85 }
86
87 pub fn reason(self) -> Reason {
88 self.reason
89 }
90}
91
92impl PatternData for Arc<Mutex<ConnectionInfo>> {
93 fn get_id(&self) -> ID {
94 self.lock().unwrap().id
95 }
96}
97
98impl<'a, P: PatternData> Debug for IterState<'a, P> {
99 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
100 write!(
101 f,
102 "{}",
103 match *self {
104 IterState::None => "None",
105 IterState::Wildcard => "Wildcard",
106 IterState::Strict => "Strict",
107 IterState::Prefix(_) => "Prefix",
108 IterState::PrefixComplete => "PrefixComplete",
109 IterState::Subs(_) => "Subs",
110 IterState::AllComplete => "AllComplete",
111 }
112 )
113 }
114}
115
116impl<P: PatternData> Debug for SubscriptionPatternNode<P> {
117 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
118 self.fmt_with_indent(f, 0)
119 }
120}
121
122impl<P: PatternData> SubscriptionPatternNode<P> {
123 fn fmt_with_indent(&self, f: &mut Formatter<'_>, indent: usize) -> fmt::Result {
124 writeln!(
125 f,
126 "{} pre: {:?} subs: {:?}",
127 self.id,
128 self.prefix_connections
129 .iter()
130 .map(|sub| sub.subscriber.get_id())
131 .join(","),
132 self.connections
133 .iter()
134 .map(|sub| sub.subscriber.get_id())
135 .join(","),
136 )?;
137 for (chunk, node) in &self.edges {
138 for _ in 0..indent * 2 {
139 write!(f, " ")?;
140 }
141 write!(f, "{} - ", chunk)?;
142 node.fmt_with_indent(f, indent + 1)?;
143 }
144 Ok(())
145 }
146
147 pub fn subscribe_with(
149 &mut self,
150 topic: &URI,
151 subscriber: P,
152 matching_policy: MatchingPolicy,
153 ) -> Result<ID, PatternError> {
154 let mut uri_bits = topic.uri.split('.');
155 let initial = match uri_bits.next() {
156 Some(initial) => initial,
157 None => return Err(PatternError::new(Reason::InvalidURI)),
158 };
159 let edge = self
160 .edges
161 .entry(initial.to_string())
162 .or_insert_with(SubscriptionPatternNode::new);
163 edge.add_subscription(uri_bits, subscriber, matching_policy)
164 }
165
166 pub fn unsubscribe_with(
168 &mut self,
169 topic: &str,
170 subscriber: &P,
171 is_prefix: bool,
172 ) -> Result<ID, PatternError> {
173 let uri_bits = topic.split('.');
174 self.remove_subscription(uri_bits, subscriber.get_id(), is_prefix)
175 }
176
177 #[inline]
179 pub fn new() -> SubscriptionPatternNode<P> {
180 SubscriptionPatternNode {
181 edges: HashMap::new(),
182 connections: Vec::new(),
183 prefix_connections: Vec::new(),
184 id: random_id(),
185 prefix_id: random_id(),
186 }
187 }
188
189 fn add_subscription<'a, I>(
190 &mut self,
191 mut uri_bits: I,
192 subscriber: P,
193 matching_policy: MatchingPolicy,
194 ) -> Result<ID, PatternError>
195 where
196 I: Iterator<Item = &'a str>,
197 {
198 match uri_bits.next() {
199 Some(uri_bit) => {
200 if uri_bit.is_empty() && matching_policy != MatchingPolicy::Wildcard {
201 return Err(PatternError::new(Reason::InvalidURI));
202 }
203 let edge = self
204 .edges
205 .entry(uri_bit.to_string())
206 .or_insert_with(SubscriptionPatternNode::new);
207 edge.add_subscription(uri_bits, subscriber, matching_policy)
208 }
209 None => {
210 if matching_policy == MatchingPolicy::Prefix {
211 self.prefix_connections.push(DataWrapper {
212 subscriber,
213 policy: matching_policy,
214 });
215 Ok(self.prefix_id)
216 } else {
217 self.connections.push(DataWrapper {
218 subscriber,
219 policy: matching_policy,
220 });
221 Ok(self.id)
222 }
223 }
224 }
225 }
226
227 fn remove_subscription<'a, I>(
228 &mut self,
229 mut uri_bits: I,
230 subscriber_id: u64,
231 is_prefix: bool,
232 ) -> Result<ID, PatternError>
233 where
234 I: Iterator<Item = &'a str>,
235 {
236 match uri_bits.next() {
238 Some(uri_bit) => {
239 if let Some(edge) = self.edges.get_mut(uri_bit) {
240 edge.remove_subscription(uri_bits, subscriber_id, is_prefix)
241 } else {
242 Err(PatternError::new(Reason::InvalidURI))
243 }
244 }
245 None => {
246 if is_prefix {
247 self.prefix_connections
248 .retain(|sub| sub.subscriber.get_id() != subscriber_id);
249 Ok(self.prefix_id)
250 } else {
251 self.connections
252 .retain(|sub| sub.subscriber.get_id() != subscriber_id);
253 Ok(self.id)
254 }
255 }
256 }
257 }
258
259 pub fn filter(&self, topic: URI) -> MatchIterator<'_, P> {
265 MatchIterator {
266 current: Box::new(StackFrame {
267 node: self,
268 depth: 0,
269 state: IterState::None,
270 parent: None,
271 }),
272 uri: topic.uri.split('.').map(|s| s.to_string()).collect(),
273 }
274 }
275}
276
277impl<'a, P: PatternData> MatchIterator<'a, P> {
278 fn push(&mut self, child: &'a SubscriptionPatternNode<P>) {
279 let new_node = Box::new(StackFrame {
280 parent: None,
281 depth: self.current.depth + 1,
282 node: child,
283 state: IterState::None,
284 });
285 let parent = mem::replace(&mut self.current, new_node);
286 self.current.parent = Some(parent);
287 }
288
289 fn traverse(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
292 match self.current.state {
302 IterState::None => {
303 self.current.state = IterState::Prefix(self.current.node.prefix_connections.iter())
304 }
305 IterState::Prefix(_) => {
306 self.current.state = IterState::PrefixComplete;
307 }
308 IterState::PrefixComplete => {
309 if self.current.depth == self.uri.len() {
310 self.current.state = IterState::Subs(self.current.node.connections.iter());
311 } else if let Some(child) = self.current.node.edges.get("") {
312 self.current.state = IterState::Wildcard;
313 self.push(child);
314 } else if let Some(child) =
315 self.current.node.edges.get(&self.uri[self.current.depth])
316 {
317 self.current.state = IterState::Strict;
318 self.push(child);
319 } else {
320 self.current.state = IterState::AllComplete;
321 }
322 }
323 IterState::Wildcard => {
324 if self.current.depth == self.uri.len() {
325 self.current.state = IterState::AllComplete;
326 } else if let Some(child) =
327 self.current.node.edges.get(&self.uri[self.current.depth])
328 {
329 self.current.state = IterState::Strict;
330 self.push(child);
331 } else {
332 self.current.state = IterState::AllComplete;
333 }
334 }
335 IterState::Strict => {
336 self.current.state = IterState::AllComplete;
337 }
338 IterState::Subs(_) => {
339 self.current.state = IterState::AllComplete;
340 }
341 IterState::AllComplete => {
342 if self.current.depth == 0 {
343 return None;
344 } else {
345 let parent = self.current.parent.take();
346 let _ = mem::replace(&mut self.current, parent.unwrap());
347 }
348 }
349 };
350 self.next()
351 }
352}
353
354impl<'a, P: PatternData> Iterator for MatchIterator<'a, P> {
355 type Item = (&'a P, ID, MatchingPolicy);
356
357 fn next(&mut self) -> Option<(&'a P, ID, MatchingPolicy)> {
358 let prefix_id = self.current.node.prefix_id;
359 let node_id = self.current.node.id;
360 match self.current.state {
362 IterState::Prefix(ref mut prefix_iter) => {
363 let next = prefix_iter.next();
364 if let Some(next) = next {
365 return Some((&next.subscriber, prefix_id, next.policy));
366 }
367 }
368 IterState::Subs(ref mut sub_iter) => {
369 let next = sub_iter.next();
370 if let Some(next) = next {
371 return Some((&next.subscriber, node_id, next.policy));
372 }
373 }
374 _ => {}
375 };
376
377 self.traverse()
379 }
380}
381
#[cfg(test)]
mod test {
    use super::{PatternData, SubscriptionPatternNode};
    use crate::{MatchingPolicy, ID, URI};

    /// Minimal subscriber stand-in identified only by its numeric id.
    #[derive(Clone)]
    struct MockData {
        id: ID,
    }

    impl MockData {
        pub fn new(id: ID) -> MockData {
            MockData { id }
        }
    }

    impl PatternData for MockData {
        fn get_id(&self) -> ID {
            self.id
        }
    }

    /// Topic that every test publishes to.
    const PUBLISHED: &str = "com.example.test.specific.topic";

    /// Builds a trie holding one wildcard, one strict and two prefix
    /// subscriptions (subscribers 1 to 4, in that order) and returns it along
    /// with the subscription ids issued for them.
    fn populated_trie() -> (SubscriptionPatternNode<MockData>, Vec<ID>) {
        let subscriptions = [
            ("com.example.test..topic", MatchingPolicy::Wildcard),
            ("com.example.test.specific.topic", MatchingPolicy::Strict),
            ("com.example", MatchingPolicy::Prefix),
            ("com.example.test", MatchingPolicy::Prefix),
        ];
        let mut root = SubscriptionPatternNode::new();
        let ids = subscriptions
            .iter()
            .enumerate()
            .map(|(index, &(topic, policy))| {
                root.subscribe_with(&URI::new(topic), MockData::new(index as ID + 1), policy)
                    .unwrap()
            })
            .collect();
        (root, ids)
    }

    /// Subscription ids matched by `PUBLISHED`, in iteration order.
    fn matched_ids(root: &SubscriptionPatternNode<MockData>) -> Vec<ID> {
        root.filter(URI::new(PUBLISHED))
            .map(|(_connection, id, _policy)| id)
            .collect()
    }

    #[test]
    fn adding_patterns() {
        let (root, ids) = populated_trie();
        assert_eq!(matched_ids(&root), vec![ids[2], ids[3], ids[0], ids[1]]);
    }

    #[test]
    fn removing_patterns() {
        let (mut root, ids) = populated_trie();
        root.unsubscribe_with("com.example.test..topic", &MockData::new(1), false)
            .unwrap();
        root.unsubscribe_with("com.example.test", &MockData::new(4), true)
            .unwrap();
        assert_eq!(matched_ids(&root), vec![ids[2], ids[1]]);
    }
}