use std::cmp::Ordering;
/// Iterator over combinations that pick one element from each of several factors.
///
/// Each factor's values are kept sorted in descending order. Enumeration starts from the
/// combination of every factor's first (largest) element. It then repeatedly widens the range
/// of sorted positions allowed for a single factor: the one whose next element scores highest
/// under `combination_fn`. Every newly reachable combination is enumerated before widening
/// again. The order is therefore a greedy heuristic and is not guaranteed to be sorted by
/// combined value.
///
/// Items are `(original index chosen from each factor, combined value)`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct RadixPermutationIter<'a, T> {
    /// Per factor: `(original index, value)` pairs sorted by value, largest first.
    sorted_dists: Vec<Vec<(usize, T)>>,
    /// Combines one value per factor (in factor order); a `None` result is never yielded.
    combination_fn: &'a dyn Fn(&[T]) -> Option<T>,
    /// Current sorted position within each factor.
    state: Vec<usize>,
    /// Largest sorted position each factor may currently take.
    max_factors: Vec<usize>,
    /// Largest `len - 1` over all factor distributions.
    global_max_digit: usize,
    /// Factor currently held at its maximum position while the others are enumerated.
    pegged_factor: usize,
    /// Progress through the special-cased opening steps of `next`.
    step_count: usize,
}
impl<'a, T> RadixPermutationIter<'a, T>
where
    T: Copy + PartialOrd + num_traits::Bounded + num_traits::Zero + core::ops::Sub<Output = T>,
{
    /// Creates an iterator over combinations that take one element from each factor.
    ///
    /// * `factor_iter` yields one distribution per factor; each distribution is a slice of
    ///   candidate values for that factor.
    /// * `combination_fn` receives one value per factor (in factor order) and returns the
    ///   combined value, or `None` to reject that combination.
    ///
    /// Every distribution is copied and sorted by value in descending order. The original
    /// index of each value is kept, so yielded items refer to positions in the caller's input.
    /// If any distribution is empty, no combination exists and the first call to `next`
    /// returns `None`.
    pub fn new<E: AsRef<[T]>, F: Fn(&[T]) -> Option<T>>(factor_iter: impl Iterator<Item=E>, combination_fn: &'a F) -> Self {
        let mut global_max_digit: usize = 0;
        let sorted_dists: Vec<Vec<(usize, T)>> = factor_iter
            .map(|factor_dist| {
                let mut sorted_elements: Vec<(usize, T)> =
                    factor_dist.as_ref().iter().copied().enumerate().collect();
                // Descending by value. `sort_by` is stable, so equal values keep input order.
                // NOTE(review): mapping incomparable pairs (e.g. NaN) to `Equal` is not a
                // total order; confirm inputs are NaN-free, since std's sort documents that it
                // may panic when the comparator is not a total order.
                sorted_elements
                    .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));
                // `saturating_sub` avoids the `0 - 1` underflow for an empty distribution.
                global_max_digit =
                    global_max_digit.max(sorted_elements.len().saturating_sub(1));
                sorted_elements
            })
            .collect();
        let factor_count = sorted_dists.len();
        Self {
            sorted_dists,
            combination_fn,
            state: vec![0; factor_count],
            max_factors: vec![0; factor_count],
            global_max_digit,
            pegged_factor: 0,
            step_count: 0,
        }
    }

    /// Number of factors (distributions) this iterator combines.
    pub fn factor_count(&self) -> usize {
        self.sorted_dists.len()
    }

    /// Calls the user-supplied combination function on one value per factor.
    fn execute_combine_fn(&self, factors: &[T]) -> Option<T> {
        (self.combination_fn)(factors)
    }

    /// Converts the current `state` into an item.
    ///
    /// Returns `(original index per factor, combined value)`. Returns `None` when
    /// `combination_fn` rejects the combination, or when some factor has no element at its
    /// current position (only possible for an empty distribution).
    fn state_to_result(&self) -> Option<(Vec<usize>, T)> {
        let mut original_indices = Vec::with_capacity(self.factor_count());
        let mut factors = Vec::with_capacity(self.factor_count());
        for (dist, &sorted_idx) in self.sorted_dists.iter().zip(&self.state) {
            let &(original_idx, value) = dist.get(sorted_idx)?;
            original_indices.push(original_idx);
            factors.push(value);
        }
        self.execute_combine_fn(&factors)
            .map(|combined_val| (original_indices, combined_val))
    }

    /// Advances `state` to the next combination.
    ///
    /// Free digits (every factor except `pegged_factor`) are incremented like an odometer,
    /// factor 0 fastest. Each digit is bounded by its `max_factors` entry and its distribution
    /// length. When all free digits wrap around, the box is grown along the factor chosen by
    /// `find_factor_to_advance`. That factor's maximum is raised by one and it becomes the new
    /// pegged factor, held at its new maximum while the others restart from 0. The previously
    /// pegged factor is released back to 0.
    ///
    /// Returns `(false, None)` when no factor can grow any more. Otherwise it returns
    /// `(true, result)`, where `result` is `None` if `combination_fn` rejected the new
    /// combination.
    fn step(&mut self) -> (bool, Option<(Vec<usize>, T)>) {
        let factor_count = self.factor_count();
        let mut digit = 0;
        loop {
            if digit == self.pegged_factor {
                digit += 1;
            }
            if digit >= factor_count {
                // Every free digit wrapped, or there are none (e.g. a single factor):
                // the current slab is exhausted, so grow the box.
                return match self.find_factor_to_advance() {
                    Some(factor_to_advance) => {
                        self.max_factors[factor_to_advance] += 1;
                        self.state[self.pegged_factor] = 0;
                        self.pegged_factor = factor_to_advance;
                        self.state[factor_to_advance] = self.max_factors[factor_to_advance];
                        // All free digits are 0 here, which is the first state of the new slab.
                        (true, self.state_to_result())
                    }
                    None => (false, None),
                };
            }
            self.state[digit] += 1;
            if self.state[digit] <= self.max_factors[digit]
                && self.state[digit] < self.sorted_dists[digit].len()
            {
                return (true, self.state_to_result());
            }
            // This digit overflowed: reset it and carry into the next free digit.
            self.state[digit] = 0;
            digit += 1;
        }
    }

    /// Chooses the factor whose allowed range should grow next.
    ///
    /// For every factor that can still grow (its next sorted element exists), this builds a
    /// combination from that next element plus the best element of every other factor. The
    /// combination is scored with `combination_fn`, and a `None` result scores as `T::zero()`.
    /// The highest-scoring factor wins; ties go to the lowest factor index.
    ///
    /// Returns `None` when no factor can grow, or when some distribution is empty.
    fn find_factor_to_advance(&self) -> Option<usize> {
        // Best element of every factor; an empty distribution means there is nothing to grow.
        let baseline: Vec<T> = self
            .sorted_dists
            .iter()
            .map(|dist| dist.first().map(|&(_, value)| value))
            .collect::<Option<Vec<T>>>()?;
        // Reused probe buffer: one slot is swapped in per candidate, then restored.
        let mut probe = baseline.clone();
        let mut best_factor: Option<(usize, T)> = None;
        for (factor_idx, (&factor_pos, dist)) in
            self.max_factors.iter().zip(&self.sorted_dists).enumerate()
        {
            if factor_pos >= self.global_max_digit {
                continue;
            }
            let next_value = match dist.get(factor_pos + 1) {
                Some(&(_, value)) => value,
                None => continue,
            };
            probe[factor_idx] = next_value;
            let comb_val = self.execute_combine_fn(&probe).unwrap_or(T::zero());
            probe[factor_idx] = baseline[factor_idx];
            let is_better = match best_factor {
                Some((_, best_val)) => comb_val > best_val,
                None => true,
            };
            if is_better {
                best_factor = Some((factor_idx, comb_val));
            }
        }
        best_factor.map(|(best_idx, _)| best_idx)
    }
}
impl<T> Iterator for RadixPermutationIter<'_, T>
where
    T: Copy + PartialOrd + num_traits::Bounded + num_traits::Zero + core::ops::Sub<Output = T>,
{
    /// `(original index chosen from each factor, combined value)`.
    type Item = (Vec<usize>, T);

    /// Yields the next combination accepted by `combination_fn`.
    ///
    /// The first two steps are special-cased:
    /// * Step 0 reports the combination of every factor's best element.
    /// * Step 1 grows the first factor chosen by `find_factor_to_advance`.
    ///
    /// After that, `step` drives the enumeration. Combinations rejected by `combination_fn`
    /// (`None`) are skipped at every stage instead of ending iteration. Once exhausted, the
    /// iterator keeps returning `None` rather than restarting its enumeration.
    fn next(&mut self) -> Option<Self::Item> {
        // `step_count` value marking exhaustion; the normal values are 0, 1 and 2.
        const EXHAUSTED: usize = usize::MAX;

        if self.step_count == EXHAUSTED {
            return None;
        }
        if self.step_count == 0 {
            self.step_count = 1;
            if let Some(result) = self.state_to_result() {
                return Some(result);
            }
        }
        if self.step_count == 1 {
            self.step_count = 2;
            // No factor can grow (e.g. every distribution has a single element), so nothing
            // beyond the initial combination exists.
            let factor = match self.find_factor_to_advance() {
                Some(factor) => factor,
                None => {
                    self.step_count = EXHAUSTED;
                    return None;
                }
            };
            self.pegged_factor = factor;
            self.max_factors[factor] = 1;
            self.state[factor] = 1;
            if let Some(result) = self.state_to_result() {
                return Some(result);
            }
        }
        loop {
            match self.step() {
                (false, _) => {
                    self.step_count = EXHAUSTED;
                    return None;
                }
                (true, Some(result)) => return Some(result),
                // Rejected by `combination_fn`: keep enumerating.
                (true, None) => {}
            }
        }
    }
}