1use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::parallel;
7use crate::reductions::{compute_strides, flat_index, increment_multi_index};
8
/// Algorithm selector for [`sort`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortKind {
    /// Default algorithm (dispatches to `parallel::parallel_sort`); callers should
    /// not rely on the relative order of elements that compare equal.
    Quick,
    /// Stable algorithm (dispatches to `parallel::parallel_sort_stable`); elements
    /// that compare equal keep their original relative order.
    Stable,
}
21
/// Which insertion point [`searchsorted`] reports for a value that equals
/// existing elements of the sorted array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// Index of the first element that is not less than the value
    /// (i.e. before any run of equal elements).
    Left,
    /// Index of the first element that is greater than the value
    /// (i.e. after any run of equal elements).
    Right,
}
34
35pub fn sort<T, D>(
47 a: &Array<T, D>,
48 axis: Option<usize>,
49 kind: SortKind,
50) -> FerrayResult<Array<T, IxDyn>>
51where
52 T: Element + PartialOrd + Copy + Send + Sync,
53 D: Dimension,
54{
55 match axis {
56 None => {
57 let mut data: Vec<T> = a.iter().copied().collect();
59 let n = data.len();
60 sort_slice(&mut data, kind);
61 Array::from_vec(IxDyn::new(&[n]), data)
62 }
63 Some(ax) => {
64 if ax >= a.ndim() {
65 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
66 }
67 let shape = a.shape().to_vec();
68 let data: Vec<T> = a.iter().copied().collect();
69 let mut result = data.clone();
70 let strides = compute_strides(&shape);
71
72 let axis_len = shape[ax];
73 let out_shape: Vec<usize> = shape
74 .iter()
75 .enumerate()
76 .filter(|&(i, _)| i != ax)
77 .map(|(_, &s)| s)
78 .collect();
79 let out_size: usize = if out_shape.is_empty() {
80 1
81 } else {
82 out_shape.iter().product()
83 };
84
85 let mut out_multi = vec![0usize; out_shape.len()];
86 let ndim = shape.len();
87
88 for _ in 0..out_size {
89 let mut in_multi = Vec::with_capacity(ndim);
91 let mut out_dim = 0;
92 for d in 0..ndim {
93 if d == ax {
94 in_multi.push(0);
95 } else {
96 in_multi.push(out_multi[out_dim]);
97 out_dim += 1;
98 }
99 }
100
101 let mut lane: Vec<T> = Vec::with_capacity(axis_len);
103 let mut lane_indices: Vec<usize> = Vec::with_capacity(axis_len);
104 for k in 0..axis_len {
105 in_multi[ax] = k;
106 let idx = flat_index(&in_multi, &strides);
107 lane.push(data[idx]);
108 lane_indices.push(idx);
109 }
110
111 sort_slice(&mut lane, kind);
112
113 for (k, &idx) in lane_indices.iter().enumerate() {
115 result[idx] = lane[k];
116 }
117
118 if !out_shape.is_empty() {
119 increment_multi_index(&mut out_multi, &out_shape);
120 }
121 }
122
123 Array::from_vec(IxDyn::new(&shape), result)
124 }
125 }
126}
127
128fn sort_slice<T: PartialOrd + Copy + Send + Sync>(data: &mut [T], kind: SortKind) {
130 match kind {
131 SortKind::Quick => {
132 parallel::parallel_sort(data);
133 }
134 SortKind::Stable => {
135 parallel::parallel_sort_stable(data);
136 }
137 }
138}
139
140pub fn argsort<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
153where
154 T: Element + PartialOrd + Copy,
155 D: Dimension,
156{
157 match axis {
158 None => {
159 let data: Vec<T> = a.iter().copied().collect();
160 let n = data.len();
161 let mut indices: Vec<usize> = (0..n).collect();
162 indices.sort_by(|&i, &j| {
163 data[i]
164 .partial_cmp(&data[j])
165 .unwrap_or(std::cmp::Ordering::Equal)
166 });
167 let result: Vec<u64> = indices.into_iter().map(|i| i as u64).collect();
168 Array::from_vec(IxDyn::new(&[n]), result)
169 }
170 Some(ax) => {
171 if ax >= a.ndim() {
172 return Err(FerrayError::axis_out_of_bounds(ax, a.ndim()));
173 }
174 let shape = a.shape().to_vec();
175 let data: Vec<T> = a.iter().copied().collect();
176 let strides = compute_strides(&shape);
177 let ndim = shape.len();
178 let axis_len = shape[ax];
179
180 let out_shape: Vec<usize> = shape
181 .iter()
182 .enumerate()
183 .filter(|&(i, _)| i != ax)
184 .map(|(_, &s)| s)
185 .collect();
186 let out_size: usize = if out_shape.is_empty() {
187 1
188 } else {
189 out_shape.iter().product()
190 };
191
192 let mut result = vec![0u64; data.len()];
193 let mut out_multi = vec![0usize; out_shape.len()];
194
195 for _ in 0..out_size {
196 let mut in_multi = Vec::with_capacity(ndim);
197 let mut out_dim = 0;
198 for d in 0..ndim {
199 if d == ax {
200 in_multi.push(0);
201 } else {
202 in_multi.push(out_multi[out_dim]);
203 out_dim += 1;
204 }
205 }
206
207 let mut lane: Vec<(usize, T)> = Vec::with_capacity(axis_len);
209 let mut lane_flat_indices: Vec<usize> = Vec::with_capacity(axis_len);
210 for k in 0..axis_len {
211 in_multi[ax] = k;
212 let idx = flat_index(&in_multi, &strides);
213 lane.push((k, data[idx]));
214 lane_flat_indices.push(idx);
215 }
216
217 lane.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
219
220 for (k, &flat_idx) in lane_flat_indices.iter().enumerate() {
222 result[flat_idx] = lane[k].0 as u64;
223 }
224
225 if !out_shape.is_empty() {
226 increment_multi_index(&mut out_multi, &out_shape);
227 }
228 }
229
230 Array::from_vec(IxDyn::new(&shape), result)
231 }
232 }
233}
234
235pub fn searchsorted<T>(
246 a: &Array<T, Ix1>,
247 v: &Array<T, Ix1>,
248 side: Side,
249) -> FerrayResult<Array<u64, Ix1>>
250where
251 T: Element + PartialOrd + Copy,
252{
253 let sorted: Vec<T> = a.iter().copied().collect();
254 let values: Vec<T> = v.iter().copied().collect();
255
256 let mut result = Vec::with_capacity(values.len());
257 for &val in &values {
258 let idx = match side {
259 Side::Left => sorted.partition_point(|x| {
260 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less) == std::cmp::Ordering::Less
261 }),
262 Side::Right => sorted.partition_point(|x| {
263 x.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Less)
264 != std::cmp::Ordering::Greater
265 }),
266 };
267 result.push(idx as u64);
268 }
269
270 let n = result.len();
271 Array::from_vec(Ix1::new([n]), result)
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix2;

    #[test]
    fn test_sort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![3.0, 1.0, 4.0, 1.0, 5.0]).unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    // With plain integers, equal elements are indistinguishable. This test can only
    // check that the stable path sorts correctly; it cannot observe stability itself.
    #[test]
    fn test_sort_stable_preserves_order() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, 1, 4, 1, 5]).unwrap();
        let s = sort(&a, None, SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[5]);
        let data: Vec<i32> = s.iter().copied().collect();
        assert_eq!(data, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_sort_2d_axis_none_returns_flat() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![6.0, 4.0, 5.0, 3.0, 1.0, 2.0])
            .unwrap();
        let s = sort(&a, None, SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[6]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let s = sort(&a, Some(1), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0])
            .unwrap();
        let s = sort(&a, Some(0), SortKind::Quick).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_sort_3d_middle_axis() {
        // Exercises a non-contiguous lane (stride 2) along the middle axis.
        let a = Array::<f64, IxDyn>::from_vec(
            IxDyn::new(&[2, 2, 2]),
            vec![4.0, 3.0, 1.0, 2.0, 8.0, 5.0, 6.0, 7.0],
        )
        .unwrap();
        let s = sort(&a, Some(1), SortKind::Stable).unwrap();
        assert_eq!(s.shape(), &[2, 2, 2]);
        let data: Vec<f64> = s.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 4.0, 3.0, 6.0, 5.0, 8.0, 7.0]);
    }

    #[test]
    fn test_argsort_1d() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, 1.0, 4.0, 2.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        assert_eq!(idx.shape(), &[4]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_ties_keep_original_order() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, 1.0, 2.0, 1.0]).unwrap();
        let idx = argsort(&a, None).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 3, 0, 2]);
    }

    #[test]
    fn test_argsort_2d_axis1() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0])
            .unwrap();
        let idx = argsort(&a, Some(1)).unwrap();
        assert_eq!(idx.shape(), &[2, 3]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 0, 1, 2, 0]);
    }

    #[test]
    fn test_argsort_2d_axis0() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![3.0, 1.0, 1.0, 2.0]).unwrap();
        let idx = argsort(&a, Some(0)).unwrap();
        assert_eq!(idx.shape(), &[2, 2]);
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![1, 0, 0, 1]);
    }

    #[test]
    fn test_argsort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(argsort(&a, Some(1)).is_err());
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.5, 1.0, 5.5]).unwrap();
        let idx = searchsorted(&a, &v, Side::Left).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 0, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![2.0, 4.0]).unwrap();
        let idx = searchsorted(&a, &v, Side::Right).unwrap();
        let data: Vec<u64> = idx.iter().copied().collect();
        assert_eq!(data, vec![2, 4]);
    }

    #[test]
    fn test_searchsorted_duplicates_left_vs_right() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 2.0, 2.0, 3.0]).unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![2.0]).unwrap();
        let left: Vec<u64> = searchsorted(&a, &v, Side::Left).unwrap().iter().copied().collect();
        let right: Vec<u64> = searchsorted(&a, &v, Side::Right).unwrap().iter().copied().collect();
        assert_eq!(left, vec![1]);
        assert_eq!(right, vec![4]);
    }

    #[test]
    fn test_sort_axis_out_of_bounds() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(sort(&a, Some(1), SortKind::Quick).is_err());
    }
}
371}