Skip to main content

ferray_ma/
sorting.rs

// ferray-ma: Masked sort and argsort (REQ-13, REQ-14)
//
// Sorts unmasked elements in ascending order and places masked elements at the end.

use std::cmp::Ordering;

use ferray_core::Array;
use ferray_core::dimension::{Dimension, Ix1};
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;

use crate::MaskedArray;

12impl<T, D> MaskedArray<T, D>
13where
14    T: Element + PartialOrd + Copy,
15    D: Dimension,
16{
17    /// Sort the masked array (flattened), placing masked elements at the end.
18    ///
19    /// Returns a new 1-D `MaskedArray` where:
20    /// - Unmasked elements are sorted in ascending order
21    /// - Masked elements come after all unmasked elements
22    ///
23    /// # Errors
24    /// Returns an error only for internal failures.
25    pub fn sort(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
26        let mut unmasked: Vec<T> = Vec::new();
27        let mut masked_vals: Vec<T> = Vec::new();
28
29        for (v, m) in self.data().iter().zip(self.mask().iter()) {
30            if *m {
31                masked_vals.push(*v);
32            } else {
33                unmasked.push(*v);
34            }
35        }
36
37        // Sort unmasked elements using partial_cmp to handle floats
38        unmasked.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
39
40        let unmasked_count = unmasked.len();
41        let total = unmasked_count + masked_vals.len();
42
43        // Build result: sorted unmasked followed by masked
44        let mut result_data = Vec::with_capacity(total);
45        result_data.extend_from_slice(&unmasked);
46        result_data.extend_from_slice(&masked_vals);
47
48        let mut result_mask = Vec::with_capacity(total);
49        result_mask.extend(std::iter::repeat_n(false, unmasked_count));
50        result_mask.extend(std::iter::repeat_n(true, masked_vals.len()));
51
52        let data_arr = Array::from_vec(Ix1::new([total]), result_data)?;
53        let mask_arr = Array::from_vec(Ix1::new([total]), result_mask)?;
54        MaskedArray::new(data_arr, mask_arr)
55    }
56
57    /// Return the indices that would sort the masked array, with masked
58    /// elements placed at the end.
59    ///
60    /// Returns a `Vec<usize>` of indices.
61    ///
62    /// # Errors
63    /// Returns an error only for internal failures.
64    pub fn argsort(&self) -> FerrayResult<Vec<usize>> {
65        let vals: Vec<T> = self.data().iter().copied().collect();
66        let masks: Vec<bool> = self.mask().iter().copied().collect();
67
68        // Separate indices into unmasked and masked
69        let mut unmasked_indices: Vec<usize> = Vec::new();
70        let mut masked_indices: Vec<usize> = Vec::new();
71
72        for (i, m) in masks.iter().enumerate() {
73            if *m {
74                masked_indices.push(i);
75            } else {
76                unmasked_indices.push(i);
77            }
78        }
79
80        // Sort unmasked indices by their data values
81        unmasked_indices.sort_by(|a, b| {
82            vals[*a]
83                .partial_cmp(&vals[*b])
84                .unwrap_or(std::cmp::Ordering::Equal)
85        });
86
87        // Concatenate: sorted unmasked indices, then masked indices
88        let mut result = Vec::with_capacity(unmasked_indices.len() + masked_indices.len());
89        result.extend_from_slice(&unmasked_indices);
90        result.extend_from_slice(&masked_indices);
91
92        Ok(result)
93    }
94}