Skip to main content

ferray_stats/
sorting.rs

1// ferray-stats: Sorting and searching — sort, argsort, searchsorted (REQ-11, REQ-12, REQ-13)
2
3use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
9// ---------------------------------------------------------------------------
10// SortKind
11// ---------------------------------------------------------------------------
12
/// Which sorting strategy `sort` should use.
///
/// `Quick` favours speed and gives no guarantee about the relative order of
/// elements that compare equal; `Stable` guarantees that such elements keep
/// the relative order they had in the input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortKind {
    /// Unstable sort: faster, but equal elements may be reordered.
    Quick,
    /// Stable sort: equal elements keep their original relative order.
    Stable,
}
21
22// ---------------------------------------------------------------------------
23// Side (for searchsorted)
24// ---------------------------------------------------------------------------
25
/// Which insertion point `searchsorted` reports when the probe value already
/// occurs in the sorted array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// Leftmost insertion point: the first index `i` such that every element
    /// before `i` is strictly less than the value.
    Left,
    /// Rightmost insertion point: the first index `i` such that every element
    /// before `i` is less than or equal to the value.
    Right,
}
34
35// ---------------------------------------------------------------------------
36// sort
37// ---------------------------------------------------------------------------
38
39/// Sort an array along the given axis (or flattened if axis is None).
40///
41/// When `axis` is `None`, the array is flattened before sorting and a 1-D
42/// array is returned (matching NumPy behaviour). When an axis is given, the
43/// returned array has the same shape as the input.
44///
45/// Equivalent to `numpy.sort`.
46pub fn sort<T, D>(
47    a: &Array<T, D>,
48    axis: Option<usize>,
49    kind: SortKind,
50) -> FerrayResult<Array<T, IxDyn>>
51where
52    T: Element + PartialOrd + Copy + Send + Sync,
53    D: Dimension,
54{
55    match axis {
56        None => {
57            // Flatten and sort — return a 1-D array (NumPy behaviour)
58            let mut data: Vec<T> = a.iter().copied().collect();
59            let n = data.len();
60            sort_slice(&mut data, kind);
61            Array::from_vec(IxDyn::new(&[n]), data)
62        }
63        Some(ax) => {
64            if ax >= a.ndim() {
65                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
66            }
67            let shape = a.shape().to_vec();
68            let data: Vec<T> = a.iter().copied().collect();
69            let mut result = data.clone();
70            let strides = compute_strides(&shape);
71
72            let axis_len = shape[ax];
73            let out_shape: Vec<usize> = shape
74                .iter()
75                .enumerate()
76                .filter(|&(i, _)| i != ax)
77                .map(|(_, &s)| s)
78                .collect();
79            let out_size: usize = if out_shape.is_empty() {
80                1
81            } else {
82                out_shape.iter().product()
83            };
84
85            let mut out_multi = vec![0usize; out_shape.len()];
86            let ndim = shape.len();
87
88            for _ in 0..out_size {
89                // Build input multi-index template
90                let mut in_multi = Vec::with_capacity(ndim);
91                let mut out_dim = 0;
92                for d in 0..ndim {
93                    if d == ax {
94                        in_multi.push(0);
95                    } else {
96                        in_multi.push(out_multi[out_dim]);
97                        out_dim += 1;
98                    }
99                }
100
101                // Gather lane
102                let mut lane: Vec<T> = Vec::with_capacity(axis_len);
103                let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
104                for k in 0..axis_len {
105                    in_multi[ax] = k;
106                    let idx = flat_index(&in_multi, &strides);
107                    lane.push(data[idx]);
108                    lane_indices.push(idx);
109                }
110
111                sort_slice(&mut lane, kind);
112
113                // Scatter sorted values back
114                for (k, &idx) in lane_indices.iter().enumerate() {
115                    result[idx] = lane[k];
116                }
117
118                if !out_shape.is_empty() {
119                    increment_multi_index(&mut out_multi, &out_shape);
120                }
121            }
122
123            Array::from_vec(IxDyn::new(&shape), result)
124        }
125    }
126}
127
128/// Sort a slice in place using the given algorithm.
129fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
130    match kind {
131        SortKind::Quick => {
132            parallel::parallel_sort(data);
133        }
134        SortKind::Stable => {
135            parallel::parallel_sort_stable(data);
136        }
137    }
138}
139
140// ---------------------------------------------------------------------------
141// argsort
142// ---------------------------------------------------------------------------
143
144/// Return the indices that would sort an array along the given axis.
145///
146/// When `axis` is `None`, the array is flattened before computing
147/// indices and a 1-D array is returned (matching NumPy behaviour).
148///
149/// Returns u64 indices.
150///
151/// Equivalent to `numpy.argsort`.
152pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
153where
154    T: Element + PartialOrd + Copy,
155    D: Dimension,
156{
157    match axis {
158        None => {
159            let data: Vec<T> = a.iter().copied().collect();
160            let n = data.len();
161            let mut indices: Vec<usize> = (0..n).collect();
162            indices.sort_by(|&i, &j| {
163                data[i]
164                    .partial_cmp(&data[j])
165                    .unwrap_or(std::cmp::Ordering::Equal)
166            });
167            let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
168            Array::from_vec(IxDyn::new(&[n]), result)
169        }
170        Some(ax) => {
171            if ax >= a.ndim() {
172                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
173            }
174            let shape = a.shape().to_vec();
175            let data: Vec<T> = a.iter().copied().collect();
176            let strides = compute_strides(&shape);
177            let ndim = shape.len();
178            let axis_len = shape[ax];
179
180            let out_shape: Vec<usize> = shape
181                .iter()
182                .enumerate()
183                .filter(|&(i, _)| i != ax)
184                .map(|(_, &s)| s)
185                .collect();
186            let out_size: usize = if out_shape.is_empty() {
187                1
188            } else {
189                out_shape.iter().product()
190            };
191
192            let mut result = vec![0u64; data.len()];
193            let mut out_multi = vec![0usize; out_shape.len()];
194
195            for _ in 0..out_size {
196                let mut in_multi = Vec::with_capacity(ndim);
197                let mut out_dim = 0;
198                for d in 0..ndim {
199                    if d == ax {
200                        in_multi.push(0);
201                    } else {
202                        in_multi.push(out_multi[out_dim]);
203                        out_dim += 1;
204                    }
205                }
206
207                // Gather lane values and their axis-local indices
208                let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
209                let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
210                for k in 0..axis_len {
211                    in_multi[ax] = k;
212                    let idx = flat_index(&in_multi, &strides);
213                    lane.push((k, data[idx]));
214                    lane_flat_indices.push(idx);
215                }
216
217                // Sort by value, tracking original axis-local index
218                lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
219
220                // Scatter the original axis-local indices into the result
221                for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
222                    result[flat_idx] = lane[k].0 as u64;
223                }
224
225                if !out_shape.is_empty() {
226                    increment_multi_index(&mut out_multi, &out_shape);
227                }
228            }
229
230            Array::from_vec(IxDyn::new(&shape), result)
231        }
232    }
233}
234
235// ---------------------------------------------------------------------------
236// searchsorted
237// ---------------------------------------------------------------------------
238
239/// Find indices where elements should be inserted to maintain order.
240///
241/// `a` must be a sorted 1-D array. For each value in `v`, find the index
242/// in `a` where it should be inserted. Returns u64 indices.
243///
244/// Equivalent to `numpy.searchsorted`.
245pub fn searchsorted<T>(
246    a: &Array<T, Ix1>,
247    v: &Array<T, Ix1>,
248    side: Side,
249) -> FerrayResult<Array<u64, Ix1>>
250where
251    T: Element + PartialOrd + Copy,
252{
253    let sorted: Vec<T> = a.iter().copied().collect();
254    let values: Vec<T> = v.iter().copied().collect();
255
256    let mut result = Vec::with_capacity(values.len());
257    for &val in &values {
258        let idx = match side {
259            Side::Left => sorted.partition_point(|x| {
260                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
261            }),
262            Side::Right => sorted.partition_point(|x| {
263                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
264                    != std::cmp::Ordering::Greater
265            }),
266        };
267        result.push(idx as u64);
268    }
269
270    let n = result.len();
271    Array::from_vec(Ix1::new([n]), result)
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    #[test]
    fn test_sort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_stable_1d() {
        // Equal i32 values are indistinguishable, so this only checks that the
        // stable path sorts correctly; observable stability is pinned by
        // `test_argsort_ties_keep_input_order`.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<i32> = s.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        // Issue #91: sort(axis=None) should return a flat 1-D array
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        // Must be 1-D with 6 elements, not [2, 3]
        assert_eq!(s.shape(), &[6]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, Some(1), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0])
            .unwrap();
        let s = sort(&a, Some(0), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_3d_middle_axis() {
        // Shape [2, 2, 2]; each lane along axis 1 holds a descending pair.
        let a = Array::<f64, IxDyn>::from_vec(
            IxDyn::new(&[2, 2, 2]),
            vec![4.0, 3.0, 2.0, 1.0, 8.0, 7.0, 6.0, 5.0],
        )
        .unwrap();
        let s = sort(&a, Some(1), SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[2, 2, 2]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![2.0, 1.0, 4.0, 3.0, 6.0, 5.0, 8.0, 7.0]);
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(sort(&a, Some(1), SortKind::Quick).is_err());
    }

    #[test]
    fn test_argsort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        assert_eq!(idx.shape(), &[4]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis_none_returns_flat() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![4.0, 1.0, 3.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        assert_eq!(idx.shape(), &[4]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 2, 0]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let idx = argsort(&a, Some(1)).unwrap();
        assert_eq!(idx.shape(), &[2, 3]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_argsort_2d_axis0() {
        // Columns: [4, 2] -> [1, 0]; [1, 3] -> [0, 1].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![4.0, 1.0, 2.0, 3.0]).unwrap();
        let idx = argsort(&a, Some(0)).unwrap();
        assert_eq!(idx.shape(), &[2, 2]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 1]);
    }

    #[test]
    fn test_argsort_ties_keep_input_order() {
        // The two 1s sit at positions 1 and 3; a stable argsort keeps 1 before 3.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = argsort(&a, None).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn test_argsort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(argsort(&a, Some(1)).is_err());
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();
        let idx = searchsorted(&a, &v, Side::Left).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![2.0, 4.0]).unwrap();
        let idx = searchsorted(&a, &v, Side::Right).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 4]);
    }

    #[test]
    fn test_searchsorted_duplicates_left_vs_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 2.0, 2.0, 3.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.0]).unwrap();
        let left: Vec<u64> = searchsorted(&a, &v, Side::Left).unwrap().iter().copied().collect();
        let right: Vec<u64> = searchsorted(&a, &v, Side::Right).unwrap().iter().copied().collect();
        assert_eq!(left, vec![1]);
        assert_eq!(right, vec![4]);
    }
}