// ferray_core/array/sort.rs

// ferray-core: Array sort / argsort methods (#366)
//
// Minimal NumPy parity for ndarray.sort() / np.argsort() at the core
// level. Both methods flatten first and return a 1-D result. Axis-aware
// sorting with multiple algorithms lives in ferray-stats for callers
// that need the richer API (and for property tests that want to pin
// the stable-sort invariant against per-axis reductions).
8
9use std::cmp::Ordering;
10
11use crate::dimension::{Dimension, Ix1};
12use crate::dtype::Element;
13
14use super::owned::Array;
15
/// Comparator for `sort_by` that places "unorderable" (NaN-like) values
/// after every orderable value, matching `NumPy`'s `np.sort` convention.
///
/// No `Float` bound is needed on `T`: a value is treated as NaN-like when
/// it does not compare to itself (`x.partial_cmp(&x)` is `None`). For any
/// sane `PartialOrd`, self-comparison yields `Some(Equal)`.
///
/// Results when `a.partial_cmp(b)` is `None`:
/// - both NaN-like: `Equal`
/// - only `a` NaN-like: `Greater` (so `a` sorts after `b`)
/// - only `b` NaN-like: `Less`
/// - neither NaN-like (genuinely incomparable, not expected for numeric
///   element types): `Equal`, so the comparator never fails to answer.
fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    a.partial_cmp(b).unwrap_or_else(|| {
        let a_unordered = a.partial_cmp(a).is_none();
        let b_unordered = b.partial_cmp(b).is_none();
        // `bool` orders `false < true`, so comparing the flags yields
        // exactly the table above: the unordered side is "greater", and
        // matching flags compare as equal.
        a_unordered.cmp(&b_unordered)
    })
}
38
39impl<T, D> Array<T, D>
40where
41    T: Element + PartialOrd,
42    D: Dimension,
43{
44    /// Return a sorted 1-D copy of the flattened array in ascending order.
45    ///
46    /// Equivalent to `np.sort(a.ravel())`. NaN values (for floating-point
47    /// element types) are ordered last, matching `NumPy`'s convention.
48    ///
49    /// For axis-aware sorting or algorithm selection (mergesort, heapsort),
50    /// use `ferray_stats::sorting::sort`.
51    pub fn sorted(&self) -> Array<T, Ix1> {
52        let mut data: Vec<T> = self.iter().cloned().collect();
53        data.sort_by(nan_last_cmp);
54        let n = data.len();
55        Array::<T, Ix1>::from_vec(Ix1::new([n]), data)
56            .expect("from_vec with exact length cannot fail")
57    }
58
59    /// Return the indices that would sort the flattened array in ascending
60    /// order.
61    ///
62    /// Equivalent to `np.argsort(a.ravel())`. The returned indices are
63    /// `u64` to match `NumPy`'s default index dtype. NaN values are sent
64    /// to the tail of the sort order, matching [`Array::sorted`].
65    ///
66    /// For axis-aware argsort, use `ferray_stats::sorting::argsort`.
67    pub fn argsort(&self) -> Array<u64, Ix1> {
68        let data: Vec<T> = self.iter().cloned().collect();
69        let mut idx: Vec<u64> = (0..data.len() as u64).collect();
70        idx.sort_by(|&i, &j| nan_last_cmp(&data[i as usize], &data[j as usize]));
71        let n = idx.len();
72        Array::<u64, Ix1>::from_vec(Ix1::new([n]), idx)
73            .expect("from_vec with exact length cannot fail")
74    }
75}
76
/// Unit tests for `Array::sorted` and `Array::argsort`: flattening,
/// NaN-last ordering, empty input, and stable tie-breaking.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn sorted_ascending_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[5]);
        assert_eq!(s.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
        // Source unchanged.
        assert_eq!(a.as_slice().unwrap(), &[3, 1, 4, 1, 5]);
    }

    #[test]
    fn sorted_flattens_2d() {
        // Row-major flatten, then sort.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[6]);
        assert_eq!(s.as_slice().unwrap(), &[1, 2, 4, 5, 8, 9]);
    }

    #[test]
    fn sorted_f64_nans_go_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, f64::NAN, 1.0, f64::NAN, 2.0])
            .unwrap();
        let s = a.sorted();
        let data = s.as_slice().unwrap();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
        assert!(data[3].is_nan());
        assert!(data[4].is_nan());
    }

    #[test]
    fn sorted_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[0]);
        assert_eq!(s.size(), 0);
    }

    #[test]
    fn argsort_matches_sorted_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[5]);
        // Applying the index permutation to `a` must yield the sorted array.
        let data = a.as_slice().unwrap();
        let picked: Vec<i32> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn argsort_is_stable_for_ties() {
        // The two 1s sit at flat indices 1 and 3; a stable sort must
        // report them in that order.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[1, 3, 0, 2, 4]);
    }

    #[test]
    fn argsort_f64_sends_nans_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, f64::NAN, 0.5, 5.0]).unwrap();
        let idx = a.argsort();
        // 0.5 (index 2) < 2.0 (index 0) < 5.0 (index 3), then NaN (index 1).
        assert_eq!(idx.as_slice().unwrap(), &[2, 0, 3, 1]);
        let data = a.as_slice().unwrap();
        let picked: Vec<f64> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked[0], 0.5);
        assert_eq!(picked[1], 2.0);
        assert_eq!(picked[2], 5.0);
        assert!(picked[3].is_nan());
    }

    #[test]
    fn argsort_flattens_2d() {
        // [[5,2,8],[1,9,4]] -> flat [5,2,8,1,9,4]
        // sorted ascending: [1,2,4,5,8,9] at flat indices [3,1,5,0,2,4]
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[6]);
        assert_eq!(idx.as_slice().unwrap(), &[3, 1, 5, 0, 2, 4]);
    }

    #[test]
    fn argsort_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[0]);
        assert_eq!(idx.size(), 0);
    }
}