compound_factor_iter/ordered_permutation_iter.rs
1
2use std::cmp::Ordering;
3
/// An expensive permutation iterator guaranteed to return results in order, from highest to lowest.
///
/// The algorithm works by exhaustively exploring a "frontier region" where results transition from
/// greater than a previous result to less than a previous result. Because of the multi-dimensional
/// nature of this frontier, the algorithm must try every possible permutation within the frontier
/// region to be sure the next-highest result is found.
///
/// The algorithm can be **insanely** expensive because it needs to invoke the `combination_fn`
/// closure potentially `n*2^n` times at each step, where `n` is the number of factors.
/// Due to the cost, the OrderedPermutationIter is only for situations where out-of-order
/// results are unacceptable. **Otherwise, [ManhattanPermutationIter](crate::ManhattanPermutationIter) is recommended.**
///
/// ## Future Work
///
/// * High performance "Equal" path in OrderedPermutationIter. Currently OrderedPermutationIter
/// buffers up results with an equal value found during one exploration. Then it returns
/// results out of that buffer until they are exhausted, after which it begins another search.
/// This assumes equal-value results are an exception. In use-cases where they are numerous, a
/// better approach would be to have two traversal modes, one mode searching for the best
/// next-result and the other mode scanning for the next equal-result.
///
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct OrderedPermutationIter<'a, T> {
    /// The individual distributions we're iterating the permutations of. Each element is an
    /// `(original_index, value)` pair, kept sorted from highest to lowest value.
    sorted_dists: Vec<Vec<(usize, T)>>,

    /// A function capable of combining one value from each factor into a single value.
    /// Returning `None` means that permutation has no valid combined value.
    combination_fn: &'a dyn Fn(&[T]) -> Option<T>,

    /// The current position of the result, as indices into the `sorted_dists` arrays
    state: Vec<usize>,

    /// For each factor, the highest index into its `sorted_dists` entry that the state has reached
    high_water_mark: Vec<usize>,

    /// The threshold value, corresponding to the last returned result
    current_val: T,

    /// A place to stash future results whose values equal the last-returned result
    result_stash: Vec<(Vec<usize>, T)>,
}
45
46impl<'a, T> OrderedPermutationIter<'a, T>
47 where
48 T: Copy + PartialOrd + num_traits::Bounded,
49{
50 pub fn new<E: AsRef<[T]>, F: Fn(&[T]) -> Option<T>>(factor_iter: impl Iterator<Item=E>, combination_fn: &'a F) -> Self {
51
52 let sorted_dists: Vec<Vec<(usize, T)>> = factor_iter
53 .map(|factor_dist| {
54 let mut sorted_elements: Vec<(usize, T)> = factor_dist.as_ref().iter().cloned().enumerate().collect();
55 sorted_elements.sort_by(|(_idx_a, element_a), (_idx_b, element_b)| element_b.partial_cmp(element_a).unwrap_or(Ordering::Equal));
56 sorted_elements
57 })
58 .collect();
59
60 let factor_count = sorted_dists.len();
61
62 Self {
63 sorted_dists,
64 combination_fn,
65 state: vec![0; factor_count],
66 high_water_mark: vec![0; factor_count],
67 current_val: T::max_value(),
68 result_stash: vec![],
69 }
70 }
71 pub fn factor_count(&self) -> usize {
72 self.sorted_dists.len()
73 }
74 fn factors_from_state(&self, state: &[usize]) -> Vec<T> {
75
76 let mut factors = Vec::with_capacity(state.len());
77 for (slot_idx, sorted_idx) in state.iter().enumerate() {
78 factors.push(self.sorted_dists[slot_idx][*sorted_idx].1);
79 }
80
81 factors
82 }
83 fn execute_combine_fn(&self, factors: &[T]) -> Option<T> {
84 (self.combination_fn)(&factors)
85 }
86 fn state_to_result(&self) -> Option<(Vec<usize>, T)> {
87
88 let result: Vec<usize> = self.state.iter()
89 .enumerate()
90 .map(|(slot_idx, sorted_factor_idx)| self.sorted_dists[slot_idx][*sorted_factor_idx].0)
91 .collect();
92
93 let factors = self.factors_from_state(&self.state);
94
95 self.execute_combine_fn(&factors)
96 .map(|combined_val| (result, combined_val))
97 }
98 /// Searches the frontier around a state, looking for the next state that has the highest overall
99 /// combined value, that is lower than the current_val. Returns None if it's impossible to advance
100 /// to a valid value.
101 ///
102 fn find_smallest_next_increment(&self) -> Option<Vec<(Vec<usize>, T)>> {
103
104 //Explanation of overall algorithm:
105 //
106 //The algorithm maintains a "frontier region" between the positions of "tops" and "bottoms"
107 //First, "tops" are set, and then "bottoms" are discovered based on "tops". Higher tops
108 // results in less constraining pressure and therfore lower bottoms as well, making the
109 // search space explode. Therefore we need to be very judicious about advancing "tops"
110 //
111 //This algorithm consists of 3 nested loops.
112 //
113 //1. The outermost loop is the "factor_to_advance" loop, which iterates for each factor
114 // plus a final iteration that advances no factors. The idea is that we advance each single
115 // factor to the furthest point it's ever been to for a new "top" on that factor alone.
116 // Then we search for permutations using that particular frontier, and repeat for each factor.
117 //
118 //2. The "while !finished" loop is the hairy one. It systematically tries every permutation
119 // between tops and bottoms. This loop can iterate 2^n times for n factors, and sometimes more
120 //
121 //3. The "rollover loop", aka `while temp_state[cur_factor] > tops[cur_factor]` is effectively
122 // just an incrementor for a mixed-radix number. It's carrying forward the increments until
123 // it finds a place to put them, or determines the iteration is finished
124 //
125
126 let factor_count = self.factor_count();
127
128 let mut highest_val = T::min_value();
129 let mut return_val = None;
130
131 //NOTE: when factor_to_advance == factor_count, that means we don't attempt to advance any factor
132 for factor_to_advance in 0..(factor_count+1) {
133
134 //The "tops" are the highest values each individual factor could possibly have and still reference
135 // the next permutation in the sequence
136 let mut skip_factor = false;
137 let mut tops = Vec::with_capacity(factor_count);
138 for (i , &val) in self.high_water_mark.iter().enumerate() {
139 if i == factor_to_advance {
140 if val+1 < self.sorted_dists[i].len() {
141 tops.push(val+1);
142 } else {
143 skip_factor = true;
144 }
145 } else {
146 tops.push(val);
147 }
148 }
149 if skip_factor {
150 continue;
151 }
152
153 //Find the "bottoms", i.e. the lowest value each factor could possibly have given
154 // the "tops", without exceeding the threshold established by self.current_val
155 let mut bottoms = Vec::with_capacity(factor_count);
156 for i in 0..factor_count {
157 let old_top = tops[i];
158 let mut new_bottom = self.state[i];
159 loop {
160 if new_bottom == 0 {
161 bottoms.push(0);
162 break;
163 }
164 tops[i] = new_bottom; //Temporarily hijacking tops
165 let factors = self.factors_from_state(&tops);
166 let val = self.execute_combine_fn(&factors);
167 if val.is_some() && val.unwrap() > self.current_val {
168 bottoms.push(new_bottom+1);
169 break;
170 } else {
171 new_bottom -= 1;
172 }
173 }
174 tops[i] = old_top;
175 }
176
177 //We need to check every combination of adjustments between tops and bottoms
178 let mut temp_state = bottoms.clone();
179 let mut temp_factors = self.factors_from_state(&temp_state);
180 if factor_to_advance < factor_count {
181 temp_state[factor_to_advance] = tops[factor_to_advance];
182 temp_factors[factor_to_advance] = self.sorted_dists[factor_to_advance][temp_state[factor_to_advance]].1;
183 }
184 let mut finished = false;
185 while !finished {
186
187 //Increment the adjustments to the next state we want to try
188 //NOTE: It is impossible for the initial starting case (all bottoms) to be the
189 // next sequence element, because it's going to be the current sequence element
190 // or something earlier
191 let mut cur_factor;
192 if factor_to_advance != 0 {
193 temp_state[0] += 1;
194 if temp_state[0] < self.sorted_dists[0].len() {
195 temp_factors[0] = self.sorted_dists[0][temp_state[0]].1;
196 }
197 cur_factor = 0;
198 } else {
199 temp_state[1] += 1;
200 if temp_state[1] < self.sorted_dists[1].len() {
201 temp_factors[1] = self.sorted_dists[1][temp_state[1]].1;
202 }
203 cur_factor = 1;
204 }
205
206 //Deal with any rollover caused by the increment above
207 while temp_state[cur_factor] > tops[cur_factor] {
208
209 temp_state[cur_factor] = bottoms[cur_factor];
210 temp_factors[cur_factor] = self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
211 cur_factor += 1;
212
213 //Skip over the factor_to_advance, which we're going to leave pegged to tops
214 if cur_factor == factor_to_advance {
215 cur_factor += 1;
216 }
217
218 if cur_factor < factor_count {
219 temp_state[cur_factor] += 1;
220 if temp_state[cur_factor] < self.sorted_dists[cur_factor].len() {
221 temp_factors[cur_factor] = self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
222 }
223 } else {
224 finished = true;
225 break;
226 }
227 }
228
229 if let Some(temp_val) = self.execute_combine_fn(&temp_factors) {
230 if temp_val < self.current_val && temp_val >= highest_val {
231
232 if temp_val > highest_val {
233 //Replace the results with a fresh array
234 highest_val = temp_val;
235 return_val = Some(vec![(temp_state.clone(), highest_val)]);
236 } else {
237 //We can infer temp_val == highest_val if we got here, so
238 // append to the results array
239 return_val.as_mut().unwrap().push((temp_state.clone(), highest_val));
240 }
241 }
242 }
243 }
244 }
245
246 //See if there are any additional results with the same combined value, adjacent to the
247 // results we found
248 if let Some(results) = &mut return_val.as_mut() {
249 let mut new_results = results.clone();
250 for (result, val) in results.iter() {
251 self.find_adjacent_equal_permutations(result, *val, &mut new_results);
252 }
253 **results = new_results;
254 }
255
256 return_val
257 }
258 //An Adjacent Permutation is defined as a permutation that can be created by adding 1 to one
259 // factor. This function will find all adjacent permutations from the supplied state, with a
260 // value equal to the supplied "val" argument. It will also find the equal permutations from
261 // all found permutations, recursively.
262 fn find_adjacent_equal_permutations(&self, state: &[usize], val: T, results: &mut Vec<(Vec<usize>, T)>) {
263
264 let factor_count = self.factor_count();
265 let mut new_state = state.to_owned();
266
267 loop {
268
269 //Increment the state by 1 and get the new value
270 new_state[0] += 1;
271 let mut cur_digit = 0;
272 let mut temp_val = if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
273 let factors = self.factors_from_state(&new_state);
274 self.execute_combine_fn(&factors)
275 } else {
276 None
277 };
278
279 //Deal with the rollover caused by the previous increment
280 //NOTE: This loop has two continuing criteria. 1.) If we must roll over because
281 // we've incremented one factor to the end, and 2.) If the new combined value is too
282 // small, indicating the factor shouldn't be considered in an equal permutation
283 while new_state[cur_digit] == self.sorted_dists[cur_digit].len()
284 || (temp_val.is_some() && temp_val.unwrap() < val) {
285
286 new_state[cur_digit] = state[cur_digit];
287 cur_digit += 1;
288
289 if cur_digit == factor_count {
290 break;
291 }
292
293 new_state[cur_digit] += 1;
294 if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
295 let factors = self.factors_from_state(&new_state);
296 temp_val = self.execute_combine_fn(&factors);
297 }
298 }
299
300 if temp_val.is_some() && temp_val.unwrap() == val {
301 //Check for duplicates, and add this state if it's unique
302 if results.iter().position(|(element_state, _val)| *element_state == new_state).is_none() {
303 results.push((new_state.clone(), val));
304 }
305 } else {
306 break;
307 }
308 }
309 }
310}
311
312impl<T> Iterator for OrderedPermutationIter<'_, T>
313 where
314 T: Copy + PartialOrd + num_traits::Bounded,
315{
316 type Item = (Vec<usize>, T);
317
318 fn next(&mut self) -> Option<Self::Item> {
319
320 let factor_count = self.factor_count();
321
322 //If we have some results in the stash, return those first
323 if let Some((new_state, new_val)) = self.result_stash.pop() {
324 self.state = new_state;
325 self.current_val = new_val;
326
327 return self.state_to_result();
328 }
329
330 //Find the next configuration with the smallest incremental impact to the combined value
331 if let Some(new_states) = self.find_smallest_next_increment() {
332
333 //Advance the high-water mark for all returned states
334 for (new_state, _new_val) in new_states.iter() {
335 for i in 0..factor_count {
336 if new_state[i] > self.high_water_mark[i] {
337 self.high_water_mark[i] = new_state[i];
338 }
339 }
340 }
341
342 //Stash all the results we got
343 self.result_stash = new_states;
344
345 //Return one result from our stash
346 let (new_state, new_val) = self.result_stash.pop().unwrap();
347 self.state = new_state;
348 self.current_val = new_val;
349
350 return self.state_to_result();
351
352 } else {
353 //If we couldn't find any factors to advancee, we've reached the end of the iteration
354 return None;
355 }
356 }
357}