ddk_trie/
multi_trie.rs

//! Data structure and functions to create, insert into, look up and iterate
//! over a trie of tries.
3
4use crate::{
5    utils::{get_max_covering_paths, pre_pad_vec},
6    LookupResult, Node, OracleNumericInfo,
7};
8use combination_iterator::CombinationIterator;
9use ddk_dlc::Error;
10use digit_trie::{DigitTrie, DigitTrieDump, DigitTrieIter};
11use multi_oracle::compute_outcome_combinations;
12
/// Pointer from an entry of an internal node to one of its child nodes.
///
/// The values stored in the digit trie of an internal node are lists of these
/// pointers: each one names the sub-trie the child belongs to and the slot of
/// the child inside the trie store.
#[derive(Clone, Debug)]
pub struct TrieNodeInfo {
    /// Index of the sub-trie the child node belongs to.
    pub trie_index: usize,
    /// Slot of the child node in the trie store.
    pub store_index: usize,
}
21
22type MultiTrieNode<T> = Node<DigitTrie<T>, DigitTrie<Vec<TrieNodeInfo>>>;
23type NodeStackElement<'a> = Vec<(IndexedPath, DigitTrieIter<'a, Vec<TrieNodeInfo>>)>;
24type IndexedPath = (usize, Vec<usize>);
25
26impl<T> MultiTrieNode<T> {
27    fn new_node(base: usize) -> MultiTrieNode<T> {
28        let m_trie = DigitTrie::<Vec<TrieNodeInfo>>::new(base);
29        MultiTrieNode::Node(m_trie)
30    }
31    fn new_leaf(base: usize) -> MultiTrieNode<T> {
32        let d_trie = DigitTrie::<T>::new(base);
33        MultiTrieNode::Leaf(d_trie)
34    }
35}
36
37/// Struct for iterating over the values of a MultiTrie.
38pub(crate) struct MultiTrieIterator<'a, T> {
39    trie: &'a MultiTrie<T>,
40    node_stack: NodeStackElement<'a>,
41    trie_info_iter: Vec<(
42        Vec<usize>,
43        std::iter::Enumerate<std::slice::Iter<'a, TrieNodeInfo>>,
44    )>,
45    leaf_iter: Vec<(usize, DigitTrieIter<'a, T>)>,
46    cur_path: Vec<(usize, Vec<usize>)>,
47}
48
49fn create_node_iterator<T>(node: &'_ MultiTrieNode<T>) -> DigitTrieIter<'_, Vec<TrieNodeInfo>> {
50    match node {
51        Node::Node(d_trie) => DigitTrieIter::new(d_trie),
52        _ => unreachable!(),
53    }
54}
55
56fn create_leaf_iterator<T>(node: &'_ MultiTrieNode<T>) -> DigitTrieIter<'_, T> {
57    match node {
58        Node::Leaf(d_trie) => DigitTrieIter::new(d_trie),
59        _ => unreachable!(),
60    }
61}
62
63impl<'a, T> MultiTrieIterator<'a, T> {
64    /// Create a new MultiTrie iterator.
65    pub fn new(trie: &'a MultiTrie<T>) -> MultiTrieIterator<'a, T> {
66        let mut node_stack = Vec::with_capacity(trie.nb_required);
67        let nb_roots = trie.nb_tries - trie.nb_required + 1;
68        let mut leaf_iter = Vec::new();
69        for i in (0..nb_roots).rev() {
70            if trie.nb_required > 1 {
71                node_stack.push((
72                    (i, Vec::<usize>::new()),
73                    create_node_iterator(&trie.store[i]),
74                ));
75            } else {
76                leaf_iter.push((i, create_leaf_iterator(&trie.store[i])));
77            }
78        }
79        MultiTrieIterator {
80            trie,
81            node_stack,
82            trie_info_iter: Vec::new(),
83            leaf_iter,
84            cur_path: Vec::new(),
85        }
86    }
87}
88
89/// Implements the Iterator trait for MultiTrieIterator.
90impl<'a, T> Iterator for MultiTrieIterator<'a, T> {
91    type Item = LookupResult<'a, T, (usize, Vec<usize>)>;
92
93    fn next(&mut self) -> Option<Self::Item> {
94        let mut leaf_iter = self.leaf_iter.last_mut();
95        if let Some(ref mut iter) = &mut leaf_iter {
96            match iter.1.next() {
97                Some(res) => {
98                    let mut path = self.cur_path.clone();
99                    path.push((iter.0, res.path));
100                    return Some(LookupResult {
101                        value: res.value,
102                        path,
103                    });
104                }
105                None => {
106                    self.leaf_iter.pop();
107                    return self.next();
108                }
109            }
110        };
111
112        let mut trie_info_iter = self.trie_info_iter.last_mut();
113
114        if let Some(ref mut iter) = &mut trie_info_iter {
115            match iter.1.next() {
116                None => {
117                    self.trie_info_iter.pop();
118                    self.cur_path.pop();
119                }
120                Some((i, info)) => {
121                    if i == 0 {
122                        self.cur_path
123                            .push((self.node_stack.last().unwrap().0 .0, iter.0.clone()));
124                    }
125                    match &self.trie.store[info.store_index] {
126                        Node::None => unreachable!(),
127                        Node::Node(d_trie) => {
128                            self.node_stack.push((
129                                (info.trie_index, iter.0.clone()),
130                                DigitTrieIter::new(d_trie),
131                            ));
132                        }
133                        Node::Leaf(d_trie) => {
134                            self.leaf_iter
135                                .push((info.trie_index, DigitTrieIter::new(d_trie)));
136                            return self.next();
137                        }
138                    }
139                }
140            }
141        }
142
143        let ((cur_trie_index, parent_path), mut cur_iter) = self.node_stack.pop()?;
144
145        match cur_iter.next() {
146            None => self.next(),
147            Some(res) => {
148                // Put back the node on the stack
149                self.node_stack
150                    .push(((cur_trie_index, parent_path), cur_iter));
151
152                // Push an iterator to the child on the trie info stack
153                self.trie_info_iter
154                    .push((res.path, res.value.iter().enumerate()));
155
156                self.next()
157            }
158        }
159    }
160}
161
162/// Struct used to store DLC outcome information for multi oracle cases.  
163#[derive(Clone)]
164pub struct MultiTrie<T> {
165    store: Vec<MultiTrieNode<T>>,
166    nb_tries: usize,
167    nb_required: usize,
168    min_support_exp: usize,
169    max_error_exp: usize,
170    maximize_coverage: bool,
171    oracle_numeric_infos: OracleNumericInfo,
172}
173
174impl<T> MultiTrie<T> {
175    /// Create a new MultiTrie. Panics if `nb_required` is less or equal to
176    /// zero, or if `nb_tries` is less than `nb_required`.
177    pub fn new(
178        oracle_numeric_infos: &OracleNumericInfo,
179        nb_required: usize,
180        min_support_exp: usize,
181        max_error_exp: usize,
182        maximize_coverage: bool,
183    ) -> MultiTrie<T> {
184        let nb_tries = oracle_numeric_infos.nb_digits.len();
185        assert!(
186            nb_required > 0
187                && nb_tries >= nb_required
188                && !oracle_numeric_infos.nb_digits.is_empty()
189        );
190        let nb_roots = nb_tries - nb_required + 1;
191
192        let store: Vec<_> = if nb_required > 1 {
193            (0..nb_tries)
194                .take(nb_roots)
195                .map(|_| MultiTrieNode::new_node(oracle_numeric_infos.base))
196                .collect()
197        } else {
198            (0..nb_tries)
199                .take(nb_roots)
200                .map(|_| MultiTrieNode::new_leaf(oracle_numeric_infos.base))
201                .collect()
202        };
203
204        MultiTrie {
205            store,
206            nb_tries,
207            nb_required,
208            min_support_exp,
209            max_error_exp,
210            maximize_coverage,
211            oracle_numeric_infos: oracle_numeric_infos.clone(),
212        }
213    }
214
215    fn swap_remove(&mut self, index: usize) -> MultiTrieNode<T> {
216        self.store.push(MultiTrieNode::None);
217        self.store.swap_remove(index)
218    }
219
220    /// Insert the paths to cover outcomes outside of the range of the oracle with
221    /// minimum number of digits. Should only be called when oracles have varying
222    /// number of digits.
223    pub fn insert_max_paths<F>(&mut self, get_value: &mut F) -> Result<(), Error>
224    where
225        F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
226    {
227        let indexed_paths = get_max_covering_paths(&self.oracle_numeric_infos, self.nb_required);
228        for indexed_path in indexed_paths {
229            let (indexes, paths): (Vec<usize>, Vec<Vec<usize>>) = indexed_path.into_iter().unzip();
230            self.insert_internal(indexes[0], &paths, 0, &indexes, get_value)?;
231        }
232        Ok(())
233    }
234
235    /// Insert the value returned by `get_value` at the position specified by `path`.
236    pub fn insert<F>(&mut self, path: &[usize], get_value: &mut F) -> Result<(), Error>
237    where
238        F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
239    {
240        let combination_iter = CombinationIterator::new(self.nb_tries, self.nb_required);
241        let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
242
243        for selector in combination_iter {
244            let combinations = if self.nb_required > 1 {
245                let mut digit_infos = self
246                    .oracle_numeric_infos
247                    .nb_digits
248                    .iter()
249                    .enumerate()
250                    .filter_map(|(i, x)| {
251                        if selector.contains(&i) {
252                            Some(*x)
253                        } else {
254                            None
255                        }
256                    })
257                    .collect::<Vec<_>>();
258                let min_index = reorder_to_min_first(&mut digit_infos);
259                let to_pad = digit_infos[0] - min_nb_digits;
260                let padded_path = pre_pad_vec(path.to_vec(), path.len() + to_pad);
261                let mut combinations = compute_outcome_combinations(
262                    &digit_infos,
263                    &padded_path,
264                    self.max_error_exp,
265                    self.min_support_exp,
266                    self.maximize_coverage,
267                );
268                if min_index != 0 {
269                    for combination in &mut combinations {
270                        let to_reorder = combination.remove(0);
271                        combination.insert(min_index, to_reorder);
272                    }
273                }
274                combinations
275            } else {
276                vec![vec![path.to_vec()]]
277            };
278
279            for combination in combinations {
280                self.insert_internal(selector[0], &combination, 0, &selector, get_value)?;
281            }
282        }
283
284        Ok(())
285    }
286
287    fn insert_new(&mut self, is_leaf: bool) {
288        let m_trie = if is_leaf {
289            let d_trie = DigitTrie::<T>::new(self.oracle_numeric_infos.base);
290            MultiTrieNode::Leaf(d_trie)
291        } else {
292            let d_trie = DigitTrie::<Vec<TrieNodeInfo>>::new(self.oracle_numeric_infos.base);
293            MultiTrieNode::Node(d_trie)
294        };
295        self.store.push(m_trie);
296    }
297
298    fn insert_internal<F>(
299        &mut self,
300        cur_node_index: usize,
301        paths: &[Vec<usize>],
302        path_index: usize,
303        trie_indexes: &[usize],
304        get_value: &mut F,
305    ) -> Result<(), Error>
306    where
307        F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
308    {
309        assert!(path_index < paths.len());
310        let cur_node = self.swap_remove(cur_node_index);
311        match cur_node {
312            MultiTrieNode::None => unreachable!(),
313            MultiTrieNode::Leaf(mut digit_trie) => {
314                assert_eq!(path_index, paths.len() - 1);
315                let mut get_data = |_| get_value(paths, trie_indexes);
316                digit_trie.insert(&paths[path_index], &mut get_data)?;
317                self.store[cur_node_index] = MultiTrieNode::Leaf(digit_trie);
318            }
319            MultiTrieNode::Node(mut node) => {
320                assert!(path_index < paths.len() - 1);
321                let mut store_index = 0;
322                let mut callback =
323                    |cur_data_res: Option<Vec<TrieNodeInfo>>| -> Result<Vec<TrieNodeInfo>, Error> {
324                        let mut cur_data = match cur_data_res {
325                            Some(cur_data) => {
326                                if let Some(cur_store_index) =
327                                    find_store_index(&cur_data, trie_indexes[path_index + 1])
328                                {
329                                    store_index = cur_store_index;
330                                    return Ok(cur_data);
331                                }
332                                cur_data
333                            }
334                            _ => vec![],
335                        };
336                        self.insert_new(paths.len() - 1 == path_index + 1);
337                        store_index = self.store.len() - 1;
338                        let trie_index = trie_indexes[path_index + 1];
339                        let trie_node_info = TrieNodeInfo {
340                            trie_index,
341                            store_index,
342                        };
343                        cur_data.push(trie_node_info);
344                        Ok(cur_data)
345                    };
346                node.insert(&paths[path_index], &mut callback)?;
347                self.store[cur_node_index] = MultiTrieNode::Node(node);
348                self.insert_internal(store_index, paths, path_index + 1, trie_indexes, get_value)?;
349            }
350        }
351        Ok(())
352    }
353
354    /// Lookup in the trie for a value that matches with `paths`.
355    pub fn look_up<'a>(
356        &'a self,
357        paths: &[(usize, Vec<usize>)],
358    ) -> Option<(&'a T, Vec<IndexedPath>)> {
359        if paths.len() < self.nb_required {
360            return None;
361        }
362
363        let store = &self.store;
364
365        let combination_iter = CombinationIterator::new(paths.len(), self.nb_required);
366
367        let nb_roots = self.nb_tries - self.nb_required + 1;
368
369        for selector in combination_iter {
370            let first_index = paths[selector[0]].0;
371            if first_index >= nb_roots {
372                continue;
373            }
374
375            let res = self.look_up_internal(
376                &store[first_index],
377                &paths
378                    .iter()
379                    .enumerate()
380                    .filter_map(|(i, x)| {
381                        if selector.contains(&i) {
382                            return Some(x);
383                        }
384                        None
385                    })
386                    .collect::<Vec<_>>(),
387                0,
388            );
389            if let Some(mut l_res) = res {
390                l_res.path.reverse();
391                return Some((l_res.value, l_res.path.clone()));
392            }
393        }
394
395        None
396    }
397
398    fn look_up_internal<'a>(
399        &'a self,
400        cur_node: &'a MultiTrieNode<T>,
401        paths: &[&(usize, Vec<usize>)],
402        path_index: usize,
403    ) -> Option<LookupResult<'a, T, (usize, Vec<usize>)>> {
404        assert!(path_index < paths.len());
405        let trie_index = paths[path_index].0;
406
407        match cur_node {
408            MultiTrieNode::None => unreachable!(),
409            MultiTrieNode::Leaf(d_trie) => {
410                let res = d_trie.look_up(&paths[path_index].1)?;
411                Some(LookupResult {
412                    value: res[0].value,
413                    path: vec![(trie_index, res[0].path.clone())],
414                })
415            }
416            MultiTrieNode::Node(d_trie) => {
417                assert!(path_index < paths.len() - 1);
418                let results = d_trie.look_up(&paths[path_index].1)?;
419
420                for l_res in results {
421                    if let Some(index) = find_store_index(l_res.value, paths[path_index + 1].0) {
422                        let next_node = &self.store[index];
423                        if let Some(mut child_l_res) =
424                            self.look_up_internal(next_node, paths, path_index + 1)
425                        {
426                            child_l_res.path.push((trie_index, l_res.path));
427                            return Some(child_l_res);
428                        }
429                    }
430                }
431
432                None
433            }
434        }
435    }
436}
437
438fn find_store_index(children: &[TrieNodeInfo], trie_index: usize) -> Option<usize> {
439    for info in children {
440        if trie_index == info.trie_index {
441            return Some(info.store_index);
442        }
443    }
444
445    None
446}
447
/// Move the smallest value of `oracle_digit_infos` to the front, shifting the
/// values before it one position to the right, and return its original index.
///
/// When several values are equally small, the first one is moved (so the slice
/// is left unchanged if it already starts with a minimum). An empty slice is
/// left untouched and `0` is returned.
fn reorder_to_min_first(oracle_digit_infos: &mut [usize]) -> usize {
    let min_index = match oracle_digit_infos
        .iter()
        .enumerate()
        .min_by_key(|(_, nb_digits)| **nb_digits)
    {
        Some((index, _)) => index,
        None => return 0,
    };
    // Rotating the prefix by one is exactly "remove at min_index, insert at 0".
    oracle_digit_infos[..=min_index].rotate_right(1);
    min_index
}
461
462/// Container for a dump of a MultiTrie used for serialization purpose.
463pub struct MultiTrieDump<T>
464where
465    T: Clone,
466{
467    /// The node data.
468    pub node_data: Vec<MultiTrieNodeData<T>>,
469    /// The total number of tries.
470    pub nb_tries: usize,
471    /// The number of trie per path.
472    pub nb_required: usize,
473    /// The guaranteed support as a power of 2.
474    pub min_support_exp: usize,
475    /// The maximum support as a power of 2.
476    pub max_error_exp: usize,
477    /// Whether this trie maximizes outcome coverage.
478    pub maximize_coverage: bool,
479    /// Information about the numerical representation of oracles
480    pub oracle_numeric_infos: OracleNumericInfo,
481}
482
483impl<T> MultiTrie<T>
484where
485    T: Clone,
486{
487    /// Dump the content of the trie for the purpose of serialization.
488    pub fn dump(&self) -> MultiTrieDump<T> {
489        let node_data = self.store.iter().map(|x| x.get_data()).collect();
490        MultiTrieDump {
491            node_data,
492            nb_tries: self.nb_tries,
493            nb_required: self.nb_required,
494            min_support_exp: self.min_support_exp,
495            max_error_exp: self.max_error_exp,
496            maximize_coverage: self.maximize_coverage,
497            oracle_numeric_infos: self.oracle_numeric_infos.clone(),
498        }
499    }
500
501    /// Restore a trie from a dump.
502    pub fn from_dump(dump: MultiTrieDump<T>) -> MultiTrie<T> {
503        let MultiTrieDump {
504            node_data,
505            nb_tries,
506            nb_required,
507            min_support_exp,
508            max_error_exp,
509            maximize_coverage,
510            oracle_numeric_infos,
511        } = dump;
512
513        let store = node_data
514            .into_iter()
515            .map(|x| MultiTrieNode::from_data(x))
516            .collect();
517
518        MultiTrie {
519            store,
520            nb_tries,
521            nb_required,
522            min_support_exp,
523            max_error_exp,
524            maximize_coverage,
525            oracle_numeric_infos,
526        }
527    }
528}
529
530/// Holds the data of a multi trie node. Used for serialization purpose.
531pub enum MultiTrieNodeData<T>
532where
533    T: Clone,
534{
535    /// A leaf in the trie.
536    Leaf(DigitTrieDump<T>),
537    /// A node in the trie.
538    Node(DigitTrieDump<Vec<TrieNodeInfo>>),
539}
540
541impl<T> MultiTrieNode<T>
542where
543    T: Clone,
544{
545    fn get_data(&self) -> MultiTrieNodeData<T> {
546        match self {
547            Node::Leaf(l) => MultiTrieNodeData::Leaf(l.dump()),
548            Node::Node(n) => MultiTrieNodeData::Node(n.dump()),
549            Node::None => unreachable!(),
550        }
551    }
552
553    fn from_data(data: MultiTrieNodeData<T>) -> MultiTrieNode<T> {
554        match data {
555            MultiTrieNodeData::Leaf(l) => Node::Leaf(DigitTrie::from_dump(l)),
556            MultiTrieNodeData::Node(n) => Node::Node(DigitTrie::from_dump(n)),
557        }
558    }
559}
560
561#[cfg(test)]
562mod tests {
563    use super::*;
564    use crate::test_utils::{
565        get_variable_oracle_numeric_infos, same_num_digits_oracle_numeric_infos,
566    };
567
568    type ExpectedIter = Vec<Vec<(usize, Vec<usize>)>>;
569
570    fn tests_common(
571        m_trie: &mut MultiTrie<usize>,
572        path: Vec<usize>,
573        good_paths: Vec<Vec<(usize, Vec<usize>)>>,
574        bad_paths: Vec<Vec<(usize, Vec<usize>)>>,
575        expected_iter: Option<ExpectedIter>,
576    ) {
577        let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> { Ok(2) };
578
579        m_trie.insert(&path, &mut get_value).unwrap();
580
581        for good_path in good_paths {
582            assert!(
583                m_trie.look_up(&good_path).is_some(),
584                "Path {:?} not found",
585                good_path
586            );
587        }
588
589        for bad_path in bad_paths {
590            assert!(
591                m_trie.look_up(&bad_path).is_none(),
592                "Path {:?} was found",
593                bad_path
594            );
595        }
596
597        if let Some(expected) = expected_iter {
598            let iter = MultiTrieIterator::new(m_trie);
599
600            for (i, res) in iter.enumerate() {
601                assert_eq!(expected[i], res.path);
602            }
603        }
604    }
605
606    #[test]
607    fn multi_trie_1_of_1_test() {
608        let mut m_trie = MultiTrie::<usize>::new(
609            &same_num_digits_oracle_numeric_infos(1, 5, 2),
610            1,
611            2,
612            3,
613            true,
614        );
615
616        let path = vec![0, 1, 1, 1];
617
618        let good_paths = vec![
619            vec![(0, vec![0, 1, 1, 1, 1])],
620            vec![(0, vec![0, 1, 1, 1, 0])],
621        ];
622
623        let bad_paths = vec![
624            vec![(0, vec![1, 1, 1, 1, 1])],
625            vec![(0, vec![0, 1, 1, 0, 1])],
626            vec![(0, vec![0, 1, 0, 1, 0])],
627        ];
628
629        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> = vec![vec![(0, vec![0, 1, 1, 1])]];
630
631        tests_common(
632            &mut m_trie,
633            path,
634            good_paths,
635            bad_paths,
636            Some(expected_iter),
637        );
638    }
639
640    #[test]
641    fn multi_trie_1_of_2_test() {
642        let mut m_trie = MultiTrie::<usize>::new(
643            &same_num_digits_oracle_numeric_infos(2, 5, 2),
644            1,
645            2,
646            3,
647            true,
648        );
649
650        let path = vec![0, 1, 1, 1];
651
652        let good_paths = vec![
653            vec![(0, vec![0, 1, 1, 1, 1])],
654            vec![(1, vec![0, 1, 1, 1, 1])],
655            vec![(0, vec![0, 1, 1, 1, 0])],
656            vec![(1, vec![0, 1, 1, 1, 0])],
657        ];
658
659        let bad_paths = vec![
660            vec![(0, vec![1, 1, 1, 1, 1])],
661            vec![(1, vec![0, 1, 1, 0, 1])],
662            vec![(0, vec![0, 1, 0, 1, 0])],
663        ];
664
665        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> =
666            vec![vec![(0, vec![0, 1, 1, 1])], vec![(1, vec![0, 1, 1, 1])]];
667
668        tests_common(
669            &mut m_trie,
670            path,
671            good_paths,
672            bad_paths,
673            Some(expected_iter),
674        );
675    }
676
677    #[test]
678    fn multi_trie_2_of_2_test() {
679        let mut m_trie = MultiTrie::<usize>::new(
680            &same_num_digits_oracle_numeric_infos(2, 5, 2),
681            2,
682            2,
683            3,
684            true,
685        );
686
687        let path = vec![0, 1, 1, 1];
688
689        let good_paths = vec![
690            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
691            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![1, 0, 0, 1, 1])],
692            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 0, 0])],
693        ];
694
695        let bad_paths = vec![
696            vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
697            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![1, 1, 0, 1, 1])],
698            vec![(0, vec![0, 1, 0, 1, 1]), (1, vec![0, 1, 1, 0, 0])],
699        ];
700
701        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> = vec![
702            vec![(0, vec![0, 1, 1, 1]), (1, vec![0, 1])],
703            vec![(0, vec![0, 1, 1, 1]), (1, vec![1, 0, 0])],
704        ];
705
706        tests_common(
707            &mut m_trie,
708            path,
709            good_paths,
710            bad_paths,
711            Some(expected_iter),
712        );
713    }
714
715    #[test]
716    fn multi_trie_2_of_3_test() {
717        let mut m_trie = MultiTrie::<usize>::new(
718            &same_num_digits_oracle_numeric_infos(3, 5, 2),
719            2,
720            2,
721            3,
722            true,
723        );
724
725        let path = vec![0, 1, 1, 1];
726
727        let good_paths = vec![
728            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
729            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 1, 1, 1, 1])],
730            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 1, 1, 1, 1])],
731            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 0, 0, 1, 1])],
732            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 0, 0, 1, 1])],
733        ];
734
735        let bad_paths = vec![
736            vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
737            vec![(2, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
738            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
739            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
740        ];
741
742        tests_common(&mut m_trie, path, good_paths, bad_paths, None);
743    }
744
745    #[test]
746    fn multi_trie_5_of_5_test() {
747        let mut m_trie = MultiTrie::<usize>::new(
748            &same_num_digits_oracle_numeric_infos(5, 3, 2),
749            5,
750            1,
751            2,
752            true,
753        );
754
755        let path = vec![0, 0, 0];
756
757        let good_paths = vec![vec![
758            (0, vec![0, 0, 0]),
759            (1, vec![0]),
760            (2, vec![0]),
761            (3, vec![0]),
762            (4, vec![0]),
763        ]];
764
765        tests_common(
766            &mut m_trie,
767            path,
768            good_paths.clone(),
769            vec![],
770            Some(good_paths),
771        );
772    }
773
774    #[test]
775    fn multi_3_of_3_test_lexicographic_order() {
776        let mut m_trie = MultiTrie::<usize>::new(
777            &same_num_digits_oracle_numeric_infos(3, 3, 2),
778            3,
779            1,
780            2,
781            true,
782        );
783
784        let inputs = vec![
785            vec![0, 0],
786            vec![0, 0, 1],
787            vec![0, 1, 0],
788            vec![0, 1, 1],
789            vec![1, 0, 0],
790            vec![1, 0, 1],
791        ];
792
793        let mut counter = 0;
794
795        let mut get_value = |_: &[std::vec::Vec<usize>], _: &[usize]| -> Result<usize, Error> {
796            counter += 1;
797            Ok(counter - 1)
798        };
799
800        for input in inputs {
801            m_trie
802                .insert(&input, &mut get_value)
803                .expect("Error inserting in trie");
804        }
805
806        let iter = MultiTrieIterator::new(&m_trie);
807
808        for (i, res) in iter.enumerate() {
809            assert_eq!(i, *res.value);
810        }
811    }
812
813    fn multi_enumerate_equal_lookup_common(mut m_trie: MultiTrie<usize>) {
814        let inputs = vec![
815            // vec![0, 0],
816            vec![0, 1, 0],
817            // vec![0, 1, 1],
818            // vec![1, 0, 0],
819            // vec![1, 0, 1],
820        ];
821
822        let mut counter = 0;
823
824        let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> {
825            counter += 1;
826            Ok(counter - 1)
827        };
828
829        for input in inputs {
830            m_trie
831                .insert(&input, &mut get_value)
832                .expect("Error inserting in trie");
833        }
834
835        let iter = MultiTrieIterator::new(&m_trie);
836
837        for res in iter {
838            assert_eq!(
839                m_trie.look_up(&res.path).expect("Path not found").0,
840                res.value
841            );
842        }
843    }
844
845    #[test]
846    fn multi_3_of_5_test_enumerate_equal_lookup() {
847        let m_trie = MultiTrie::<usize>::new(
848            &same_num_digits_oracle_numeric_infos(5, 3, 2),
849            3,
850            1,
851            2,
852            true,
853        );
854        multi_enumerate_equal_lookup_common(m_trie);
855    }
856
857    #[test]
858    fn multi_5_of_5_test_enumerate_equal_lookup() {
859        let m_trie = MultiTrie::<usize>::new(
860            &same_num_digits_oracle_numeric_infos(5, 3, 2),
861            5,
862            1,
863            2,
864            true,
865        );
866        multi_enumerate_equal_lookup_common(m_trie);
867    }
868
869    #[test]
870    fn multi_2_of_3_diff_nb_digits_enumerate_equal_lookup() {
871        let m_trie = MultiTrie::<usize>::new(
872            &get_variable_oracle_numeric_infos(&[3, 4, 5], 2),
873            2,
874            1,
875            2,
876            true,
877        );
878        multi_enumerate_equal_lookup_common(m_trie);
879    }
880
881    struct TestCase {
882        path: Vec<usize>,
883        good_paths: Vec<Vec<(usize, Vec<usize>)>>,
884        bad_paths: Vec<Vec<(usize, Vec<usize>)>>,
885    }
886
/// 2-of-3 multi-trie test where the three oracles are configured with 5, 6 and
/// 7 digits respectively (declared in increasing order).
///
/// Each case hands `path` together with its `good_paths` / `bad_paths`
/// combinations to `tests_common`, which performs the actual insert/lookup
/// checks.
#[test]
fn multi_trie_2_of_3_diff_nb_digits_test() {
    let mut m_trie = MultiTrie::<usize>::new(
        &get_variable_oracle_numeric_infos(&[5, 6, 7], 2),
        2,
        2,
        3,
        true,
    );

    // The same outcome shows up with one extra leading zero for each extra
    // digit an oracle has (e.g. `0,1,1,1,1` vs `0,0,1,1,1,1` vs `0,0,0,1,1,1,1`).
    let test_cases = vec![
        TestCase {
            path: vec![0, 1, 1, 1],
            good_paths: vec![
                vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 0, 1, 1, 1, 1])],
                vec![(1, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                vec![(1, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
            ],
            bad_paths: vec![
                vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                vec![(2, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
            ],
        },
        TestCase {
            path: vec![1, 1, 1],
            good_paths: vec![
                vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![1, 0, 0, 0, 0])],
                vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 1])],
                vec![(1, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
            ],
            bad_paths: vec![
                vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![1, 0, 0, 1, 1, 1])],
                vec![(1, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 1, 1])],
                vec![(0, vec![1, 1, 1, 0, 0]), (2, vec![0, 1, 0, 0, 1, 0, 1])],
                vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 0, 0, 0])],
            ],
        },
    ];

    // The same trie is reused across cases, so later cases run against a trie
    // that already contains the earlier cases' paths.
    for TestCase {
        path,
        good_paths,
        bad_paths,
    } in test_cases
    {
        tests_common(&mut m_trie, path, good_paths, bad_paths, None);
    }
}
942
/// 2-of-3 multi-trie test where the oracles' digit counts are not declared in
/// increasing order: oracle 0 has 6 digits, oracle 1 has 5 and oracle 2 has 7.
///
/// Each case hands `path` together with its `good_paths` / `bad_paths`
/// combinations to `tests_common`, which performs the actual insert/lookup
/// checks.
#[test]
fn multi_trie_2_of_3_diff_nb_digits_unordered_test() {
    let mut m_trie = MultiTrie::<usize>::new(
        &get_variable_oracle_numeric_infos(&[6, 5, 7], 2),
        2,
        2,
        3,
        true,
    );

    // Oracle 0 now carries the 6-digit paths and oracle 1 the 5-digit ones.
    let test_cases = vec![
        TestCase {
            path: vec![0, 1, 1, 1],
            good_paths: vec![
                vec![(0, vec![0, 0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                vec![(0, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                vec![(0, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
            ],
            bad_paths: vec![
                vec![(1, vec![1, 1, 1, 1, 1]), (0, vec![0, 1, 1, 1, 1])],
                vec![(2, vec![0, 1, 1, 1, 1]), (0, vec![0, 1, 1, 1, 1])],
                vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
            ],
        },
        TestCase {
            path: vec![1, 1, 1],
            good_paths: vec![
                vec![(0, vec![1, 0, 0, 0, 0]), (1, vec![1, 1, 1, 1, 1])],
                vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 1])],
                vec![(0, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
            ],
            bad_paths: vec![
                vec![(1, vec![1, 1, 1, 1, 1]), (0, vec![1, 0, 0, 1, 1, 1])],
                vec![(0, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 1, 1])],
                vec![(1, vec![1, 1, 1, 0, 0]), (2, vec![0, 1, 0, 0, 1, 0, 1])],
                vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 0, 0, 0])],
            ],
        },
    ];

    // The same trie is reused across cases, so later cases run against a trie
    // that already contains the earlier cases' paths.
    for TestCase {
        path,
        good_paths,
        bad_paths,
    } in test_cases
    {
        tests_common(&mut m_trie, path, good_paths, bad_paths, None);
    }
}
998
/// Checks that iterating a multi-trie yields each stored value exactly once.
///
/// Values are handed out by an incrementing counter (0, 1, 2, ...), so once
/// the iterated values are sorted they must form the contiguous run
/// `0, 1, ..., len - 1`. A duplicated value, a gap, or a missing first value
/// breaks that run.
#[test]
fn ttt() {
    let inputs = vec![
        vec![0, 0, 0],
        vec![0, 0, 1],
        vec![0, 1, 0],
        vec![0, 1, 1],
        vec![1],
    ];
    let mut m_trie = MultiTrie::<usize>::new(
        &get_variable_oracle_numeric_infos(&[4, 3], 2),
        2,
        1,
        2,
        true,
    );

    // Each call returns a fresh, unique value so the iterated output can be
    // compared against a known sequence.
    let mut counter = 0;
    let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> {
        let res = counter;
        counter += 1;
        Ok(res)
    };
    for input in &inputs {
        m_trie.insert(input, &mut get_value).unwrap();
    }

    let mut values = MultiTrieIterator::new(&m_trie)
        .map(|x| *x.value)
        .collect::<Vec<_>>();
    values.sort_unstable();

    // Compare the whole sequence, including the first element, rather than
    // only checking elements from index 1 onwards.
    assert!(!values.is_empty(), "multi-trie iterator yielded no values");
    let expected = (0..values.len()).collect::<Vec<_>>();
    assert_eq!(values, expected);
}
1035}