Skip to main content

ferray_stats/
sorting.rs

1// ferray-stats: Sorting and searching — sort, argsort, searchsorted (REQ-11, REQ-12, REQ-13)
2
3use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
9// ---------------------------------------------------------------------------
10// SortKind
11// ---------------------------------------------------------------------------
12
/// Selects which sorting strategy [`sort`] uses for each lane.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortKind {
    /// Unstable sort: typically faster, but equal elements may end up in
    /// any relative order.
    Quick,
    /// Stable sort: equal elements keep the relative order they had in
    /// the input.
    Stable,
}
21
22// ---------------------------------------------------------------------------
23// Side (for searchsorted)
24// ---------------------------------------------------------------------------
25
/// Which insertion point `searchsorted` reports when the searched value
/// is equal to one or more existing elements.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// Report the leftmost valid insertion index (before any equal elements).
    Left,
    /// Report the rightmost valid insertion index (after any equal elements).
    Right,
}
34
35// ---------------------------------------------------------------------------
36// sort
37// ---------------------------------------------------------------------------
38
39/// Sort an array along the given axis (or flattened if axis is None).
40///
41/// When `axis` is `None`, the array is flattened before sorting and a 1-D
42/// array is returned. When an axis is given, the returned array has the
43/// same shape as the input.
44///
45/// **Note:** `NumPy`'s `np.sort(a)` defaults to `axis=-1` (last axis).
46/// ferray's `sort(a, None, kind)` flattens instead. To match `NumPy`'s
47/// default, pass the last axis explicitly:
48/// `sort(a, Some(a.ndim() - 1), kind)`.
49///
50/// Equivalent to `numpy.sort`.
51pub fn sort<T, D>(
52    a: &Array<T, D>,
53    axis: Option<usize>,
54    kind: SortKind,
55) -> FerrayResult<Array<T, IxDyn>>
56where
57    T: Element + PartialOrd + Copy + Send + Sync,
58    D: Dimension,
59{
60    match axis {
61        None => {
62            // Flatten and sort — return a 1-D array (NumPy behaviour)
63            let mut data: Vec<T> = a.iter().copied().collect();
64            let n = data.len();
65            sort_slice(&mut data, kind);
66            Array::from_vec(IxDyn::new(&[n]), data)
67        }
68        Some(ax) => {
69            if ax >= a.ndim() {
70                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
71            }
72            let shape = a.shape().to_vec();
73            let ndim = shape.len();
74            // Materialize once into a single buffer that we sort in
75            // place — the previous code allocated `data` plus a full
76            // `result = data.clone()` second copy (#171).
77            let mut buf: Vec<T> = a.iter().copied().collect();
78            let axis_len = shape[ax];
79
80            // Last-axis fast path: lanes are already contiguous in
81            // row-major order, so we can hand each `axis_len` window to
82            // `sort_slice` directly with no gather/scatter.
83            if ax == ndim - 1 {
84                for chunk in buf.chunks_exact_mut(axis_len) {
85                    sort_slice(chunk, kind);
86                }
87                return Array::from_vec(IxDyn::new(&shape), buf);
88            }
89
90            // General axis: gather a temporary lane, sort it, scatter
91            // values back into the same buffer.
92            let strides = compute_strides(&shape);
93            let out_shape: Vec<usize> = shape
94                .iter()
95                .enumerate()
96                .filter(|&(i, _)| i != ax)
97                .map(|(_, &s)| s)
98                .collect();
99            let out_size: usize = if out_shape.is_empty() {
100                1
101            } else {
102                out_shape.iter().product()
103            };
104
105            let mut out_multi = vec![0usize; out_shape.len()];
106            // Re-used per-lane scratch buffers to avoid `axis_len`
107            // re-allocations on every output position.
108            let mut in_multi = vec![0usize; ndim];
109            let mut lane: Vec<T> = Vec::with_capacity(axis_len);
110            let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
111
112            for _ in 0..out_size {
113                // Build input multi-index template
114                let mut out_dim = 0;
115                for (d, slot) in in_multi.iter_mut().enumerate() {
116                    if d == ax {
117                        *slot = 0;
118                    } else {
119                        *slot = out_multi[out_dim];
120                        out_dim += 1;
121                    }
122                }
123
124                lane.clear();
125                lane_indices.clear();
126                for k in 0..axis_len {
127                    in_multi[ax] = k;
128                    let idx = flat_index(&in_multi, &strides);
129                    lane.push(buf[idx]);
130                    lane_indices.push(idx);
131                }
132
133                sort_slice(&mut lane, kind);
134
135                // Scatter sorted values back into the in-place buffer.
136                for (k, &idx) in lane_indices.iter().enumerate() {
137                    buf[idx] = lane[k];
138                }
139
140                if !out_shape.is_empty() {
141                    increment_multi_index(&mut out_multi, &out_shape);
142                }
143            }
144
145            Array::from_vec(IxDyn::new(&shape), buf)
146        }
147    }
148}
149
150/// Sort a slice in place using the given algorithm.
151fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
152    match kind {
153        SortKind::Quick => {
154            parallel::parallel_sort(data);
155        }
156        SortKind::Stable => {
157            parallel::parallel_sort_stable(data);
158        }
159    }
160}
161
162// ---------------------------------------------------------------------------
163// argsort
164// ---------------------------------------------------------------------------
165
166/// Return the indices that would sort an array along the given axis.
167///
168/// When `axis` is `None`, the array is flattened before computing
169/// indices and a 1-D array is returned.
170///
171/// Returns u64 indices.
172///
173/// **Note:** `NumPy`'s `np.argsort(a)` defaults to `axis=-1` (last axis).
174/// ferray's `argsort(a, None)` flattens instead. To match `NumPy`'s
175/// default, pass the last axis explicitly: `argsort(a, Some(a.ndim() - 1))`.
176///
177/// Equivalent to `numpy.argsort`.
178pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
179where
180    T: Element + PartialOrd + Copy,
181    D: Dimension,
182{
183    match axis {
184        None => {
185            let data: Vec<T> = a.iter().copied().collect();
186            let n = data.len();
187            let mut indices: Vec<usize> = (0..n).collect();
188            indices.sort_by(|&i, &j| {
189                data[i]
190                    .partial_cmp(&data[j])
191                    .unwrap_or(std::cmp::Ordering::Equal)
192            });
193            let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
194            Array::from_vec(IxDyn::new(&[n]), result)
195        }
196        Some(ax) => {
197            if ax >= a.ndim() {
198                return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
199            }
200            let shape = a.shape().to_vec();
201            let data: Vec<T> = a.iter().copied().collect();
202            let strides = compute_strides(&shape);
203            let ndim = shape.len();
204            let axis_len = shape[ax];
205
206            let out_shape: Vec<usize> = shape
207                .iter()
208                .enumerate()
209                .filter(|&(i, _)| i != ax)
210                .map(|(_, &s)| s)
211                .collect();
212            let out_size: usize = if out_shape.is_empty() {
213                1
214            } else {
215                out_shape.iter().product()
216            };
217
218            let mut result = vec![0u64; data.len()];
219            let mut out_multi = vec![0usize; out_shape.len()];
220
221            for _ in 0..out_size {
222                let mut in_multi = Vec::with_capacity(ndim);
223                let mut out_dim = 0;
224                for d in 0..ndim {
225                    if d == ax {
226                        in_multi.push(0);
227                    } else {
228                        in_multi.push(out_multi[out_dim]);
229                        out_dim += 1;
230                    }
231                }
232
233                // Gather lane values and their axis-local indices
234                let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
235                let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
236                for k in 0..axis_len {
237                    in_multi[ax] = k;
238                    let idx = flat_index(&in_multi, &strides);
239                    lane.push((k, data[idx]));
240                    lane_flat_indices.push(idx);
241                }
242
243                // Sort by value, tracking original axis-local index
244                lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
245
246                // Scatter the original axis-local indices into the result
247                for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
248                    result[flat_idx] = lane[k].0 as u64;
249                }
250
251                if !out_shape.is_empty() {
252                    increment_multi_index(&mut out_multi, &out_shape);
253                }
254            }
255
256            Array::from_vec(IxDyn::new(&shape), result)
257        }
258    }
259}
260
261// ---------------------------------------------------------------------------
262// partition / argpartition
263// ---------------------------------------------------------------------------
264
265/// Partial sort: rearrange elements so that `a[kth]` is the value that
266/// would be there in a sorted array, all elements before it are `<=`,
267/// and all elements after are `>=`. The relative order within the two
268/// halves is undefined.
269///
270/// Equivalent to `numpy.partition(a, kth)`. Uses `select_nth_unstable`
271/// for O(n) average-case performance (#466).
272///
273/// # Errors
274/// - `FerrayError::AxisOutOfBounds` if `kth >= a.size()`.
275pub fn partition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<T, Ix1>>
276where
277    T: Element + PartialOrd + Copy,
278{
279    let n = a.size();
280    if kth >= n {
281        return Err(FerrayError::invalid_value(format!(
282            "partition: kth={kth} out of range for array of size {n}"
283        )));
284    }
285    let mut data: Vec<T> = a.iter().copied().collect();
286    data.select_nth_unstable_by(kth, |x, y| {
287        x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal)
288    });
289    Array::from_vec(Ix1::new([n]), data)
290}
291
292/// Return indices that would partition the array. The k-th element of
293/// the result is the index of the k-th smallest element; elements
294/// before index kth are indices of smaller-or-equal elements, and
295/// elements after are indices of greater-or-equal.
296///
297/// Equivalent to `numpy.argpartition(a, kth)` (#466).
298pub fn argpartition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<u64, Ix1>>
299where
300    T: Element + PartialOrd + Copy,
301{
302    let n = a.size();
303    if kth >= n {
304        return Err(FerrayError::invalid_value(format!(
305            "argpartition: kth={kth} out of range for array of size {n}"
306        )));
307    }
308    let data: Vec<T> = a.iter().copied().collect();
309    let mut idx: Vec<u64> = (0..n as u64).collect();
310    idx.select_nth_unstable_by(kth, |&a_i, &b_i| {
311        let va = data[a_i as usize];
312        let vb = data[b_i as usize];
313        va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
314    });
315    Array::from_vec(Ix1::new([n]), idx)
316}
317
318// ---------------------------------------------------------------------------
319// lexsort
320// ---------------------------------------------------------------------------
321
322/// Indirect stable sort using a sequence of keys.
323///
324/// `keys` is a list of 1-D arrays of the same length. The **last** key
325/// in the list is the primary sort key (matching `NumPy`'s
326/// `numpy.lexsort` convention); ties are broken by the second-to-last
327/// key, then the third-to-last, and so on. Returns a permutation
328/// `idx` such that `keys[-1][idx]` is non-decreasing.
329///
330/// Implementation notes: the underlying sort is `sort_by` (stable),
331/// applied once with a comparator that walks the keys from primary
332/// (last) to secondary (earlier). This avoids the multi-pass stable
333/// sort that `NumPy` historically used.
334///
335/// # Errors
336/// - `FerrayError::InvalidValue` if `keys` is empty or the keys have
337///   different lengths.
338pub fn lexsort<T>(keys: &[&Array<T, Ix1>]) -> FerrayResult<Array<u64, Ix1>>
339where
340    T: Element + PartialOrd + Copy,
341{
342    if keys.is_empty() {
343        return Err(FerrayError::invalid_value(
344            "lexsort: keys must contain at least one array",
345        ));
346    }
347    let n = keys[0].size();
348    for (i, k) in keys.iter().enumerate().skip(1) {
349        if k.size() != n {
350            return Err(FerrayError::invalid_value(format!(
351                "lexsort: key {i} has length {}, expected {n}",
352                k.size()
353            )));
354        }
355    }
356
357    // Materialize each key into a contiguous Vec so the comparator
358    // can index into them directly without re-borrowing the array
359    // iterator on every comparison.
360    let key_data: Vec<Vec<T>> = keys.iter().map(|k| k.iter().copied().collect()).collect();
361
362    let mut idx: Vec<u64> = (0..n as u64).collect();
363    idx.sort_by(|&a, &b| {
364        let ai = a as usize;
365        let bi = b as usize;
366        // Iterate keys from primary (last) to secondary (earlier).
367        for k in key_data.iter().rev() {
368            match k[ai]
369                .partial_cmp(&k[bi])
370                .unwrap_or(std::cmp::Ordering::Equal)
371            {
372                std::cmp::Ordering::Equal => {}
373                ord => return ord,
374            }
375        }
376        std::cmp::Ordering::Equal
377    });
378
379    Array::from_vec(Ix1::new([n]), idx)
380}
381
382// ---------------------------------------------------------------------------
383// searchsorted
384// ---------------------------------------------------------------------------
385
386/// Find indices where elements should be inserted to maintain order.
387///
388/// `a` must be a sorted 1-D array. For each value in `v`, find the index
389/// in `a` where it should be inserted. Returns u64 indices.
390///
391/// Equivalent to `numpy.searchsorted` (without `sorter`). For an
392/// already-permuted view of an unsorted array, see [`searchsorted_with_sorter`].
393pub fn searchsorted<T>(
394    a: &Array<T, Ix1>,
395    v: &Array<T, Ix1>,
396    side: Side,
397) -> FerrayResult<Array<u64, Ix1>>
398where
399    T: Element + PartialOrd + Copy,
400{
401    let sorted: Vec<T> = a.iter().copied().collect();
402    searchsorted_inner(&sorted, v, side)
403}
404
405/// Find indices where elements should be inserted to maintain order,
406/// using `sorter` as a permutation that would sort `a`.
407///
408/// Mirrors `numpy.searchsorted(a, v, side, sorter)`. `a` may be in any
409/// order; `sorter[i]` gives the index in `a` of the i-th smallest
410/// element (i.e. `sorter` is the output of an `argsort` over `a`). The
411/// returned indices are positions into the **sorted** view, matching
412/// `NumPy`'s behaviour. See issue #473.
413///
414/// # Errors
415/// - `FerrayError::ShapeMismatch` if `sorter.len() != a.len()`.
416/// - `FerrayError::InvalidValue` if `sorter` contains an out-of-range index.
417pub fn searchsorted_with_sorter<T>(
418    a: &Array<T, Ix1>,
419    v: &Array<T, Ix1>,
420    side: Side,
421    sorter: &Array<u64, Ix1>,
422) -> FerrayResult<Array<u64, Ix1>>
423where
424    T: Element + PartialOrd + Copy,
425{
426    let n = a.size();
427    if sorter.size() != n {
428        return Err(FerrayError::shape_mismatch(format!(
429            "searchsorted: sorter length {} does not match array length {}",
430            sorter.size(),
431            n
432        )));
433    }
434
435    // Materialize `a` once and gather it in sorter order.
436    let a_data: Vec<T> = a.iter().copied().collect();
437    let mut sorted: Vec<T> = Vec::with_capacity(n);
438    for &idx in sorter.iter() {
439        let i = idx as usize;
440        if i >= n {
441            return Err(FerrayError::invalid_value(format!(
442                "searchsorted: sorter index {i} out of range for array of length {n}"
443            )));
444        }
445        sorted.push(a_data[i]);
446    }
447
448    searchsorted_inner(&sorted, v, side)
449}
450
451/// Shared binary-search core used by both [`searchsorted`] and
452/// [`searchsorted_with_sorter`].
453fn searchsorted_inner<T>(
454    sorted: &[T],
455    v: &Array<T, Ix1>,
456    side: Side,
457) -> FerrayResult<Array<u64, Ix1>>
458where
459    T: Element + PartialOrd + Copy,
460{
461    let mut result = Vec::with_capacity(v.size());
462    for &val in v.iter() {
463        let idx = match side {
464            Side::Left => sorted.partition_point(|x| {
465                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
466            }),
467            Side::Right => sorted.partition_point(|x| {
468                x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
469                    != std::cmp::Ordering::Greater
470            }),
471        };
472        result.push(idx as u64);
473    }
474    let n = result.len();
475    Array::from_vec(Ix1::new([n]), result)
476}
477
478// ---------------------------------------------------------------------------
479// sort_complex
480// ---------------------------------------------------------------------------
481
482/// Sort a 1-D complex array, comparing first by real part, then by imaginary.
483///
484/// Equivalent to `numpy.sort_complex`. Always returns a stable sort —
485/// matching NumPy's behavior. Operates on the flattened input.
486pub fn sort_complex<T>(
487    a: &Array<num_complex::Complex<T>, Ix1>,
488) -> FerrayResult<Array<num_complex::Complex<T>, Ix1>>
489where
490    T: Element + num_traits::Float,
491    num_complex::Complex<T>: Element,
492{
493    let mut data: Vec<num_complex::Complex<T>> = a.iter().copied().collect();
494    data.sort_by(|x, y| {
495        let r = x.re.partial_cmp(&y.re).unwrap_or(std::cmp::Ordering::Equal);
496        if r != std::cmp::Ordering::Equal {
497            r
498        } else {
499            x.im.partial_cmp(&y.im).unwrap_or(std::cmp::Ordering::Equal)
500        }
501    });
502    let n = data.len();
503    Array::from_vec(Ix1::new([n]), data)
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    /// Build a 1-D array whose length is taken from `values`.
    fn arr1<T: Element>(values: Vec<T>) -> Array<T, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Build a 2x3 `f64` array from six row-major values.
    fn arr2x3(values: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([2, 3]), values).unwrap()
    }

    /// Copy an array's elements out in iteration order.
    fn elems<T: Element + Copy, D: Dimension>(arr: &Array<T, D>) -> Vec<T> {
        arr.iter().copied().collect()
    }

    #[test]
    fn test_sort_1d() {
        let input = arr1(vec![3.0_f64, 1.0, 4.0, 1.0, 5.0]);
        let sorted = sort(&input, None, SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[5]);
        assert_eq!(elems(&sorted), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_stable_preserves_order() {
        let input = arr1(vec![3_i32, 1, 4, 1, 5]);
        let sorted = sort(&input, None, SortKind::Stable).unwrap();
        assert_eq!(sorted.shape(), &[5]);
        assert_eq!(elems(&sorted), vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        // Issue #91: sort(axis=None) must return a flat 1-D array.
        let input = arr2x3(vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0]);
        let sorted = sort(&input, None, SortKind::Quick).unwrap();
        // Six elements in one dimension, not a [2, 3] result.
        assert_eq!(sorted.shape(), &[6]);
        assert_eq!(elems(&sorted), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let input = arr2x3(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let sorted = sort(&input, Some(1), SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[2, 3]);
        assert_eq!(elems(&sorted), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let input = arr2x3(vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        let sorted = sort(&input, Some(0), SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[2, 3]);
        assert_eq!(elems(&sorted), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_argsort_1d() {
        let input = arr1(vec![3.0_f64, 1.0, 4.0, 2.0]);
        let order = argsort(&input, None).unwrap();
        assert_eq!(order.shape(), &[4]);
        assert_eq!(elems(&order), vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let input = arr2x3(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let order = argsort(&input, Some(1)).unwrap();
        assert_eq!(order.shape(), &[2, 3]);
        assert_eq!(elems(&order), vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_searchsorted_left() {
        let haystack = arr1(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]);
        let needles = arr1(vec![2.5_f64, 1.0, 5.5]);
        let found = searchsorted(&haystack, &needles, Side::Left).unwrap();
        assert_eq!(elems(&found), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let haystack = arr1(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]);
        let needles = arr1(vec![2.0_f64, 4.0]);
        let found = searchsorted(&haystack, &needles, Side::Right).unwrap();
        assert_eq!(elems(&found), vec![2, 4]);
    }

    // ----- searchsorted_with_sorter (#473) -----

    #[test]
    fn test_searchsorted_with_sorter_matches_pre_sorted() {
        // An unsorted array plus its argsort must give the same indices
        // as searchsorted on the pre-sorted array.
        let unsorted = arr1(vec![3.0_f64, 1.0, 5.0, 2.0, 4.0]);
        // unsorted[sorter] == [1.0, 2.0, 3.0, 4.0, 5.0]
        let sorter = arr1(vec![1_u64, 3, 0, 4, 2]);
        let needles = arr1(vec![2.5_f64, 1.0, 5.5]);

        let found = searchsorted_with_sorter(&unsorted, &needles, Side::Left, &sorter).unwrap();
        assert_eq!(elems(&found), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_with_sorter_length_mismatch_errors() {
        let haystack = arr1(vec![3.0_f64, 1.0, 5.0, 2.0]);
        let short_sorter = arr1(vec![1_u64, 3, 0]);
        let needles = arr1(vec![2.5_f64]);
        assert!(searchsorted_with_sorter(&haystack, &needles, Side::Left, &short_sorter).is_err());
    }

    #[test]
    fn test_searchsorted_with_sorter_out_of_range_errors() {
        let haystack = arr1(vec![3.0_f64, 1.0, 5.0]);
        let wild_sorter = arr1(vec![1_u64, 99, 0]);
        let needles = arr1(vec![2.5_f64]);
        assert!(searchsorted_with_sorter(&haystack, &needles, Side::Left, &wild_sorter).is_err());
    }

    // ----- lexsort (#469) -----

    #[test]
    fn test_lexsort_single_key_matches_argsort() {
        let key = arr1(vec![3_i32, 1, 4, 1, 5]);
        let order = lexsort(&[&key]).unwrap();
        // Sorted order: 1@idx1, 1@idx3, 3@idx0, 4@idx2, 5@idx4
        assert_eq!(elems(&order), vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn test_lexsort_secondary_key_breaks_ties() {
        // The last key in the slice is primary (ascending); ties are
        // resolved by the earlier key, as in NumPy's
        // lexsort([secondary, primary]).
        let secondary = arr1(vec![20_i32, 10, 40, 30]);
        let primary = arr1(vec![1_i32, 2, 1, 2]);
        let order = lexsort(&[&secondary, &primary]).unwrap();
        // primary == 1 -> indices 0, 2 (secondary 20, 40) -> 0, 2
        // primary == 2 -> indices 1, 3 (secondary 10, 30) -> 1, 3
        assert_eq!(elems(&order), vec![0, 2, 1, 3]);
    }

    #[test]
    fn test_lexsort_length_mismatch_errors() {
        let short = arr1(vec![1_i32, 2, 3]);
        let long = arr1(vec![1_i32, 2, 3, 4]);
        assert!(lexsort(&[&short, &long]).is_err());
    }

    #[test]
    fn test_lexsort_empty_keys_errors() {
        let keys: &[&Array<i32, Ix1>] = &[];
        assert!(lexsort(keys).is_err());
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let input = arr1(vec![1.0_f64, 2.0, 3.0]);
        assert!(sort(&input, Some(1), SortKind::Quick).is_err());
    }

    // -- sort_complex --

    #[test]
    fn test_sort_complex_basic() {
        use num_complex::Complex64;
        let input = arr1(vec![
            Complex64::new(2.0, 1.0),
            Complex64::new(1.0, 5.0),
            Complex64::new(2.0, -3.0),
            Complex64::new(1.0, 2.0),
        ]);
        let got = elems(&sort_complex(&input).unwrap());
        // Real parts order the values (1, 1, 2, 2); imaginary parts break
        // ties: 1+2i, 1+5i, 2-3i, 2+1i
        let expected = [
            Complex64::new(1.0, 2.0),
            Complex64::new(1.0, 5.0),
            Complex64::new(2.0, -3.0),
            Complex64::new(2.0, 1.0),
        ];
        assert_eq!(got, expected);
    }
}