// sudoku_solver/utils/combination_generator.rs
use arrayvec::ArrayVec;

use crate::solver::return_if_some;
/// Largest combination size `k` that [`CombinationIterator`] supports.
///
/// The iterator keeps its per-slot state in `ArrayVec`s of this fixed capacity.
const MAX_SIZE: usize = 4;

/// Optional hooks that let the caller observe and prune the search of a [`CombinationIterator`].
///
/// Both callbacks receive `(position, index)`: `position` is the slot within the combination
/// being built (`0..k`) and `index` is the index into the source slice of the element placed in
/// (or removed from) that slot -- an index, not the element value.
///
/// The derived [`Default`] leaves both hooks as `None`.
#[derive(Default)]
pub struct CombinationOptions<'a> {
    /// Called whenever an element is placed into a slot. Returning `false` vetoes the placement:
    /// the element is skipped for that slot (pruning every combination that would have used it
    /// there with the current prefix), and no matching `on_element_unselected` call is made.
    pub on_element_selected: Option<&'a mut dyn FnMut(usize, usize) -> bool>,
    /// Called whenever a previously accepted element is removed from its slot.
    pub on_element_unselected: Option<&'a mut dyn FnMut(usize, usize)>,
}
21pub struct CombinationIterator<'a, T: Copy> {
22 arr: &'a [T],
23 n: usize,
24 k: usize,
25 options: CombinationOptions<'a>,
26 stack: ArrayVec<usize, MAX_SIZE>,
27 result: ArrayVec<T, MAX_SIZE>,
28}
29
30impl<'a, T: Copy> CombinationIterator<'a, T> {
31 #[inline(always)]
32 pub fn new(arr: &'a [T], k: usize, options: CombinationOptions<'a>) -> Self {
33 debug_assert!(k <= MAX_SIZE);
34
35 let stack = ArrayVec::<usize, MAX_SIZE>::new();
36 let result = ArrayVec::<T, MAX_SIZE>::new();
37 let n = arr.len();
38 Self {
39 arr,
40 n,
41 k,
42 options,
43 stack,
44 result,
45 }
46 }
47
48 #[inline(always)]
49 fn try_update(&mut self, current: usize) -> Option<&'a [T]> {
50 if let Some(ref mut on_element_selected) = self.options.on_element_selected {
51 if !on_element_selected(current, self.stack[current]) {
52 return None;
53 }
54 }
55 self.result.push(self.arr[self.stack[current]]);
56 for i in current + 1..self.k {
57 self.stack.push(self.stack[i - 1] + 1);
58 if let Some(ref mut on_element_selected) = self.options.on_element_selected {
59 if !on_element_selected(i, self.stack[i]) {
60 return None;
61 }
62 }
63 self.result.push(self.arr[self.stack[i]]);
64 }
65
66 unsafe { Some(&*(self.result.as_slice() as *const [T])) }
68 }
69}
70
71impl<'a, T: Copy> Iterator for CombinationIterator<'a, T> {
72 type Item = &'a [T];
73
74 #[inline(always)]
75 fn next(&mut self) -> Option<Self::Item> {
76 let mut skip_unselected = false;
77 if self.stack.len() == 0 {
78 self.stack.push(0);
79 return_if_some!(self.try_update(0));
80 skip_unselected = true;
81 }
82
83 while let Some(¤t_element) = self.stack.last() {
84 let stack_index = self.stack.len() - 1;
85 if current_element + (self.k - stack_index) >= self.n {
86 self.stack.pop().unwrap();
87 if skip_unselected {
88 skip_unselected = false;
89 } else {
90 self.result.pop().unwrap();
91 if let Some(on_element_unselected) = self.options.on_element_unselected.as_mut()
92 {
93 on_element_unselected(stack_index, current_element);
94 }
95 }
96 continue;
97 }
98
99 if skip_unselected {
100 skip_unselected = false;
101 } else {
102 self.result.pop().unwrap();
103 if let Some(on_element_unselected) = self.options.on_element_unselected.as_mut() {
104 on_element_unselected(stack_index, current_element);
105 }
106 }
107
108 *self.stack.last_mut().unwrap() += 1;
109 return_if_some!(self.try_update(stack_index));
110 skip_unselected = true;
111 }
112
113 None
114 }
115}
116
117pub fn combinations<'a, T: Copy>(
118 arr: &'a [T],
119 k: usize,
120 options: CombinationOptions<'a>,
121) -> CombinationIterator<'a, T> {
122 CombinationIterator::new(arr, k, options)
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;

    /// Copies every combination out of `iter`. Each yielded slice is only valid until the
    /// iterator advances, so it is turned into an owned `Vec` immediately.
    fn collect_owned<T: Copy>(iter: CombinationIterator<'_, T>) -> Vec<Vec<T>> {
        iter.map(<[T]>::to_vec).collect()
    }

    #[test]
    fn test_combination_iterator() {
        let arr = [1, 2, 3, 4, 5];
        let iter = CombinationIterator::new(&arr, 2, CombinationOptions::default());
        let expected = arr.iter().copied().combinations(2).collect::<Vec<_>>();
        assert_eq!(collect_owned(iter), expected);
    }

    #[test]
    fn test_combinations_matches_itertools_for_every_supported_k() {
        let arr = [1, 2, 3, 4, 5];
        for k in 1..=MAX_SIZE {
            let result = collect_owned(combinations(&arr, k, CombinationOptions::default()));
            let expected = arr.iter().copied().combinations(k).collect::<Vec<_>>();
            assert_eq!(result, expected, "k = {}", k);
        }
    }

    #[test]
    fn test_combination_iterator_options() {
        let arr = [1, 2, 3, 4, 5];
        let mut selected_order = vec![];
        let mut unselected_order = vec![];
        // Veto index 3 (the value 4) in every slot; the callbacks receive indices, not values.
        let mut on_element_selected = |_position: usize, index: usize| {
            if index == 3 {
                return false;
            }
            selected_order.push(index);
            true
        };
        let mut on_element_unselected = |_position: usize, index: usize| {
            unselected_order.push(index);
        };
        let options = CombinationOptions {
            on_element_selected: Some(&mut on_element_selected),
            on_element_unselected: Some(&mut on_element_unselected),
        };
        let result = collect_owned(CombinationIterator::new(&arr, 2, options));
        let expected = arr
            .iter()
            .copied()
            .filter(|&x| x != 4)
            .combinations(2)
            .collect::<Vec<_>>();
        assert_eq!(result, expected);
        assert_eq!(selected_order, [0, 1, 2, 4, 1, 2, 4, 2, 4]);
        assert_eq!(unselected_order, [1, 2, 4, 0, 2, 4, 1, 4, 2]);
    }
}