sudoku_solver/utils/combination_generator.rs

1use crate::solver::return_if_some;
2
3use arrayvec::ArrayVec;
4
/// Largest combination size `k` supported; it bounds the inline `ArrayVec` buffers
/// used by [`CombinationIterator`].
const MAX_SIZE: usize = 4;

/// Optional hooks invoked by [`CombinationIterator`] while it builds combinations.
///
/// Both callbacks receive `(position, index)`:
/// - `position` is the slot within the combination (`0..k`);
/// - `index` is the index into the source slice of the element placed in that slot.
///
/// The default value has no callbacks installed.
#[derive(Default)]
pub struct CombinationOptions<'a> {
    /// Called when an element is placed at a position.
    ///
    /// Returning `false` rejects it. No combination that extends the current prefix with
    /// that index at that position is produced, and the iterator moves on to the next
    /// candidate index.
    pub on_element_selected: Option<&'a mut dyn FnMut(usize, usize) -> bool>,
    /// Called when a previously accepted element is removed from its position.
    ///
    /// It is never called for elements that `on_element_selected` rejected.
    pub on_element_unselected: Option<&'a mut dyn FnMut(usize, usize)>,
}
20
21pub struct CombinationIterator<'a, T: Copy> {
22    arr: &'a [T],
23    n: usize,
24    k: usize,
25    options: CombinationOptions<'a>,
26    stack: ArrayVec<usize, MAX_SIZE>,
27    result: ArrayVec<T, MAX_SIZE>,
28}
29
30impl<'a, T: Copy> CombinationIterator<'a, T> {
31    #[inline(always)]
32    pub fn new(arr: &'a [T], k: usize, options: CombinationOptions<'a>) -> Self {
33        debug_assert!(k <= MAX_SIZE);
34
35        let stack = ArrayVec::<usize, MAX_SIZE>::new();
36        let result = ArrayVec::<T, MAX_SIZE>::new();
37        let n = arr.len();
38        Self {
39            arr,
40            n,
41            k,
42            options,
43            stack,
44            result,
45        }
46    }
47
48    #[inline(always)]
49    fn try_update(&mut self, current: usize) -> Option<&'a [T]> {
50        if let Some(ref mut on_element_selected) = self.options.on_element_selected {
51            if !on_element_selected(current, self.stack[current]) {
52                return None;
53            }
54        }
55        self.result.push(self.arr[self.stack[current]]);
56        for i in current + 1..self.k {
57            self.stack.push(self.stack[i - 1] + 1);
58            if let Some(ref mut on_element_selected) = self.options.on_element_selected {
59                if !on_element_selected(i, self.stack[i]) {
60                    return None;
61                }
62            }
63            self.result.push(self.arr[self.stack[i]]);
64        }
65
66        // Some(self.result.as_slice())
67        unsafe { Some(&*(self.result.as_slice() as *const [T])) }
68    }
69}
70
71impl<'a, T: Copy> Iterator for CombinationIterator<'a, T> {
72    type Item = &'a [T];
73
74    #[inline(always)]
75    fn next(&mut self) -> Option<Self::Item> {
76        let mut skip_unselected = false;
77        if self.stack.len() == 0 {
78            self.stack.push(0);
79            return_if_some!(self.try_update(0));
80            skip_unselected = true;
81        }
82
83        while let Some(&current_element) = self.stack.last() {
84            let stack_index = self.stack.len() - 1;
85            if current_element + (self.k - stack_index) >= self.n {
86                self.stack.pop().unwrap();
87                if skip_unselected {
88                    skip_unselected = false;
89                } else {
90                    self.result.pop().unwrap();
91                    if let Some(on_element_unselected) = self.options.on_element_unselected.as_mut()
92                    {
93                        on_element_unselected(stack_index, current_element);
94                    }
95                }
96                continue;
97            }
98
99            if skip_unselected {
100                skip_unselected = false;
101            } else {
102                self.result.pop().unwrap();
103                if let Some(on_element_unselected) = self.options.on_element_unselected.as_mut() {
104                    on_element_unselected(stack_index, current_element);
105                }
106            }
107
108            *self.stack.last_mut().unwrap() += 1;
109            return_if_some!(self.try_update(stack_index));
110            skip_unselected = true;
111        }
112
113        None
114    }
115}
116
117pub fn combinations<'a, T: Copy>(
118    arr: &'a [T],
119    k: usize,
120    options: CombinationOptions<'a>,
121) -> CombinationIterator<'a, T> {
122    CombinationIterator::new(arr, k, options)
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;

    /// Drains the iterator into owned vectors, copying each slice out before the
    /// iterator advances (the yielded slices alias an internal buffer).
    fn collect_owned<T: Copy>(iter: CombinationIterator<'_, T>) -> Vec<Vec<T>> {
        iter.map(|s| s.to_vec()).collect()
    }

    #[test]
    fn test_combination_iterator() {
        let arr = [1, 2, 3, 4, 5];
        let options = CombinationOptions {
            on_element_selected: None,
            on_element_unselected: None,
        };
        let result = collect_owned(CombinationIterator::new(&arr, 2, options));
        let expected = arr.iter().copied().combinations(2).collect::<Vec<_>>();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_combination_iterator_matches_itertools_for_all_sizes() {
        let arr = [1, 2, 3];
        for k in 1..=arr.len() {
            let result = collect_owned(CombinationIterator::new(
                &arr,
                k,
                CombinationOptions::default(),
            ));
            let expected = arr.iter().copied().combinations(k).collect::<Vec<_>>();
            assert_eq!(result, expected, "k = {}", k);
        }
    }

    #[test]
    fn test_combination_iterator_options() {
        let arr = [1, 2, 3, 4, 5];
        let mut selected_order = vec![];
        let mut unselected_order = vec![];
        // Reject index 3 (value 4) wherever it is offered; record every accepted index.
        let mut on_element_selected = |_pos: usize, element: usize| {
            if element == 3 {
                return false;
            }
            selected_order.push(element);
            true
        };
        let mut on_element_unselected = |_pos: usize, element: usize| {
            unselected_order.push(element);
        };
        let options = CombinationOptions {
            on_element_selected: Some(&mut on_element_selected),
            on_element_unselected: Some(&mut on_element_unselected),
        };
        let result = collect_owned(CombinationIterator::new(&arr, 2, options));
        let expected = arr
            .iter()
            .copied()
            .filter(|&x| x != 4)
            .combinations(2)
            .collect::<Vec<_>>();
        assert_eq!(result, expected);
        assert_eq!(selected_order, [0, 1, 2, 4, 1, 2, 4, 2, 4]);
        assert_eq!(unselected_order, [1, 2, 4, 0, 2, 4, 1, 4, 2]);
    }
}