1use crate::{
5 utils::{get_max_covering_paths, pre_pad_vec},
6 LookupResult, Node, OracleNumericInfo,
7};
8use combination_iterator::CombinationIterator;
9use ddk_dlc::Error;
10use digit_trie::{DigitTrie, DigitTrieDump, DigitTrieIter};
11use multi_oracle::compute_outcome_combinations;
12
/// Link stored in an inner trie entry, pointing at the child node that continues the path
/// with the next oracle.
#[derive(Clone, Debug)]
pub struct TrieNodeInfo {
    /// Index of the oracle (trie) that the child node belongs to.
    pub trie_index: usize,
    /// Position of the child node inside the `MultiTrie` node store.
    pub store_index: usize,
}
21
22type MultiTrieNode<T> = Node<DigitTrie<T>, DigitTrie<Vec<TrieNodeInfo>>>;
23type NodeStackElement<'a> = Vec<(IndexedPath, DigitTrieIter<'a, Vec<TrieNodeInfo>>)>;
24type IndexedPath = (usize, Vec<usize>);
25
26impl<T> MultiTrieNode<T> {
27 fn new_node(base: usize) -> MultiTrieNode<T> {
28 let m_trie = DigitTrie::<Vec<TrieNodeInfo>>::new(base);
29 MultiTrieNode::Node(m_trie)
30 }
31 fn new_leaf(base: usize) -> MultiTrieNode<T> {
32 let d_trie = DigitTrie::<T>::new(base);
33 MultiTrieNode::Leaf(d_trie)
34 }
35}
36
37pub(crate) struct MultiTrieIterator<'a, T> {
39 trie: &'a MultiTrie<T>,
40 node_stack: NodeStackElement<'a>,
41 trie_info_iter: Vec<(
42 Vec<usize>,
43 std::iter::Enumerate<std::slice::Iter<'a, TrieNodeInfo>>,
44 )>,
45 leaf_iter: Vec<(usize, DigitTrieIter<'a, T>)>,
46 cur_path: Vec<(usize, Vec<usize>)>,
47}
48
49fn create_node_iterator<T>(node: &'_ MultiTrieNode<T>) -> DigitTrieIter<'_, Vec<TrieNodeInfo>> {
50 match node {
51 Node::Node(d_trie) => DigitTrieIter::new(d_trie),
52 _ => unreachable!(),
53 }
54}
55
56fn create_leaf_iterator<T>(node: &'_ MultiTrieNode<T>) -> DigitTrieIter<'_, T> {
57 match node {
58 Node::Leaf(d_trie) => DigitTrieIter::new(d_trie),
59 _ => unreachable!(),
60 }
61}
62
63impl<'a, T> MultiTrieIterator<'a, T> {
64 pub fn new(trie: &'a MultiTrie<T>) -> MultiTrieIterator<'a, T> {
66 let mut node_stack = Vec::with_capacity(trie.nb_required);
67 let nb_roots = trie.nb_tries - trie.nb_required + 1;
68 let mut leaf_iter = Vec::new();
69 for i in (0..nb_roots).rev() {
70 if trie.nb_required > 1 {
71 node_stack.push((
72 (i, Vec::<usize>::new()),
73 create_node_iterator(&trie.store[i]),
74 ));
75 } else {
76 leaf_iter.push((i, create_leaf_iterator(&trie.store[i])));
77 }
78 }
79 MultiTrieIterator {
80 trie,
81 node_stack,
82 trie_info_iter: Vec::new(),
83 leaf_iter,
84 cur_path: Vec::new(),
85 }
86 }
87}
88
89impl<'a, T> Iterator for MultiTrieIterator<'a, T> {
91 type Item = LookupResult<'a, T, (usize, Vec<usize>)>;
92
93 fn next(&mut self) -> Option<Self::Item> {
94 let mut leaf_iter = self.leaf_iter.last_mut();
95 if let Some(ref mut iter) = &mut leaf_iter {
96 match iter.1.next() {
97 Some(res) => {
98 let mut path = self.cur_path.clone();
99 path.push((iter.0, res.path));
100 return Some(LookupResult {
101 value: res.value,
102 path,
103 });
104 }
105 None => {
106 self.leaf_iter.pop();
107 return self.next();
108 }
109 }
110 };
111
112 let mut trie_info_iter = self.trie_info_iter.last_mut();
113
114 if let Some(ref mut iter) = &mut trie_info_iter {
115 match iter.1.next() {
116 None => {
117 self.trie_info_iter.pop();
118 self.cur_path.pop();
119 }
120 Some((i, info)) => {
121 if i == 0 {
122 self.cur_path
123 .push((self.node_stack.last().unwrap().0 .0, iter.0.clone()));
124 }
125 match &self.trie.store[info.store_index] {
126 Node::None => unreachable!(),
127 Node::Node(d_trie) => {
128 self.node_stack.push((
129 (info.trie_index, iter.0.clone()),
130 DigitTrieIter::new(d_trie),
131 ));
132 }
133 Node::Leaf(d_trie) => {
134 self.leaf_iter
135 .push((info.trie_index, DigitTrieIter::new(d_trie)));
136 return self.next();
137 }
138 }
139 }
140 }
141 }
142
143 let ((cur_trie_index, parent_path), mut cur_iter) = self.node_stack.pop()?;
144
145 match cur_iter.next() {
146 None => self.next(),
147 Some(res) => {
148 self.node_stack
150 .push(((cur_trie_index, parent_path), cur_iter));
151
152 self.trie_info_iter
154 .push((res.path, res.value.iter().enumerate()));
155
156 self.next()
157 }
158 }
159 }
160}
161
162#[derive(Clone)]
164pub struct MultiTrie<T> {
165 store: Vec<MultiTrieNode<T>>,
166 nb_tries: usize,
167 nb_required: usize,
168 min_support_exp: usize,
169 max_error_exp: usize,
170 maximize_coverage: bool,
171 oracle_numeric_infos: OracleNumericInfo,
172}
173
174impl<T> MultiTrie<T> {
175 pub fn new(
178 oracle_numeric_infos: &OracleNumericInfo,
179 nb_required: usize,
180 min_support_exp: usize,
181 max_error_exp: usize,
182 maximize_coverage: bool,
183 ) -> MultiTrie<T> {
184 let nb_tries = oracle_numeric_infos.nb_digits.len();
185 assert!(
186 nb_required > 0
187 && nb_tries >= nb_required
188 && !oracle_numeric_infos.nb_digits.is_empty()
189 );
190 let nb_roots = nb_tries - nb_required + 1;
191
192 let store: Vec<_> = if nb_required > 1 {
193 (0..nb_tries)
194 .take(nb_roots)
195 .map(|_| MultiTrieNode::new_node(oracle_numeric_infos.base))
196 .collect()
197 } else {
198 (0..nb_tries)
199 .take(nb_roots)
200 .map(|_| MultiTrieNode::new_leaf(oracle_numeric_infos.base))
201 .collect()
202 };
203
204 MultiTrie {
205 store,
206 nb_tries,
207 nb_required,
208 min_support_exp,
209 max_error_exp,
210 maximize_coverage,
211 oracle_numeric_infos: oracle_numeric_infos.clone(),
212 }
213 }
214
215 fn swap_remove(&mut self, index: usize) -> MultiTrieNode<T> {
216 self.store.push(MultiTrieNode::None);
217 self.store.swap_remove(index)
218 }
219
220 pub fn insert_max_paths<F>(&mut self, get_value: &mut F) -> Result<(), Error>
224 where
225 F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
226 {
227 let indexed_paths = get_max_covering_paths(&self.oracle_numeric_infos, self.nb_required);
228 for indexed_path in indexed_paths {
229 let (indexes, paths): (Vec<usize>, Vec<Vec<usize>>) = indexed_path.into_iter().unzip();
230 self.insert_internal(indexes[0], &paths, 0, &indexes, get_value)?;
231 }
232 Ok(())
233 }
234
235 pub fn insert<F>(&mut self, path: &[usize], get_value: &mut F) -> Result<(), Error>
237 where
238 F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
239 {
240 let combination_iter = CombinationIterator::new(self.nb_tries, self.nb_required);
241 let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
242
243 for selector in combination_iter {
244 let combinations = if self.nb_required > 1 {
245 let mut digit_infos = self
246 .oracle_numeric_infos
247 .nb_digits
248 .iter()
249 .enumerate()
250 .filter_map(|(i, x)| {
251 if selector.contains(&i) {
252 Some(*x)
253 } else {
254 None
255 }
256 })
257 .collect::<Vec<_>>();
258 let min_index = reorder_to_min_first(&mut digit_infos);
259 let to_pad = digit_infos[0] - min_nb_digits;
260 let padded_path = pre_pad_vec(path.to_vec(), path.len() + to_pad);
261 let mut combinations = compute_outcome_combinations(
262 &digit_infos,
263 &padded_path,
264 self.max_error_exp,
265 self.min_support_exp,
266 self.maximize_coverage,
267 );
268 if min_index != 0 {
269 for combination in &mut combinations {
270 let to_reorder = combination.remove(0);
271 combination.insert(min_index, to_reorder);
272 }
273 }
274 combinations
275 } else {
276 vec![vec![path.to_vec()]]
277 };
278
279 for combination in combinations {
280 self.insert_internal(selector[0], &combination, 0, &selector, get_value)?;
281 }
282 }
283
284 Ok(())
285 }
286
287 fn insert_new(&mut self, is_leaf: bool) {
288 let m_trie = if is_leaf {
289 let d_trie = DigitTrie::<T>::new(self.oracle_numeric_infos.base);
290 MultiTrieNode::Leaf(d_trie)
291 } else {
292 let d_trie = DigitTrie::<Vec<TrieNodeInfo>>::new(self.oracle_numeric_infos.base);
293 MultiTrieNode::Node(d_trie)
294 };
295 self.store.push(m_trie);
296 }
297
298 fn insert_internal<F>(
299 &mut self,
300 cur_node_index: usize,
301 paths: &[Vec<usize>],
302 path_index: usize,
303 trie_indexes: &[usize],
304 get_value: &mut F,
305 ) -> Result<(), Error>
306 where
307 F: FnMut(&[Vec<usize>], &[usize]) -> Result<T, Error>,
308 {
309 assert!(path_index < paths.len());
310 let cur_node = self.swap_remove(cur_node_index);
311 match cur_node {
312 MultiTrieNode::None => unreachable!(),
313 MultiTrieNode::Leaf(mut digit_trie) => {
314 assert_eq!(path_index, paths.len() - 1);
315 let mut get_data = |_| get_value(paths, trie_indexes);
316 digit_trie.insert(&paths[path_index], &mut get_data)?;
317 self.store[cur_node_index] = MultiTrieNode::Leaf(digit_trie);
318 }
319 MultiTrieNode::Node(mut node) => {
320 assert!(path_index < paths.len() - 1);
321 let mut store_index = 0;
322 let mut callback =
323 |cur_data_res: Option<Vec<TrieNodeInfo>>| -> Result<Vec<TrieNodeInfo>, Error> {
324 let mut cur_data = match cur_data_res {
325 Some(cur_data) => {
326 if let Some(cur_store_index) =
327 find_store_index(&cur_data, trie_indexes[path_index + 1])
328 {
329 store_index = cur_store_index;
330 return Ok(cur_data);
331 }
332 cur_data
333 }
334 _ => vec![],
335 };
336 self.insert_new(paths.len() - 1 == path_index + 1);
337 store_index = self.store.len() - 1;
338 let trie_index = trie_indexes[path_index + 1];
339 let trie_node_info = TrieNodeInfo {
340 trie_index,
341 store_index,
342 };
343 cur_data.push(trie_node_info);
344 Ok(cur_data)
345 };
346 node.insert(&paths[path_index], &mut callback)?;
347 self.store[cur_node_index] = MultiTrieNode::Node(node);
348 self.insert_internal(store_index, paths, path_index + 1, trie_indexes, get_value)?;
349 }
350 }
351 Ok(())
352 }
353
354 pub fn look_up<'a>(
356 &'a self,
357 paths: &[(usize, Vec<usize>)],
358 ) -> Option<(&'a T, Vec<IndexedPath>)> {
359 if paths.len() < self.nb_required {
360 return None;
361 }
362
363 let store = &self.store;
364
365 let combination_iter = CombinationIterator::new(paths.len(), self.nb_required);
366
367 let nb_roots = self.nb_tries - self.nb_required + 1;
368
369 for selector in combination_iter {
370 let first_index = paths[selector[0]].0;
371 if first_index >= nb_roots {
372 continue;
373 }
374
375 let res = self.look_up_internal(
376 &store[first_index],
377 &paths
378 .iter()
379 .enumerate()
380 .filter_map(|(i, x)| {
381 if selector.contains(&i) {
382 return Some(x);
383 }
384 None
385 })
386 .collect::<Vec<_>>(),
387 0,
388 );
389 if let Some(mut l_res) = res {
390 l_res.path.reverse();
391 return Some((l_res.value, l_res.path.clone()));
392 }
393 }
394
395 None
396 }
397
398 fn look_up_internal<'a>(
399 &'a self,
400 cur_node: &'a MultiTrieNode<T>,
401 paths: &[&(usize, Vec<usize>)],
402 path_index: usize,
403 ) -> Option<LookupResult<'a, T, (usize, Vec<usize>)>> {
404 assert!(path_index < paths.len());
405 let trie_index = paths[path_index].0;
406
407 match cur_node {
408 MultiTrieNode::None => unreachable!(),
409 MultiTrieNode::Leaf(d_trie) => {
410 let res = d_trie.look_up(&paths[path_index].1)?;
411 Some(LookupResult {
412 value: res[0].value,
413 path: vec![(trie_index, res[0].path.clone())],
414 })
415 }
416 MultiTrieNode::Node(d_trie) => {
417 assert!(path_index < paths.len() - 1);
418 let results = d_trie.look_up(&paths[path_index].1)?;
419
420 for l_res in results {
421 if let Some(index) = find_store_index(l_res.value, paths[path_index + 1].0) {
422 let next_node = &self.store[index];
423 if let Some(mut child_l_res) =
424 self.look_up_internal(next_node, paths, path_index + 1)
425 {
426 child_l_res.path.push((trie_index, l_res.path));
427 return Some(child_l_res);
428 }
429 }
430 }
431
432 None
433 }
434 }
435 }
436}
437
438fn find_store_index(children: &[TrieNodeInfo], trie_index: usize) -> Option<usize> {
439 for info in children {
440 if trie_index == info.trie_index {
441 return Some(info.store_index);
442 }
443 }
444
445 None
446}
447
/// Moves the smallest value of `oracle_digit_infos` to the front, keeping the relative order
/// of the other elements, and returns the index it was originally at.
///
/// If several elements are equally small, the first one is chosen. An empty slice is left
/// untouched and yields `0`.
fn reorder_to_min_first(oracle_digit_infos: &mut [usize]) -> usize {
    let min_index = oracle_digit_infos
        .iter()
        .enumerate()
        .min_by_key(|&(_, nb_digits)| *nb_digits)
        .map_or(0, |(index, _)| index);
    if min_index != 0 {
        // Rotating the prefix by one is equivalent to `remove(min_index)` + `insert(0, _)`.
        oracle_digit_infos[..=min_index].rotate_right(1);
    }
    min_index
}
461
462pub struct MultiTrieDump<T>
464where
465 T: Clone,
466{
467 pub node_data: Vec<MultiTrieNodeData<T>>,
469 pub nb_tries: usize,
471 pub nb_required: usize,
473 pub min_support_exp: usize,
475 pub max_error_exp: usize,
477 pub maximize_coverage: bool,
479 pub oracle_numeric_infos: OracleNumericInfo,
481}
482
483impl<T> MultiTrie<T>
484where
485 T: Clone,
486{
487 pub fn dump(&self) -> MultiTrieDump<T> {
489 let node_data = self.store.iter().map(|x| x.get_data()).collect();
490 MultiTrieDump {
491 node_data,
492 nb_tries: self.nb_tries,
493 nb_required: self.nb_required,
494 min_support_exp: self.min_support_exp,
495 max_error_exp: self.max_error_exp,
496 maximize_coverage: self.maximize_coverage,
497 oracle_numeric_infos: self.oracle_numeric_infos.clone(),
498 }
499 }
500
501 pub fn from_dump(dump: MultiTrieDump<T>) -> MultiTrie<T> {
503 let MultiTrieDump {
504 node_data,
505 nb_tries,
506 nb_required,
507 min_support_exp,
508 max_error_exp,
509 maximize_coverage,
510 oracle_numeric_infos,
511 } = dump;
512
513 let store = node_data
514 .into_iter()
515 .map(|x| MultiTrieNode::from_data(x))
516 .collect();
517
518 MultiTrie {
519 store,
520 nb_tries,
521 nb_required,
522 min_support_exp,
523 max_error_exp,
524 maximize_coverage,
525 oracle_numeric_infos,
526 }
527 }
528}
529
530pub enum MultiTrieNodeData<T>
532where
533 T: Clone,
534{
535 Leaf(DigitTrieDump<T>),
537 Node(DigitTrieDump<Vec<TrieNodeInfo>>),
539}
540
541impl<T> MultiTrieNode<T>
542where
543 T: Clone,
544{
545 fn get_data(&self) -> MultiTrieNodeData<T> {
546 match self {
547 Node::Leaf(l) => MultiTrieNodeData::Leaf(l.dump()),
548 Node::Node(n) => MultiTrieNodeData::Node(n.dump()),
549 Node::None => unreachable!(),
550 }
551 }
552
553 fn from_data(data: MultiTrieNodeData<T>) -> MultiTrieNode<T> {
554 match data {
555 MultiTrieNodeData::Leaf(l) => Node::Leaf(DigitTrie::from_dump(l)),
556 MultiTrieNodeData::Node(n) => Node::Node(DigitTrie::from_dump(n)),
557 }
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        get_variable_oracle_numeric_infos, same_num_digits_oracle_numeric_infos,
    };

    /// Full sequence of indexed paths expected from a `MultiTrieIterator`.
    type ExpectedIter = Vec<Vec<(usize, Vec<usize>)>>;

    /// Inserts `path` and checks that every `good_paths` entry is found and every `bad_paths`
    /// entry is not.
    ///
    /// When `expected_iter` is given, also checks that iterating the trie yields exactly that
    /// sequence of paths: no missing entries and no extra ones.
    fn tests_common(
        m_trie: &mut MultiTrie<usize>,
        path: Vec<usize>,
        good_paths: Vec<Vec<(usize, Vec<usize>)>>,
        bad_paths: Vec<Vec<(usize, Vec<usize>)>>,
        expected_iter: Option<ExpectedIter>,
    ) {
        let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> { Ok(2) };

        m_trie.insert(&path, &mut get_value).unwrap();

        for good_path in good_paths {
            assert!(
                m_trie.look_up(&good_path).is_some(),
                "Path {:?} not found",
                good_path
            );
        }

        for bad_path in bad_paths {
            assert!(
                m_trie.look_up(&bad_path).is_none(),
                "Path {:?} was found",
                bad_path
            );
        }

        if let Some(expected) = expected_iter {
            // Compare whole sequences so that a truncated iterator is caught as well.
            let actual: ExpectedIter = MultiTrieIterator::new(m_trie)
                .map(|res| res.path)
                .collect();
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn multi_trie_1_of_1_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(1, 5, 2),
            1,
            2,
            3,
            true,
        );

        let path = vec![0, 1, 1, 1];

        let good_paths = vec![
            vec![(0, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 0])],
        ];

        let bad_paths = vec![
            vec![(0, vec![1, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 0, 1])],
            vec![(0, vec![0, 1, 0, 1, 0])],
        ];

        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> = vec![vec![(0, vec![0, 1, 1, 1])]];

        tests_common(
            &mut m_trie,
            path,
            good_paths,
            bad_paths,
            Some(expected_iter),
        );
    }

    #[test]
    fn multi_trie_1_of_2_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(2, 5, 2),
            1,
            2,
            3,
            true,
        );

        let path = vec![0, 1, 1, 1];

        let good_paths = vec![
            vec![(0, vec![0, 1, 1, 1, 1])],
            vec![(1, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 0])],
            vec![(1, vec![0, 1, 1, 1, 0])],
        ];

        let bad_paths = vec![
            vec![(0, vec![1, 1, 1, 1, 1])],
            vec![(1, vec![0, 1, 1, 0, 1])],
            vec![(0, vec![0, 1, 0, 1, 0])],
        ];

        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> =
            vec![vec![(0, vec![0, 1, 1, 1])], vec![(1, vec![0, 1, 1, 1])]];

        tests_common(
            &mut m_trie,
            path,
            good_paths,
            bad_paths,
            Some(expected_iter),
        );
    }

    #[test]
    fn multi_trie_2_of_2_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(2, 5, 2),
            2,
            2,
            3,
            true,
        );

        let path = vec![0, 1, 1, 1];

        let good_paths = vec![
            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![1, 0, 0, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 0, 0])],
        ];

        let bad_paths = vec![
            vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![1, 1, 0, 1, 1])],
            vec![(0, vec![0, 1, 0, 1, 1]), (1, vec![0, 1, 1, 0, 0])],
        ];

        let expected_iter: Vec<Vec<(usize, Vec<usize>)>> = vec![
            vec![(0, vec![0, 1, 1, 1]), (1, vec![0, 1])],
            vec![(0, vec![0, 1, 1, 1]), (1, vec![1, 0, 0])],
        ];

        tests_common(
            &mut m_trie,
            path,
            good_paths,
            bad_paths,
            Some(expected_iter),
        );
    }

    #[test]
    fn multi_trie_2_of_3_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(3, 5, 2),
            2,
            2,
            3,
            true,
        );

        let path = vec![0, 1, 1, 1];

        let good_paths = vec![
            vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 0, 0, 1, 1])],
            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 0, 0, 1, 1])],
        ];

        let bad_paths = vec![
            vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
            vec![(2, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
            vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
            vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
        ];

        tests_common(&mut m_trie, path, good_paths, bad_paths, None);
    }

    #[test]
    fn multi_trie_5_of_5_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(5, 3, 2),
            5,
            1,
            2,
            true,
        );

        let path = vec![0, 0, 0];

        let good_paths = vec![vec![
            (0, vec![0, 0, 0]),
            (1, vec![0]),
            (2, vec![0]),
            (3, vec![0]),
            (4, vec![0]),
        ]];

        tests_common(
            &mut m_trie,
            path,
            good_paths.clone(),
            vec![],
            Some(good_paths),
        );
    }

    /// Values are numbered in insertion order; iteration must return them in that order.
    #[test]
    fn multi_3_of_3_test_lexicographic_order() {
        let mut m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(3, 3, 2),
            3,
            1,
            2,
            true,
        );

        let inputs = vec![
            vec![0, 0],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![0, 1, 1],
            vec![1, 0, 0],
            vec![1, 0, 1],
        ];

        let mut counter = 0;

        let mut get_value = |_: &[std::vec::Vec<usize>], _: &[usize]| -> Result<usize, Error> {
            counter += 1;
            Ok(counter - 1)
        };

        for input in inputs {
            m_trie
                .insert(&input, &mut get_value)
                .expect("Error inserting in trie");
        }

        let values: Vec<usize> = MultiTrieIterator::new(&m_trie)
            .map(|res| *res.value)
            .collect();

        // Guard against a vacuous pass when the iterator yields nothing.
        assert!(!values.is_empty(), "iterator yielded no values");
        for (i, value) in values.iter().enumerate() {
            assert_eq!(i, *value);
        }
    }

    /// Every (path, value) pair yielded by the iterator must be retrievable through `look_up`.
    fn multi_enumerate_equal_lookup_common(mut m_trie: MultiTrie<usize>) {
        let inputs = vec![vec![0, 1, 0]];

        let mut counter = 0;

        let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> {
            counter += 1;
            Ok(counter - 1)
        };

        for input in inputs {
            m_trie
                .insert(&input, &mut get_value)
                .expect("Error inserting in trie");
        }

        let mut nb_visited = 0;
        for res in MultiTrieIterator::new(&m_trie) {
            assert_eq!(
                m_trie.look_up(&res.path).expect("Path not found").0,
                res.value
            );
            nb_visited += 1;
        }
        // Guard against a vacuous pass when the iterator yields nothing.
        assert!(nb_visited > 0, "iterator yielded no values");
    }

    #[test]
    fn multi_3_of_5_test_enumerate_equal_lookup() {
        let m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(5, 3, 2),
            3,
            1,
            2,
            true,
        );
        multi_enumerate_equal_lookup_common(m_trie);
    }

    #[test]
    fn multi_5_of_5_test_enumerate_equal_lookup() {
        let m_trie = MultiTrie::<usize>::new(
            &same_num_digits_oracle_numeric_infos(5, 3, 2),
            5,
            1,
            2,
            true,
        );
        multi_enumerate_equal_lookup_common(m_trie);
    }

    #[test]
    fn multi_2_of_3_diff_nb_digits_enumerate_equal_lookup() {
        let m_trie = MultiTrie::<usize>::new(
            &get_variable_oracle_numeric_infos(&[3, 4, 5], 2),
            2,
            1,
            2,
            true,
        );
        multi_enumerate_equal_lookup_common(m_trie);
    }

    /// One insertion and the paths that must, and must not, be found afterwards.
    struct TestCase {
        path: Vec<usize>,
        good_paths: Vec<Vec<(usize, Vec<usize>)>>,
        bad_paths: Vec<Vec<(usize, Vec<usize>)>>,
    }

    #[test]
    fn multi_trie_2_of_3_diff_nb_digits_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &get_variable_oracle_numeric_infos(&[5, 6, 7], 2),
            2,
            2,
            3,
            true,
        );

        let test_cases = vec![
            TestCase {
                path: vec![0, 1, 1, 1],
                good_paths: vec![
                    vec![(0, vec![0, 1, 1, 1, 1]), (1, vec![0, 0, 1, 1, 1, 1])],
                    vec![(1, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                    vec![(1, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                    vec![(1, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                ],
                bad_paths: vec![
                    vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                    vec![(2, vec![0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                ],
            },
            TestCase {
                path: vec![1, 1, 1],
                good_paths: vec![
                    vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![1, 0, 0, 0, 0])],
                    vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                    vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                ],
                bad_paths: vec![
                    vec![(0, vec![1, 1, 1, 1, 1]), (1, vec![1, 0, 0, 1, 1, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 1, 1])],
                    vec![(0, vec![1, 1, 1, 0, 0]), (2, vec![0, 1, 0, 0, 1, 0, 1])],
                    vec![(0, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 0, 0, 0])],
                ],
            },
        ];

        for case in test_cases {
            tests_common(
                &mut m_trie,
                case.path,
                case.good_paths,
                case.bad_paths,
                None,
            );
        }
    }

    #[test]
    fn multi_trie_2_of_3_diff_nb_digits_unordered_test() {
        let mut m_trie = MultiTrie::<usize>::new(
            &get_variable_oracle_numeric_infos(&[6, 5, 7], 2),
            2,
            2,
            3,
            true,
        );

        let test_cases = vec![
            TestCase {
                path: vec![0, 1, 1, 1],
                good_paths: vec![
                    vec![(0, vec![0, 0, 1, 1, 1, 1]), (1, vec![0, 1, 1, 1, 1])],
                    vec![(0, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 0, 1, 1, 1, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                    vec![(0, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                    vec![(0, vec![0, 0, 1, 1, 1, 1]), (2, vec![0, 0, 1, 0, 0, 1, 1])],
                ],
                bad_paths: vec![
                    vec![(1, vec![1, 1, 1, 1, 1]), (0, vec![0, 1, 1, 1, 1])],
                    vec![(2, vec![0, 1, 1, 1, 1]), (0, vec![0, 1, 1, 1, 1])],
                    vec![(1, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1]), (2, vec![1, 1, 1, 1, 1])],
                ],
            },
            TestCase {
                path: vec![1, 1, 1],
                good_paths: vec![
                    vec![(0, vec![1, 0, 0, 0, 0]), (1, vec![1, 1, 1, 1, 1])],
                    vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                    vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 0, 0, 0, 0])],
                ],
                bad_paths: vec![
                    vec![(1, vec![1, 1, 1, 1, 1]), (0, vec![1, 0, 0, 1, 1, 1])],
                    vec![(0, vec![0, 1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 1, 1])],
                    vec![(1, vec![1, 1, 1, 0, 0]), (2, vec![0, 1, 0, 0, 1, 0, 1])],
                    vec![(1, vec![1, 1, 1, 1, 1]), (2, vec![0, 1, 0, 1, 0, 0, 0])],
                ],
            },
        ];

        for case in test_cases {
            tests_common(
                &mut m_trie,
                case.path,
                case.good_paths,
                case.bad_paths,
                None,
            );
        }
    }

    /// Each inserted value (numbered from 0) must be yielded exactly once by the iterator.
    #[test]
    fn ttt() {
        let inputs = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![0, 1, 1],
            vec![1],
        ];
        let mut m_trie = MultiTrie::<usize>::new(
            &get_variable_oracle_numeric_infos(&[4, 3], 2),
            2,
            1,
            2,
            true,
        );

        let mut counter = 0;
        let mut get_value = |_: &[Vec<usize>], _: &[usize]| -> Result<usize, Error> {
            let res = counter;
            counter += 1;
            Ok(res)
        };
        for input in inputs {
            m_trie.insert(&input, &mut get_value).unwrap();
        }

        let mut values: Vec<usize> = MultiTrieIterator::new(&m_trie)
            .map(|x| *x.value)
            .collect();
        values.sort_unstable();

        // The whole sorted sequence (including its first element) must be 0, 1, 2, ...
        assert!(!values.is_empty(), "iterator yielded no values");
        let expected: Vec<usize> = (0..values.len()).collect();
        assert_eq!(values, expected);
    }
}