use indexmap::IndexMap;
use ndarray::prelude::*;
use ndarray::Slice;
use rand::prelude::*;
use rand::thread_rng;
/// Order-statistic helpers for one-dimensional arrays that work by reordering
/// the array in place.
pub trait Sort1dExt<A> {
    /// Returns the value that would be found at index `i` if `self` were sorted
    /// in increasing order.
    ///
    /// `self` is reordered in place while searching. `i` is expected to be less
    /// than the length of the array.
    fn get_from_sorted_mut(&mut self, i: usize) -> A
    where
        A: Ord + Clone;

    /// Multi-index version of
    /// [`get_from_sorted_mut`](Sort1dExt::get_from_sorted_mut).
    ///
    /// Repeated entries in `indexes` are collapsed. The returned map has one
    /// entry per distinct requested index, in increasing index order.
    fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
    where
        A: Ord + Clone;

    /// Partitions `self` around the value initially stored at `pivot_index`
    /// and returns the position that value ends up at.
    ///
    /// Afterwards, every element before the returned position is strictly
    /// smaller than the pivot value. Every element after it is greater than or
    /// equal to the pivot value. The order within each side is unspecified.
    fn partition_mut(&mut self, pivot_index: usize) -> usize
    where
        A: Ord + Clone;

    private_decl! {}
}
impl<A> Sort1dExt<A> for ArrayRef<A, Ix1> {
    /// Returns the value that would be at index `i` if the array were sorted in
    /// increasing order. Uses quickselect with a uniformly random pivot.
    ///
    /// `self` is rearranged in place. On return, `self[i]` holds the result,
    /// every element before `i` is `<=` it, and every element after `i` is
    /// `>=` it.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`, which includes any call on an empty array.
    fn get_from_sorted_mut(&mut self, i: usize) -> A
    where
        A: Ord + Clone,
    {
        let n = self.len();
        // Without this check an out-of-range `i` could silently return a wrong
        // element (e.g. `self[0]` when `n == 1`) or panic deep in the recursion.
        assert!(
            i < n,
            "get_from_sorted_mut: index {} is out of bounds for length {}",
            i,
            n
        );
        if n == 1 {
            return self[0].clone();
        }
        let mut rng = thread_rng();
        let pivot_index = rng.gen_range(0..n);
        // After this: [0, p) < pivot, self[p] == pivot, (p, n) >= pivot.
        let partition_index = self.partition_mut(pivot_index);
        if i < partition_index {
            return self
                .slice_axis_mut(Axis(0), Slice::from(..partition_index))
                .get_from_sorted_mut(i);
        }
        if i == partition_index {
            return self[i].clone();
        }
        // The right-hand side also holds every duplicate of the pivot. Pack
        // them into the run [partition_index, equal_end) so they are resolved
        // here. Resolving them one per recursion level made inputs with many
        // ties quadratic and could overflow the stack.
        let mut equal_end = partition_index + 1;
        for k in (partition_index + 1)..n {
            if self[k] == self[partition_index] {
                self.swap(equal_end, k);
                equal_end += 1;
            }
        }
        if i < equal_end {
            // `self[i]` lies inside the run of values equal to the pivot.
            self[i].clone()
        } else {
            // Everything from `equal_end` on is strictly greater than the pivot.
            self.slice_axis_mut(Axis(0), Slice::from(equal_end..))
                .get_from_sorted_mut(i - equal_end)
        }
    }

    /// Returns the sorted-order value for each distinct index in `indexes`.
    ///
    /// The indexes are copied, sorted and deduplicated, so callers may pass
    /// them in any order and with repeats. The work is then delegated to
    /// `get_many_from_sorted_mut_unchecked`. `self` is rearranged in place.
    fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
    where
        A: Ord + Clone,
    {
        let mut deduped_indexes: Vec<usize> = indexes.to_vec();
        deduped_indexes.sort_unstable();
        deduped_indexes.dedup();
        get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
    }

    /// Hoare-style partition around the value initially at `pivot_index`.
    /// Returns the final position of that value.
    ///
    /// Afterwards, elements before the returned index are strictly smaller
    /// than the pivot. Elements after it are greater than or equal to it.
    ///
    /// # Panics
    ///
    /// Panics if `pivot_index >= self.len()`.
    fn partition_mut(&mut self, pivot_index: usize) -> usize
    where
        A: Ord + Clone,
    {
        let pivot_value = self[pivot_index].clone();
        // Park the pivot at the front; it is swapped into place at the end.
        self.swap(pivot_index, 0);
        let n = self.len();
        let mut i = 1;
        let mut j = n - 1;
        loop {
            // Advance `i` to the first element that is >= pivot.
            loop {
                if i > j {
                    break;
                }
                if self[i] >= pivot_value {
                    break;
                }
                i += 1;
            }
            // Retreat `j` to the last element that is < pivot. Equal elements
            // are skipped, which is what keeps the left side strictly smaller.
            while pivot_value <= self[j] {
                if j == 1 {
                    break;
                }
                j -= 1;
            }
            if i >= j {
                break;
            } else {
                self.swap(i, j);
                i += 1;
                j -= 1;
            }
        }
        // [1, i) holds elements < pivot; put the pivot right after them.
        self.swap(0, i - 1);
        i - 1
    }

    private_impl! {}
}
/// Finds, for each index in `indexes`, the value at that position in the
/// sorted order of `array`, reordering `array` in place.
///
/// "Unchecked": the caller must pass `indexes` already sorted in increasing
/// order, free of duplicates and in bounds.
///
/// Returns a map from each requested index to its value, keyed in the order of
/// `indexes`. An empty `indexes` yields an empty map without touching `array`.
pub(crate) fn get_many_from_sorted_mut_unchecked<A>(
    array: &mut ArrayRef1<A>,
    indexes: &[usize],
) -> IndexMap<usize, A>
where
    A: Ord + Clone,
{
    // Bail out first: the placeholder below reads `array[0]`.
    if indexes.is_empty() {
        return IndexMap::new();
    }
    // Every slot is overwritten by the recursive search; `array[0]` just
    // supplies a valid `A` to pre-fill the output buffer with.
    let placeholder = array[0].clone();
    let mut found = vec![placeholder; indexes.len()];
    // The recursion rewrites indexes as it narrows the view, so give it a copy.
    let mut scratch_indexes = indexes.to_vec();
    _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut scratch_indexes, &mut found);
    indexes.iter().copied().zip(found).collect()
}
/// Recursive multi-index quickselect.
///
/// Writes into `values[k]` the value at position `indexes[k]` of the sorted
/// order of `array`.
///
/// `indexes` must be sorted in increasing order, unique and in bounds, and
/// `values` must have the same length as `indexes`. `array` is reordered in
/// place, and `indexes` is overwritten: entries are shifted as sub-views are
/// processed.
fn _get_many_from_sorted_mut_unchecked<A>(
    mut array: ArrayViewMut1<'_, A>,
    indexes: &mut [usize],
    values: &mut [A],
) where
    A: Ord + Clone,
{
    let n = array.len();
    // Unique, in-bounds indexes cannot outnumber the elements.
    debug_assert!(n >= indexes.len());
    debug_assert_eq!(indexes.len(), values.len());

    if indexes.is_empty() {
        return;
    }
    // From here `n >= 1`; with one element, the only valid index is 0.
    if n == 1 {
        debug_assert_eq!(indexes.len(), 1);
        values[0] = array[0].clone();
        return;
    }

    let mut rng = thread_rng();
    let pivot_index = rng.gen_range(0..n);
    // After this: [0, p) < pivot, array[p] == pivot, (p, n) >= pivot.
    let partition_index = array.partition_mut(pivot_index);

    // When some requested index lies right of the pivot, pack the pivot's
    // duplicates into [partition_index, equal_end) so they are all answered at
    // this level. Peeling them off one per recursion level made inputs with
    // many ties quadratic and could overflow the stack.
    let mut equal_end = partition_index + 1;
    let largest_requested = indexes[indexes.len() - 1]; // non-empty: checked above
    if largest_requested > partition_index {
        for k in (partition_index + 1)..n {
            if array[k] == array[partition_index] {
                array.swap(equal_end, k);
                equal_end += 1;
            }
        }
    }

    // `indexes` is sorted, so it splits into three contiguous runs:
    // below the pivot, inside the equal run, and above the equal run.
    let smaller_count = indexes.partition_point(|&x| x < partition_index);
    let not_bigger_count = indexes.partition_point(|&x| x < equal_end);
    let equal_count = not_bigger_count - smaller_count;

    let (smaller_indexes, rest_indexes) = indexes.split_at_mut(smaller_count);
    let (smaller_values, rest_values) = values.split_at_mut(smaller_count);
    let (_, bigger_indexes) = rest_indexes.split_at_mut(equal_count);
    let (equal_values, bigger_values) = rest_values.split_at_mut(equal_count);

    // Every requested index inside the equal run maps to the pivot value.
    if !equal_values.is_empty() {
        equal_values.fill(array[partition_index].clone());
    }

    _get_many_from_sorted_mut_unchecked(
        array.slice_axis_mut(Axis(0), Slice::from(..partition_index)),
        smaller_indexes,
        smaller_values,
    );

    // Only the tail after the equal run is passed down, so shift the remaining
    // indexes by the length of the prefix that was cut off.
    bigger_indexes.iter_mut().for_each(|x| *x -= equal_end);
    _get_many_from_sorted_mut_unchecked(
        array.slice_axis_mut(Axis(0), Slice::from(equal_end..)),
        bigger_indexes,
        bigger_values,
    );
}