1use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
/// Algorithm family requested from [`sort`].
///
/// `sort_slice` maps `Quick` to `parallel::parallel_sort` and `Stable` to
/// `parallel::parallel_sort_stable`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortKind {
    /// Uses `parallel::parallel_sort`. NOTE(review): presumably not stable for
    /// equal elements — confirm against the `parallel` module.
    Quick,
    /// Uses `parallel::parallel_sort_stable`; equal elements are expected to keep
    /// their original relative order.
    Stable,
}
21
/// Which insertion point [`searchsorted`] reports when the searched value is
/// already present in the sorted input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// First index `i` such that `sorted[i] >= value` (before any run of equal values).
    Left,
    /// First index `i` such that `sorted[i] > value` (after any run of equal values).
    Right,
}
34
35pub fn sort<T, D>(a: &Array<T, D>, axis: Option<usize>, kind: SortKind) -> FerrayResult<Array<T, D>>
45where
46 T: Element + PartialOrd + Copy + Send + Sync,
47 D: Dimension,
48{
49 match axis {
50 None => {
51 let mut data: Vec<T> = a.iter().copied().collect();
53 sort_slice(&mut data, kind);
54 Array::from_vec(a.dim().clone(), data)
55 }
56 Some(ax) => {
57 if ax >= a.ndim() {
58 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
59 }
60 let shape = a.shape().to_vec();
61 let data: Vec<T> = a.iter().copied().collect();
62 let mut result = data.clone();
63 let strides = compute_strides(&shape);
64
65 let axis_len = shape[ax];
66 let out_shape: Vec<usize> = shape
67 .iter()
68 .enumerate()
69 .filter(|&(i, _)| i != ax)
70 .map(|(_, &s)| s)
71 .collect();
72 let out_size: usize = if out_shape.is_empty() {
73 1
74 } else {
75 out_shape.iter().product()
76 };
77
78 let mut out_multi = vec![0usize; out_shape.len()];
79 let ndim = shape.len();
80
81 for _ in 0..out_size {
82 let mut in_multi = Vec::with_capacity(ndim);
84 let mut out_dim = 0;
85 for d in 0..ndim {
86 if d == ax {
87 in_multi.push(0);
88 } else {
89 in_multi.push(out_multi[out_dim]);
90 out_dim += 1;
91 }
92 }
93
94 let mut lane: Vec<T> = Vec::with_capacity(axis_len);
96 let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
97 for k in 0..axis_len {
98 in_multi[ax] = k;
99 let idx = flat_index(&in_multi, &strides);
100 lane.push(data[idx]);
101 lane_indices.push(idx);
102 }
103
104 sort_slice(&mut lane, kind);
105
106 for (k, &idx) in lane_indices.iter().enumerate() {
108 result[idx] = lane[k];
109 }
110
111 if !out_shape.is_empty() {
112 increment_multi_index(&mut out_multi, &out_shape);
113 }
114 }
115
116 Array::from_vec(a.dim().clone(), result)
117 }
118 }
119}
120
121fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
123 match kind {
124 SortKind::Quick => {
125 parallel::parallel_sort(data);
126 }
127 SortKind::Stable => {
128 parallel::parallel_sort_stable(data);
129 }
130 }
131}
132
133pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, D>>
143where
144 T: Element + PartialOrd + Copy,
145 D: Dimension,
146{
147 match axis {
148 None => {
149 let data: Vec<T> = a.iter().copied().collect();
150 let mut indices: Vec<usize> = (0..data.len()).collect();
151 indices.sort_by(|&i, &j| {
152 data[i]
153 .partial_cmp(&data[j])
154 .unwrap_or(std::cmp::Ordering::Equal)
155 });
156 let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
157 Array::from_vec(a.dim().clone(), result)
158 }
159 Some(ax) => {
160 if ax >= a.ndim() {
161 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
162 }
163 let shape = a.shape().to_vec();
164 let data: Vec<T> = a.iter().copied().collect();
165 let strides = compute_strides(&shape);
166 let ndim = shape.len();
167 let axis_len = shape[ax];
168
169 let out_shape: Vec<usize> = shape
170 .iter()
171 .enumerate()
172 .filter(|&(i, _)| i != ax)
173 .map(|(_, &s)| s)
174 .collect();
175 let out_size: usize = if out_shape.is_empty() {
176 1
177 } else {
178 out_shape.iter().product()
179 };
180
181 let mut result = vec![0u64; data.len()];
182 let mut out_multi = vec![0usize; out_shape.len()];
183
184 for _ in 0..out_size {
185 let mut in_multi = Vec::with_capacity(ndim);
186 let mut out_dim = 0;
187 for d in 0..ndim {
188 if d == ax {
189 in_multi.push(0);
190 } else {
191 in_multi.push(out_multi[out_dim]);
192 out_dim += 1;
193 }
194 }
195
196 let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
198 let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
199 for k in 0..axis_len {
200 in_multi[ax] = k;
201 let idx = flat_index(&in_multi, &strides);
202 lane.push((k, data[idx]));
203 lane_flat_indices.push(idx);
204 }
205
206 lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
208
209 for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
211 result[flat_idx] = lane[k].0 as u64;
212 }
213
214 if !out_shape.is_empty() {
215 increment_multi_index(&mut out_multi, &out_shape);
216 }
217 }
218
219 Array::from_vec(a.dim().clone(), result)
220 }
221 }
222}
223
224pub fn searchsorted<T>(
235 a: &Array<T, Ix1>,
236 v: &Array<T, Ix1>,
237 side: Side,
238) -> FerrayResult<Array<u64, Ix1>>
239where
240 T: Element + PartialOrd + Copy,
241{
242 let sorted: Vec<T> = a.iter().copied().collect();
243 let values: Vec<T> = v.iter().copied().collect();
244
245 let mut result = Vec::with_capacity(values.len());
246 for &val in &values {
247 let idx = match side {
248 Side::Left => sorted.partition_point(|x| {
249 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
250 }),
251 Side::Right => sorted.partition_point(|x| {
252 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
253 != std::cmp::Ordering::Greater
254 }),
255 };
256 result.push(idx as u64);
257 }
258
259 let n = result.len();
260 Array::from_vec(Ix1::new([n]), result)
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    /// Builds a 1-D `f64` array whose length is taken from `values`.
    fn vector_f64(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Builds a 2x3 `f64` array from row-major `values`.
    fn matrix_2x3(values: Vec<f64>) -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), values).unwrap()
    }

    #[test]
    fn test_sort_1d() {
        let input = vector_f64(vec![3.0, 1.0, 4.0, 1.0, 5.0]);
        let sorted = sort(&input, None, SortKind::Quick).unwrap();
        let expected = [1.0, 1.0, 3.0, 4.0, 5.0];
        assert_eq!(sorted.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_sort_stable_preserves_order() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let sorted = sort(&input, None, SortKind::Stable).unwrap();
        let expected = [1, 1, 3, 4, 5];
        assert_eq!(sorted.as_slice().unwrap(), &expected);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let input = matrix_2x3(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let sorted = sort(&input, Some(1), SortKind::Quick).unwrap();
        let flat = sorted.iter().copied().collect::<Vec<f64>>();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let input = matrix_2x3(vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        let sorted = sort(&input, Some(0), SortKind::Quick).unwrap();
        let flat = sorted.iter().copied().collect::<Vec<f64>>();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_argsort_1d() {
        let input = vector_f64(vec![3.0, 1.0, 4.0, 2.0]);
        let order = argsort(&input, None).unwrap();
        let flat = order.iter().copied().collect::<Vec<u64>>();
        assert_eq!(flat, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let input = matrix_2x3(vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
        let order = argsort(&input, Some(1)).unwrap();
        let flat = order.iter().copied().collect::<Vec<u64>>();
        assert_eq!(flat, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_searchsorted_left() {
        let haystack = vector_f64(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let needles = vector_f64(vec![2.5, 1.0, 5.5]);
        let positions = searchsorted(&haystack, &needles, Side::Left).unwrap();
        let flat = positions.iter().copied().collect::<Vec<u64>>();
        assert_eq!(flat, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let haystack = vector_f64(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let needles = vector_f64(vec![2.0, 4.0]);
        let positions = searchsorted(&haystack, &needles, Side::Right).unwrap();
        let flat = positions.iter().copied().collect::<Vec<u64>>();
        assert_eq!(flat, vec![2, 4]);
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let input = vector_f64(vec![1.0, 2.0, 3.0]);
        let outcome = sort(&input, Some(1), SortKind::Quick);
        assert!(outcome.is_err());
    }
}