use std::cmp::Ordering;
use crate::dimension::{Dimension, Ix1};
use crate::dtype::Element;
use super::owned::Array;
/// Compares two `PartialOrd` values, placing "NaN-like" values after all others.
///
/// If `a.partial_cmp(b)` yields an ordering, that ordering is returned as-is.
/// Otherwise each operand is classified as NaN-like when it cannot be compared
/// with itself (as with a float NaN):
/// - a NaN-like operand compares `Greater` than a non-NaN-like one;
/// - two operands in the same class compare `Equal`.
fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    // Self-incomparability is the generic stand-in for `is_nan()`.
    let is_nan_like = |x: &T| x.partial_cmp(x).is_none();
    // `bool` orders `false < true`, so NaN-like operands land after comparable
    // ones, and operands of the same class compare `Equal`.
    a.partial_cmp(b)
        .unwrap_or_else(|| is_nan_like(a).cmp(&is_nan_like(b)))
}
impl<T, D> Array<T, D>
where
    T: Element + PartialOrd,
    D: Dimension,
{
    /// Returns a new 1-D array containing every element of `self`, sorted ascending.
    ///
    /// Elements are taken in the order yielded by `self.iter()`, so a
    /// multi-dimensional array is flattened before sorting. The sort is stable
    /// and uses `nan_last_cmp`, so values that cannot be compared with
    /// themselves (float NaNs) come after all other values. `self` is unchanged.
    pub fn sorted(&self) -> Array<T, Ix1> {
        let mut data: Vec<T> = self.iter().cloned().collect();
        data.sort_by(nan_last_cmp);
        let n = data.len();
        Array::<T, Ix1>::from_vec(Ix1::new([n]), data)
            .expect("from_vec with exact length cannot fail")
    }

    /// Returns the indices that would sort the flattened elements of `self`.
    ///
    /// Index `k` refers to the `k`-th element yielded by `self.iter()`. The
    /// ordering matches `sorted`: ascending, NaN-like values last, and stable,
    /// so equal elements keep their original relative order.
    pub fn argsort(&self) -> Array<u64, Ix1> {
        // Only comparisons are needed, so borrow the elements instead of
        // cloning each one.
        let data: Vec<&T> = self.iter().collect();
        // `usize` is at most 64 bits on all current Rust targets, so this cast is lossless.
        let mut idx: Vec<u64> = (0..data.len() as u64).collect();
        idx.sort_by(|&i, &j| nan_last_cmp(data[i as usize], data[j as usize]));
        let n = idx.len();
        Array::<u64, Ix1>::from_vec(Ix1::new([n]), idx)
            .expect("from_vec with exact length cannot fail")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn sorted_ascending_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[5]);
        assert_eq!(s.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
        // The source array must be left untouched.
        assert_eq!(a.as_slice().unwrap(), &[3, 1, 4, 1, 5]);
    }

    #[test]
    fn sorted_flattens_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[6]);
        assert_eq!(s.as_slice().unwrap(), &[1, 2, 4, 5, 8, 9]);
    }

    #[test]
    fn sorted_f64_nans_go_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, f64::NAN, 1.0, f64::NAN, 2.0])
            .unwrap();
        let s = a.sorted();
        let data = s.as_slice().unwrap();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
        assert!(data[3].is_nan());
        assert!(data[4].is_nan());
    }

    #[test]
    fn sorted_f64_infinities_before_nan() {
        let a = Array::<f64, Ix1>::from_vec(
            Ix1::new([4]),
            vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0],
        )
        .unwrap();
        let s = a.sorted();
        let data = s.as_slice().unwrap();
        assert_eq!(data[0], f64::NEG_INFINITY);
        assert_eq!(data[1], 0.0);
        assert_eq!(data[2], f64::INFINITY);
        assert!(data[3].is_nan());
    }

    #[test]
    fn sorted_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[0]);
        assert_eq!(s.size(), 0);
    }

    #[test]
    fn argsort_matches_sorted_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[5]);
        let data = a.as_slice().unwrap();
        let picked: Vec<i32> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn argsort_is_stable_for_ties() {
        // The two `1`s sit at indices 1 and 3; a stable sort keeps them in that order.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[1, 3, 0, 2, 4]);
    }

    #[test]
    fn argsort_f64_sends_nans_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, f64::NAN, 0.5, 5.0]).unwrap();
        let idx = a.argsort();
        let data = a.as_slice().unwrap();
        let picked: Vec<f64> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked[0], 0.5);
        assert_eq!(picked[1], 2.0);
        assert_eq!(picked[2], 5.0);
        assert!(picked[3].is_nan());
    }

    #[test]
    fn argsort_flattens_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[6]);
        assert_eq!(idx.as_slice().unwrap(), &[3, 1, 5, 0, 2, 4]);
    }

    #[test]
    fn argsort_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[0]);
        assert_eq!(idx.size(), 0);
    }
}