compound_factor_iter/
ordered_permutation_iter.rs

1
2use std::cmp::Ordering;
3
/// An expensive permutation iterator guaranteed to return results in order, from highest to lowest.
///
/// The algorithm works by exhaustively exploring a "frontier region" where results transition from
/// greater than a previous result to less than a previous result.  Because of the multi-dimensional
/// nature of this frontier, the algorithm must try every possible permutation within the frontier
/// region to be sure the next-highest result is found.
///
/// The algorithm can be **insanely** expensive because it needs to invoke the `combination_fn`
/// closure potentially `n*2^n` times at each step, where `n` is the number of factors.
/// Due to the cost, the OrderedPermutationIter is only for situations where out-of-order
/// results are unacceptable.  **Otherwise, [ManhattanPermutationIter](crate::ManhattanPermutationIter) is recommended.**
///
/// ## Future Work
///
/// * High performance "Equal" path in OrderedPermutationIter.  Currently OrderedPermutationIter
/// buffers up results with an equal value found during one exploration.  Then it returns
/// results out of that buffer until they are exhausted, after which it begins another search.
/// This assumes equal-value results are an exception.  In use-cases where they are numerous, a
/// better approach would be to have two traversal modes, one mode searching for the best
/// next-result and the other mode scanning for the next equal-result.
///
pub struct OrderedPermutationIter<'a, T> {

    /// The individual distributions we're iterating the permutations of.  Each element is stored
    /// as `(original_index, value)` and each distribution is sorted from highest to lowest value
    sorted_dists: Vec<Vec<(usize, T)>>,

    /// A function capable of combining factors
    combination_fn: &'a dyn Fn(&[T]) -> Option<T>,

    /// The current position of the result, as indices into the sorted_dists arrays
    state: Vec<usize>,

    /// The highest sorted index that any returned state has reached, for each factor
    high_water_mark: Vec<usize>,

    /// The threshold value, corresponding to the last returned result
    current_val: T,

    /// A place to stash future results with values equal to the last-returned result
    result_stash: Vec<(Vec<usize>, T)>,
}
45
46impl<'a, T> OrderedPermutationIter<'a, T>
47    where
48    T: Copy + PartialOrd + num_traits::Bounded,
49{
50    pub fn new<E: AsRef<[T]>, F: Fn(&[T]) -> Option<T>>(factor_iter: impl Iterator<Item=E>, combination_fn: &'a F) -> Self {
51
52        let sorted_dists: Vec<Vec<(usize, T)>> = factor_iter
53            .map(|factor_dist| {
54                let mut sorted_elements: Vec<(usize, T)> = factor_dist.as_ref().iter().cloned().enumerate().collect();
55                sorted_elements.sort_by(|(_idx_a, element_a), (_idx_b, element_b)| element_b.partial_cmp(element_a).unwrap_or(Ordering::Equal));
56                sorted_elements
57            })
58            .collect();
59
60        let factor_count = sorted_dists.len();
61
62        Self {
63            sorted_dists,
64            combination_fn,
65            state: vec![0; factor_count],
66            high_water_mark: vec![0; factor_count],
67            current_val: T::max_value(),
68            result_stash: vec![],
69        }
70    }
71    pub fn factor_count(&self) -> usize {
72        self.sorted_dists.len()
73    }
74    fn factors_from_state(&self, state: &[usize]) -> Vec<T> {
75
76        let mut factors = Vec::with_capacity(state.len());
77        for (slot_idx, sorted_idx) in state.iter().enumerate() {
78            factors.push(self.sorted_dists[slot_idx][*sorted_idx].1);
79        }
80
81        factors
82    }
83    fn execute_combine_fn(&self, factors: &[T]) -> Option<T> {
84        (self.combination_fn)(&factors)
85    }
86    fn state_to_result(&self) -> Option<(Vec<usize>, T)> {
87        
88        let result: Vec<usize> = self.state.iter()
89            .enumerate()
90            .map(|(slot_idx, sorted_factor_idx)| self.sorted_dists[slot_idx][*sorted_factor_idx].0)
91            .collect();
92
93        let factors = self.factors_from_state(&self.state);
94
95        self.execute_combine_fn(&factors)
96            .map(|combined_val| (result, combined_val))
97    }
98    /// Searches the frontier around a state, looking for the next state that has the highest overall
99    /// combined value, that is lower than the current_val.  Returns None if it's impossible to advance
100    /// to a valid value.
101    /// 
102    fn find_smallest_next_increment(&self) -> Option<Vec<(Vec<usize>, T)>> {
103
104        //Explanation of overall algorithm:
105        //
106        //The algorithm maintains a "frontier region" between the positions of "tops" and "bottoms"
107        //First, "tops" are set, and then "bottoms" are discovered based on "tops".  Higher tops
108        // results in less constraining pressure and therfore lower bottoms as well, making the
109        // search space explode.  Therefore we need to be very judicious about advancing "tops"
110        //
111        //This algorithm consists of 3 nested loops.
112        //
113        //1. The outermost loop is the "factor_to_advance" loop, which iterates for each factor
114        // plus a final iteration that advances no factors.  The idea is that we advance each single
115        // factor to the furthest point it's ever been to for a new "top" on that factor alone.
116        // Then we search for permutations using that particular frontier, and repeat for each factor.
117        //
118        //2. The "while !finished" loop is the hairy one.  It systematically tries every permutation
119        // between tops and bottoms. This loop can iterate 2^n times for n factors, and sometimes more
120        //
121        //3. The "rollover loop", aka `while temp_state[cur_factor] > tops[cur_factor]` is effectively
122        // just an incrementor for a mixed-radix number.  It's carrying forward the increments until
123        // it finds a place to put them, or determines the iteration is finished
124        //
125
126        let factor_count = self.factor_count();
127
128        let mut highest_val = T::min_value();
129        let mut return_val = None;
130
131        //NOTE: when factor_to_advance == factor_count, that means we don't attempt to advance any factor
132        for factor_to_advance in 0..(factor_count+1) {
133
134            //The "tops" are the highest values each individual factor could possibly have and still reference
135            // the next permutation in the sequence
136            let mut skip_factor = false;
137            let mut tops = Vec::with_capacity(factor_count);
138            for (i , &val) in self.high_water_mark.iter().enumerate() {
139                if i == factor_to_advance {
140                    if val+1 < self.sorted_dists[i].len() {
141                        tops.push(val+1);
142                    } else {
143                        skip_factor = true;
144                    }
145                } else {
146                    tops.push(val);
147                }
148            }
149            if skip_factor {
150                continue;
151            }
152
153            //Find the "bottoms", i.e. the lowest value each factor could possibly have given
154            // the "tops", without exceeding the threshold established by self.current_val
155            let mut bottoms = Vec::with_capacity(factor_count);
156            for i in 0..factor_count {
157                let old_top = tops[i];
158                let mut new_bottom = self.state[i];
159                loop {
160                    if new_bottom == 0 {
161                        bottoms.push(0);
162                        break;
163                    }
164                    tops[i] = new_bottom; //Temporarily hijacking tops
165                    let factors = self.factors_from_state(&tops);
166                    let val = self.execute_combine_fn(&factors);
167                    if val.is_some() && val.unwrap() > self.current_val {
168                        bottoms.push(new_bottom+1);
169                        break;
170                    } else {
171                        new_bottom -= 1;
172                    }
173                }
174                tops[i] = old_top;
175            }
176
177            //We need to check every combination of adjustments between tops and bottoms
178            let mut temp_state = bottoms.clone();
179            let mut temp_factors = self.factors_from_state(&temp_state);
180            if factor_to_advance < factor_count {
181                temp_state[factor_to_advance] = tops[factor_to_advance];
182                temp_factors[factor_to_advance] = self.sorted_dists[factor_to_advance][temp_state[factor_to_advance]].1;
183            }
184            let mut finished = false;
185            while !finished {
186    
187                //Increment the adjustments to the next state we want to try
188                //NOTE: It is impossible for the initial starting case (all bottoms) to be the
189                // next sequence element, because it's going to be the current sequence element
190                // or something earlier
191                let mut cur_factor;
192                if factor_to_advance != 0 {
193                    temp_state[0] += 1;
194                    if temp_state[0] < self.sorted_dists[0].len() {
195                        temp_factors[0] = self.sorted_dists[0][temp_state[0]].1;
196                    }
197                    cur_factor = 0;
198                } else {
199                    temp_state[1] += 1;
200                    if temp_state[1] < self.sorted_dists[1].len() {
201                        temp_factors[1] = self.sorted_dists[1][temp_state[1]].1;
202                    }
203                    cur_factor = 1;
204                }
205
206                //Deal with any rollover caused by the increment above
207                while temp_state[cur_factor] > tops[cur_factor] {
208
209                    temp_state[cur_factor] = bottoms[cur_factor];
210                    temp_factors[cur_factor] = self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
211                    cur_factor += 1;
212
213                    //Skip over the factor_to_advance, which we're going to leave pegged to tops
214                    if cur_factor == factor_to_advance {
215                        cur_factor += 1;
216                    }
217
218                    if cur_factor < factor_count {
219                        temp_state[cur_factor] += 1;
220                        if temp_state[cur_factor] < self.sorted_dists[cur_factor].len() {
221                            temp_factors[cur_factor] = self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
222                        }
223                    } else {
224                        finished = true;
225                        break;
226                    }
227                }
228    
229                if let Some(temp_val) = self.execute_combine_fn(&temp_factors) {
230                    if temp_val < self.current_val && temp_val >= highest_val {
231
232                        if temp_val > highest_val {
233                            //Replace the results with a fresh array
234                            highest_val = temp_val;
235                            return_val = Some(vec![(temp_state.clone(), highest_val)]);
236                        } else {
237                            //We can infer temp_val == highest_val if we got here, so
238                            // append to the results array
239                            return_val.as_mut().unwrap().push((temp_state.clone(), highest_val));
240                        }
241                    }
242                }
243            }
244        }
245
246        //See if there are any additional results with the same combined value, adjacent to the
247        // results we found
248        if let Some(results) = &mut return_val.as_mut() {
249            let mut new_results = results.clone();
250            for (result, val) in results.iter() {
251                self.find_adjacent_equal_permutations(result, *val, &mut new_results);
252            }
253            **results = new_results;
254        }
255
256        return_val
257    }
258    //An Adjacent Permutation is defined as a permutation that can be created by adding 1 to one
259    // factor.  This function will find all adjacent permutations from the supplied state, with a
260    // value equal to the supplied "val" argument.  It will also find the equal permutations from
261    // all found permutations, recursively.
262    fn find_adjacent_equal_permutations(&self, state: &[usize], val: T, results: &mut Vec<(Vec<usize>, T)>) {
263
264        let factor_count = self.factor_count();
265        let mut new_state = state.to_owned();
266        
267        loop {
268
269            //Increment the state by 1 and get the new value
270            new_state[0] += 1;
271            let mut cur_digit = 0;
272            let mut temp_val = if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
273                let factors = self.factors_from_state(&new_state);
274                self.execute_combine_fn(&factors)
275            } else {
276                None
277            };
278
279            //Deal with the rollover caused by the previous increment
280            //NOTE: This loop has two continuing criteria. 1.) If we must roll over because
281            // we've incremented one factor to the end, and 2.) If the new combined value is too
282            // small, indicating the factor shouldn't be considered in an equal permutation
283            while new_state[cur_digit] == self.sorted_dists[cur_digit].len()
284                || (temp_val.is_some() && temp_val.unwrap() < val) {
285
286                new_state[cur_digit] = state[cur_digit];
287                cur_digit += 1;
288
289                if cur_digit == factor_count {
290                    break;
291                }
292
293                new_state[cur_digit] += 1;
294                if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
295                    let factors = self.factors_from_state(&new_state);
296                    temp_val = self.execute_combine_fn(&factors);
297                }
298            }
299            
300            if temp_val.is_some() && temp_val.unwrap() == val {
301                //Check for duplicates, and add this state if it's unique
302                if results.iter().position(|(element_state, _val)| *element_state == new_state).is_none() {
303                    results.push((new_state.clone(), val));
304                }
305            } else {
306                break;
307            }
308        }
309    }
310}
311
312impl<T> Iterator for OrderedPermutationIter<'_, T>
313    where
314    T: Copy + PartialOrd + num_traits::Bounded,
315{
316    type Item = (Vec<usize>, T);
317
318    fn next(&mut self) -> Option<Self::Item> {
319        
320        let factor_count = self.factor_count();
321
322        //If we have some results in the stash, return those first
323        if let Some((new_state, new_val)) = self.result_stash.pop() {
324            self.state = new_state;
325            self.current_val = new_val;
326
327            return self.state_to_result();
328        }
329
330        //Find the next configuration with the smallest incremental impact to the combined value
331        if let Some(new_states) = self.find_smallest_next_increment() {
332        
333            //Advance the high-water mark for all returned states
334            for (new_state, _new_val) in new_states.iter() {
335                for i in 0..factor_count {
336                    if new_state[i] > self.high_water_mark[i] {
337                        self.high_water_mark[i] = new_state[i];
338                    }
339                }
340            }
341
342            //Stash all the results we got
343            self.result_stash = new_states;
344
345            //Return one result from our stash
346            let (new_state, new_val) = self.result_stash.pop().unwrap();
347            self.state = new_state;
348            self.current_val = new_val;
349
350            return self.state_to_result();
351                
352        } else {
353            //If we couldn't find any factors to advancee, we've reached the end of the iteration
354            return None;
355        }
356    }
357}