iterstats 0.7.0

Statistics for Rust iterators.
Documentation
//! Calculate argsort.

use std::vec;

use num_traits::float::TotalOrder;

/// Argsort over an iterator of `A` items.
///
/// The argsort of a sequence is the list of original positions of its items, arranged in the
/// order the items would take once sorted.
pub trait ArgSort<A = Self>: Sized {
    /// Item type carried by the returned [`ArgSortIter`].
    type Output;

    /// Consumes `iter` and returns an iterator over its argsort.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = A>;
}

macro_rules! impl_argsort {
    ($typ:ty) => {
        impl ArgSort for $typ {
            type Output = $typ;

            fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
                ArgSortIter::new(iter)
            }
        }

        impl<'a> ArgSort for &'a $typ {
            type Output = &'a $typ;

            fn argsort<I: Iterator<Item = Self>>(iter: I) -> ArgSortIter<Self::Output> {
                ArgSortIter::new(iter)
            }
        }
    };
}

impl_argsort!(usize);
impl_argsort!(u8);
impl_argsort!(u16);
impl_argsort!(u32);
impl_argsort!(u64);
impl_argsort!(u128);
impl_argsort!(isize);
impl_argsort!(i8);
impl_argsort!(i16);
impl_argsort!(i32);
impl_argsort!(i64);
impl_argsort!(i128);
impl_argsort!(char);
impl_argsort!(String);

/// Argsort for iterators over string slices, compared with `str`'s `Ord` implementation.
impl<'a> ArgSort for &'a str {
    type Output = Self;

    /// Passes `iter` straight to `ArgSortIter::new`; the slices are not copied.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new(iter)
    }
}

/// Argsort for iterators yielding owned `f32` values.
impl ArgSort for f32 {
    type Output = Self;

    /// Floats are not `Ord`, so the ordering comes from `ArgSortIter::new_total_order`
    /// (which compares with `TotalOrder::total_cmp`) rather than `ArgSortIter::new`.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new_total_order(iter)
    }
}

/// Argsort for iterators yielding `&f32`.
///
/// Unlike the reference impls for the `Ord` types, `Output` here is the owned `f32`: each
/// borrowed value is copied before sorting.
impl ArgSort for &f32 {
    type Output = f32;

    /// Copies every item out of its reference and sorts with `ArgSortIter::new_total_order`.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new_total_order(iter.copied())
    }
}

/// Argsort for iterators yielding owned `f64` values.
impl ArgSort for f64 {
    type Output = Self;

    /// Floats are not `Ord`, so the ordering comes from `ArgSortIter::new_total_order`
    /// (which compares with `TotalOrder::total_cmp`) rather than `ArgSortIter::new`.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new_total_order(iter)
    }
}

/// Argsort for iterators yielding `&f64`.
///
/// Unlike the reference impls for the `Ord` types, `Output` here is the owned `f64`: each
/// borrowed value is copied before sorting.
impl ArgSort for &f64 {
    type Output = f64;

    /// Copies every item out of its reference and sorts with `ArgSortIter::new_total_order`.
    fn argsort<I>(iter: I) -> ArgSortIter<Self::Output>
    where
        I: Iterator<Item = Self>,
    {
        ArgSortIter::new_total_order(iter.copied())
    }
}

/// Iterator over the argsort of a sequence.
///
/// The ordering is computed when the value is constructed: the source items are paired with
/// their original positions and sorted. Iterating then yields those positions in sorted order.
///
/// `Debug` and `Clone` are derived so the iterator can be inspected and duplicated; both are
/// available whenever `T` implements the corresponding trait.
#[derive(Debug, Clone)]
pub struct ArgSortIter<T> {
    /// `(original index, item)` pairs, already sorted by item.
    sorted_enumerated: vec::IntoIter<(usize, T)>,
}

impl<T: Ord> ArgSortIter<T> {
    /// Builds the iterator from `iter`, ordering the items with their `Ord` implementation.
    ///
    /// The whole input is collected and sorted up front. The sort is stable, so items that
    /// compare equal keep their original relative order.
    fn new(iter: impl Iterator<Item = T>) -> Self {
        let mut indexed: Vec<(usize, T)> = iter.enumerate().collect();
        indexed.sort_by(|lhs, rhs| lhs.1.cmp(&rhs.1));
        let sorted_enumerated = indexed.into_iter();
        Self { sorted_enumerated }
    }
}

impl<T: TotalOrder> ArgSortIter<T> {
    /// Builds the iterator from `iter`, ordering the items with `TotalOrder::total_cmp`.
    ///
    /// The whole input is collected and sorted up front. The sort is stable, so items that
    /// compare equal keep their original relative order.
    fn new_total_order(iter: impl Iterator<Item = T>) -> Self {
        let mut indexed: Vec<(usize, T)> = iter.enumerate().collect();
        indexed.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        let sorted_enumerated = indexed.into_iter();
        Self { sorted_enumerated }
    }
}

impl<T> Iterator for ArgSortIter<T> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        self.sorted_enumerated.next().map(|(i, _)| i)
    }
}

// Unit tests for the `ArgSort` impls defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;

    // Expands one case into two `#[test]` functions:
    //   * `<name>_iter` calls the `&T` impl with `$input.iter()`;
    //   * `<name>_into_iter` calls the `T` impl with `$input.into_iter()`.
    // Both collect the resulting indexes and compare them with `$expected`.
    // `paste!` is used to build the two function names from `$name`.
    macro_rules! test_argsort {
        ($name:ident : $typ:ty as $input:expr => $expected:expr) => {
            paste! {
                #[test]
                fn [<$name _iter>]() {
                    let output = <&$typ>::argsort($input.iter()).collect::<Vec<usize>>();
                    assert_eq!(output, $expected)
                }
            }
            paste! {
                #[test]
                fn [<$name _into_iter>]() {
                    let output = $typ::argsort($input.into_iter()).collect::<Vec<usize>>();
                    assert_eq!(output, $expected)
                }
            }
        };
    }

    // Already sorted input: the argsort is the identity.
    test_argsort!(usize_ordered : usize as [0, 1, 2, 3, 4, 5] => vec![0, 1, 2, 3, 4, 5]);
    // Descending input: the argsort is the reversed index list.
    test_argsort!(usize_reversed : usize as [5, 4, 3, 2, 1] => vec![4, 3, 2, 1, 0]);
    test_argsort!(char_mixed : char as ['d', 'a', 'c', 'b', 'z', 'y'] => vec![1, 3, 2, 0, 5, 4]);
    // Equal items (the two 0s, the two 2s) are expected in their original relative order.
    test_argsort!(i16_with_ties : i16 as [8, -1, 2, 0, 0, 2] => vec![1, 3, 4, 2, 5, 0]);
    // Float cases: the expected indexes place the inputs in the order
    // -NaN, -inf, -0.0, 0.0, 4.2, +inf, NaN.
    test_argsort!(f32 : f32 as [f32::NAN, 4.2, 0.0, f32::NEG_INFINITY, -0.0, f32::INFINITY, -f32::NAN] => vec![6, 3, 4, 2, 1, 5, 0]);
    test_argsort!(f64 : f64 as [f64::NAN, 4.2, 0.0, f64::NEG_INFINITY, -0.0, f64::INFINITY, -f64::NAN] => vec![6, 3, 4, 2, 1, 5, 0]);
}