asteroid_mq/protocol/
interest.rs1use std::{
5 collections::{BTreeMap, HashMap, HashSet},
6 hash::Hash,
7};
8
9pub use asteroid_mq_model::{
10 Interest, InterestSegment, OwnedInterestSegment, Subject, SubjectSegments,
11};
12use serde::{Deserialize, Serialize};
/// Index from interest patterns to the values that registered them.
///
/// Lookups by subject go through the radix tree in `root`; `raw` keeps each
/// value's own set of interests so that a value can be removed from every
/// path it was inserted under, and so the map can be serialized as just that
/// table and rebuilt from it.
#[derive(Debug, Clone)]
pub struct InterestMap<T> {
    // Radix tree used to answer "which values match this subject?".
    root: InterestRadixTreeNode<T>,
    // value -> every interest pattern registered for that value.
    pub(crate) raw: HashMap<T, HashSet<Interest>>,
}
18
19impl<T> Default for InterestMap<T> {
20 fn default() -> Self {
21 Self {
22 root: Default::default(),
23 raw: HashMap::default(),
24 }
25 }
26}
27
28#[derive(Clone)]
29pub struct InterestRadixTreeNode<T> {
30 value: HashSet<T>,
31 children: BTreeMap<Vec<u8>, InterestRadixTreeNode<T>>,
32 any_child: Option<Box<InterestRadixTreeNode<T>>>,
33 recursive_any_child: Option<Box<InterestRadixTreeNode<T>>>,
34}
35
36struct ChildrenDebugProxy<'a, T>(&'a InterestRadixTreeNode<T>);
37
38impl<T: std::fmt::Debug> std::fmt::Debug for ChildrenDebugProxy<'_, T> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 let mut debug = f.debug_map();
41 debug.entries(
42 self.0
43 .children
44 .iter()
45 .map(|(k, v)| (std::str::from_utf8(k).unwrap_or("<invalid utf8 str>"), v)),
46 );
47 if let Some(any_child) = &self.0.any_child {
48 debug.entry(&"*", any_child);
49 }
50 if let Some(recursive_any_child) = &self.0.recursive_any_child {
51 debug.entry(&"**", recursive_any_child);
52 }
53 debug.finish()
54 }
55}
56
57impl<T: std::fmt::Debug> std::fmt::Debug for InterestRadixTreeNode<T> {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("InterestRadixTreeNode")
60 .field("value", &self.value)
61 .field("children", &ChildrenDebugProxy(self))
62 .finish()
63 }
64}
65
66impl<T> Default for InterestRadixTreeNode<T> {
67 fn default() -> Self {
68 Self {
69 value: HashSet::default(),
70 children: BTreeMap::new(),
71 any_child: None,
72 recursive_any_child: None,
73 }
74 }
75}
76
77impl<T> InterestRadixTreeNode<T>
78where
79 T: Hash + Eq + PartialEq,
80{
81 fn insert_recursive<'a>(
82 &mut self,
83 mut path: impl Iterator<Item = InterestSegment<'a>>,
84 value: T,
85 ) {
86 match path.next() {
87 Some(InterestSegment::Specific(seg)) => {
88 if let Some(child) = self.children.get_mut(seg) {
89 child.insert_recursive(path, value)
90 } else {
91 let mut child_tree = InterestRadixTreeNode::default();
92 child_tree.insert_recursive(path, value);
93 self.children.insert(seg.to_owned(), child_tree);
94 }
95 }
96 Some(InterestSegment::Any) => {
97 let child = self.any_child.get_or_insert_with(Default::default);
98 child.insert_recursive(path, value)
99 }
100 Some(InterestSegment::RecursiveAny) => {
101 let child = self
102 .recursive_any_child
103 .get_or_insert_with(Default::default);
104 child.insert_recursive(path, value)
105 }
106 None => {
107 self.value.insert(value);
108 }
109 }
110 }
111 fn delete_recursive<'a>(
112 &mut self,
113 mut path: impl Iterator<Item = InterestSegment<'a>>,
114 value: &T,
115 ) {
116 match path.next() {
117 Some(InterestSegment::Specific(seg)) => {
118 if let Some(child) = self.children.get_mut(seg) {
119 child.delete_recursive(path, value)
120 }
121 }
122 Some(InterestSegment::Any) => {
123 if let Some(ref mut child) = self.any_child {
124 child.delete_recursive(path, value)
125 }
126 }
127 Some(InterestSegment::RecursiveAny) => {
128 if let Some(ref mut child) = self.recursive_any_child {
129 child.delete_recursive(path, value)
130 }
131 }
132 None => {
133 self.value.remove(value);
134 }
135 }
136 }
137 fn find_all_recursive<'a, 'i>(
138 &'a self,
139 mut path: impl Iterator<Item = &'i [u8]> + Clone,
140 collector: &mut HashSet<&'a T>,
141 ) {
142 if let Some(seg) = path.next() {
143 if let Some(ref rac) = self.recursive_any_child {
144 let mut rest_path = path.clone();
145 collector.extend(&rac.value);
146 while let Some(recursive_seg) = rest_path.next() {
147 if let Some(matched) = rac.children.get(recursive_seg) {
148 matched.find_all_recursive(rest_path.clone(), collector)
149 }
150 }
151 }
152 if let Some(ref ac) = self.any_child {
153 ac.find_all_recursive(path.clone(), collector)
154 }
155 if let Some(child) = self.children.get(seg) {
156 child.find_all_recursive(path, collector)
157 }
158 } else {
159 collector.extend(&self.value)
160 }
161 }
162}
163impl<T> InterestMap<T>
164where
165 T: Hash + Eq + PartialEq + Clone,
166{
167 pub fn new() -> Self {
168 Self {
169 root: InterestRadixTreeNode::default(),
170 raw: HashMap::default(),
171 }
172 }
173 pub fn from_raw(raw: HashMap<T, HashSet<Interest>>) -> Self {
174 let mut map = Self::new();
175 for (value, interests) in raw {
176 for interest in &interests {
177 map.root
178 .insert_recursive(interest.as_segments(), value.clone());
179 }
180 map.raw.insert(value, interests);
181 }
182 map
183 }
184
185 pub fn insert(&mut self, interest: Interest, value: T) {
186 self.root
187 .insert_recursive(interest.as_segments(), value.clone());
188 self.raw.entry(value).or_default().insert(interest);
189 }
190
191 pub fn find(&self, subject: &Subject) -> HashSet<&T> {
192 let mut collector = HashSet::new();
193 self.root
194 .find_all_recursive(subject.segments(), &mut collector);
195 collector
196 }
197
198 pub fn delete(&mut self, value: &T) {
199 if let Some(interests) = self.raw.remove(value) {
200 for interest in interests {
201 let mut path = interest.as_segments();
202 self.root.delete_recursive(&mut path, value);
203 }
204 }
205 }
206
207 pub fn interest_of(&self, value: &T) -> Option<&HashSet<Interest>> {
208 self.raw.get(value)
209 }
210}
211
212impl<T> Serialize for InterestMap<T>
213where
214 T: Serialize,
215{
216 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
217 where
218 S: serde::Serializer,
219 {
220 self.raw.serialize(serializer)
221 }
222}
223
224impl<'de, T> Deserialize<'de> for InterestMap<T>
225where
226 T: Deserialize<'de> + Hash + Eq + Clone,
227{
228 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229 where
230 D: serde::Deserializer<'de>,
231 {
232 let raw = HashMap::<T, HashSet<Interest>>::deserialize(deserializer)?;
233 Ok(Self::from_raw(raw))
234 }
235}
/// A `**` in the middle of a pattern spans a segment, and both a literal tail
/// and a `*` tail match the same subject.
#[test]
fn test_interest_map() {
    let mut map = InterestMap::new();
    map.insert(Interest::new("event/**/user/a"), 1);
    map.insert(Interest::new("event/**/user/*"), 2);

    let subject = Subject::new("event/hello-world/user/a");
    let matched = map.find(&subject);
    for expected in &[1, 2] {
        assert!(matched.contains(expected));
    }
}