1use ferray_core::Array;
6use ferray_core::dimension::{Dimension, Ix1, IxDyn};
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12impl<T, D> MaskedArray<T, D>
13where
14 T: Element + PartialOrd + Copy,
15 D: Dimension,
16{
17 pub fn sort(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
26 let mut unmasked: Vec<T> = Vec::new();
27 let mut masked_vals: Vec<T> = Vec::new();
28
29 for (v, m) in self.data().iter().zip(self.mask().iter()) {
30 if *m {
31 masked_vals.push(*v);
32 } else {
33 unmasked.push(*v);
34 }
35 }
36
37 unmasked.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
39
40 let unmasked_count = unmasked.len();
41 let total = unmasked_count + masked_vals.len();
42
43 let mut result_data = Vec::with_capacity(total);
45 result_data.extend_from_slice(&unmasked);
46 result_data.extend_from_slice(&masked_vals);
47
48 let mut result_mask = Vec::with_capacity(total);
49 result_mask.extend(std::iter::repeat_n(false, unmasked_count));
50 result_mask.extend(std::iter::repeat_n(true, masked_vals.len()));
51
52 let data_arr = Array::from_vec(Ix1::new([total]), result_data)?;
53 let mask_arr = Array::from_vec(Ix1::new([total]), result_mask)?;
54 MaskedArray::new(data_arr, mask_arr)
55 }
56
57 pub fn argsort(&self) -> FerrayResult<Array<u64, Ix1>> {
69 let vals: Vec<T> = self.data().iter().copied().collect();
70 let masks: Vec<bool> = self.mask().iter().copied().collect();
71
72 let mut unmasked_indices: Vec<usize> = Vec::new();
74 let mut masked_indices: Vec<usize> = Vec::new();
75
76 for (i, m) in masks.iter().enumerate() {
77 if *m {
78 masked_indices.push(i);
79 } else {
80 unmasked_indices.push(i);
81 }
82 }
83
84 unmasked_indices.sort_by(|a, b| {
86 vals[*a]
87 .partial_cmp(&vals[*b])
88 .unwrap_or(std::cmp::Ordering::Equal)
89 });
90
91 let total = unmasked_indices.len() + masked_indices.len();
93 let mut result: Vec<u64> = Vec::with_capacity(total);
94 for &i in &unmasked_indices {
95 result.push(i as u64);
96 }
97 for &i in &masked_indices {
98 result.push(i as u64);
99 }
100
101 Array::from_vec(Ix1::new([total]), result)
102 }
103
104 pub fn sort_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
115 let ndim = self.ndim();
116 if axis >= ndim {
117 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
118 }
119 let shape = self.shape().to_vec();
120 let axis_len = shape[axis];
121 let total: usize = shape.iter().product();
122
123 let src_data: Vec<T> = self.data().iter().copied().collect();
125 let src_mask: Vec<bool> = self.mask().iter().copied().collect();
126 let mut strides = vec![1usize; ndim];
127 for i in (0..ndim.saturating_sub(1)).rev() {
128 strides[i] = strides[i + 1] * shape[i + 1];
129 }
130
131 let mut out_data = vec![src_data[0]; total];
132 let mut out_mask = vec![false; total];
133
134 let outer_shape: Vec<usize> = shape
138 .iter()
139 .enumerate()
140 .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
141 .collect();
142 let outer_size: usize = if outer_shape.is_empty() {
143 1
144 } else {
145 outer_shape.iter().product()
146 };
147
148 let mut outer_multi = vec![0usize; outer_shape.len()];
149 for _ in 0..outer_size {
150 let mut lane: Vec<(T, bool, usize)> = Vec::with_capacity(axis_len);
152 for k in 0..axis_len {
153 let mut flat = 0usize;
154 let mut o = 0usize;
155 for (i, &stride) in strides.iter().enumerate() {
156 if i == axis {
157 flat += stride * k;
158 } else {
159 flat += stride * outer_multi[o];
160 o += 1;
161 }
162 }
163 lane.push((src_data[flat], src_mask[flat], flat));
164 }
165 let mut unmasked: Vec<(T, usize)> = Vec::new();
168 let mut masked: Vec<(T, usize)> = Vec::new();
169 for (v, m, flat) in lane {
170 if m {
171 masked.push((v, flat));
172 } else {
173 unmasked.push((v, flat));
174 }
175 }
176 unmasked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
177 for (k, (v, _flat)) in unmasked.iter().chain(masked.iter()).enumerate() {
180 let mut flat = 0usize;
181 let mut o = 0usize;
182 for (i, &stride) in strides.iter().enumerate() {
183 if i == axis {
184 flat += stride * k;
185 } else {
186 flat += stride * outer_multi[o];
187 o += 1;
188 }
189 }
190 out_data[flat] = *v;
191 out_mask[flat] = k >= unmasked.len();
192 }
193
194 for i in (0..outer_shape.len()).rev() {
196 outer_multi[i] += 1;
197 if outer_multi[i] < outer_shape[i] {
198 break;
199 }
200 outer_multi[i] = 0;
201 }
202 }
203
204 let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(&shape), out_data)?;
205 let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&shape), out_mask)?;
206 MaskedArray::new(data_arr, mask_arr)
207 }
208}