kermit_algos/
leapfrog_triejoin.rs

1use {
2    crate::{
3        join_algo::JoinAlgo,
4        leapfrog_join::{LeapfrogJoinIter, LeapfrogJoinIterator},
5        JoinQuery,
6    },
7    kermit_iters::{LinearIterator, TrieIterable, TrieIterator, TrieIteratorWrapper},
8    std::collections::HashMap,
9};
10
11/// A trait for iterators that implement the [Leapfrog Triejoin algorithm](https://arxiv.org/abs/1210.0481).
12pub trait LeapfrogTriejoinIterator: LeapfrogJoinIterator {
13    fn triejoin_open(&mut self) -> bool;
14
15    fn triejoin_up(&mut self) -> bool;
16}
17
18/// An iterator that performs the [Leapfrog Triejoin algorithm](https://arxiv.org/abs/1210.0481).
19pub struct LeapfrogTriejoinIter<IT>
20where
21    IT: TrieIterator,
22{
23    /// The key of the current position.
24    arity: usize,
25    iters: Vec<Option<IT>>,
26    current_iters_indexes: Vec<usize>,
27    iter_indexes_at_variable: Vec<Vec<usize>>,
28    depth: usize,
29    leapfrog: LeapfrogJoinIter<IT>,
30}
31
32impl<IT> LeapfrogJoinIterator for LeapfrogTriejoinIter<IT>
33where
34    IT: TrieIterator,
35{
36    fn leapfrog_next(&mut self) -> Option<usize> { self.leapfrog.leapfrog_next() }
37
38    fn key(&self) -> Option<usize> {
39        if self.depth == 0 {
40            None
41        } else {
42            self.leapfrog.key()
43        }
44    }
45
46    fn leapfrog_init(&mut self) -> bool { self.leapfrog.leapfrog_init() }
47
48    fn leapfrog_search(&mut self) -> bool { self.leapfrog.leapfrog_search() }
49
50    fn at_end(&self) -> bool {
51        if self.depth == 0 {
52            return true;
53        }
54        self.leapfrog.at_end()
55    }
56
57    fn leapfrog_seek(&mut self, seek_key: usize) -> bool { self.leapfrog.leapfrog_seek(seek_key) }
58}
59
60impl<IT> LeapfrogTriejoinIter<IT>
61where
62    IT: TrieIterator,
63{
64    /// Construct a new `LeapfrogTriejoinIter` with the given iterators.
65    ///
66    /// Q(a, b, c) = R(a, b) S(b, c), T(a, c)
67    /// variables = [a, b, c]
68    /// rel_variables = [[a, b], [b, c], [a, c]]
69    ///
70    /// # Arguments
71    /// * `variables` - The variables and their ordering.
72    /// * `rel_variables` - The variables in their relations.
73    /// * `iters` - Trie iterators.
74    pub fn new(variables: Vec<usize>, rel_variables: Vec<Vec<usize>>, iters: Vec<IT>) -> Self {
75        let mut iter_indexes_at_variable: Vec<Vec<usize>> = Vec::new();
76        for v in &variables {
77            let mut iters_at_level_v: Vec<usize> = Vec::new();
78            for (r_i, r) in rel_variables.iter().enumerate() {
79                if r.contains(v) {
80                    iters_at_level_v.push(r_i);
81                }
82            }
83            iter_indexes_at_variable.push(iters_at_level_v);
84        }
85
86        let iters = iters.into_iter().map(Some).collect();
87
88        LeapfrogTriejoinIter {
89            iters,
90            current_iters_indexes: Vec::new(),
91            iter_indexes_at_variable,
92            arity: variables.len(),
93            depth: 0,
94            leapfrog: LeapfrogJoinIter::new(vec![]),
95        }
96    }
97
98    fn update_iters(&mut self) {
99        while let Some(i) = self.current_iters_indexes.pop() {
100            let iter = self
101                .leapfrog
102                .iterators
103                .pop()
104                .expect("There should always be an iterator here");
105            self.iters[i] = Some(iter);
106        }
107
108        if self.depth == 0 {
109            return;
110        }
111
112        let mut next_iters =
113            Vec::<IT>::with_capacity(self.iter_indexes_at_variable[self.depth - 1].len());
114        for i in &self.iter_indexes_at_variable[self.depth - 1] {
115            let iter = self.iters[*i].take().expect("There is an iterator here");
116            next_iters.push(iter);
117            self.current_iters_indexes.push(*i);
118        }
119        self.leapfrog = LeapfrogJoinIter::new(next_iters);
120    }
121}
122
123impl<IT> LeapfrogTriejoinIterator for LeapfrogTriejoinIter<IT>
124where
125    IT: TrieIterator,
126{
127    fn triejoin_open(&mut self) -> bool {
128        if self.depth == self.arity {
129            return false;
130        }
131        self.depth += 1;
132        self.update_iters();
133        for iter in &mut self.leapfrog.iterators {
134            if !iter.open() {
135                return false;
136            }
137        }
138        self.leapfrog_init()
139    }
140
141    fn triejoin_up(&mut self) -> bool {
142        if self.depth == 0 {
143            return false;
144        }
145        for iter in &mut self.leapfrog.iterators {
146            assert!(iter.up());
147        }
148        self.depth -= 1;
149        self.update_iters();
150        true
151    }
152}
153
154impl<IT> TrieIterator for LeapfrogTriejoinIter<IT>
155where
156    IT: TrieIterator,
157{
158    fn open(&mut self) -> bool { self.triejoin_open() }
159
160    fn up(&mut self) -> bool { self.triejoin_up() }
161}
162
163impl<IT> LinearIterator for LeapfrogTriejoinIter<IT>
164where
165    IT: TrieIterator,
166{
167    fn key(&self) -> Option<usize> { LeapfrogJoinIterator::key(self) }
168
169    fn next(&mut self) -> Option<usize> { self.leapfrog_next() }
170
171    fn seek(&mut self, seek_key: usize) -> bool { self.leapfrog_seek(seek_key) }
172
173    fn at_end(&self) -> bool { LeapfrogJoinIterator::at_end(self) }
174}
175
176impl<IT> IntoIterator for LeapfrogTriejoinIter<IT>
177where
178    IT: TrieIterator,
179{
180    type IntoIter = TrieIteratorWrapper<Self>;
181    type Item = Vec<usize>;
182
183    fn into_iter(self) -> Self::IntoIter { TrieIteratorWrapper::new(self) }
184}
185
186pub struct LeapfrogTriejoin {}
187
188impl<DS> JoinAlgo<DS> for LeapfrogTriejoin
189where
190    DS: TrieIterable,
191{
192    fn join_iter(
193        query: JoinQuery, datastructures: HashMap<String, &DS>,
194    ) -> impl Iterator<Item = Vec<usize>> {
195        // Map variable names to unique indices, ordered by first appearance in head
196        // then body
197        let mut var_to_index: HashMap<String, usize> = HashMap::new();
198        let mut next_index: usize = 0;
199
200        // Helper to register a variable name and return its index
201        let register_var = |name: &str, map: &mut HashMap<String, usize>, next: &mut usize| {
202            if let std::collections::hash_map::Entry::Vacant(v) = map.entry(name.to_string()) {
203                let idx = *next;
204                v.insert(idx);
205                *next += 1;
206                idx
207            } else {
208                map[name]
209            }
210        };
211
212        // First pass: head variables (ignore placeholders and atoms)
213        for t in &query.head.terms {
214            if let crate::queries::join_query::Term::Var(ref vname) = t {
215                let _ = register_var(vname, &mut var_to_index, &mut next_index);
216            }
217        }
218
219        // Second pass: body variables (ignore placeholders and atoms)
220        for pred in &query.body {
221            for t in &pred.terms {
222                if let crate::queries::join_query::Term::Var(ref vname) = t {
223                    let _ = register_var(vname, &mut var_to_index, &mut next_index);
224                }
225            }
226        }
227
228        // Variables vector is 0..num_vars in the discovered order
229        let variables: Vec<usize> = (0..var_to_index.len()).collect();
230
231        // Build rel_variables following each predicate's order; ignore placeholders and
232        // atoms
233        let mut rel_variables: Vec<Vec<usize>> = Vec::with_capacity(query.body.len());
234        for pred in &query.body {
235            let mut rel_vars_for_pred: Vec<usize> = Vec::new();
236            for t in &pred.terms {
237                if let crate::queries::join_query::Term::Var(ref vname) = t {
238                    if let Some(idx) = var_to_index.get(vname) {
239                        rel_vars_for_pred.push(*idx);
240                    }
241                }
242            }
243            rel_variables.push(rel_vars_for_pred);
244        }
245
246        // Build trie iterators in the same order as query body using provided
247        // datastructures
248        let trie_iters: Vec<_> = query
249            .body
250            .iter()
251            .map(|pred| {
252                let ds = datastructures
253                    .get(&pred.name)
254                    .expect("Missing datastructure for predicate name");
255                ds.trie_iter()
256            })
257            .collect();
258
259        LeapfrogTriejoinIter::new(variables, rel_variables, trie_iters).into_iter()
260    }
261}
262
#[cfg(test)]
mod tests {
    use {
        crate::{
            leapfrog_join::LeapfrogJoinIterator,
            leapfrog_triejoin::{LeapfrogTriejoinIter, LeapfrogTriejoinIterator},
        },
        kermit_ds::{Relation, TreeTrie},
        kermit_iters::TrieIterable,
    };

    /// Intersecting two identical unary relations yields every key.
    #[test]
    fn test_classic() {
        let t1 = TreeTrie::from_tuples(1.into(), vec![vec![1], vec![2], vec![3]]);
        let t2 = TreeTrie::from_tuples(1.into(), vec![vec![1], vec![2], vec![3]]);
        let t1_iter = t1.trie_iter();
        let t2_iter = t2.trie_iter();
        let mut triejoin_iter =
            LeapfrogTriejoinIter::new(vec![0], vec![vec![0], vec![0]], vec![t1_iter, t2_iter]);
        triejoin_iter.triejoin_open();
        assert_eq!(triejoin_iter.key(), Some(1));
        assert_eq!(triejoin_iter.leapfrog_next(), Some(2));
        assert_eq!(triejoin_iter.leapfrog_next(), Some(3));
        triejoin_iter.leapfrog_next();
        assert!(triejoin_iter.at_end());
        triejoin_iter.triejoin_up();
        assert!(triejoin_iter.at_end());
        let res = triejoin_iter.into_iter().collect::<Vec<_>>();
        assert_eq!(res, vec![vec![1], vec![2], vec![3]]);
    }

    /// Triangle query: each level has exactly one common key.
    #[test]
    fn more_complicated() {
        let r = TreeTrie::from_tuples(2.into(), vec![vec![7, 4]]);
        let s = TreeTrie::from_tuples(2.into(), vec![vec![4, 1], vec![4, 4], vec![4, 5], vec![
            4, 9,
        ]]);
        let t = TreeTrie::from_tuples(2.into(), vec![vec![7, 2], vec![7, 3], vec![7, 5]]);
        let r_iter = r.trie_iter();
        let s_iter = s.trie_iter();
        let t_iter = t.trie_iter();
        let mut triejoin_iter = LeapfrogTriejoinIter::new(
            vec![0, 1, 2],
            vec![vec![0, 1], vec![1, 2], vec![0, 2]],
            vec![r_iter, s_iter, t_iter],
        );
        triejoin_iter.triejoin_open();
        assert_eq!(triejoin_iter.key(), Some(7));
        triejoin_iter.leapfrog_next();
        assert!(triejoin_iter.at_end());
        triejoin_iter.triejoin_open();
        assert_eq!(triejoin_iter.key(), Some(4));
        triejoin_iter.leapfrog_next();
        assert!(triejoin_iter.at_end());
        triejoin_iter.triejoin_open();
        assert_eq!(triejoin_iter.key(), Some(5));
    }

    /// Chain query R(a, b), S(b, c), T(c, d): walk down, back up, then the second chain.
    #[test]
    fn chain() {
        let r = TreeTrie::from_tuples(2.into(), vec![vec![1, 2], vec![2, 3]]);
        let s = TreeTrie::from_tuples(2.into(), vec![vec![2, 4], vec![3, 5]]);
        let t = TreeTrie::from_tuples(2.into(), vec![vec![4, 6], vec![5, 7]]);
        let r_iter = r.trie_iter();
        let s_iter = s.trie_iter();
        let t_iter = t.trie_iter();
        let mut triejoin_iter = LeapfrogTriejoinIter::new(
            vec![0, 1, 2, 3],
            vec![vec![0, 1], vec![1, 2], vec![2, 3]],
            vec![r_iter, s_iter, t_iter],
        );
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(1));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(2));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(4));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(6));

        assert!(triejoin_iter.triejoin_up());
        assert!(triejoin_iter.triejoin_up());
        assert!(triejoin_iter.triejoin_up());

        assert_eq!(triejoin_iter.leapfrog_next(), Some(2));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(3));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(5));
        assert!(triejoin_iter.triejoin_open());
        assert_eq!(triejoin_iter.key(), Some(7));
    }

    // TODO: the former file-based (`tests/data/*.csv`) and table-driven tests targeted the
    // removed `TrieBuilder` / `leapfrog_triejoin` API. Port them to `TreeTrie` if still wanted.
}