ndarray_stats/
sort.rs

1use indexmap::IndexMap;
2use ndarray::prelude::*;
3use ndarray::Slice;
4use rand::prelude::*;
5use rand::thread_rng;
6
7/// Methods for sorting and partitioning 1-D arrays.
8pub trait Sort1dExt<A> {
9    /// Return the element that would occupy the `i`-th position if
10    /// the array were sorted in increasing order.
11    ///
12    /// The array is shuffled **in place** to retrieve the desired element:
13    /// no copy of the array is allocated.
14    /// After the shuffling, all elements with an index smaller than `i`
15    /// are smaller than the desired element, while all elements with
16    /// an index greater or equal than `i` are greater than or equal
17    /// to the desired element.
18    ///
19    /// No other assumptions should be made on the ordering of the
20    /// elements after this computation.
21    ///
22    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
23    /// - average case: O(`n`);
24    /// - worst case: O(`n`^2);
25    /// where n is the number of elements in the array.
26    ///
27    /// **Panics** if `i` is greater than or equal to `n`.
28    fn get_from_sorted_mut(&mut self, i: usize) -> A
29    where
30        A: Ord + Clone;
31
32    /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple
33    /// indexes at once.
34    /// It returns an `IndexMap`, with indexes as keys and retrieved elements as
35    /// values.
36    /// The `IndexMap` is sorted with respect to indexes in increasing order:
37    /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`).
38    ///
39    /// **Panics** if any element in `indexes` is greater than or equal to `n`,
40    /// where `n` is the length of the array..
41    ///
42    /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut
43    fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
44    where
45        A: Ord + Clone;
46
47    /// Partitions the array in increasing order based on the value initially
48    /// located at `pivot_index` and returns the new index of the value.
49    ///
50    /// The elements are rearranged in such a way that the value initially
51    /// located at `pivot_index` is moved to the position it would be in an
52    /// array sorted in increasing order. The return value is the new index of
53    /// the value after rearrangement. All elements smaller than the value are
54    /// moved to its left and all elements equal or greater than the value are
55    /// moved to its right. The ordering of the elements in the two partitions
56    /// is undefined.
57    ///
58    /// `self` is shuffled **in place** to operate the desired partition:
59    /// no copy of the array is allocated.
60    ///
61    /// The method uses Hoare's partition algorithm.
62    /// Complexity: O(`n`), where `n` is the number of elements in the array.
63    /// Average number of element swaps: n/6 - 1/3 (see
64    /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
65    ///
66    /// **Panics** if `pivot_index` is greater than or equal to `n`.
67    ///
68    /// # Example
69    ///
70    /// ```
71    /// use ndarray::array;
72    /// use ndarray_stats::Sort1dExt;
73    ///
74    /// let mut data = array![3, 1, 4, 5, 2];
75    /// let pivot_index = 2;
76    /// let pivot_value = data[pivot_index];
77    ///
78    /// // Partition by the value located at `pivot_index`.
79    /// let new_index = data.partition_mut(pivot_index);
80    /// // The pivot value is now located at `new_index`.
81    /// assert_eq!(data[new_index], pivot_value);
82    /// // Elements less than that value are moved to the left.
83    /// for i in 0..new_index {
84    ///     assert!(data[i] < pivot_value);
85    /// }
86    /// // Elements greater than or equal to that value are moved to the right.
87    /// for i in (new_index + 1)..data.len() {
88    ///      assert!(data[i] >= pivot_value);
89    /// }
90    /// ```
91    fn partition_mut(&mut self, pivot_index: usize) -> usize
92    where
93        A: Ord + Clone;
94
95    private_decl! {}
96}
97
98impl<A> Sort1dExt<A> for ArrayRef<A, Ix1> {
99    fn get_from_sorted_mut(&mut self, i: usize) -> A
100    where
101        A: Ord + Clone,
102    {
103        let n = self.len();
104        if n == 1 {
105            self[0].clone()
106        } else {
107            let mut rng = thread_rng();
108            let pivot_index = rng.gen_range(0..n);
109            let partition_index = self.partition_mut(pivot_index);
110            if i < partition_index {
111                self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
112                    .get_from_sorted_mut(i)
113            } else if i == partition_index {
114                self[i].clone()
115            } else {
116                self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
117                    .get_from_sorted_mut(i - (partition_index + 1))
118            }
119        }
120    }
121
122    fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
123    where
124        A: Ord + Clone,
125    {
126        let mut deduped_indexes: Vec<usize> = indexes.to_vec();
127        deduped_indexes.sort_unstable();
128        deduped_indexes.dedup();
129
130        get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
131    }
132
133    fn partition_mut(&mut self, pivot_index: usize) -> usize
134    where
135        A: Ord + Clone,
136    {
137        let pivot_value = self[pivot_index].clone();
138        self.swap(pivot_index, 0);
139        let n = self.len();
140        let mut i = 1;
141        let mut j = n - 1;
142        loop {
143            loop {
144                if i > j {
145                    break;
146                }
147                if self[i] >= pivot_value {
148                    break;
149                }
150                i += 1;
151            }
152            while pivot_value <= self[j] {
153                if j == 1 {
154                    break;
155                }
156                j -= 1;
157            }
158            if i >= j {
159                break;
160            } else {
161                self.swap(i, j);
162                i += 1;
163                j -= 1;
164            }
165        }
166        self.swap(0, i - 1);
167        i - 1
168    }
169
170    private_impl! {}
171}
172
173/// To retrieve multiple indexes from the sorted array in an optimized fashion,
174/// [get_many_from_sorted_mut] first of all sorts and deduplicates the
175/// `indexes` vector.
176///
177/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and
178/// deduplication, assuming that the user has already taken care of it.
179///
180/// Useful when you have to call [get_many_from_sorted_mut] multiple times
181/// using the same indexes.
182///
183/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut
184pub(crate) fn get_many_from_sorted_mut_unchecked<A>(
185    array: &mut ArrayRef1<A>,
186    indexes: &[usize],
187) -> IndexMap<usize, A>
188where
189    A: Ord + Clone,
190{
191    if indexes.is_empty() {
192        return IndexMap::new();
193    }
194
195    // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must
196    // be non-empty.
197    let mut values = vec![array[0].clone(); indexes.len()];
198    _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
199
200    // We convert the vector to a more search-friendly `IndexMap`.
201    indexes.iter().cloned().zip(values.into_iter()).collect()
202}
203
204/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`.
205///
206/// `indexes` is the list of indexes to get. `indexes` is mutable so that it
207/// can be used as scratch space for this routine; the value of `indexes` after
208/// calling this routine should be ignored.
209///
210/// `values` is a pre-allocated slice to use for writing the output. Its
211/// initial element values are ignored.
212fn _get_many_from_sorted_mut_unchecked<A>(
213    mut array: ArrayViewMut1<'_, A>,
214    indexes: &mut [usize],
215    values: &mut [A],
216) where
217    A: Ord + Clone,
218{
219    let n = array.len();
220    debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds
221    debug_assert_eq!(indexes.len(), values.len());
222
223    if indexes.is_empty() {
224        // Nothing to do in this case.
225        return;
226    }
227
228    // At this point, `n >= 1` since `indexes.len() >= 1`.
229    if n == 1 {
230        // We can only reach this point if `indexes.len() == 1`, so we only
231        // need to assign the single value, and then we're done.
232        debug_assert_eq!(indexes.len(), 1);
233        values[0] = array[0].clone();
234        return;
235    }
236
237    // We pick a random pivot index: the corresponding element is the pivot value
238    let mut rng = thread_rng();
239    let pivot_index = rng.gen_range(0..n);
240
241    // We partition the array with respect to the pivot value.
242    // The pivot value moves to `array_partition_index`.
243    // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
244    // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
245    let array_partition_index = array.partition_mut(pivot_index);
246
247    // We use a divide-and-conquer strategy, splitting the indexes we are
248    // searching for (`indexes`) and the corresponding portions of the output
249    // slice (`values`) into pieces with respect to `array_partition_index`.
250    let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
251        Ok(index) => (true, index),
252        Err(index) => (false, index),
253    };
254    let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
255    let (smaller_values, other_values) = values.split_at_mut(index_split);
256    let (bigger_indexes, bigger_values) = if found_exact {
257        other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
258        (&mut other_indexes[1..], &mut other_values[1..])
259    } else {
260        (other_indexes, other_values)
261    };
262
263    // We search recursively for the values corresponding to strictly smaller
264    // indexes to the left of `partition_index`.
265    _get_many_from_sorted_mut_unchecked(
266        array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
267        smaller_indexes,
268        smaller_values,
269    );
270
271    // We search recursively for the values corresponding to strictly bigger
272    // indexes to the right of `partition_index`. Since only the right portion
273    // of the array is passed in, the indexes need to be shifted by length of
274    // the removed portion.
275    bigger_indexes
276        .iter_mut()
277        .for_each(|x| *x -= array_partition_index + 1);
278    _get_many_from_sorted_mut_unchecked(
279        array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
280        bigger_indexes,
281        bigger_values,
282    );
283}