Skip to main content

graphos_core/index/
trie.rs

1//! Trie index for Worst-Case Optimal Joins (WCOJ).
2//!
3//! The trie index is built lazily on-demand for complex pattern queries
4//! like triangles, cliques, and multi-way joins. It enables the
5//! Leapfrog Trie Join algorithm.
6
7use graphos_common::types::{EdgeId, NodeId};
8use graphos_common::utils::hash::FxHashMap;
9use smallvec::SmallVec;
10
11/// A trie node in the edge trie.
12#[derive(Debug, Clone)]
13struct TrieNode {
14    /// Children indexed by node ID.
15    children: FxHashMap<NodeId, TrieNode>,
16    /// Edge IDs at this level (for leaf nodes or intermediate data).
17    edges: SmallVec<[EdgeId; 4]>,
18}
19
20impl TrieNode {
21    fn new() -> Self {
22        Self {
23            children: FxHashMap::default(),
24            edges: SmallVec::new(),
25        }
26    }
27
28    fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
29        if path.is_empty() {
30            self.edges.push(edge_id);
31            return;
32        }
33
34        self.children
35            .entry(path[0])
36            .or_insert_with(TrieNode::new)
37            .insert(&path[1..], edge_id);
38    }
39
40    fn get_child(&self, key: NodeId) -> Option<&TrieNode> {
41        self.children.get(&key)
42    }
43
44    #[allow(dead_code)]
45    fn children_keys(&self) -> impl Iterator<Item = NodeId> + '_ {
46        self.children.keys().copied()
47    }
48
49    fn children_sorted(&self) -> Vec<NodeId> {
50        let mut keys: Vec<_> = self.children.keys().copied().collect();
51        keys.sort();
52        keys
53    }
54}
55
56/// A trie index for edge patterns.
57///
58/// Used to support Worst-Case Optimal Joins (WCOJ) via the
59/// Leapfrog Trie Join algorithm.
60pub struct TrieIndex {
61    /// Root of the trie.
62    root: TrieNode,
63    /// Number of entries in the trie.
64    size: usize,
65}
66
67impl TrieIndex {
68    /// Creates a new empty trie index.
69    #[must_use]
70    pub fn new() -> Self {
71        Self {
72            root: TrieNode::new(),
73            size: 0,
74        }
75    }
76
77    /// Inserts an edge into the trie.
78    ///
79    /// The path typically represents [src, dst] for a directed edge.
80    pub fn insert(&mut self, path: &[NodeId], edge_id: EdgeId) {
81        self.root.insert(path, edge_id);
82        self.size += 1;
83    }
84
85    /// Inserts a directed edge (src -> dst).
86    pub fn insert_edge(&mut self, src: NodeId, dst: NodeId, edge_id: EdgeId) {
87        self.insert(&[src, dst], edge_id);
88    }
89
90    /// Returns the number of entries.
91    pub fn len(&self) -> usize {
92        self.size
93    }
94
95    /// Returns true if the trie is empty.
96    pub fn is_empty(&self) -> bool {
97        self.size == 0
98    }
99
100    /// Creates an iterator at the root level.
101    pub fn iter(&self) -> TrieIterator<'_> {
102        TrieIterator::new(&self.root)
103    }
104
105    /// Creates an iterator at a specific path.
106    pub fn iter_at(&self, path: &[NodeId]) -> Option<TrieIterator<'_>> {
107        let mut node = &self.root;
108        for &key in path {
109            node = node.get_child(key)?;
110        }
111        Some(TrieIterator::new(node))
112    }
113
114    /// Gets all values at the end of a path.
115    pub fn get(&self, path: &[NodeId]) -> Option<&[EdgeId]> {
116        let mut node = &self.root;
117        for &key in path {
118            node = node.get_child(key)?;
119        }
120        if node.edges.is_empty() {
121            None
122        } else {
123            Some(&node.edges)
124        }
125    }
126}
127
128impl Default for TrieIndex {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134/// An iterator over trie children at a single level.
135pub struct TrieIterator<'a> {
136    node: &'a TrieNode,
137    keys: Vec<NodeId>,
138    pos: usize,
139}
140
141impl<'a> TrieIterator<'a> {
142    fn new(node: &'a TrieNode) -> Self {
143        let keys = node.children_sorted();
144        Self { node, keys, pos: 0 }
145    }
146
147    /// Returns the current key, if any.
148    pub fn key(&self) -> Option<NodeId> {
149        self.keys.get(self.pos).copied()
150    }
151
152    /// Advances to the next key.
153    pub fn next(&mut self) -> bool {
154        if self.pos < self.keys.len() {
155            self.pos += 1;
156            self.pos < self.keys.len()
157        } else {
158            false
159        }
160    }
161
162    /// Seeks to the first key >= target.
163    ///
164    /// Returns true if a key was found.
165    pub fn seek(&mut self, target: NodeId) -> bool {
166        // Binary search for the target
167        match self.keys[self.pos..].binary_search(&target) {
168            Ok(offset) => {
169                self.pos += offset;
170                true
171            }
172            Err(offset) => {
173                self.pos += offset;
174                self.pos < self.keys.len()
175            }
176        }
177    }
178
179    /// Opens the current key's child node.
180    pub fn open(&self) -> Option<TrieIterator<'a>> {
181        let key = self.key()?;
182        let child = self.node.get_child(key)?;
183        Some(TrieIterator::new(child))
184    }
185
186    /// Returns whether the iterator is at a valid position.
187    pub fn is_valid(&self) -> bool {
188        self.pos < self.keys.len()
189    }
190}
191
192/// Leapfrog trie join implementation.
193///
194/// This performs worst-case optimal joins over multiple trie iterators.
195pub struct LeapfrogJoin<'a> {
196    iters: Vec<TrieIterator<'a>>,
197    current_key: Option<NodeId>,
198}
199
200impl<'a> LeapfrogJoin<'a> {
201    /// Creates a new leapfrog join over the given iterators.
202    ///
203    /// All iterators must be at the same level (representing the same variable).
204    pub fn new(iters: Vec<TrieIterator<'a>>) -> Self {
205        let mut join = Self {
206            iters,
207            current_key: None,
208        };
209        join.init();
210        join
211    }
212
213    fn init(&mut self) {
214        if self.iters.is_empty() {
215            return;
216        }
217
218        // Sort iterators by current key
219        self.iters.sort_by_key(|it| it.key());
220
221        // Check if all at same position (intersection found)
222        self.search();
223    }
224
225    fn search(&mut self) {
226        if self.iters.is_empty() || !self.iters[0].is_valid() {
227            self.current_key = None;
228            return;
229        }
230
231        loop {
232            let max_key = self.iters.last().and_then(|it| it.key());
233            let min_key = self.iters.first().and_then(|it| it.key());
234
235            match (min_key, max_key) {
236                (Some(min), Some(max)) if min == max => {
237                    // All iterators at the same key - found intersection
238                    self.current_key = Some(min);
239                    return;
240                }
241                (Some(_), Some(max)) => {
242                    // Seek minimum to max
243                    if !self.iters[0].seek(max) {
244                        self.current_key = None;
245                        return;
246                    }
247                    // Re-sort after seek
248                    self.iters.sort_by_key(|it| it.key());
249                }
250                _ => {
251                    self.current_key = None;
252                    return;
253                }
254            }
255        }
256    }
257
258    /// Returns the current intersection key.
259    pub fn key(&self) -> Option<NodeId> {
260        self.current_key
261    }
262
263    /// Advances to the next intersection.
264    pub fn next(&mut self) -> bool {
265        if self.current_key.is_none() || self.iters.is_empty() {
266            return false;
267        }
268
269        // Advance the first iterator
270        self.iters[0].next();
271        self.iters.sort_by_key(|it| it.key());
272        self.search();
273
274        self.current_key.is_some()
275    }
276
277    /// Opens the current level and returns iterators for the next level.
278    pub fn open(&self) -> Option<Vec<TrieIterator<'a>>> {
279        if self.current_key.is_none() {
280            return None;
281        }
282
283        self.iters.iter().map(|it| it.open()).collect()
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trie_basic() {
        let mut trie = TrieIndex::new();

        // (src, dst, edge id)
        for (src, dst, id) in [(1, 2, 0), (1, 3, 1), (2, 3, 2)] {
            trie.insert_edge(NodeId::new(src), NodeId::new(dst), EdgeId::new(id));
        }

        assert_eq!(trie.len(), 3);
    }

    #[test]
    fn test_trie_iterator() {
        let mut trie = TrieIndex::new();

        for (src, dst, id) in [(1, 10, 0), (2, 20, 1), (3, 30, 2)] {
            trie.insert_edge(NodeId::new(src), NodeId::new(dst), EdgeId::new(id));
        }

        let mut cursor = trie.iter();

        // Root-level keys come back in ascending order: 1, 2, 3.
        assert_eq!(cursor.key(), Some(NodeId::new(1)));
        for expected in [2, 3] {
            assert!(cursor.next());
            assert_eq!(cursor.key(), Some(NodeId::new(expected)));
        }
        assert!(!cursor.next());
    }

    #[test]
    fn test_trie_seek() {
        let mut trie = TrieIndex::new();

        for src in [1, 3, 5, 7, 9] {
            trie.insert_edge(NodeId::new(src), NodeId::new(100), EdgeId::new(src));
        }

        let mut cursor = trie.iter();

        // 4 is absent, so the cursor lands on the next larger key, 5.
        assert!(cursor.seek(NodeId::new(4)));
        assert_eq!(cursor.key(), Some(NodeId::new(5)));

        // 7 is present, so the cursor lands exactly on it.
        assert!(cursor.seek(NodeId::new(7)));
        assert_eq!(cursor.key(), Some(NodeId::new(7)));

        // Nothing is >= 10, so the cursor runs off the end.
        assert!(!cursor.seek(NodeId::new(10)));
    }

    #[test]
    fn test_leapfrog_join() {
        // Two edge sets whose source nodes overlap on {2, 3, 5}.
        let mut left = TrieIndex::new();
        let mut right = TrieIndex::new();

        for src in [1, 2, 3, 5] {
            left.insert_edge(NodeId::new(src), NodeId::new(100), EdgeId::new(src));
        }
        for src in [2, 3, 4, 5] {
            right.insert_edge(NodeId::new(src), NodeId::new(100), EdgeId::new(src + 10));
        }

        let mut join = LeapfrogJoin::new(vec![left.iter(), right.iter()]);

        // `next` clears the current key once the join is exhausted, which
        // terminates this loop.
        let mut results = Vec::new();
        while let Some(key) = join.key() {
            results.push(key);
            join.next();
        }

        assert_eq!(results.len(), 3);
        for expected in [2, 3, 5] {
            assert!(results.contains(&NodeId::new(expected)));
        }
    }
}