1use std::cmp::Ordering;
10
11use crate::dimension::{Dimension, Ix1};
12use crate::dtype::Element;
13
14use super::owned::Array;
15
/// Comparator for `PartialOrd` values that places NaN-like values last.
///
/// A value counts as "NaN-like" when it is not comparable with itself
/// (`x.partial_cmp(x)` is `None`), which is the case for floating-point NaN.
/// - Pairs that `partial_cmp` can order keep their natural ordering.
/// - A NaN-like value compares `Greater` than any value that is not NaN-like.
/// - Two NaN-like values compare `Equal`.
///
/// NOTE(review): two values that are each self-comparable but incomparable
/// with one another also compare `Equal`. For types where that can happen this
/// is not a consistent total order, and `slice::sort_by` may then produce an
/// unspecified order (or panic on newer toolchains). Floats are unaffected.
fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    let is_nan_like = |x: &T| x.partial_cmp(x).is_none();
    // `false < true` for `bool`, so NaN-like values sort after everything else
    // and two NaN-like (or two mutually incomparable) values tie.
    a.partial_cmp(b)
        .unwrap_or_else(|| is_nan_like(a).cmp(&is_nan_like(b)))
}
39
40impl<T, D> Array<T, D>
41where
42 T: Element + PartialOrd,
43 D: Dimension,
44{
45 pub fn sorted(&self) -> Array<T, Ix1> {
53 let mut data: Vec<T> = self.iter().cloned().collect();
54 data.sort_by(nan_last_cmp);
55 let n = data.len();
56 Array::<T, Ix1>::from_vec(Ix1::new([n]), data)
57 .expect("from_vec with exact length cannot fail")
58 }
59
60 pub fn argsort(&self) -> Array<u64, Ix1> {
69 let data: Vec<T> = self.iter().cloned().collect();
70 let mut idx: Vec<u64> = (0..data.len() as u64).collect();
71 idx.sort_by(|&i, &j| nan_last_cmp(&data[i as usize], &data[j as usize]));
72 let n = idx.len();
73 Array::<u64, Ix1>::from_vec(Ix1::new([n]), idx)
74 .expect("from_vec with exact length cannot fail")
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn sorted_ascending_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[5]);
        assert_eq!(s.as_slice().unwrap(), &[1, 1, 3, 4, 5]);
        // The source array must be left untouched.
        assert_eq!(a.as_slice().unwrap(), &[3, 1, 4, 1, 5]);
    }

    #[test]
    fn sorted_flattens_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[6]);
        assert_eq!(s.as_slice().unwrap(), &[1, 2, 4, 5, 8, 9]);
    }

    #[test]
    fn sorted_f64_nans_go_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, f64::NAN, 1.0, f64::NAN, 2.0])
            .unwrap();
        let s = a.sorted();
        let data = s.as_slice().unwrap();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
        assert!(data[3].is_nan());
        assert!(data[4].is_nan());
    }

    #[test]
    fn sorted_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let s = a.sorted();
        assert_eq!(s.shape(), &[0]);
        assert_eq!(s.size(), 0);
    }

    #[test]
    fn argsort_matches_sorted_1d() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[5]);
        let data = a.as_slice().unwrap();
        let picked: Vec<i32> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn argsort_f64_sends_nans_last() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, f64::NAN, 0.5, 5.0]).unwrap();
        let idx = a.argsort();
        let data = a.as_slice().unwrap();
        let picked: Vec<f64> = idx
            .as_slice()
            .unwrap()
            .iter()
            .map(|&i| data[i as usize])
            .collect();
        assert_eq!(picked[0], 0.5);
        assert_eq!(picked[1], 2.0);
        assert_eq!(picked[2], 5.0);
        assert!(picked[3].is_nan());
    }

    #[test]
    fn argsort_flattens_2d() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![5, 2, 8, 1, 9, 4]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[6]);
        assert_eq!(idx.as_slice().unwrap(), &[3, 1, 5, 0, 2, 4]);
    }

    #[test]
    fn argsort_empty() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.shape(), &[0]);
        assert_eq!(idx.size(), 0);
    }

    #[test]
    fn argsort_is_stable_for_ties() {
        // The two `1`s sit at indices 1 and 3; a stable sort keeps 1 before 3.
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[1, 3, 0, 2, 4]);
    }

    #[test]
    fn argsort_nans_keep_original_order() {
        // NaNs compare equal to each other, so they stay in input order at the end.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![f64::NAN, 1.0, f64::NAN]).unwrap();
        let idx = a.argsort();
        assert_eq!(idx.as_slice().unwrap(), &[1, 0, 2]);
    }
}