Skip to main content

ferray_stats/
sorting.rs

1// ferray-stats: Sorting and searching — sort, argsort, searchsorted (REQ-11, REQ-12, REQ-13)
2//
3// ## REQ status (ferray-stats sorting, NumPy parity)
4//  - REQ-11 (sort with `SortKind::Stable` (merge) and `SortKind::Quick`) —
5//    SHIPPED: `pub fn sort` (this file) dispatches on `SortKind`, routing large
6//    inputs through `parallel::parallel_sort` / `parallel::parallel_sort_stable`
7//    and using `nan_last_cmp` so NaN sorts last, matching numpy
8//    (numpy/_core/fromnumeric.py:879 `sort`; NaN-last per numpy/_core/src/npysort).
9//    Non-test consumer: the `ferray_stats::sort` `#[pyfunction]` shim in
10//    `ferray-python/src/stats.rs`.
11//  - REQ-12 (argsort returning index array) — SHIPPED: `pub fn argsort`
12//    returns `Array<u64, IxDyn>` (ferray's `intp` analog), matching
13//    numpy/_core/fromnumeric.py:1118 `argsort`. Consumer:
14//    `ferray_stats::argsort` `#[pyfunction]` shim in `stats.rs`.
15//  - REQ-13 (searchsorted with `Side::Left`/`Side::Right`) — SHIPPED:
16//    `pub fn searchsorted` and `pub fn searchsorted_with_sorter` (binary search,
17//    left/right side), matching numpy/_core/fromnumeric.py:1387 `searchsorted`.
18//    Consumer: `ferray_stats::searchsorted` `#[pyfunction]` shim in `stats.rs`.
19//  - partition / argpartition / lexsort / sort_complex — SHIPPED:
20//    `pub fn partition` / `pub fn argpartition` (introselect-style kth
21//    placement, numpy `partition`/`argpartition`), `pub fn lexsort` (stable
22//    multi-key sort, numpy/_core/fromnumeric.py:1041 `lexsort`), and
23//    `pub fn sort_complex` (sort by real then imaginary,
24//    numpy/lib/_function_base_impl.py `sort_complex`). Consumers:
25//    `ferray_stats::lexsort` / `ferray_stats::sort_complex` `#[pyfunction]`
26//    shims in `stats.rs`.
27
28use ferray_core::error::{FerrayError, FerrayResult};
29use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
30
31use crate::parallel;
32use crate::parallel::nan_last_cmp;
33use crate::reductions::{compute_strides, flat_index, increment_multi_index};
34
35// ---------------------------------------------------------------------------
36// SortKind
37// ---------------------------------------------------------------------------
38
/// Which sorting algorithm [`sort`] should use.
///
/// The two variants trade speed against stability: only `Stable`
/// guarantees that elements comparing equal keep their input order.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortKind {
    /// Unstable quicksort-style sort: usually faster, but equal elements
    /// may be reordered.
    Quick,
    /// Stable merge-style sort: equal elements keep their relative order.
    Stable,
}
47
48// ---------------------------------------------------------------------------
49// Side (for searchsorted)
50// ---------------------------------------------------------------------------
51
/// Which insertion point `searchsorted` reports when the needle equals
/// existing elements.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// Leftmost valid insertion point: before any run of equal elements.
    Left,
    /// Rightmost valid insertion point: after any run of equal elements.
    Right,
}
60
61// ---------------------------------------------------------------------------
62// sort
63// ---------------------------------------------------------------------------
64
65/// Sort an array along the given axis (or flattened if axis is None).
66///
67/// When `axis` is `None`, the array is flattened before sorting and a 1-D
68/// array is returned. When an axis is given, the returned array has the
69/// same shape as the input.
70///
71/// **Note:** `NumPy`'s `np.sort(a)` defaults to `axis=-1` (last axis).
72/// ferray's `sort(a, None, kind)` flattens instead. To match `NumPy`'s
73/// default, pass the last axis explicitly:
74/// `sort(a, Some(a.ndim() - 1), kind)`.
75///
76/// Equivalent to `numpy.sort`.
77pub fn sort<T, D>(
78    a: &Array<T, D>,
79    axis: Option<usize>,
80    kind: SortKind,
81) -> FerrayResult<Array<T, IxDyn>>
82where
83    T: Element + PartialOrd + Copy + Send + Sync,
84    D: Dimension,
85{
86    match axis {
87        None => {
88            // Flatten and sort — return a 1-D array (NumPy behaviour)
89            let mut data: Vec<T> = a.iter().copied().collect();
90            let n = data.len();
91            sort_slice(&mut data, kind);
92            Array::from_vec(IxDyn::new(&[n]), data)
93        }
94        Some(ax) => {
95            if ax >= a.ndim() {
96                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
97            }
98            let shape = a.shape().to_vec();
99            let ndim = shape.len();
100            // Materialize once into a single buffer that we sort in
101            // place — the previous code allocated `data` plus a full
102            // `result = data.clone()` second copy (#171).
103            let mut buf: Vec<T> = a.iter().copied().collect();
104            let axis_len = shape[ax];
105
106            // Last-axis fast path: lanes are already contiguous in
107            // row-major order, so we can hand each `axis_len` window to
108            // `sort_slice` directly with no gather/scatter.
109            if ax == ndim - 1 {
110                for chunk in buf.chunks_exact_mut(axis_len) {
111                    sort_slice(chunk, kind);
112                }
113                return Array::from_vec(IxDyn::new(&shape), buf);
114            }
115
116            // General axis: gather a temporary lane, sort it, scatter
117            // values back into the same buffer.
118            let strides = compute_strides(&shape);
119            let out_shape: Vec<usize> = shape
120                .iter()
121                .enumerate()
122                .filter(|&(i, _)| i != ax)
123                .map(|(_, &s)| s)
124                .collect();
125            let out_size: usize = if out_shape.is_empty() {
126                1
127            } else {
128                out_shape.iter().product()
129            };
130
131            let mut out_multi = vec![0usize; out_shape.len()];
132            // Re-used per-lane scratch buffers to avoid `axis_len`
133            // re-allocations on every output position.
134            let mut in_multi = vec![0usize; ndim];
135            let mut lane: Vec<T> = Vec::with_capacity(axis_len);
136            let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
137
138            for _ in 0..out_size {
139                // Build input multi-index template
140                let mut out_dim = 0;
141                for (d, slot) in in_multi.iter_mut().enumerate() {
142                    if d == ax {
143                        *slot = 0;
144                    } else {
145                        *slot = out_multi[out_dim];
146                        out_dim += 1;
147                    }
148                }
149
150                lane.clear();
151                lane_indices.clear();
152                for k in 0..axis_len {
153                    in_multi[ax] = k;
154                    let idx = flat_index(&in_multi, &strides);
155                    lane.push(buf[idx]);
156                    lane_indices.push(idx);
157                }
158
159                sort_slice(&mut lane, kind);
160
161                // Scatter sorted values back into the in-place buffer.
162                for (k, &idx) in lane_indices.iter().enumerate() {
163                    buf[idx] = lane[k];
164                }
165
166                if !out_shape.is_empty() {
167                    increment_multi_index(&mut out_multi, &out_shape);
168                }
169            }
170
171            Array::from_vec(IxDyn::new(&shape), buf)
172        }
173    }
174}
175
176/// Sort a slice in place using the given algorithm.
177fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
178    match kind {
179        SortKind::Quick => {
180            parallel::parallel_sort(data);
181        }
182        SortKind::Stable => {
183            parallel::parallel_sort_stable(data);
184        }
185    }
186}
187
188// ---------------------------------------------------------------------------
189// argsort
190// ---------------------------------------------------------------------------
191
192/// Return the indices that would sort an array along the given axis.
193///
194/// When `axis` is `None`, the array is flattened before computing
195/// indices and a 1-D array is returned.
196///
197/// Returns u64 indices.
198///
199/// **Note:** `NumPy`'s `np.argsort(a)` defaults to `axis=-1` (last axis).
200/// ferray's `argsort(a, None)` flattens instead. To match `NumPy`'s
201/// default, pass the last axis explicitly: `argsort(a, Some(a.ndim() - 1))`.
202///
203/// Equivalent to `numpy.argsort`.
204pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
205where
206    T: Element + PartialOrd + Copy,
207    D: Dimension,
208{
209    match axis {
210        None => {
211            let data: Vec<T> = a.iter().copied().collect();
212            let n = data.len();
213            let mut indices: Vec<usize> = (0..n).collect();
214            indices.sort_by(|&i, &j| nan_last_cmp(&data[i], &data[j]));
215            let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
216            Array::from_vec(IxDyn::new(&[n]), result)
217        }
218        Some(ax) => {
219            if ax >= a.ndim() {
220                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
221            }
222            let shape = a.shape().to_vec();
223            let data: Vec<T> = a.iter().copied().collect();
224            let strides = compute_strides(&shape);
225            let ndim = shape.len();
226            let axis_len = shape[ax];
227
228            let out_shape: Vec<usize> = shape
229                .iter()
230                .enumerate()
231                .filter(|&(i, _)| i != ax)
232                .map(|(_, &s)| s)
233                .collect();
234            let out_size: usize = if out_shape.is_empty() {
235                1
236            } else {
237                out_shape.iter().product()
238            };
239
240            let mut result = vec![0u64; data.len()];
241            let mut out_multi = vec![0usize; out_shape.len()];
242
243            for _ in 0..out_size {
244                let mut in_multi = Vec::with_capacity(ndim);
245                let mut out_dim = 0;
246                for d in 0..ndim {
247                    if d == ax {
248                        in_multi.push(0);
249                    } else {
250                        in_multi.push(out_multi[out_dim]);
251                        out_dim += 1;
252                    }
253                }
254
255                // Gather lane values and their axis-local indices
256                let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
257                let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
258                for k in 0..axis_len {
259                    in_multi[ax] = k;
260                    let idx = flat_index(&in_multi, &strides);
261                    lane.push((k, data[idx]));
262                    lane_flat_indices.push(idx);
263                }
264
265                // Sort by value, tracking original axis-local index
266                lane.sort_by(|a, b| nan_last_cmp(&a.1, &b.1));
267
268                // Scatter the original axis-local indices into the result
269                for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
270                    result[flat_idx] = lane[k].0 as u64;
271                }
272
273                if !out_shape.is_empty() {
274                    increment_multi_index(&mut out_multi, &out_shape);
275                }
276            }
277
278            Array::from_vec(IxDyn::new(&shape), result)
279        }
280    }
281}
282
283// ---------------------------------------------------------------------------
284// partition / argpartition
285// ---------------------------------------------------------------------------
286
287/// Partial sort: rearrange elements so that `a[kth]` is the value that
288/// would be there in a sorted array, all elements before it are `<=`,
289/// and all elements after are `>=`. The relative order within the two
290/// halves is undefined.
291///
292/// Equivalent to `numpy.partition(a, kth)`. Uses `select_nth_unstable`
293/// for O(n) average-case performance (#466).
294///
295/// # Errors
296/// - `FerrayError::AxisOutOfBounds` if `kth >= a.size()`.
297pub fn partition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<T, Ix1>>
298where
299    T: Element + PartialOrd + Copy,
300{
301    let n = a.size();
302    if kth >= n {
303        return Err(FerrayError::invalid_value(format!(
304            "partition: kth={kth} out of range for array of size {n}"
305        )));
306    }
307    let mut data: Vec<T> = a.iter().copied().collect();
308    data.select_nth_unstable_by(kth, nan_last_cmp);
309    Array::from_vec(Ix1::new([n]), data)
310}
311
312/// Return indices that would partition the array. The k-th element of
313/// the result is the index of the k-th smallest element; elements
314/// before index kth are indices of smaller-or-equal elements, and
315/// elements after are indices of greater-or-equal.
316///
317/// Equivalent to `numpy.argpartition(a, kth)` (#466).
318pub fn argpartition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<u64, Ix1>>
319where
320    T: Element + PartialOrd + Copy,
321{
322    let n = a.size();
323    if kth >= n {
324        return Err(FerrayError::invalid_value(format!(
325            "argpartition: kth={kth} out of range for array of size {n}"
326        )));
327    }
328    let data: Vec<T> = a.iter().copied().collect();
329    let mut idx: Vec<u64> = (0..n as u64).collect();
330    idx.select_nth_unstable_by(kth, |&a_i, &b_i| {
331        nan_last_cmp(&data[a_i as usize], &data[b_i as usize])
332    });
333    Array::from_vec(Ix1::new([n]), idx)
334}
335
336// ---------------------------------------------------------------------------
337// lexsort
338// ---------------------------------------------------------------------------
339
340/// Indirect stable sort using a sequence of keys.
341///
342/// `keys` is a list of 1-D arrays of the same length. The **last** key
343/// in the list is the primary sort key (matching `NumPy`'s
344/// `numpy.lexsort` convention); ties are broken by the second-to-last
345/// key, then the third-to-last, and so on. Returns a permutation
346/// `idx` such that `keys[-1][idx]` is non-decreasing.
347///
348/// Implementation notes: the underlying sort is `sort_by` (stable),
349/// applied once with a comparator that walks the keys from primary
350/// (last) to secondary (earlier). This avoids the multi-pass stable
351/// sort that `NumPy` historically used.
352///
353/// # Errors
354/// - `FerrayError::InvalidValue` if `keys` is empty or the keys have
355///   different lengths.
356pub fn lexsort<T>(keys: &[&Array<T, Ix1>]) -> FerrayResult<Array<u64, Ix1>>
357where
358    T: Element + PartialOrd + Copy,
359{
360    if keys.is_empty() {
361        return Err(FerrayError::invalid_value(
362            "lexsort: keys must contain at least one array",
363        ));
364    }
365    let n = keys[0].size();
366    for (i, k) in keys.iter().enumerate().skip(1) {
367        if k.size() != n {
368            return Err(FerrayError::invalid_value(format!(
369                "lexsort: key {i} has length {}, expected {n}",
370                k.size()
371            )));
372        }
373    }
374
375    // Materialize each key into a contiguous Vec so the comparator
376    // can index into them directly without re-borrowing the array
377    // iterator on every comparison.
378    let key_data: Vec<Vec<T>> = keys.iter().map(|k| k.iter().copied().collect()).collect();
379
380    let mut idx: Vec<u64> = (0..n as u64).collect();
381    idx.sort_by(|&a, &b| {
382        let ai = a as usize;
383        let bi = b as usize;
384        // Iterate keys from primary (last) to secondary (earlier).
385        for k in key_data.iter().rev() {
386            match nan_last_cmp(&k[ai], &k[bi]) {
387                std::cmp::Ordering::Equal => {}
388                ord => return ord,
389            }
390        }
391        std::cmp::Ordering::Equal
392    });
393
394    Array::from_vec(Ix1::new([n]), idx)
395}
396
397// ---------------------------------------------------------------------------
398// searchsorted
399// ---------------------------------------------------------------------------
400
401/// Find indices where elements should be inserted to maintain order.
402///
403/// `a` must be a sorted 1-D array. For each value in `v`, find the index
404/// in `a` where it should be inserted. Returns u64 indices.
405///
406/// Equivalent to `numpy.searchsorted` (without `sorter`). For an
407/// already-permuted view of an unsorted array, see [`searchsorted_with_sorter`].
408pub fn searchsorted<T>(
409    a: &Array<T, Ix1>,
410    v: &Array<T, Ix1>,
411    side: Side,
412) -> FerrayResult<Array<u64, Ix1>>
413where
414    T: Element + PartialOrd + Copy,
415{
416    let sorted: Vec<T> = a.iter().copied().collect();
417    searchsorted_inner(&sorted, v, side)
418}
419
420/// Find indices where elements should be inserted to maintain order,
421/// using `sorter` as a permutation that would sort `a`.
422///
423/// Mirrors `numpy.searchsorted(a, v, side, sorter)`. `a` may be in any
424/// order; `sorter[i]` gives the index in `a` of the i-th smallest
425/// element (i.e. `sorter` is the output of an `argsort` over `a`). The
426/// returned indices are positions into the **sorted** view, matching
427/// `NumPy`'s behaviour. See issue #473.
428///
429/// # Errors
430/// - `FerrayError::ShapeMismatch` if `sorter.len() != a.len()`.
431/// - `FerrayError::InvalidValue` if `sorter` contains an out-of-range index.
432pub fn searchsorted_with_sorter<T>(
433    a: &Array<T, Ix1>,
434    v: &Array<T, Ix1>,
435    side: Side,
436    sorter: &Array<u64, Ix1>,
437) -> FerrayResult<Array<u64, Ix1>>
438where
439    T: Element + PartialOrd + Copy,
440{
441    let n = a.size();
442    if sorter.size() != n {
443        return Err(FerrayError::shape_mismatch(format!(
444            "searchsorted: sorter length {} does not match array length {}",
445            sorter.size(),
446            n
447        )));
448    }
449
450    // Materialize `a` once and gather it in sorter order.
451    let a_data: Vec<T> = a.iter().copied().collect();
452    let mut sorted: Vec<T> = Vec::with_capacity(n);
453    for &idx in sorter.iter() {
454        let i = idx as usize;
455        if i >= n {
456            return Err(FerrayError::invalid_value(format!(
457                "searchsorted: sorter index {i} out of range for array of length {n}"
458            )));
459        }
460        sorted.push(a_data[i]);
461    }
462
463    searchsorted_inner(&sorted, v, side)
464}
465
466/// Shared binary-search core used by both [`searchsorted`] and
467/// [`searchsorted_with_sorter`].
468fn searchsorted_inner<T>(
469    sorted: &[T],
470    v: &Array<T, Ix1>,
471    side: Side,
472) -> FerrayResult<Array<u64, Ix1>>
473where
474    T: Element + PartialOrd + Copy,
475{
476    let mut result = Vec::with_capacity(v.size());
477    for &val in v.iter() {
478        let idx = match side {
479            Side::Left => {
480                sorted.partition_point(|x| nan_last_cmp(x, &val) == std::cmp::Ordering::Less)
481            }
482            Side::Right => {
483                sorted.partition_point(|x| nan_last_cmp(x, &val) != std::cmp::Ordering::Greater)
484            }
485        };
486        result.push(idx as u64);
487    }
488    let n = result.len();
489    Array::from_vec(Ix1::new([n]), result)
490}
491
492// ---------------------------------------------------------------------------
493// sort_complex
494// ---------------------------------------------------------------------------
495
496/// Sort a 1-D complex array, comparing first by real part, then by imaginary.
497///
498/// Equivalent to `numpy.sort_complex`. Always returns a stable sort —
499/// matching NumPy's behavior. Operates on the flattened input.
500pub fn sort_complex<T>(
501    a: &Array<num_complex::Complex<T>, Ix1>,
502) -> FerrayResult<Array<num_complex::Complex<T>, Ix1>>
503where
504    T: Element + num_traits::Float,
505    num_complex::Complex<T>: Element,
506{
507    let mut data: Vec<num_complex::Complex<T>> = a.iter().copied().collect();
508    data.sort_by(|x, y| {
509        let r = nan_last_cmp(&x.re, &y.re);
510        if r != std::cmp::Ordering::Equal {
511            r
512        } else {
513            nan_last_cmp(&x.im, &y.im)
514        }
515    });
516    let n = data.len();
517    Array::from_vec(Ix1::new([n]), data)
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    #[test]
    fn test_sort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_stable_preserves_order() {
        // Checks the values produced by the stable path. Equal i32s are
        // indistinguishable, so this alone cannot observe stability; the
        // signed-zero test below does.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<i32> = s.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_stable_keeps_equal_elements_in_input_order() {
        // 0.0 and -0.0 compare equal, so a stable sort must keep +0.0
        // (input index 1) ahead of -0.0 (input index 3). The sign bit makes
        // that order observable.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 0.0, -1.0, -0.0, 5.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![-1.0, 0.0, 0.0, 3.0, 5.0]);
        assert!(data[1].is_sign_positive());
        assert!(data[2].is_sign_negative());
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        // Issue #91: sort(axis=None) should return a flat 1-D array
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        // Must be 1-D with 6 elements, not [2, 3]
        assert_eq!(s.shape(), &[6]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, Some(1), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0])
            .unwrap();
        let s = sort(&a, Some(0), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_argsort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        assert_eq!(idx.shape(), &[4]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let idx = argsort(&a, Some(1)).unwrap();
        assert_eq!(idx.shape(), &[2, 3]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();
        let idx = searchsorted(&a, &v, Side::Left).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![2.0, 4.0]).unwrap();
        let idx = searchsorted(&a, &v, Side::Right).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 4]);
    }

    // ----- searchsorted_with_sorter (#473) -----

    #[test]
    fn test_searchsorted_with_sorter_matches_pre_sorted() {
        // Unsorted `a` plus its argsort gives the same indices as
        // calling searchsorted on the pre-sorted array.
        let unsorted =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 5.0, 2.0, 4.0]).unwrap();
        // sorter so that unsorted[sorter] = [1.0, 2.0, 3.0, 4.0, 5.0]
        let sorter = Array::<u64, Ix1>::from_vec(Ix1::new([5]), vec![1, 3, 0, 4, 2]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();

        let idx = searchsorted_with_sorter(&unsorted, &v, Side::Left, &sorter).unwrap();
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_with_sorter_length_mismatch_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 5.0, 2.0]).unwrap();
        let bad_sorter = Array::<u64, Ix1>::from_vec(Ix1::new([3]), vec![1, 3, 0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.5]).unwrap();
        assert!(searchsorted_with_sorter(&a, &v, Side::Left, &bad_sorter).is_err());
    }

    #[test]
    fn test_searchsorted_with_sorter_out_of_range_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![3.0, 1.0, 5.0]).unwrap();
        let bad_sorter = Array::<u64, Ix1>::from_vec(Ix1::new([3]), vec![1, 99, 0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.5]).unwrap();
        assert!(searchsorted_with_sorter(&a, &v, Side::Left, &bad_sorter).is_err());
    }

    // ----- lexsort (#469) -----

    #[test]
    fn test_lexsort_single_key_matches_argsort() {
        let k = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = lexsort(&[&k]).unwrap();
        // Sorted order: 1@idx1, 1@idx3, 3@idx0, 4@idx2, 5@idx4
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn test_lexsort_secondary_key_breaks_ties() {
        // Primary key (last in slice) sorts by ascending; ties resolved
        // by the earlier key. Match NumPy's lexsort([secondary, primary]).
        let secondary = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![20, 10, 40, 30]).unwrap();
        let primary = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 1, 2]).unwrap();
        let idx = lexsort(&[&secondary, &primary]).unwrap();
        // primary buckets:
        //   1 -> indices 0, 2 with secondary 20, 40 -> ordered as 0, 2
        //   2 -> indices 1, 3 with secondary 10, 30 -> ordered as 1, 3
        // result: [0, 2, 1, 3]
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![0, 2, 1, 3]);
    }

    #[test]
    fn test_lexsort_length_mismatch_errors() {
        let k1 = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let k2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        assert!(lexsort(&[&k1, &k2]).is_err());
    }

    #[test]
    fn test_lexsort_empty_keys_errors() {
        let keys: &[&Array<i32, Ix1>] = &[];
        assert!(lexsort(keys).is_err());
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(sort(&a, Some(1), SortKind::Quick).is_err());
    }

    // -- partition / argpartition (#466) --

    #[test]
    fn test_partition_places_kth_element() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 4.0, 2.0, 3.0]).unwrap();
        let p = partition(&a, 2).unwrap();
        let data: Vec<f64> = p.iter().copied().collect();
        // Third-smallest value lands at kth; halves bracket it.
        assert_eq!(data[2], 3.0);
        assert!(data[..2].iter().all(|&x| x <= 3.0));
        assert!(data[3..].iter().all(|&x| x >= 3.0));
    }

    #[test]
    fn test_argpartition_places_kth_index() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 4.0, 2.0, 3.0]).unwrap();
        let idx = argpartition(&a, 1).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        // Second-smallest value 2.0 lives at index 3; only 1.0 (index 1)
        // is smaller.
        assert_eq!(data[1], 3);
        assert_eq!(data[0], 1);
        // The result must be a permutation of 0..n.
        let mut perm = data.clone();
        perm.sort_unstable();
        assert_eq!(perm, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_partition_kth_out_of_range_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![3.0, 1.0, 2.0]).unwrap();
        assert!(partition(&a, 3).is_err());
        assert!(argpartition(&a, 3).is_err());
    }

    // -- sort_complex --

    #[test]
    fn test_sort_complex_basic() {
        use num_complex::Complex64;
        let a = Array::<Complex64, Ix1>::from_vec(
            Ix1::new([4]),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(1.0, 5.0),
                Complex64::new(2.0, -3.0),
                Complex64::new(1.0, 2.0),
            ],
        )
        .unwrap();
        let r = sort_complex(&a).unwrap();
        let v: Vec<Complex64> = r.iter().copied().collect();
        // Sort by real first (1, 1, 2, 2), then by imag for ties
        // 1+2i, 1+5i, 2-3i, 2+1i
        assert_eq!(v[0], Complex64::new(1.0, 2.0));
        assert_eq!(v[1], Complex64::new(1.0, 5.0));
        assert_eq!(v[2], Complex64::new(2.0, -3.0));
        assert_eq!(v[3], Complex64::new(2.0, 1.0));
    }
}