ddk_trie/
combination_iterator.rs

//! # Combination Iterator
//! Utility struct and functions to support iterating through all possible
//! t-combinations of n elements (where n >= t).
4
/// Iterator over all `nb_selected`-combinations of the indices `0..nb_elements`,
/// yielded in lexicographic order.
///
/// Each item is a strictly increasing `Vec<usize>` of length `nb_selected`.
/// Once the iterator has returned `None` it keeps returning `None`.
#[derive(Debug, Clone)]
pub struct CombinationIterator {
    /// Current combination as strictly increasing indices.
    selector: Vec<usize>,
    pub(crate) nb_selected: usize,
    pub(crate) nb_elements: usize,
    /// Whether the first combination has already been yielded.
    is_init: bool,
}

impl CombinationIterator {
    /// Creates a new combination iterator for a collection of `nb_elements`
    /// where each combination includes `nb_selected` elements.
    ///
    /// # Panics
    ///
    /// Panics if `nb_elements < nb_selected`.
    pub fn new(nb_elements: usize, nb_selected: usize) -> CombinationIterator {
        assert!(
            nb_elements >= nb_selected,
            "cannot select {} elements out of {}",
            nb_selected,
            nb_elements
        );

        CombinationIterator {
            selector: (0..nb_selected).collect(),
            nb_elements,
            nb_selected,
            is_init: false,
        }
    }

    /// Returns the index of `combination` in the sequence this iterator
    /// produces (counting from its current position), or `None` if the
    /// iterator never produces it.
    ///
    /// `combination` must use the same form the iterator yields (strictly
    /// increasing indices) to be found. Consumes the iterator.
    pub fn get_index_for_combination(mut self, combination: &[usize]) -> Option<usize> {
        // Every yielded combination has exactly `nb_selected` entries, so a
        // slice of any other length can never match; skip the full scan.
        if combination.len() != self.nb_selected {
            return None;
        }

        self.position(|cur| cur.as_slice() == combination)
    }
}

impl Iterator for CombinationIterator {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if !self.is_init {
            self.is_init = true;
            return Some(self.selector.clone());
        }

        // Selecting zero elements has exactly one (empty) combination, which
        // was yielded above. Returning here also avoids `nb_selected - 1`
        // underflowing below.
        if self.nb_selected == 0 {
            return None;
        }

        // Slot `i` can hold at most `max_first + i`; a first slot larger than
        // `max_first` marks the iterator as exhausted.
        let max_first = self.nb_elements - self.nb_selected;
        if self.selector[0] > max_first {
            // Already exhausted: return without touching the state so later
            // calls keep yielding `None` rather than out-of-range indices.
            return None;
        }

        // Find the right-most slot that has not yet reached its maximum.
        let mut cur_index = self.nb_selected - 1;
        while cur_index > 0 && self.selector[cur_index] == max_first + cur_index {
            cur_index -= 1;
        }

        // Bump that slot and reset every slot to its right to the smallest
        // strictly increasing run following it.
        self.selector[cur_index] += 1;
        let mut value = self.selector[cur_index];
        for slot in &mut self.selector[cur_index + 1..] {
            value += 1;
            *slot = value;
        }

        // Bumping the first slot past its maximum means the previous
        // combination was the last one.
        if self.selector[0] > max_first {
            return None;
        }

        Some(self.selector.clone())
    }
}

// Once `next` returns `None` the state is never mutated again, so the
// iterator is fused.
impl std::iter::FusedIterator for CombinationIterator {}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_combinations_test() {
        let combinations: Vec<Vec<usize>> = CombinationIterator::new(4, 3).collect();
        let expected = vec![vec![0, 1, 2], vec![0, 1, 3], vec![0, 2, 3], vec![1, 2, 3]];

        // Compare the whole sequence so a missing or extra combination fails
        // the test, not only a wrong one.
        assert_eq!(combinations, expected);
    }

    #[test]
    fn generate_boundary_combinations_test() {
        // Selecting every element yields the single full combination.
        let full: Vec<Vec<usize>> = CombinationIterator::new(3, 3).collect();
        assert_eq!(full, vec![vec![0, 1, 2]]);

        // Selecting one element yields each index on its own.
        let singles: Vec<Vec<usize>> = CombinationIterator::new(3, 1).collect();
        assert_eq!(singles, vec![vec![0], vec![1], vec![2]]);
    }

    #[test]
    fn get_combination_index_test() {
        let combination_iterator = CombinationIterator::new(4, 3);

        assert_eq!(
            2,
            combination_iterator
                .get_index_for_combination(&[0, 2, 3])
                .expect("Could not find combination")
        );
    }

    #[test]
    fn get_missing_combination_index_test() {
        // Index 4 is outside `0..4`, so no produced combination can match.
        let combination_iterator = CombinationIterator::new(4, 3);

        assert_eq!(
            None,
            combination_iterator.get_index_for_combination(&[0, 1, 4])
        );
    }
}