use std::vec;
use num_traits::float::TotalOrder;
/// Types whose values can be arg-sorted: given a sequence of values, produce the
/// positions those values occupied in the input, ordered from smallest to largest value.
///
/// `A` is the item type of the input iterator; every implementation in this module
/// uses the default, `Self`. Integer, `char` and string types order by [`Ord`], while
/// `f32`/`f64` (and references to them) order by their total order (`total_cmp`), so
/// NaN and signed zeros get a fixed position.
pub trait ArgSort<A = Self>: Sized {
    /// The value type held (alongside each index) inside the returned iterator.
    type Output;

    /// Consumes `iter` and returns an iterator over the 0-based input indices of its
    /// items, in ascending order of the items. Items that compare equal keep their
    /// input order.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = A>;
}
macro_rules! impl_argsort {
($typ:ty) => {
impl ArgSort for $typ {
type Output = $typ;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new(iter)
}
}
impl<'a> ArgSort for &'a $typ {
type Output = &'a $typ;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new(iter)
}
}
};
}
impl_argsort!(usize);
impl_argsort!(u8);
impl_argsort!(u16);
impl_argsort!(u32);
impl_argsort!(u64);
impl_argsort!(u128);
impl_argsort!(isize);
impl_argsort!(i8);
impl_argsort!(i16);
impl_argsort!(i32);
impl_argsort!(i64);
impl_argsort!(i128);
impl_argsort!(char);
impl_argsort!(String);
/// String slices are ordered by `str`'s `Ord` (lexicographic). The returned iterator
/// keeps borrowing the slices rather than allocating owned strings.
impl<'a> ArgSort for &'a str {
    type Output = Self;
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new(iter)
    }
}
impl ArgSort for f32 {
type Output = f32;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new_total_order(iter)
}
}
impl<'a> ArgSort for &'a f32 {
type Output = f32;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new_total_order(iter.copied())
}
}
impl ArgSort for f64 {
type Output = f64;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new_total_order(iter)
}
}
impl<'a> ArgSort for &'a f64 {
type Output = f64;
fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
ArgSortIter::new_total_order(iter.copied())
}
}
/// Iterator over the input indices of a sequence, in ascending order of the values
/// found at those indices. Created by [`ArgSort::argsort`].
///
/// All sorting happens eagerly when the iterator is built; iterating only hands out
/// the precomputed indices.
#[derive(Debug, Clone)]
#[must_use = "an ArgSortIter yields nothing unless it is consumed"]
pub struct ArgSortIter<T> {
    // (input index, value) pairs, already sorted by value. The values were moved in for
    // sorting and are kept until the iterator is dropped; only the indices are yielded.
    sorted_enumerated: vec::IntoIter<(usize, T)>,
}

impl<T: Ord> ArgSortIter<T> {
    /// Pairs every item of `iter` with its position and sorts the pairs by item.
    ///
    /// The sort is stable, so items that compare equal are yielded in input order.
    fn new(iter: impl Iterator<Item = T>) -> Self {
        let mut indexed: Vec<(usize, T)> = iter.enumerate().collect();
        // Stability matters here (ties keep input order), so no `sort_unstable_by`.
        indexed.sort_by(|(_, lhs), (_, rhs)| lhs.cmp(rhs));
        Self {
            sorted_enumerated: indexed.into_iter(),
        }
    }
}
impl<T: TotalOrder> ArgSortIter<T> {
    /// Same as the `Ord`-based constructor, but compares items with
    /// `TotalOrder::total_cmp`, so floating-point inputs (including NaN) sort
    /// deterministically.
    ///
    /// The sort is stable: items that compare equal keep their input order.
    fn new_total_order(iter: impl Iterator<Item = T>) -> Self {
        let mut indexed: Vec<(usize, T)> = iter.enumerate().collect();
        indexed.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        Self {
            sorted_enumerated: indexed.into_iter(),
        }
    }
}
impl<T> Iterator for ArgSortIter<T> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
self.sorted_enumerated.next().map(|(i, _)| i)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;

    /// Generates two tests for `$typ`: one arg-sorting references
    /// (`<&$typ>::argsort` over `.iter()`) and one arg-sorting owned values
    /// (`$typ::argsort` over `.into_iter()`). Both must yield `$expected`.
    macro_rules! test_argsort {
        ($name:ident : $typ:ty as $input:expr => $expected:expr) => {
            paste! {
                #[test]
                fn [<$name _iter>]() {
                    let output = <&$typ>::argsort($input.iter()).collect::<Vec<usize>>();
                    assert_eq!(output, $expected)
                }
            }
            paste! {
                #[test]
                fn [<$name _into_iter>]() {
                    let output = $typ::argsort($input.into_iter()).collect::<Vec<usize>>();
                    assert_eq!(output, $expected)
                }
            }
        };
    }

    test_argsort!(usize_ordered : usize as [0, 1, 2, 3, 4, 5] => vec![0, 1, 2, 3, 4, 5]);
    test_argsort!(usize_reversed : usize as [5, 4, 3, 2, 1] => vec![4, 3, 2, 1, 0]);
    test_argsort!(char_mixed : char as ['d', 'a', 'c', 'b', 'z', 'y'] => vec![1, 3, 2, 0, 5, 4]);
    test_argsort!(i16_with_ties : i16 as [8, -1, 2, 0, 0, 2] => vec![1, 3, 4, 2, 5, 0]);
    test_argsort!(f32 : f32 as [f32::NAN, 4.2, 0.0, f32::NEG_INFINITY, -0.0, f32::INFINITY, -f32::NAN] => vec![6, 3, 4, 2, 1, 5, 0]);
    test_argsort!(f64 : f64 as [f64::NAN, 4.2, 0.0, f64::NEG_INFINITY, -0.0, f64::INFINITY, -f64::NAN] => vec![6, 3, 4, 2, 1, 5, 0]);

    // Edge cases: nothing to sort, and a single element.
    test_argsort!(i32_empty : i32 as [0i32; 0] => Vec::<usize>::new());
    test_argsort!(u8_single : u8 as [42u8] => vec![0]);

    // Owned strings, through both the owned and the borrowed impl.
    test_argsort!(string_mixed : String as [String::from("pear"), String::from("apple"), String::from("fig")] => vec![1, 2, 0]);

    // Float ties keep their input order, as with `Ord` types (stable sort).
    test_argsort!(f64_with_ties : f64 as [1.5, -2.0, 1.5, -2.0] => vec![1, 3, 0, 2]);

    // `&str` has its own impl (there is no `&&str` impl), so the macro cannot be used.
    #[test]
    fn str_slices() {
        let output = <&str>::argsort(["pear", "apple", "fig"].into_iter()).collect::<Vec<usize>>();
        assert_eq!(output, vec![1, 2, 0]);
    }
}