1use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
/// Selects which sorting routine [`sort`] applies to each run of elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortKind {
    /// Dispatches to `parallel::parallel_sort`.
    /// NOTE(review): presumably an unstable sort — confirm against `parallel`.
    Quick,
    /// Dispatches to `parallel::parallel_sort_stable`.
    /// NOTE(review): presumably keeps equal elements in input order — confirm.
    Stable,
}
21
/// Chooses which insertion point the `searchsorted` functions report when the
/// searched value is equal to elements of the sorted sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Side {
    /// Index of the first element that is not less than the searched value
    /// (the leftmost insertion point).
    Left,
    /// Index of the first element that is greater than the searched value
    /// (the rightmost insertion point).
    Right,
}
34
35pub fn sort<T, D>(
52 a: &Array<T, D>,
53 axis: Option<usize>,
54 kind: SortKind,
55) -> FerrayResult<Array<T, IxDyn>>
56where
57 T: Element + PartialOrd + Copy + Send + Sync,
58 D: Dimension,
59{
60 match axis {
61 None => {
62 let mut data: Vec<T> = a.iter().copied().collect();
64 let n = data.len();
65 sort_slice(&mut data, kind);
66 Array::from_vec(IxDyn::new(&[n]), data)
67 }
68 Some(ax) => {
69 if ax >= a.ndim() {
70 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
71 }
72 let shape = a.shape().to_vec();
73 let ndim = shape.len();
74 let mut buf: Vec<T> = a.iter().copied().collect();
78 let axis_len = shape[ax];
79
80 if ax == ndim - 1 {
84 for chunk in buf.chunks_exact_mut(axis_len) {
85 sort_slice(chunk, kind);
86 }
87 return Array::from_vec(IxDyn::new(&shape), buf);
88 }
89
90 let strides = compute_strides(&shape);
93 let out_shape: Vec<usize> = shape
94 .iter()
95 .enumerate()
96 .filter(|&(i, _)| i != ax)
97 .map(|(_, &s)| s)
98 .collect();
99 let out_size: usize = if out_shape.is_empty() {
100 1
101 } else {
102 out_shape.iter().product()
103 };
104
105 let mut out_multi = vec![0usize; out_shape.len()];
106 let mut in_multi = vec![0usize; ndim];
109 let mut lane: Vec<T> = Vec::with_capacity(axis_len);
110 let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
111
112 for _ in 0..out_size {
113 let mut out_dim = 0;
115 for (d, slot) in in_multi.iter_mut().enumerate() {
116 if d == ax {
117 *slot = 0;
118 } else {
119 *slot = out_multi[out_dim];
120 out_dim += 1;
121 }
122 }
123
124 lane.clear();
125 lane_indices.clear();
126 for k in 0..axis_len {
127 in_multi[ax] = k;
128 let idx = flat_index(&in_multi, &strides);
129 lane.push(buf[idx]);
130 lane_indices.push(idx);
131 }
132
133 sort_slice(&mut lane, kind);
134
135 for (k, &idx) in lane_indices.iter().enumerate() {
137 buf[idx] = lane[k];
138 }
139
140 if !out_shape.is_empty() {
141 increment_multi_index(&mut out_multi, &out_shape);
142 }
143 }
144
145 Array::from_vec(IxDyn::new(&shape), buf)
146 }
147 }
148}
149
150fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
152 match kind {
153 SortKind::Quick => {
154 parallel::parallel_sort(data);
155 }
156 SortKind::Stable => {
157 parallel::parallel_sort_stable(data);
158 }
159 }
160}
161
162pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
179where
180 T: Element + PartialOrd + Copy,
181 D: Dimension,
182{
183 match axis {
184 None => {
185 let data: Vec<T> = a.iter().copied().collect();
186 let n = data.len();
187 let mut indices: Vec<usize> = (0..n).collect();
188 indices.sort_by(|&i, &j| {
189 data[i]
190 .partial_cmp(&data[j])
191 .unwrap_or(std::cmp::Ordering::Equal)
192 });
193 let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
194 Array::from_vec(IxDyn::new(&[n]), result)
195 }
196 Some(ax) => {
197 if ax >= a.ndim() {
198 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
199 }
200 let shape = a.shape().to_vec();
201 let data: Vec<T> = a.iter().copied().collect();
202 let strides = compute_strides(&shape);
203 let ndim = shape.len();
204 let axis_len = shape[ax];
205
206 let out_shape: Vec<usize> = shape
207 .iter()
208 .enumerate()
209 .filter(|&(i, _)| i != ax)
210 .map(|(_, &s)| s)
211 .collect();
212 let out_size: usize = if out_shape.is_empty() {
213 1
214 } else {
215 out_shape.iter().product()
216 };
217
218 let mut result = vec![0u64; data.len()];
219 let mut out_multi = vec![0usize; out_shape.len()];
220
221 for _ in 0..out_size {
222 let mut in_multi = Vec::with_capacity(ndim);
223 let mut out_dim = 0;
224 for d in 0..ndim {
225 if d == ax {
226 in_multi.push(0);
227 } else {
228 in_multi.push(out_multi[out_dim]);
229 out_dim += 1;
230 }
231 }
232
233 let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
235 let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
236 for k in 0..axis_len {
237 in_multi[ax] = k;
238 let idx = flat_index(&in_multi, &strides);
239 lane.push((k, data[idx]));
240 lane_flat_indices.push(idx);
241 }
242
243 lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
245
246 for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
248 result[flat_idx] = lane[k].0 as u64;
249 }
250
251 if !out_shape.is_empty() {
252 increment_multi_index(&mut out_multi, &out_shape);
253 }
254 }
255
256 Array::from_vec(IxDyn::new(&shape), result)
257 }
258 }
259}
260
261pub fn partition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<T, Ix1>>
276where
277 T: Element + PartialOrd + Copy,
278{
279 let n = a.size();
280 if kth >= n {
281 return Err(FerrayError::invalid_value(format!(
282 "partition: kth={kth} out of range for array of size {n}"
283 )));
284 }
285 let mut data: Vec<T> = a.iter().copied().collect();
286 data.select_nth_unstable_by(kth, |x, y| {
287 x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal)
288 });
289 Array::from_vec(Ix1::new([n]), data)
290}
291
292pub fn argpartition<T>(a: &Array<T, Ix1>, kth: usize) -> FerrayResult<Array<u64, Ix1>>
299where
300 T: Element + PartialOrd + Copy,
301{
302 let n = a.size();
303 if kth >= n {
304 return Err(FerrayError::invalid_value(format!(
305 "argpartition: kth={kth} out of range for array of size {n}"
306 )));
307 }
308 let data: Vec<T> = a.iter().copied().collect();
309 let mut idx: Vec<u64> = (0..n as u64).collect();
310 idx.select_nth_unstable_by(kth, |&a_i, &b_i| {
311 let va = data[a_i as usize];
312 let vb = data[b_i as usize];
313 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
314 });
315 Array::from_vec(Ix1::new([n]), idx)
316}
317
318pub fn lexsort<T>(keys: &[&Array<T, Ix1>]) -> FerrayResult<Array<u64, Ix1>>
339where
340 T: Element + PartialOrd + Copy,
341{
342 if keys.is_empty() {
343 return Err(FerrayError::invalid_value(
344 "lexsort: keys must contain at least one array",
345 ));
346 }
347 let n = keys[0].size();
348 for (i, k) in keys.iter().enumerate().skip(1) {
349 if k.size() != n {
350 return Err(FerrayError::invalid_value(format!(
351 "lexsort: key {i} has length {}, expected {n}",
352 k.size()
353 )));
354 }
355 }
356
357 let key_data: Vec<Vec<T>> = keys.iter().map(|k| k.iter().copied().collect()).collect();
361
362 let mut idx: Vec<u64> = (0..n as u64).collect();
363 idx.sort_by(|&a, &b| {
364 let ai = a as usize;
365 let bi = b as usize;
366 for k in key_data.iter().rev() {
368 match k[ai]
369 .partial_cmp(&k[bi])
370 .unwrap_or(std::cmp::Ordering::Equal)
371 {
372 std::cmp::Ordering::Equal => {}
373 ord => return ord,
374 }
375 }
376 std::cmp::Ordering::Equal
377 });
378
379 Array::from_vec(Ix1::new([n]), idx)
380}
381
382pub fn searchsorted<T>(
394 a: &Array<T, Ix1>,
395 v: &Array<T, Ix1>,
396 side: Side,
397) -> FerrayResult<Array<u64, Ix1>>
398where
399 T: Element + PartialOrd + Copy,
400{
401 let sorted: Vec<T> = a.iter().copied().collect();
402 searchsorted_inner(&sorted, v, side)
403}
404
405pub fn searchsorted_with_sorter<T>(
418 a: &Array<T, Ix1>,
419 v: &Array<T, Ix1>,
420 side: Side,
421 sorter: &Array<u64, Ix1>,
422) -> FerrayResult<Array<u64, Ix1>>
423where
424 T: Element + PartialOrd + Copy,
425{
426 let n = a.size();
427 if sorter.size() != n {
428 return Err(FerrayError::shape_mismatch(format!(
429 "searchsorted: sorter length {} does not match array length {}",
430 sorter.size(),
431 n
432 )));
433 }
434
435 let a_data: Vec<T> = a.iter().copied().collect();
437 let mut sorted: Vec<T> = Vec::with_capacity(n);
438 for &idx in sorter.iter() {
439 let i = idx as usize;
440 if i >= n {
441 return Err(FerrayError::invalid_value(format!(
442 "searchsorted: sorter index {i} out of range for array of length {n}"
443 )));
444 }
445 sorted.push(a_data[i]);
446 }
447
448 searchsorted_inner(&sorted, v, side)
449}
450
451fn searchsorted_inner<T>(
454 sorted: &[T],
455 v: &Array<T, Ix1>,
456 side: Side,
457) -> FerrayResult<Array<u64, Ix1>>
458where
459 T: Element + PartialOrd + Copy,
460{
461 let mut result = Vec::with_capacity(v.size());
462 for &val in v.iter() {
463 let idx = match side {
464 Side::Left => sorted.partition_point(|x| {
465 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
466 }),
467 Side::Right => sorted.partition_point(|x| {
468 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
469 != std::cmp::Ordering::Greater
470 }),
471 };
472 result.push(idx as u64);
473 }
474 let n = result.len();
475 Array::from_vec(Ix1::new([n]), result)
476}
477
478pub fn sort_complex<T>(
487 a: &Array<num_complex::Complex<T>, Ix1>,
488) -> FerrayResult<Array<num_complex::Complex<T>, Ix1>>
489where
490 T: Element + num_traits::Float,
491 num_complex::Complex<T>: Element,
492{
493 let mut data: Vec<num_complex::Complex<T>> = a.iter().copied().collect();
494 data.sort_by(|x, y| {
495 let r = x.re.partial_cmp(&y.re).unwrap_or(std::cmp::Ordering::Equal);
496 if r != std::cmp::Ordering::Equal {
497 r
498 } else {
499 x.im.partial_cmp(&y.im).unwrap_or(std::cmp::Ordering::Equal)
500 }
501 });
502 let n = data.len();
503 Array::from_vec(Ix1::new([n]), data)
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    /// Wraps `data` in a 1-D array whose length is `data.len()`.
    fn arr1<T: Element>(data: Vec<T>) -> Array<T, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Wraps `data` in a `rows` x `cols` 2-D array.
    fn arr2<T: Element>(rows: usize, cols: usize, data: Vec<T>) -> Array<T, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // ---- sort -------------------------------------------------------------

    #[test]
    fn test_sort_1d() {
        let input = arr1::<f64>(vec![3.0, 1.0, 4.0, 1.0, 5.0]);
        let sorted = sort(&input, None, SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[5]);
        assert_eq!(
            sorted.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 1.0, 3.0, 4.0, 5.0]
        );
    }

    #[test]
    fn test_sort_stable_preserves_order() {
        let input = arr1::<i32>(vec![3, 1, 4, 1, 5]);
        let sorted = sort(&input, None, SortKind::Stable).unwrap();
        assert_eq!(sorted.shape(), &[5]);
        assert_eq!(
            sorted.iter().copied().collect::<Vec<i32>>(),
            vec![1, 1, 3, 4, 5]
        );
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        let input = arr2::<f64>(2, 3, vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0]);
        let sorted = sort(&input, None, SortKind::Quick).unwrap();
        // With no axis the result is flattened to one dimension.
        assert_eq!(sorted.shape(), &[6]);
        assert_eq!(
            sorted.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_sort_2d_axis1() {
        let input = arr2::<f64>(2, 3, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let sorted = sort(&input, Some(1), SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[2, 3]);
        assert_eq!(
            sorted.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_sort_2d_axis0() {
        let input = arr2::<f64>(2, 3, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        let sorted = sort(&input, Some(0), SortKind::Quick).unwrap();
        assert_eq!(sorted.shape(), &[2, 3]);
        assert_eq!(
            sorted.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let input = arr1::<f64>(vec![1.0, 2.0, 3.0]);
        assert!(sort(&input, Some(1), SortKind::Quick).is_err());
    }

    // ---- argsort ----------------------------------------------------------

    #[test]
    fn test_argsort_1d() {
        let input = arr1::<f64>(vec![3.0, 1.0, 4.0, 2.0]);
        let order = argsort(&input, None).unwrap();
        assert_eq!(order.shape(), &[4]);
        assert_eq!(
            order.iter().copied().collect::<Vec<u64>>(),
            vec![1, 3, 0, 2]
        );
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let input = arr2::<f64>(2, 3, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let order = argsort(&input, Some(1)).unwrap();
        assert_eq!(order.shape(), &[2, 3]);
        assert_eq!(
            order.iter().copied().collect::<Vec<u64>>(),
            vec![1, 2, 0, 1, 2, 0]
        );
    }

    // ---- searchsorted -----------------------------------------------------

    #[test]
    fn test_searchsorted_left() {
        let haystack = arr1::<f64>(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let needles = arr1::<f64>(vec![2.5, 1.0, 5.5]);
        let found = searchsorted(&haystack, &needles, Side::Left).unwrap();
        assert_eq!(found.iter().copied().collect::<Vec<u64>>(), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let haystack = arr1::<f64>(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let needles = arr1::<f64>(vec![2.0, 4.0]);
        let found = searchsorted(&haystack, &needles, Side::Right).unwrap();
        assert_eq!(found.iter().copied().collect::<Vec<u64>>(), vec![2, 4]);
    }

    #[test]
    fn test_searchsorted_with_sorter_matches_pre_sorted() {
        // `sorter` lists the positions of `unsorted` in ascending value order.
        let unsorted = arr1::<f64>(vec![3.0, 1.0, 5.0, 2.0, 4.0]);
        let sorter = arr1::<u64>(vec![1, 3, 0, 4, 2]);
        let needles = arr1::<f64>(vec![2.5, 1.0, 5.5]);

        let found = searchsorted_with_sorter(&unsorted, &needles, Side::Left, &sorter).unwrap();
        assert_eq!(found.iter().copied().collect::<Vec<_>>(), vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_with_sorter_length_mismatch_errors() {
        let input = arr1::<f64>(vec![3.0, 1.0, 5.0, 2.0]);
        let short_sorter = arr1::<u64>(vec![1, 3, 0]);
        let needles = arr1::<f64>(vec![2.5]);
        assert!(searchsorted_with_sorter(&input, &needles, Side::Left, &short_sorter).is_err());
    }

    #[test]
    fn test_searchsorted_with_sorter_out_of_range_errors() {
        let input = arr1::<f64>(vec![3.0, 1.0, 5.0]);
        let wild_sorter = arr1::<u64>(vec![1, 99, 0]);
        let needles = arr1::<f64>(vec![2.5]);
        assert!(searchsorted_with_sorter(&input, &needles, Side::Left, &wild_sorter).is_err());
    }

    // ---- lexsort ----------------------------------------------------------

    #[test]
    fn test_lexsort_single_key_matches_argsort() {
        let key = arr1::<i32>(vec![3, 1, 4, 1, 5]);
        let order = lexsort(&[&key]).unwrap();
        assert_eq!(
            order.iter().copied().collect::<Vec<_>>(),
            vec![1, 3, 0, 2, 4]
        );
    }

    #[test]
    fn test_lexsort_secondary_key_breaks_ties() {
        // The last key passed is the primary one.
        let secondary = arr1::<i32>(vec![20, 10, 40, 30]);
        let primary = arr1::<i32>(vec![1, 2, 1, 2]);
        let order = lexsort(&[&secondary, &primary]).unwrap();
        assert_eq!(order.iter().copied().collect::<Vec<_>>(), vec![0, 2, 1, 3]);
    }

    #[test]
    fn test_lexsort_length_mismatch_errors() {
        let three = arr1::<i32>(vec![1, 2, 3]);
        let four = arr1::<i32>(vec![1, 2, 3, 4]);
        assert!(lexsort(&[&three, &four]).is_err());
    }

    #[test]
    fn test_lexsort_empty_keys_errors() {
        let no_keys: &[&Array<i32, Ix1>] = &[];
        assert!(lexsort(no_keys).is_err());
    }

    // ---- sort_complex -----------------------------------------------------

    #[test]
    fn test_sort_complex_basic() {
        use num_complex::Complex64;

        let input = arr1::<Complex64>(vec![
            Complex64::new(2.0, 1.0),
            Complex64::new(1.0, 5.0),
            Complex64::new(2.0, -3.0),
            Complex64::new(1.0, 2.0),
        ]);
        let sorted: Vec<Complex64> = sort_complex(&input).unwrap().iter().copied().collect();

        // Ordered by real part, then by imaginary part.
        let expected = [
            Complex64::new(1.0, 2.0),
            Complex64::new(1.0, 5.0),
            Complex64::new(2.0, -3.0),
            Complex64::new(2.0, 1.0),
        ];
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(sorted[slot], *want);
        }
    }
}