Skip to main content

ferray_core/array/
sort.rs

1// ferray-core: Array sort / argsort methods (#366)
2//
3// Minimal NumPy parity for ndarray.sort() / np.argsort() at the core
4// level. Both methods flatten first and return a 1-D result. Axis-aware
5// sorting with multiple algorithms lives in ferray-stats for callers
6// that need the richer API (and for property tests that want to pin
7// the stable-sort invariant against per-axis reductions).
8
9use std::cmp::Ordering;
10
11use crate::dimension::{Dimension, Ix1};
12use crate::dtype::Element;
13
14use super::owned::Array;
15
16/// Total order comparator that sends "unorderable" (NaN) values to the
17/// end, matching NumPy's `np.sort` convention. Uses self-comparison to
18/// detect NaN without needing a `Float` bound on `T`: for any sane
19/// `PartialOrd`, `x.partial_cmp(&x)` is `Some(Equal)`; only NaN-like
20/// values break that invariant and return `None`.
21fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> Ordering {
22    match a.partial_cmp(b) {
23        Some(ord) => ord,
24        None => {
25            let a_nan = a.partial_cmp(a).is_none();
26            let b_nan = b.partial_cmp(b).is_none();
27            match (a_nan, b_nan) {
28                (true, true) => Ordering::Equal,
29                (true, false) => Ordering::Greater,
30                (false, true) => Ordering::Less,
31                // Genuinely incomparable non-NaN values shouldn't exist
32                // for any numeric Element type, but keep sort_by total
33                // by treating them as equal.
34                (false, false) => Ordering::Equal,
35            }
36        }
37    }
38}
39
40impl<T, D> Array<T, D>
41where
42    T: Element + PartialOrd,
43    D: Dimension,
44{
45    /// Return a sorted 1-D copy of the flattened array in ascending order.
46    ///
47    /// Equivalent to `np.sort(a.ravel())`. NaN values (for floating-point
48    /// element types) are ordered last, matching NumPy's convention.
49    ///
50    /// For axis-aware sorting or algorithm selection (mergesort, heapsort),
51    /// use `ferray_stats::sorting::sort`.
52    pub fn sorted(&self) -> Array<T, Ix1> {
53        let mut data: Vec<T> = self.iter().cloned().collect();
54        data.sort_by(nan_last_cmp);
55        let n = data.len();
56        Array::<T, Ix1>::from_vec(Ix1::new([n]), data)
57            .expect("from_vec with exact length cannot fail")
58    }
59
60    /// Return the indices that would sort the flattened array in ascending
61    /// order.
62    ///
63    /// Equivalent to `np.argsort(a.ravel())`. The returned indices are
64    /// `u64` to match NumPy's default index dtype. NaN values are sent
65    /// to the tail of the sort order, matching [`Array::sorted`].
66    ///
67    /// For axis-aware argsort, use `ferray_stats::sorting::argsort`.
68    pub fn argsort(&self) -> Array<u64, Ix1> {
69        let data: Vec<T> = self.iter().cloned().collect();
70        let mut idx: Vec<u64> = (0..data.len() as u64).collect();
71        idx.sort_by(|&i, &j| nan_last_cmp(&data[i as usize], &data[j as usize]));
72        let n = idx.len();
73        Array::<u64, Ix1>::from_vec(Ix1::new([n]), idx)
74            .expect("from_vec with exact length cannot fail")
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn sorted_ascending_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[5]);
        assert_eq!(s.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
        // Source unchanged.
        assert_eq!(a.as_slice().unwrap(), &[3, 1, 4, 1, 5]);
    }

    #[test]
    fn sorted_flattens_2d() {
        // Row-major flatten, then sort.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[6]);
        assert_eq!(s.as_slice().unwrap(), &[1, 2, 4, 5, 8, 9]);
    }

    #[test]
    fn sorted_f64_nans_go_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, f64::NAN, 1.0, f64::NAN, 2.0])
            .unwrap();
        let s = a.sorted();
        let data = s.as_slice().unwrap();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
        assert!(data[3].is_nan());
        assert!(data[4].is_nan());
    }

    #[test]
    fn sorted_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[0]);
        assert_eq!(s.size(), 0);
    }

    #[test]
    fn argsort_matches_sorted_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[5]);
        // Applying the index permutation to `a` must yield the sorted array.
        let data = a.as_slice().unwrap();
        let picked: Vec<i32> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn argsort_is_stable_for_ties() {
        // The two 1s sit at flat indices 1 and 3; a stable sort must keep
        // them in that order.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[1, 3, 0, 2, 4]);
    }

    #[test]
    fn argsort_f64_sends_nans_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, f64::NAN, 0.5, 5.0]).unwrap();
        let idx = a.argsort();
        let data = a.as_slice().unwrap();
        let picked: Vec<f64> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked[0], 0.5);
        assert_eq!(picked[1], 2.0);
        assert_eq!(picked[2], 5.0);
        assert!(picked[3].is_nan());
    }

    #[test]
    fn argsort_nans_keep_input_order() {
        // NaNs compare equal to each other, so a stable sort keeps the NaN at
        // index 0 ahead of the NaN at index 2.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![f64::NAN, 1.0, f64::NAN, 0.0])
            .unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[3, 1, 0, 2]);
    }

    #[test]
    fn argsort_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[0]);
        assert_eq!(idx.size(), 0);
    }

    #[test]
    fn argsort_flattens_2d() {
        // [[5,2,8],[1,9,4]] -> flat [5,2,8,1,9,4]
        // sorted ascending: [1,2,4,5,8,9] at flat indices [3,1,5,0,2,4]
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[6]);
        assert_eq!(idx.as_slice().unwrap(), &[3, 1, 5, 0, 2, 4]);
    }
}