Skip to main content

ferray_ma/
sorting.rs

1// ferray-ma: Masked sort and argsort (REQ-13, REQ-14)
2//
3// Sorting unmasked elements while pushing masked elements to the end.
4
5use ferray_core::Array;
6use ferray_core::dimension::{Dimension, Ix1, IxDyn};
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12impl<T, D> MaskedArray<T, D>
13where
14    T: Element + PartialOrd + Copy,
15    D: Dimension,
16{
17    /// Sort the masked array (flattened), placing masked elements at the end.
18    ///
19    /// Returns a new 1-D `MaskedArray` where:
20    /// - Unmasked elements are sorted in ascending order
21    /// - Masked elements come after all unmasked elements
22    ///
23    /// # Errors
24    /// Returns an error only for internal failures.
25    pub fn sort(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
26        let mut unmasked: Vec<T> = Vec::new();
27        let mut masked_vals: Vec<T> = Vec::new();
28
29        for (v, m) in self.data().iter().zip(self.mask().iter()) {
30            if *m {
31                masked_vals.push(*v);
32            } else {
33                unmasked.push(*v);
34            }
35        }
36
37        // Sort unmasked elements using partial_cmp to handle floats
38        unmasked.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
39
40        let unmasked_count = unmasked.len();
41        let total = unmasked_count + masked_vals.len();
42
43        // Build result: sorted unmasked followed by masked
44        let mut result_data = Vec::with_capacity(total);
45        result_data.extend_from_slice(&unmasked);
46        result_data.extend_from_slice(&masked_vals);
47
48        let mut result_mask = Vec::with_capacity(total);
49        result_mask.extend(std::iter::repeat_n(false, unmasked_count));
50        result_mask.extend(std::iter::repeat_n(true, masked_vals.len()));
51
52        let data_arr = Array::from_vec(Ix1::new([total]), result_data)?;
53        let mask_arr = Array::from_vec(Ix1::new([total]), result_mask)?;
54        MaskedArray::new(data_arr, mask_arr)
55    }
56
57    /// Return the indices that would sort the masked array, with masked
58    /// elements placed at the end.
59    ///
60    /// Returns a 1-D `Array<u64, Ix1>` of indices, matching the index
61    /// dtype used elsewhere in ferray (#269). Previous versions returned
62    /// `Vec<usize>` which was inconsistent with [`MaskedArray::sort`]
63    /// (returns a `MaskedArray`) and forced callers to wrap manually
64    /// for any downstream array op.
65    ///
66    /// # Errors
67    /// Returns an error only for internal failures.
68    pub fn argsort(&self) -> FerrayResult<Array<u64, Ix1>> {
69        let vals: Vec<T> = self.data().iter().copied().collect();
70        let masks: Vec<bool> = self.mask().iter().copied().collect();
71
72        // Separate indices into unmasked and masked
73        let mut unmasked_indices: Vec<usize> = Vec::new();
74        let mut masked_indices: Vec<usize> = Vec::new();
75
76        for (i, m) in masks.iter().enumerate() {
77            if *m {
78                masked_indices.push(i);
79            } else {
80                unmasked_indices.push(i);
81            }
82        }
83
84        // Sort unmasked indices by their data values
85        unmasked_indices.sort_by(|a, b| {
86            vals[*a]
87                .partial_cmp(&vals[*b])
88                .unwrap_or(std::cmp::Ordering::Equal)
89        });
90
91        // Concatenate: sorted unmasked indices, then masked indices.
92        let total = unmasked_indices.len() + masked_indices.len();
93        let mut result: Vec<u64> = Vec::with_capacity(total);
94        for &i in &unmasked_indices {
95            result.push(i as u64);
96        }
97        for &i in &masked_indices {
98            result.push(i as u64);
99        }
100
101        Array::from_vec(Ix1::new([total]), result)
102    }
103
104    /// Sort the masked array along `axis`, placing masked elements at
105    /// the end of each lane (#271).
106    ///
107    /// Each 1-D slice along `axis` is sorted independently — unmasked
108    /// values ascend, masked values trail. The output preserves the
109    /// input shape (no flattening) and produces an `IxDyn` mask
110    /// reflecting the new positions of masked entries.
111    ///
112    /// # Errors
113    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= self.ndim()`.
114    pub fn sort_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
115        let ndim = self.ndim();
116        if axis >= ndim {
117            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
118        }
119        let shape = self.shape().to_vec();
120        let axis_len = shape[axis];
121        let total: usize = shape.iter().product();
122
123        // Materialize source data and mask in row-major flat order.
124        let src_data: Vec<T> = self.data().iter().copied().collect();
125        let src_mask: Vec<bool> = self.mask().iter().copied().collect();
126        let mut strides = vec![1usize; ndim];
127        for i in (0..ndim.saturating_sub(1)).rev() {
128            strides[i] = strides[i + 1] * shape[i + 1];
129        }
130
131        let mut out_data = vec![src_data[0]; total];
132        let mut out_mask = vec![false; total];
133
134        // Iterate each lane along `axis` by walking the multi-index
135        // over the "outer" axes (all but `axis`), then sweeping the
136        // axis from 0..axis_len for each.
137        let outer_shape: Vec<usize> = shape
138            .iter()
139            .enumerate()
140            .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
141            .collect();
142        let outer_size: usize = if outer_shape.is_empty() {
143            1
144        } else {
145            outer_shape.iter().product()
146        };
147
148        let mut outer_multi = vec![0usize; outer_shape.len()];
149        for _ in 0..outer_size {
150            // Gather the lane's (value, mask) pairs and their flat indices.
151            let mut lane: Vec<(T, bool, usize)> = Vec::with_capacity(axis_len);
152            for k in 0..axis_len {
153                let mut flat = 0usize;
154                let mut o = 0usize;
155                for (i, &stride) in strides.iter().enumerate() {
156                    if i == axis {
157                        flat += stride * k;
158                    } else {
159                        flat += stride * outer_multi[o];
160                        o += 1;
161                    }
162                }
163                lane.push((src_data[flat], src_mask[flat], flat));
164            }
165            // Partition: unmasked first, then masked. Sort unmasked
166            // ascending. Masked entries keep relative input order.
167            let mut unmasked: Vec<(T, usize)> = Vec::new();
168            let mut masked: Vec<(T, usize)> = Vec::new();
169            for (v, m, flat) in lane {
170                if m {
171                    masked.push((v, flat));
172                } else {
173                    unmasked.push((v, flat));
174                }
175            }
176            unmasked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
177            // Write back: lane position k gets the k-th value, with
178            // mask=false for unmasked positions and mask=true after.
179            for (k, (v, _flat)) in unmasked.iter().chain(masked.iter()).enumerate() {
180                let mut flat = 0usize;
181                let mut o = 0usize;
182                for (i, &stride) in strides.iter().enumerate() {
183                    if i == axis {
184                        flat += stride * k;
185                    } else {
186                        flat += stride * outer_multi[o];
187                        o += 1;
188                    }
189                }
190                out_data[flat] = *v;
191                out_mask[flat] = k >= unmasked.len();
192            }
193
194            // Increment the outer multi-index.
195            for i in (0..outer_shape.len()).rev() {
196                outer_multi[i] += 1;
197                if outer_multi[i] < outer_shape[i] {
198                    break;
199                }
200                outer_multi[i] = 0;
201            }
202        }
203
204        let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(&shape), out_data)?;
205        let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&shape), out_mask)?;
206        MaskedArray::new(data_arr, mask_arr)
207    }
208}