1use std::cmp::Ordering;
10
11use crate::dimension::{Dimension, Ix1};
12use crate::dtype::Element;
13
14use super::owned::Array;
15
/// Comparator for `PartialOrd` values that ranks unordered values (such as NaN) last.
///
/// If `a.partial_cmp(b)` produces an ordering, that ordering is returned as-is.
/// Otherwise each operand is compared with itself: a value whose self-comparison
/// is `None` (the behaviour of floating-point NaN) counts as "unordered" and
/// sorts after any value that is comparable with itself. Two unordered values
/// compare `Equal`, as do two values that are each self-comparable but have no
/// ordering relative to one another.
fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    a.partial_cmp(b).unwrap_or_else(|| {
        let is_unordered = |value: &T| value.partial_cmp(value).is_none();
        // `false < true` for `bool`, so comparing the two flags yields `Greater`
        // when only `a` is unordered, `Less` when only `b` is, and `Equal` when
        // the flags agree.
        is_unordered(a).cmp(&is_unordered(b))
    })
}
38
39impl<T, D> Array<T, D>
40where
41 T: Element + PartialOrd,
42 D: Dimension,
43{
44 pub fn sorted(&self) -> Array<T, Ix1> {
52 let mut data: Vec<T> = self.iter().cloned().collect();
53 data.sort_by(nan_last_cmp);
54 let n = data.len();
55 Array::<T, Ix1>::from_vec(Ix1::new([n]), data)
56 .expect("from_vec with exact length cannot fail")
57 }
58
59 pub fn argsort(&self) -> Array<u64, Ix1> {
68 let data: Vec<T> = self.iter().cloned().collect();
69 let mut idx: Vec<u64> = (0..data.len() as u64).collect();
70 idx.sort_by(|&i, &j| nan_last_cmp(&data[i as usize], &data[j as usize]));
71 let n = idx.len();
72 Array::<u64, Ix1>::from_vec(Ix1::new([n]), idx)
73 .expect("from_vec with exact length cannot fail")
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Looks up `values[i]` for each index in `order`, keeping the order of `order`.
    fn gather<T: Copy>(values: &[T], order: &[u64]) -> Vec<T> {
        order.iter().map(|&i| values[i as usize]).collect()
    }

    /// Sorting a 1-D integer array yields ascending values and leaves the input alone.
    #[test]
    fn sorted_ascending_1d() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let output = input.sorted();
        assert_eq!(output.shape(), &[5]);
        assert_eq!(output.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
        assert_eq!(input.as_slice().unwrap(), &[3, 1, 4, 1, 5]);
    }

    /// Sorting a 2x3 array produces a single 1-D array of all six values.
    #[test]
    fn sorted_flattens_2d() {
        let grid =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let output = grid.sorted();
        assert_eq!(output.shape(), &[6]);
        assert_eq!(output.as_slice().unwrap(), &[1, 2, 4, 5, 8, 9]);
    }

    /// NaN entries end up after every finite value.
    #[test]
    fn sorted_f64_nans_go_last() {
        let raw = vec![3.0, f64::NAN, 1.0, f64::NAN, 2.0];
        let input = Array::<f64, Ix1>::from_vec(Ix1::new([5]), raw).unwrap();
        let output = input.sorted();
        let values = output.as_slice().unwrap();
        assert_eq!(&values[..3], &[1.0, 2.0, 3.0]);
        assert!(values[3].is_nan());
        assert!(values[4].is_nan());
    }

    /// Sorting an empty array gives an empty 1-D array.
    #[test]
    fn sorted_empty() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([0]), Vec::new()).unwrap();
        let output = input.sorted();
        assert_eq!(output.shape(), &[0]);
        assert_eq!(output.size(), 0);
    }

    /// Gathering the input through the argsort indices reproduces the sorted values.
    #[test]
    fn argsort_matches_sorted_1d() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let order = input.argsort();
        assert_eq!(order.shape(), &[5]);
        let picked: Vec<i32> = gather(input.as_slice().unwrap(), order.as_slice().unwrap());
        assert_eq!(picked, vec![1, 1, 3, 4, 5]);
    }

    /// The index of a NaN entry is placed after the indices of all finite values.
    #[test]
    fn argsort_f64_sends_nans_last() {
        let raw = vec![2.0, f64::NAN, 0.5, 5.0];
        let input = Array::<f64, Ix1>::from_vec(Ix1::new([4]), raw).unwrap();
        let order = input.argsort();
        let picked: Vec<f64> = gather(input.as_slice().unwrap(), order.as_slice().unwrap());
        assert_eq!(&picked[..3], &[0.5, 2.0, 5.0]);
        assert!(picked[3].is_nan());
    }

    /// Argsort on a 2x3 array returns six indices into the flattened element sequence.
    #[test]
    fn argsort_flattens_2d() {
        let grid =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let order = grid.argsort();
        assert_eq!(order.shape(), &[6]);
        assert_eq!(order.as_slice().unwrap(), &[3, 1, 5, 0, 2, 4]);
    }
}