1use ferray_core::error::{FerrayError, FerrayResult};
29use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
30
31use crate::parallel;
32use crate::parallel::nan_last_cmp;
33use crate::reductions::{compute_strides, flat_index, increment_multi_index};
34
/// Sorting backend requested by [`sort`].
///
/// The variant only selects which `crate::parallel` routine the internal
/// `sort_slice` helper calls; see the variant docs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortKind {
    /// Dispatches to `parallel::parallel_sort`; no stability guarantee is made.
    Quick,
    /// Dispatches to `parallel::parallel_sort_stable`, for callers that need
    /// equal elements to keep their input order.
    Stable,
}
47
/// Which insertion point [`searchsorted`] reports when the probe value
/// compares equal to elements already in the sorted array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// First position whose element is not less than the probe, i.e. the
    /// probe would be inserted *before* any equal elements.
    Left,
    /// First position whose element is greater than the probe, i.e. the
    /// probe would be inserted *after* any equal elements.
    Right,
}
60
61pub fn sort<T, D>(
78 a: &Array<T, D>,
79 axis: Option<usize>,
80 kind: SortKind,
81) -> FerrayResult<Array<T, IxDyn>>
82where
83 T: Element + PartialOrd + Copy + Send + Sync,
84 D: Dimension,
85{
86 match axis {
87 None => {
88 let mut data: Vec<T> = a.iter().copied().collect();
90 let n = data.len();
91 sort_slice(&mut data, kind);
92 Array::from_vec(IxDyn::new(&[n]), data)
93 }
94 Some(ax) => {
95 if ax >= a.ndim() {
96 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
97 }
98 let shape = a.shape().to_vec();
99 let ndim = shape.len();
100 let mut buf: Vec<T> = a.iter().copied().collect();
104 let axis_len = shape[ax];
105
106 if ax == ndim - 1 {
110 for chunk in buf.chunks_exact_mut(axis_len) {
111 sort_slice(chunk, kind);
112 }
113 return Array::from_vec(IxDyn::new(&shape), buf);
114 }
115
116 let strides = compute_strides(&shape);
119 let out_shape: Vec<usize> = shape
120 .iter()
121 .enumerate()
122 .filter(|&(i, _)| i != ax)
123 .map(|(_, &s)| s)
124 .collect();
125 let out_size: usize = if out_shape.is_empty() {
126 1
127 } else {
128 out_shape.iter().product()
129 };
130
131 let mut out_multi = vec![0usize; out_shape.len()];
132 let mut in_multi = vec![0usize; ndim];
135 let mut lane: Vec<T> = Vec::with_capacity(axis_len);
136 let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
137
138 for _ in 0..out_size {
139 let mut out_dim = 0;
141 for (d, slot) in in_multi.iter_mut().enumerate() {
142 if d == ax {
143 *slot = 0;
144 } else {
145 *slot = out_multi[out_dim];
146 out_dim += 1;
147 }
148 }
149
150 lane.clear();
151 lane_indices.clear();
152 for k in 0..axis_len {
153 in_multi[ax] = k;
154 let idx = flat_index(&in_multi, &strides);
155 lane.push(buf[idx]);
156 lane_indices.push(idx);
157 }
158
159 sort_slice(&mut lane, kind);
160
161 for (k, &idx) in lane_indices.iter().enumerate() {
163 buf[idx] = lane[k];
164 }
165
166 if !out_shape.is_empty() {
167 increment_multi_index(&mut out_multi, &out_shape);
168 }
169 }
170
171 Array::from_vec(IxDyn::new(&shape), buf)
172 }
173 }
174}
175
176fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
178 match kind {
179 SortKind::Quick => {
180 parallel::parallel_sort(data);
181 }
182 SortKind::Stable => {
183 parallel::parallel_sort_stable(data);
184 }
185 }
186}
187
188pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
205where
206 T: Element + PartialOrd + Copy,
207 D: Dimension,
208{
209 match axis {
210 None => {
211 let data: Vec<T> = a.iter().copied().collect();
212 let n = data.len();
213 let mut indices: Vec<usize> = (0..n).collect();
214 indices.sort_by(|&i, &j| nan_last_cmp(&data[i], &data[j]));
215 let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
216 Array::from_vec(IxDyn::new(&[n]), result)
217 }
218 Some(ax) => {
219 if ax >= a.ndim() {
220 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
221 }
222 let shape = a.shape().to_vec();
223 let data: Vec<T> = a.iter().copied().collect();
224 let strides = compute_strides(&shape);
225 let ndim = shape.len();
226 let axis_len = shape[ax];
227
228 let out_shape: Vec<usize> = shape
229 .iter()
230 .enumerate()
231 .filter(|&(i, _)| i != ax)
232 .map(|(_, &s)| s)
233 .collect();
234 let out_size: usize = if out_shape.is_empty() {
235 1
236 } else {
237 out_shape.iter().product()
238 };
239
240 let mut result = vec![0u64; data.len()];
241 let mut out_multi = vec![0usize; out_shape.len()];
242
243 for _ in 0..out_size {
244 let mut in_multi = Vec::with_capacity(ndim);
245 let mut out_dim = 0;
246 for d in 0..ndim {
247 if d == ax {
248 in_multi.push(0);
249 } else {
250 in_multi.push(out_multi[out_dim]);
251 out_dim += 1;
252 }
253 }
254
255 let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
257 let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
258 for k in 0..axis_len {
259 in_multi[ax] = k;
260 let idx = flat_index(&in_multi, &strides);
261 lane.push((k, data[idx]));
262 lane_flat_indices.push(idx);
263 }
264
265 lane.sort_by(|a, b| nan_last_cmp(&a.1, &b.1));
267
268 for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
270 result[flat_idx] = lane[k].0 as u64;
271 }
272
273 if !out_shape.is_empty() {
274 increment_multi_index(&mut out_multi, &out_shape);
275 }
276 }
277
278 Array::from_vec(IxDyn::new(&shape), result)
279 }
280 }
281}
282
283pub fn partition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<T, Ix1>>
298where
299 T: Element + PartialOrd + Copy,
300{
301 let n = a.size();
302 if kth >= n {
303 return Err(FerrayError::invalid_value(format!(
304 "partition: kth={kth} out of range for array of size {n}"
305 )));
306 }
307 let mut data: Vec<T> = a.iter().copied().collect();
308 data.select_nth_unstable_by(kth, nan_last_cmp);
309 Array::from_vec(Ix1::new([n]), data)
310}
311
312pub fn argpartition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<u64, Ix1>>
319where
320 T: Element + PartialOrd + Copy,
321{
322 let n = a.size();
323 if kth >= n {
324 return Err(FerrayError::invalid_value(format!(
325 "argpartition: kth={kth} out of range for array of size {n}"
326 )));
327 }
328 let data: Vec<T> = a.iter().copied().collect();
329 let mut idx: Vec<u64> = (0..n as u64).collect();
330 idx.select_nth_unstable_by(kth, |&a_i, &b_i| {
331 nan_last_cmp(&data[a_i as usize], &data[b_i as usize])
332 });
333 Array::from_vec(Ix1::new([n]), idx)
334}
335
336pub fn lexsort<T>(keys: &[&Array<T, Ix1>]) -> FerrayResult<Array<u64, Ix1>>
357where
358 T: Element + PartialOrd + Copy,
359{
360 if keys.is_empty() {
361 return Err(FerrayError::invalid_value(
362 "lexsort: keys must contain at least one array",
363 ));
364 }
365 let n = keys[0].size();
366 for (i, k) in keys.iter().enumerate().skip(1) {
367 if k.size() != n {
368 return Err(FerrayError::invalid_value(format!(
369 "lexsort: key {i} has length {}, expected {n}",
370 k.size()
371 )));
372 }
373 }
374
375 let key_data: Vec<Vec<T>> = keys.iter().map(|k| k.iter().copied().collect()).collect();
379
380 let mut idx: Vec<u64> = (0..n as u64).collect();
381 idx.sort_by(|&a, &b| {
382 let ai = a as usize;
383 let bi = b as usize;
384 for k in key_data.iter().rev() {
386 match nan_last_cmp(&k[ai], &k[bi]) {
387 std::cmp::Ordering::Equal => {}
388 ord => return ord,
389 }
390 }
391 std::cmp::Ordering::Equal
392 });
393
394 Array::from_vec(Ix1::new([n]), idx)
395}
396
397pub fn searchsorted<T>(
409 a: &Array<T, Ix1>,
410 v: &Array<T, Ix1>,
411 side: Side,
412) -> FerrayResult<Array<u64, Ix1>>
413where
414 T: Element + PartialOrd + Copy,
415{
416 let sorted: Vec<T> = a.iter().copied().collect();
417 searchsorted_inner(&sorted, v, side)
418}
419
420pub fn searchsorted_with_sorter<T>(
433 a: &Array<T, Ix1>,
434 v: &Array<T, Ix1>,
435 side: Side,
436 sorter: &Array<u64, Ix1>,
437) -> FerrayResult<Array<u64, Ix1>>
438where
439 T: Element + PartialOrd + Copy,
440{
441 let n = a.size();
442 if sorter.size() != n {
443 return Err(FerrayError::shape_mismatch(format!(
444 "searchsorted: sorter length {} does not match array length {}",
445 sorter.size(),
446 n
447 )));
448 }
449
450 let a_data: Vec<T> = a.iter().copied().collect();
452 let mut sorted: Vec<T> = Vec::with_capacity(n);
453 for &idx in sorter.iter() {
454 let i = idx as usize;
455 if i >= n {
456 return Err(FerrayError::invalid_value(format!(
457 "searchsorted: sorter index {i} out of range for array of length {n}"
458 )));
459 }
460 sorted.push(a_data[i]);
461 }
462
463 searchsorted_inner(&sorted, v, side)
464}
465
466fn searchsorted_inner<T>(
469 sorted: &[T],
470 v: &Array<T, Ix1>,
471 side: Side,
472) -> FerrayResult<Array<u64, Ix1>>
473where
474 T: Element + PartialOrd + Copy,
475{
476 let mut result = Vec::with_capacity(v.size());
477 for &val in v.iter() {
478 let idx = match side {
479 Side::Left => {
480 sorted.partition_point(|x| nan_last_cmp(x, &val) == std::cmp::Ordering::Less)
481 }
482 Side::Right => {
483 sorted.partition_point(|x| nan_last_cmp(x, &val) != std::cmp::Ordering::Greater)
484 }
485 };
486 result.push(idx as u64);
487 }
488 let n = result.len();
489 Array::from_vec(Ix1::new([n]), result)
490}
491
492pub fn sort_complex<T>(
501 a: &Array<num_complex::Complex<T>, Ix1>,
502) -> FerrayResult<Array<num_complex::Complex<T>, Ix1>>
503where
504 T: Element + num_traits::Float,
505 num_complex::Complex<T>: Element,
506{
507 let mut data: Vec<num_complex::Complex<T>> = a.iter().copied().collect();
508 data.sort_by(|x, y| {
509 let r = nan_last_cmp(&x.re, &y.re);
510 if r != std::cmp::Ordering::Equal {
511 r
512 } else {
513 nan_last_cmp(&x.im, &y.im)
514 }
515 });
516 let n = data.len();
517 Array::from_vec(Ix1::new([n]), data)
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    #[test]
    fn test_sort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_stable_preserves_order() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<i32> = s.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[6]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, Some(1), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0])
            .unwrap();
        let s = sort(&a, Some(0), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // Middle axis of a 3-D array: lanes are strided, not contiguous.
    #[test]
    fn test_sort_3d_middle_axis() {
        let a = Array::<i32, IxDyn>::from_vec(
            IxDyn::new(&[2, 3, 2]),
            vec![5, 0, 3, 2, 1, 4, 11, 6, 9, 8, 7, 10],
        )
        .unwrap();
        let s = sort(&a, Some(1), SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[2, 3, 2]);
        let data: Vec<i32> = s.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10]);
    }

    #[test]
    fn test_argsort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        assert_eq!(idx.shape(), &[4]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let idx = argsort(&a, Some(1)).unwrap();
        assert_eq!(idx.shape(), &[2, 3]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_argsort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![3.0, 1.0, 1.0, 2.0]).unwrap();
        let idx = argsort(&a, Some(0)).unwrap();
        assert_eq!(idx.shape(), &[2, 2]);
        let data: Vec<u64> = idx.iter().copied().collect();
        // Column 0 is [3, 1] -> [1, 0]; column 1 is [1, 2] -> [0, 1].
        assert_eq!(data, vec![1, 0, 0, 1]);
    }

    #[test]
    fn test_partition_places_kth_element() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let p = partition(&a, 2).unwrap();
        let data: Vec<f64> = p.iter().copied().collect();
        assert_eq!(data[2], 3.0);
        assert!(data[..2].iter().all(|&x| x <= 3.0));
        assert!(data[3..].iter().all(|&x| x >= 3.0));
    }

    #[test]
    fn test_partition_kth_out_of_range_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(partition(&a, 3).is_err());
    }

    #[test]
    fn test_argpartition_places_kth_index() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![30, 10, 40, 20, 50]).unwrap();
        let idx: Vec<u64> = argpartition(&a, 2).unwrap().iter().copied().collect();
        // 30 (at original index 0) is the median.
        assert_eq!(idx[2], 0);
        let mut permutation = idx.clone();
        permutation.sort_unstable();
        assert_eq!(permutation, vec![0, 1, 2, 3, 4]);
        assert!(argpartition(&a, 5).is_err());
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();
        let idx = searchsorted(&a, &v, Side::Left).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![2.0, 4.0]).unwrap();
        let idx = searchsorted(&a, &v, Side::Right).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 4]);
    }

    #[test]
    fn test_searchsorted_duplicates_left_vs_right() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 2, 3]).unwrap();
        let v = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![2]).unwrap();
        let left: Vec<u64> = searchsorted(&a, &v, Side::Left)
            .unwrap()
            .iter()
            .copied()
            .collect();
        let right: Vec<u64> = searchsorted(&a, &v, Side::Right)
            .unwrap()
            .iter()
            .copied()
            .collect();
        assert_eq!(left, vec![1]);
        assert_eq!(right, vec![3]);
    }

    #[test]
    fn test_searchsorted_with_sorter_matches_pre_sorted() {
        let unsorted =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 5.0, 2.0, 4.0]).unwrap();
        let sorter = Array::<u64, Ix1>::from_vec(Ix1::new([5]), vec![1, 3, 0, 4, 2]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();

        let idx = searchsorted_with_sorter(&unsorted, &v, Side::Left, &sorter).unwrap();
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_with_sorter_length_mismatch_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 5.0, 2.0]).unwrap();
        let bad_sorter = Array::<u64, Ix1>::from_vec(Ix1::new([3]), vec![1, 3, 0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.5]).unwrap();
        assert!(searchsorted_with_sorter(&a, &v, Side::Left, &bad_sorter).is_err());
    }

    #[test]
    fn test_searchsorted_with_sorter_out_of_range_errors() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![3.0, 1.0, 5.0]).unwrap();
        let bad_sorter = Array::<u64, Ix1>::from_vec(Ix1::new([3]), vec![1, 99, 0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.5]).unwrap();
        assert!(searchsorted_with_sorter(&a, &v, Side::Left, &bad_sorter).is_err());
    }

    #[test]
    fn test_lexsort_single_key_matches_argsort() {
        let k = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = lexsort(&[&k]).unwrap();
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn test_lexsort_secondary_key_breaks_ties() {
        let secondary = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![20, 10, 40, 30]).unwrap();
        let primary = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 1, 2]).unwrap();
        let idx = lexsort(&[&secondary, &primary]).unwrap();
        assert_eq!(idx.iter().copied().collect::<Vec<_>>(), vec![0, 2, 1, 3]);
    }

    #[test]
    fn test_lexsort_length_mismatch_errors() {
        let k1 = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let k2 = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        assert!(lexsort(&[&k1, &k2]).is_err());
    }

    #[test]
    fn test_lexsort_empty_keys_errors() {
        let keys: &[&Array<i32, Ix1>] = &[];
        assert!(lexsort(keys).is_err());
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(sort(&a, Some(1), SortKind::Quick).is_err());
    }

    #[test]
    fn test_sort_complex_basic() {
        use num_complex::Complex64;
        let a = Array::<Complex64, Ix1>::from_vec(
            Ix1::new([4]),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(1.0, 5.0),
                Complex64::new(2.0, -3.0),
                Complex64::new(1.0, 2.0),
            ],
        )
        .unwrap();
        let r = sort_complex(&a).unwrap();
        let v: Vec<Complex64> = r.iter().copied().collect();
        assert_eq!(v[0], Complex64::new(1.0, 2.0));
        assert_eq!(v[1], Complex64::new(1.0, 5.0));
        assert_eq!(v[2], Complex64::new(2.0, -3.0));
        assert_eq!(v[3], Complex64::new(2.0, 1.0));
    }
}