graphos_core/index/
trie.rs1use graphos_common::types::{EdgeId, NodeId};
8use graphos_common::utils::hash::FxHashMap;
9use smallvec::SmallVec;
10
11#[derive(Debug, Clone)]
13struct TrieNode {
14 children: FxHashMap<NodeId, TrieNode>,
16 edges: SmallVec<[EdgeId; 4]>,
18}
19
20impl TrieNode {
21 fn new() -> Self {
22 Self {
23 children: FxHashMap::default(),
24 edges: SmallVec::new(),
25 }
26 }
27
28 fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
29 if path.is_empty() {
30 self.edges.push(edge_id);
31 return;
32 }
33
34 self.children
35 .entry(path[0])
36 .or_insert_with(TrieNode::new)
37 .insert(&path[1..], edge_id);
38 }
39
40 fn get_child(&self, key: NodeId) -> Option<&TrieNode> {
41 self.children.get(&key)
42 }
43
44 #[allow(dead_code)]
45 fn children_keys(&self) -> impl Iterator<Item = NodeId> + '_ {
46 self.children.keys().copied()
47 }
48
49 fn children_sorted(&self) -> Vec<NodeId> {
50 let mut keys: Vec<_> = self.children.keys().copied().collect();
51 keys.sort();
52 keys
53 }
54}
55
56pub struct TrieIndex {
61 root: TrieNode,
63 size: usize,
65}
66
67impl TrieIndex {
68 #[must_use]
70 pub fn new() -> Self {
71 Self {
72 root: TrieNode::new(),
73 size: 0,
74 }
75 }
76
77 pub fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
81 self.root.insert(path, edge_id);
82 self.size += 1;
83 }
84
85 pub fn insert_edge(&mut self, src: NodeId, dst: NodeId, edge_id: EdgeId) {
87 self.insert(&[src, dst], edge_id);
88 }
89
90 pub fn len(&self) -> usize {
92 self.size
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.size == 0
98 }
99
100 pub fn iter(&self) -> TrieIterator<'_> {
102 TrieIterator::new(&self.root)
103 }
104
105 pub fn iter_at(&self, path: &[NodeId]) -> Option<TrieIterator<'_>> {
107 let mut node = &self.root;
108 for &key in path {
109 node = node.get_child(key)?;
110 }
111 Some(TrieIterator::new(node))
112 }
113
114 pub fn get(&self, path: &[NodeId]) -> Option<&[EdgeId]> {
116 let mut node = &self.root;
117 for &key in path {
118 node = node.get_child(key)?;
119 }
120 if node.edges.is_empty() {
121 None
122 } else {
123 Some(&node.edges)
124 }
125 }
126}
127
128impl Default for TrieIndex {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134pub struct TrieIterator<'a> {
136 node: &'a TrieNode,
137 keys: Vec<NodeId>,
138 pos: usize,
139}
140
141impl<'a> TrieIterator<'a> {
142 fn new(node: &'a TrieNode) -> Self {
143 let keys = node.children_sorted();
144 Self { node, keys, pos: 0 }
145 }
146
147 pub fn key(&self) -> Option<NodeId> {
149 self.keys.get(self.pos).copied()
150 }
151
152 pub fn next(&mut self) -> bool {
154 if self.pos < self.keys.len() {
155 self.pos += 1;
156 self.pos < self.keys.len()
157 } else {
158 false
159 }
160 }
161
162 pub fn seek(&mut self, target: NodeId) -> bool {
166 match self.keys[self.pos..].binary_search(&target) {
168 Ok(offset) => {
169 self.pos += offset;
170 true
171 }
172 Err(offset) => {
173 self.pos += offset;
174 self.pos < self.keys.len()
175 }
176 }
177 }
178
179 pub fn open(&self) -> Option<TrieIterator<'a>> {
181 let key = self.key()?;
182 let child = self.node.get_child(key)?;
183 Some(TrieIterator::new(child))
184 }
185
186 pub fn is_valid(&self) -> bool {
188 self.pos < self.keys.len()
189 }
190}
191
192pub struct LeapfrogJoin<'a> {
196 iters: Vec<TrieIterator<'a>>,
197 current_key: Option<NodeId>,
198}
199
200impl<'a> LeapfrogJoin<'a> {
201 pub fn new(iters: Vec<TrieIterator<'a>>) -> Self {
205 let mut join = Self {
206 iters,
207 current_key: None,
208 };
209 join.init();
210 join
211 }
212
213 fn init(&mut self) {
214 if self.iters.is_empty() {
215 return;
216 }
217
218 self.iters.sort_by_key(|it| it.key());
220
221 self.search();
223 }
224
225 fn search(&mut self) {
226 if self.iters.is_empty() || !self.iters[0].is_valid() {
227 self.current_key = None;
228 return;
229 }
230
231 loop {
232 let max_key = self.iters.last().and_then(|it| it.key());
233 let min_key = self.iters.first().and_then(|it| it.key());
234
235 match (min_key, max_key) {
236 (Some(min), Some(max)) if min == max => {
237 self.current_key = Some(min);
239 return;
240 }
241 (Some(_), Some(max)) => {
242 if !self.iters[0].seek(max) {
244 self.current_key = None;
245 return;
246 }
247 self.iters.sort_by_key(|it| it.key());
249 }
250 _ => {
251 self.current_key = None;
252 return;
253 }
254 }
255 }
256 }
257
258 pub fn key(&self) -> Option<NodeId> {
260 self.current_key
261 }
262
263 pub fn next(&mut self) -> bool {
265 if self.current_key.is_none() || self.iters.is_empty() {
266 return false;
267 }
268
269 self.iters[0].next();
271 self.iters.sort_by_key(|it| it.key());
272 self.search();
273
274 self.current_key.is_some()
275 }
276
277 pub fn open(&self) -> Option<Vec<TrieIterator<'a>>> {
279 if self.current_key.is_none() {
280 return None;
281 }
282
283 self.iters.iter().map(|it| it.open()).collect()
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` counts every inserted edge.
    #[test]
    fn test_trie_basic() {
        let mut trie = TrieIndex::new();

        for (src, dst, edge) in [(1, 2, 0), (1, 3, 1), (2, 3, 2)] {
            trie.insert_edge(NodeId::new(src), NodeId::new(dst), EdgeId::new(edge));
        }

        assert_eq!(trie.len(), 3);
    }

    /// The root cursor visits the source keys in ascending order and then stops.
    #[test]
    fn test_trie_iterator() {
        let mut trie = TrieIndex::new();

        for (src, dst, edge) in [(1, 10, 0), (2, 20, 1), (3, 30, 2)] {
            trie.insert_edge(NodeId::new(src), NodeId::new(dst), EdgeId::new(edge));
        }

        let mut iter = trie.iter();

        assert_eq!(iter.key(), Some(NodeId::new(1)));
        for expected in [2, 3] {
            assert!(iter.next());
            assert_eq!(iter.key(), Some(NodeId::new(expected)));
        }
        assert!(!iter.next());
    }

    /// `seek` lands on the first key `>= target` and fails past the last key.
    #[test]
    fn test_trie_seek() {
        let mut trie = TrieIndex::new();

        for i in [1, 3, 5, 7, 9] {
            trie.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i));
        }

        let mut iter = trie.iter();

        // 4 is absent: the cursor stops on the next larger key.
        assert!(iter.seek(NodeId::new(4)));
        assert_eq!(iter.key(), Some(NodeId::new(5)));

        // 7 is present: the cursor stops exactly on it.
        assert!(iter.seek(NodeId::new(7)));
        assert_eq!(iter.key(), Some(NodeId::new(7)));

        // Nothing is >= 10.
        assert!(!iter.seek(NodeId::new(10)));
    }

    /// Joining {1, 2, 3, 5} with {2, 3, 4, 5} yields {2, 3, 5}.
    #[test]
    fn test_leapfrog_join() {
        let mut trie1 = TrieIndex::new();
        let mut trie2 = TrieIndex::new();

        for i in [1, 2, 3, 5] {
            trie1.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i));
        }

        for i in [2, 3, 4, 5] {
            trie2.insert_edge(NodeId::new(i), NodeId::new(100), EdgeId::new(i + 10));
        }

        let mut join = LeapfrogJoin::new(vec![trie1.iter(), trie2.iter()]);

        let mut results = Vec::new();
        while let Some(key) = join.key() {
            results.push(key);
            if !join.next() {
                break;
            }
        }

        assert_eq!(results.len(), 3);
        for expected in [2, 3, 5] {
            assert!(results.contains(&NodeId::new(expected)));
        }
    }
}