// asteroid_mq/protocol/interest.rs

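//! # Interest
//!
//! Pattern-based interests and the [`InterestMap`] that matches them against subjects.
//!
//! ## Match Interest
//!
//! An interest pattern has the form `/?SEGMENT(/SEGMENT)*`: an optional leading `/`
//! followed by `/`-separated segments. Each segment is a literal `<path>` segment, `*`
//! or `**`. When matched against a subject, `*` matches exactly one subject segment
//! and `**` matches one or more subject segments.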
4use std::{
5    collections::{BTreeMap, HashMap, HashSet},
6    hash::Hash,
7};
8
9pub use asteroid_mq_model::{
10    Interest, InterestSegment, OwnedInterestSegment, Subject, SubjectSegments,
11};
12use serde::{Deserialize, Serialize};
13#[derive(Debug, Clone)]
14pub struct InterestMap<T> {
15    root: InterestRadixTreeNode<T>,
16    pub(crate) raw: HashMap<T, HashSet<Interest>>,
17}
18
19impl<T> Default for InterestMap<T> {
20    fn default() -> Self {
21        Self {
22            root: Default::default(),
23            raw: HashMap::default(),
24        }
25    }
26}
27
/// A single node of the interest radix tree.
///
/// Each outgoing edge corresponds to one interest segment: a literal, `*` or `**`.
#[derive(Clone)]
pub struct InterestRadixTreeNode<T> {
    /// Values whose interest pattern ends exactly at this node.
    value: HashSet<T>,
    /// Children reached through a literal segment, keyed by that segment's bytes.
    children: BTreeMap<Vec<u8>, InterestRadixTreeNode<T>>,
    /// Child reached through a `*` segment.
    any_child: Option<Box<InterestRadixTreeNode<T>>>,
    /// Child reached through a `**` segment.
    recursive_any_child: Option<Box<InterestRadixTreeNode<T>>>,
}
35
36struct ChildrenDebugProxy<'a, T>(&'a InterestRadixTreeNode<T>);
37
38impl<T: std::fmt::Debug> std::fmt::Debug for ChildrenDebugProxy<'_, T> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        let mut debug = f.debug_map();
41        debug.entries(
42            self.0
43                .children
44                .iter()
45                .map(|(k, v)| (std::str::from_utf8(k).unwrap_or("<invalid utf8 str>"), v)),
46        );
47        if let Some(any_child) = &self.0.any_child {
48            debug.entry(&"*", any_child);
49        }
50        if let Some(recursive_any_child) = &self.0.recursive_any_child {
51            debug.entry(&"**", recursive_any_child);
52        }
53        debug.finish()
54    }
55}
56
57impl<T: std::fmt::Debug> std::fmt::Debug for InterestRadixTreeNode<T> {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("InterestRadixTreeNode")
60            .field("value", &self.value)
61            .field("children", &ChildrenDebugProxy(self))
62            .finish()
63    }
64}
65
66impl<T> Default for InterestRadixTreeNode<T> {
67    fn default() -> Self {
68        Self {
69            value: HashSet::default(),
70            children: BTreeMap::new(),
71            any_child: None,
72            recursive_any_child: None,
73        }
74    }
75}
76
77impl<T> InterestRadixTreeNode<T>
78where
79    T: Hash + Eq + PartialEq,
80{
81    fn insert_recursive<'a>(
82        &mut self,
83        mut path: impl Iterator<Item = InterestSegment<'a>>,
84        value: T,
85    ) {
86        match path.next() {
87            Some(InterestSegment::Specific(seg)) => {
88                if let Some(child) = self.children.get_mut(seg) {
89                    child.insert_recursive(path, value)
90                } else {
91                    let mut child_tree = InterestRadixTreeNode::default();
92                    child_tree.insert_recursive(path, value);
93                    self.children.insert(seg.to_owned(), child_tree);
94                }
95            }
96            Some(InterestSegment::Any) => {
97                let child = self.any_child.get_or_insert_with(Default::default);
98                child.insert_recursive(path, value)
99            }
100            Some(InterestSegment::RecursiveAny) => {
101                let child = self
102                    .recursive_any_child
103                    .get_or_insert_with(Default::default);
104                child.insert_recursive(path, value)
105            }
106            None => {
107                self.value.insert(value);
108            }
109        }
110    }
111    fn delete_recursive<'a>(
112        &mut self,
113        mut path: impl Iterator<Item = InterestSegment<'a>>,
114        value: &T,
115    ) {
116        match path.next() {
117            Some(InterestSegment::Specific(seg)) => {
118                if let Some(child) = self.children.get_mut(seg) {
119                    child.delete_recursive(path, value)
120                }
121            }
122            Some(InterestSegment::Any) => {
123                if let Some(ref mut child) = self.any_child {
124                    child.delete_recursive(path, value)
125                }
126            }
127            Some(InterestSegment::RecursiveAny) => {
128                if let Some(ref mut child) = self.recursive_any_child {
129                    child.delete_recursive(path, value)
130                }
131            }
132            None => {
133                self.value.remove(value);
134            }
135        }
136    }
137    fn find_all_recursive<'a, 'i>(
138        &'a self,
139        mut path: impl Iterator<Item = &'i [u8]> + Clone,
140        collector: &mut HashSet<&'a T>,
141    ) {
142        if let Some(seg) = path.next() {
143            if let Some(ref rac) = self.recursive_any_child {
144                let mut rest_path = path.clone();
145                collector.extend(&rac.value);
146                while let Some(recursive_seg) = rest_path.next() {
147                    if let Some(matched) = rac.children.get(recursive_seg) {
148                        matched.find_all_recursive(rest_path.clone(), collector)
149                    }
150                }
151            }
152            if let Some(ref ac) = self.any_child {
153                ac.find_all_recursive(path.clone(), collector)
154            }
155            if let Some(child) = self.children.get(seg) {
156                child.find_all_recursive(path, collector)
157            }
158        } else {
159            collector.extend(&self.value)
160        }
161    }
162}
163impl<T> InterestMap<T>
164where
165    T: Hash + Eq + PartialEq + Clone,
166{
167    pub fn new() -> Self {
168        Self {
169            root: InterestRadixTreeNode::default(),
170            raw: HashMap::default(),
171        }
172    }
173    pub fn from_raw(raw: HashMap<T, HashSet<Interest>>) -> Self {
174        let mut map = Self::new();
175        for (value, interests) in raw {
176            for interest in &interests {
177                map.root
178                    .insert_recursive(interest.as_segments(), value.clone());
179            }
180            map.raw.insert(value, interests);
181        }
182        map
183    }
184
185    pub fn insert(&mut self, interest: Interest, value: T) {
186        self.root
187            .insert_recursive(interest.as_segments(), value.clone());
188        self.raw.entry(value).or_default().insert(interest);
189    }
190
191    pub fn find(&self, subject: &Subject) -> HashSet<&T> {
192        let mut collector = HashSet::new();
193        self.root
194            .find_all_recursive(subject.segments(), &mut collector);
195        collector
196    }
197
198    pub fn delete(&mut self, value: &T) {
199        if let Some(interests) = self.raw.remove(value) {
200            for interest in interests {
201                let mut path = interest.as_segments();
202                self.root.delete_recursive(&mut path, value);
203            }
204        }
205    }
206
207    pub fn interest_of(&self, value: &T) -> Option<&HashSet<Interest>> {
208        self.raw.get(value)
209    }
210}
211
212impl<T> Serialize for InterestMap<T>
213where
214    T: Serialize,
215{
216    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
217    where
218        S: serde::Serializer,
219    {
220        self.raw.serialize(serializer)
221    }
222}
223
224impl<'de, T> Deserialize<'de> for InterestMap<T>
225where
226    T: Deserialize<'de> + Hash + Eq + Clone,
227{
228    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229    where
230        D: serde::Deserializer<'de>,
231    {
232        let raw = HashMap::<T, HashSet<Interest>>::deserialize(deserializer)?;
233        Ok(Self::from_raw(raw))
234    }
235}
/// Checks that `**` spans a segment, that `*` accepts any single trailing segment, that
/// unrelated subjects match nothing, and that `delete` removes a value's interests.
#[test]
fn test_interest_map() {
    let mut map = InterestMap::new();
    map.insert(Interest::new("event/**/user/a"), 1);
    map.insert(Interest::new("event/**/user/*"), 2);

    let values = map.find(&Subject::new("event/hello-world/user/a"));
    assert!(values.contains(&1));
    assert!(values.contains(&2));

    // Only the `*` interest accepts a different last segment.
    let values = map.find(&Subject::new("event/hello-world/user/b"));
    assert!(!values.contains(&1));
    assert!(values.contains(&2));

    // A different first segment matches neither interest.
    assert!(map
        .find(&Subject::new("other/hello-world/user/a"))
        .is_empty());

    // Deleting a value removes all of its interests from lookups and from `raw`.
    map.delete(&1);
    let values = map.find(&Subject::new("event/hello-world/user/a"));
    assert!(!values.contains(&1));
    assert!(values.contains(&2));
    assert!(map.interest_of(&1).is_none());
}