use std::cmp::Ordering;
/// Iterator over index permutations drawn from several factor distributions.
///
/// Items are produced in batches: each batch holds the states whose combined value
/// is the largest one found strictly below the previously yielded value.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct OrderedPermutationIter<'a, T> {
    /// One entry per factor: `(original index, value)` pairs sorted by value, descending.
    sorted_dists: Vec<Vec<(usize, T)>>,
    /// Folds one value per factor into a combined value; `None` rejects the combination.
    combination_fn: &'a dyn Fn(&[T]) -> Option<T>,
    /// Position into each factor's `sorted_dists` entry for the most recently yielded item.
    state: Vec<usize>,
    /// Per factor, the largest position seen in any search result so far.
    high_water_mark: Vec<usize>,
    /// Combined value of the most recently yielded item.
    current_val: T,
    /// Equally-valued states found by the last search that have not been yielded yet.
    result_stash: Vec<(Vec<usize>, T)>,
}
impl<'a, T> OrderedPermutationIter<'a, T>
where
    T: Copy + PartialOrd + num_traits::Bounded,
{
    /// Creates an iterator over one index per factor distribution.
    ///
    /// Each item of `factor_iter` is one factor's distribution of values. Every
    /// distribution is copied and sorted in descending order, keeping each value's
    /// original index alongside it. Values that `partial_cmp` cannot order (e.g. NaN)
    /// are treated as equal; since the sort is stable they keep their input order.
    ///
    /// `combination_fn` receives one value per factor, in factor order, and returns
    /// the combined value, or `None` to reject that combination.
    pub fn new<E: AsRef<[T]>, F: Fn(&[T]) -> Option<T>>(
        factor_iter: impl Iterator<Item = E>,
        combination_fn: &'a F,
    ) -> Self {
        let sorted_dists: Vec<Vec<(usize, T)>> = factor_iter
            .map(|factor_dist| {
                let mut sorted_elements: Vec<(usize, T)> =
                    factor_dist.as_ref().iter().copied().enumerate().collect();
                // `b` compared against `a`: descending order.
                sorted_elements
                    .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));
                sorted_elements
            })
            .collect();
        let factor_count = sorted_dists.len();
        Self {
            sorted_dists,
            combination_fn,
            state: vec![0; factor_count],
            high_water_mark: vec![0; factor_count],
            // "Previous value" sentinel: searches only accept values strictly below it,
            // so a combination equal to `T::max_value()` is never yielded.
            current_val: T::max_value(),
            result_stash: vec![],
        }
    }

    /// Number of factor distributions this iterator combines.
    pub fn factor_count(&self) -> usize {
        self.sorted_dists.len()
    }

    /// Looks up the value at each slot's sorted position in `state`.
    fn factors_from_state(&self, state: &[usize]) -> Vec<T> {
        state
            .iter()
            .zip(&self.sorted_dists)
            .map(|(&sorted_idx, dist)| dist[sorted_idx].1)
            .collect()
    }

    /// Applies the user-supplied combination function to one value per factor.
    fn execute_combine_fn(&self, factors: &[T]) -> Option<T> {
        (self.combination_fn)(factors)
    }

    /// Translates `self.state` (sorted positions) into the factors' original
    /// indices, paired with the combined value; `None` if the combination is rejected.
    fn state_to_result(&self) -> Option<(Vec<usize>, T)> {
        let original_indices: Vec<usize> = self
            .state
            .iter()
            .zip(&self.sorted_dists)
            .map(|(&sorted_idx, dist)| dist[sorted_idx].0)
            .collect();
        let factors = self.factors_from_state(&self.state);
        self.execute_combine_fn(&factors)
            .map(|combined_val| (original_indices, combined_val))
    }

    /// Upper sweep bound per slot: the high-water mark, with `factor_to_advance`
    /// pushed one position further. Returns `None` when that slot is already at its
    /// last position. `factor_to_advance == factor_count()` leaves every mark as is.
    fn advanced_tops(&self, factor_to_advance: usize) -> Option<Vec<usize>> {
        self.high_water_mark
            .iter()
            .enumerate()
            .map(|(slot, &mark)| {
                if slot != factor_to_advance {
                    Some(mark)
                } else if mark + 1 < self.sorted_dists[slot].len() {
                    Some(mark + 1)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Lower sweep bound per slot. Starting at the slot's current position and walking
    /// toward 0 (other slots held at `tops`), it stops at the first position whose
    /// combined value exceeds `current_val` and returns one past it. It returns 0 if
    /// no such position exists.
    fn find_bottoms(&self, tops: &[usize]) -> Vec<usize> {
        let mut probe = tops.to_vec();
        let mut bottoms = Vec::with_capacity(tops.len());
        for (slot, &start) in self.state.iter().enumerate() {
            let mut candidate = start;
            let bottom = loop {
                if candidate == 0 {
                    break 0;
                }
                probe[slot] = candidate;
                let factors = self.factors_from_state(&probe);
                match self.execute_combine_fn(&factors) {
                    Some(val) if val > self.current_val => break candidate + 1,
                    _ => candidate -= 1,
                }
            };
            // Restore so the next slot is probed against the unmodified tops.
            probe[slot] = tops[slot];
            bottoms.push(bottom);
        }
        bottoms
    }

    /// Finds the next batch of states to yield: those whose combined value is the
    /// largest one strictly below `current_val`. States are positions into
    /// `sorted_dists`, not original indices.
    ///
    /// For each `factor_to_advance` (one slot pinned one past its high-water mark,
    /// or none when it equals `factor_count`), the free slots are swept odometer-style
    /// between `find_bottoms` and `advanced_tops`. Ties are collected and then widened
    /// with `find_adjacent_equal_permutations`. Returns `None` when no accepted
    /// candidate lies below `current_val`.
    fn find_smallest_next_increment(&self) -> Option<Vec<(Vec<usize>, T)>> {
        let factor_count = self.factor_count();
        let mut highest_val = T::min_value();
        let mut return_val: Option<Vec<(Vec<usize>, T)>> = None;
        for factor_to_advance in 0..=factor_count {
            let tops = match self.advanced_tops(factor_to_advance) {
                Some(tops) => tops,
                None => continue,
            };
            let bottoms = self.find_bottoms(&tops);

            let mut temp_state = bottoms.clone();
            // NOTE(review): a bottom can equal its slot's length if `combination_fn`
            // exceeds `current_val` at the current position; the lookup below would then
            // panic. Presumably callers pass a monotone combination — confirm.
            let mut temp_factors = self.factors_from_state(&temp_state);
            if factor_to_advance < factor_count {
                temp_state[factor_to_advance] = tops[factor_to_advance];
                temp_factors[factor_to_advance] =
                    self.sorted_dists[factor_to_advance][temp_state[factor_to_advance]].1;
            }

            // Fastest odometer digit: slot 0, unless slot 0 is the pinned slot. With
            // fewer than two factors there may be no free digit at all; then the
            // pinned starting state is the only candidate.
            let first_free = if factor_to_advance == 0 { 1 } else { 0 };
            let mut finished = false;
            while !finished {
                if first_free < factor_count {
                    let mut cur_factor = first_free;
                    temp_state[cur_factor] += 1;
                    if temp_state[cur_factor] < self.sorted_dists[cur_factor].len() {
                        temp_factors[cur_factor] =
                            self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
                    }
                    // Carry: reset an overflowing digit to its bottom and bump the next
                    // free digit. Running off the end resets every free digit, so the
                    // starting state is evaluated last.
                    while temp_state[cur_factor] > tops[cur_factor] {
                        temp_state[cur_factor] = bottoms[cur_factor];
                        temp_factors[cur_factor] =
                            self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
                        cur_factor += 1;
                        if cur_factor == factor_to_advance {
                            cur_factor += 1;
                        }
                        if cur_factor < factor_count {
                            temp_state[cur_factor] += 1;
                            if temp_state[cur_factor] < self.sorted_dists[cur_factor].len() {
                                temp_factors[cur_factor] =
                                    self.sorted_dists[cur_factor][temp_state[cur_factor]].1;
                            }
                        } else {
                            finished = true;
                            break;
                        }
                    }
                } else {
                    finished = true;
                }

                if let Some(temp_val) = self.execute_combine_fn(&temp_factors) {
                    if temp_val < self.current_val && temp_val >= highest_val {
                        // The first accepted candidate always starts a batch, even if it
                        // ties the `T::min_value()` seed (previously an `unwrap` on `None`).
                        if temp_val > highest_val || return_val.is_none() {
                            highest_val = temp_val;
                            return_val = Some(vec![(temp_state.clone(), highest_val)]);
                        } else if let Some(results) = return_val.as_mut() {
                            results.push((temp_state.clone(), highest_val));
                        }
                    }
                }
            }
        }
        if let Some(results) = return_val.as_mut() {
            let mut new_results = results.clone();
            for (result, val) in results.iter() {
                self.find_adjacent_equal_permutations(result, *val, &mut new_results);
            }
            *results = new_results;
        }
        return_val
    }

    /// Appends to `results` further states that combine to exactly `val`.
    ///
    /// Starting from `state`, it advances an odometer in which digit 0 moves fastest
    /// and no digit drops below its value in `state`. It carries past any digit whose
    /// next position is out of range or whose combined value falls below `val`. Each
    /// landing state equal to `val` is appended unless already present. The walk ends
    /// at the first landing state whose value is not exactly `val` (or is rejected), or
    /// when the carry runs off the last digit.
    fn find_adjacent_equal_permutations(
        &self,
        state: &[usize],
        val: T,
        results: &mut Vec<(Vec<usize>, T)>,
    ) {
        let factor_count = self.factor_count();
        // No digit to advance (and `new_state[0]` below would be out of bounds).
        if factor_count == 0 {
            return;
        }
        let mut new_state = state.to_owned();
        loop {
            new_state[0] += 1;
            let mut cur_digit = 0;
            let mut temp_val = if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
                let factors = self.factors_from_state(&new_state);
                self.execute_combine_fn(&factors)
            } else {
                None
            };
            while new_state[cur_digit] == self.sorted_dists[cur_digit].len()
                || matches!(temp_val, Some(v) if v < val)
            {
                new_state[cur_digit] = state[cur_digit];
                cur_digit += 1;
                if cur_digit == factor_count {
                    break;
                }
                new_state[cur_digit] += 1;
                if new_state[cur_digit] < self.sorted_dists[cur_digit].len() {
                    let factors = self.factors_from_state(&new_state);
                    temp_val = self.execute_combine_fn(&factors);
                }
            }
            if matches!(temp_val, Some(v) if v == val) {
                if !results
                    .iter()
                    .any(|(element_state, _)| *element_state == new_state)
                {
                    results.push((new_state.clone(), val));
                }
            } else {
                break;
            }
        }
    }
}
impl<T> Iterator for OrderedPermutationIter<'_, T>
where
    T: Copy + PartialOrd + num_traits::Bounded,
{
    /// Original index per factor, paired with the combined value.
    type Item = (Vec<usize>, T);

    /// Yields the next stashed state. When the stash is empty, it first searches for
    /// the next batch of equally-valued states and raises the per-factor high-water
    /// marks to cover every state in that batch.
    fn next(&mut self) -> Option<Self::Item> {
        if self.result_stash.is_empty() {
            let new_states = self.find_smallest_next_increment()?;
            for (new_state, _new_val) in &new_states {
                for (mark, &sorted_idx) in self.high_water_mark.iter_mut().zip(new_state) {
                    if sorted_idx > *mark {
                        *mark = sorted_idx;
                    }
                }
            }
            self.result_stash = new_states;
        }
        let (new_state, new_val) = self.result_stash.pop()?;
        self.state = new_state;
        self.current_val = new_val;
        self.state_to_result()
    }
}