Skip to main content

ferray_stats/
sorting.rs

1// ferray-stats: Sorting and searching — sort, argsort, searchsorted (REQ-11, REQ-12, REQ-13)
2
3use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
9// ---------------------------------------------------------------------------
10// SortKind
11// ---------------------------------------------------------------------------
12
/// Algorithm choice for [`sort`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortKind {
    /// Unstable sort (NumPy `kind="quicksort"`): faster, but equal elements
    /// may end up in any relative order.
    Quick,
    /// Stable sort (NumPy `kind="stable"`): equal elements keep their
    /// original relative order.
    Stable,
}
21
22// ---------------------------------------------------------------------------
23// Side (for searchsorted)
24// ---------------------------------------------------------------------------
25
/// Which insertion point [`searchsorted`] reports when the searched value
/// compares equal to elements already in the sorted array.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// Leftmost insertion point: the first index whose element is not less
    /// than the value (NumPy `side="left"`).
    Left,
    /// Rightmost insertion point: the first index whose element is greater
    /// than the value (NumPy `side="right"`).
    Right,
}
34
35// ---------------------------------------------------------------------------
36// sort
37// ---------------------------------------------------------------------------
38
39/// Sort an array along the given axis (or flattened if axis is None).
40///
41/// Returns a new sorted array with the same shape.
42///
43/// Equivalent to `numpy.sort`.
44pub fn sort<T, D>(a: &Array<T, D>, axis: Option<usize>, kind: SortKind) -> FerrayResult<Array<T, D>>
45where
46    T: Element + PartialOrd + Copy + Send + Sync,
47    D: Dimension,
48{
49    match axis {
50        None => {
51            // Flatten and sort
52            let mut data: Vec<T> = a.iter().copied().collect();
53            sort_slice(&mut data, kind);
54            Array::from_vec(a.dim().clone(), data)
55        }
56        Some(ax) => {
57            if ax >= a.ndim() {
58                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
59            }
60            let shape = a.shape().to_vec();
61            let data: Vec<T> = a.iter().copied().collect();
62            let mut result = data.clone();
63            let strides = compute_strides(&shape);
64
65            let axis_len = shape[ax];
66            let out_shape: Vec<usize> = shape
67                .iter()
68                .enumerate()
69                .filter(|&(i, _)| i != ax)
70                .map(|(_, &s)| s)
71                .collect();
72            let out_size: usize = if out_shape.is_empty() {
73                1
74            } else {
75                out_shape.iter().product()
76            };
77
78            let mut out_multi = vec![0usize; out_shape.len()];
79            let ndim = shape.len();
80
81            for _ in 0..out_size {
82                // Build input multi-index template
83                let mut in_multi = Vec::with_capacity(ndim);
84                let mut out_dim = 0;
85                for d in 0..ndim {
86                    if d == ax {
87                        in_multi.push(0);
88                    } else {
89                        in_multi.push(out_multi[out_dim]);
90                        out_dim += 1;
91                    }
92                }
93
94                // Gather lane
95                let mut lane: Vec<T> = Vec::with_capacity(axis_len);
96                let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
97                for k in 0..axis_len {
98                    in_multi[ax] = k;
99                    let idx = flat_index(&in_multi, &strides);
100                    lane.push(data[idx]);
101                    lane_indices.push(idx);
102                }
103
104                sort_slice(&mut lane, kind);
105
106                // Scatter sorted values back
107                for (k, &idx) in lane_indices.iter().enumerate() {
108                    result[idx] = lane[k];
109                }
110
111                if !out_shape.is_empty() {
112                    increment_multi_index(&mut out_multi, &out_shape);
113                }
114            }
115
116            Array::from_vec(a.dim().clone(), result)
117        }
118    }
119}
120
121/// Sort a slice in place using the given algorithm.
122fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
123    match kind {
124        SortKind::Quick => {
125            parallel::parallel_sort(data);
126        }
127        SortKind::Stable => {
128            parallel::parallel_sort_stable(data);
129        }
130    }
131}
132
133// ---------------------------------------------------------------------------
134// argsort
135// ---------------------------------------------------------------------------
136
137/// Return the indices that would sort an array along the given axis.
138///
139/// Returns u64 indices.
140///
141/// Equivalent to `numpy.argsort`.
142pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, D>>
143where
144    T: Element + PartialOrd + Copy,
145    D: Dimension,
146{
147    match axis {
148        None => {
149            let data: Vec<T> = a.iter().copied().collect();
150            let mut indices: Vec<usize> = (0..data.len()).collect();
151            indices.sort_by(|&i, &j| {
152                data[i]
153                    .partial_cmp(&data[j])
154                    .unwrap_or(std::cmp::Ordering::Equal)
155            });
156            let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
157            Array::from_vec(a.dim().clone(), result)
158        }
159        Some(ax) => {
160            if ax >= a.ndim() {
161                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
162            }
163            let shape = a.shape().to_vec();
164            let data: Vec<T> = a.iter().copied().collect();
165            let strides = compute_strides(&shape);
166            let ndim = shape.len();
167            let axis_len = shape[ax];
168
169            let out_shape: Vec<usize> = shape
170                .iter()
171                .enumerate()
172                .filter(|&(i, _)| i != ax)
173                .map(|(_, &s)| s)
174                .collect();
175            let out_size: usize = if out_shape.is_empty() {
176                1
177            } else {
178                out_shape.iter().product()
179            };
180
181            let mut result = vec![0u64; data.len()];
182            let mut out_multi = vec![0usize; out_shape.len()];
183
184            for _ in 0..out_size {
185                let mut in_multi = Vec::with_capacity(ndim);
186                let mut out_dim = 0;
187                for d in 0..ndim {
188                    if d == ax {
189                        in_multi.push(0);
190                    } else {
191                        in_multi.push(out_multi[out_dim]);
192                        out_dim += 1;
193                    }
194                }
195
196                // Gather lane values and their axis-local indices
197                let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
198                let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
199                for k in 0..axis_len {
200                    in_multi[ax] = k;
201                    let idx = flat_index(&in_multi, &strides);
202                    lane.push((k, data[idx]));
203                    lane_flat_indices.push(idx);
204                }
205
206                // Sort by value, tracking original axis-local index
207                lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
208
209                // Scatter the original axis-local indices into the result
210                for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
211                    result[flat_idx] = lane[k].0 as u64;
212                }
213
214                if !out_shape.is_empty() {
215                    increment_multi_index(&mut out_multi, &out_shape);
216                }
217            }
218
219            Array::from_vec(a.dim().clone(), result)
220        }
221    }
222}
223
224// ---------------------------------------------------------------------------
225// searchsorted
226// ---------------------------------------------------------------------------
227
228/// Find indices where elements should be inserted to maintain order.
229///
230/// `a` must be a sorted 1-D array. For each value in `v`, find the index
231/// in `a` where it should be inserted. Returns u64 indices.
232///
233/// Equivalent to `numpy.searchsorted`.
234pub fn searchsorted<T>(
235    a: &Array<T, Ix1>,
236    v: &Array<T, Ix1>,
237    side: Side,
238) -> FerrayResult<Array<u64, Ix1>>
239where
240    T: Element + PartialOrd + Copy,
241{
242    let sorted: Vec<T> = a.iter().copied().collect();
243    let values: Vec<T> = v.iter().copied().collect();
244
245    let mut result = Vec::with_capacity(values.len());
246    for &val in &values {
247        let idx = match side {
248            Side::Left => sorted.partition_point(|x| {
249                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
250            }),
251            Side::Right => sorted.partition_point(|x| {
252                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
253                    != std::cmp::Ordering::Greater
254            }),
255        };
256        result.push(idx as u64);
257    }
258
259    let n = result.len();
260    Array::from_vec(Ix1::new([n]), result)
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    #[test]
    fn test_sort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.as_slice().unwrap(), &[1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    // Equal i32 values are indistinguishable, so this only checks that the
    // stable path sorts correctly. Tie order is covered by
    // `test_argsort_stable_ties`.
    #[test]
    fn test_sort_stable_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        assert_eq!(s.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_flatten_keeps_shape() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape().to_vec(), vec![2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, Some(1), SortKind::Quick).unwrap();
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0])
            .unwrap();
        let s = sort(&a, Some(0), SortKind::Quick).unwrap();
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_argsort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_stable_ties() {
        // The two 1s sit at indices 1 and 3; a stable argsort keeps 1 before 3.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = argsort(&a, None).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let idx = argsort(&a, Some(1)).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_argsort_2d_axis0() {
        // Columns: [3, 2] -> [1, 0] and [1, 4] -> [0, 1].
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![3.0, 1.0, 2.0, 4.0]).unwrap();
        let idx = argsort(&a, Some(0)).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 1]);
    }

    #[test]
    fn test_argsort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(argsort(&a, Some(1)).is_err());
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();
        let idx = searchsorted(&a, &v, Side::Left).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![2.0, 4.0]).unwrap();
        let idx = searchsorted(&a, &v, Side::Right).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 4]);
    }

    #[test]
    fn test_searchsorted_duplicates() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 2.0, 3.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.0]).unwrap();
        let left: Vec<u64> = searchsorted(&a, &v, Side::Left).unwrap().iter().copied().collect();
        let right: Vec<u64> = searchsorted(&a, &v, Side::Right).unwrap().iter().copied().collect();
        assert_eq!(left, vec![1]);
        assert_eq!(right, vec![3]);
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(sort(&a, Some(1), SortKind::Quick).is_err());
    }
}