ndarray_stats/sort.rs
1use indexmap::IndexMap;
2use ndarray::prelude::*;
3use ndarray::Slice;
4use rand::prelude::*;
5use rand::thread_rng;
6
7/// Methods for sorting and partitioning 1-D arrays.
8pub trait Sort1dExt<A> {
9 /// Return the element that would occupy the `i`-th position if
10 /// the array were sorted in increasing order.
11 ///
12 /// The array is shuffled **in place** to retrieve the desired element:
13 /// no copy of the array is allocated.
14 /// After the shuffling, all elements with an index smaller than `i`
15 /// are smaller than the desired element, while all elements with
16 /// an index greater or equal than `i` are greater than or equal
17 /// to the desired element.
18 ///
19 /// No other assumptions should be made on the ordering of the
20 /// elements after this computation.
21 ///
22 /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
23 /// - average case: O(`n`);
24 /// - worst case: O(`n`^2);
25 /// where n is the number of elements in the array.
26 ///
27 /// **Panics** if `i` is greater than or equal to `n`.
28 fn get_from_sorted_mut(&mut self, i: usize) -> A
29 where
30 A: Ord + Clone;
31
32 /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple
33 /// indexes at once.
34 /// It returns an `IndexMap`, with indexes as keys and retrieved elements as
35 /// values.
36 /// The `IndexMap` is sorted with respect to indexes in increasing order:
37 /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`).
38 ///
39 /// **Panics** if any element in `indexes` is greater than or equal to `n`,
40 /// where `n` is the length of the array..
41 ///
42 /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut
43 fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
44 where
45 A: Ord + Clone;
46
47 /// Partitions the array in increasing order based on the value initially
48 /// located at `pivot_index` and returns the new index of the value.
49 ///
50 /// The elements are rearranged in such a way that the value initially
51 /// located at `pivot_index` is moved to the position it would be in an
52 /// array sorted in increasing order. The return value is the new index of
53 /// the value after rearrangement. All elements smaller than the value are
54 /// moved to its left and all elements equal or greater than the value are
55 /// moved to its right. The ordering of the elements in the two partitions
56 /// is undefined.
57 ///
58 /// `self` is shuffled **in place** to operate the desired partition:
59 /// no copy of the array is allocated.
60 ///
61 /// The method uses Hoare's partition algorithm.
62 /// Complexity: O(`n`), where `n` is the number of elements in the array.
63 /// Average number of element swaps: n/6 - 1/3 (see
64 /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
65 ///
66 /// **Panics** if `pivot_index` is greater than or equal to `n`.
67 ///
68 /// # Example
69 ///
70 /// ```
71 /// use ndarray::array;
72 /// use ndarray_stats::Sort1dExt;
73 ///
74 /// let mut data = array![3, 1, 4, 5, 2];
75 /// let pivot_index = 2;
76 /// let pivot_value = data[pivot_index];
77 ///
78 /// // Partition by the value located at `pivot_index`.
79 /// let new_index = data.partition_mut(pivot_index);
80 /// // The pivot value is now located at `new_index`.
81 /// assert_eq!(data[new_index], pivot_value);
82 /// // Elements less than that value are moved to the left.
83 /// for i in 0..new_index {
84 /// assert!(data[i] < pivot_value);
85 /// }
86 /// // Elements greater than or equal to that value are moved to the right.
87 /// for i in (new_index + 1)..data.len() {
88 /// assert!(data[i] >= pivot_value);
89 /// }
90 /// ```
91 fn partition_mut(&mut self, pivot_index: usize) -> usize
92 where
93 A: Ord + Clone;
94
95 private_decl! {}
96}
97
98impl<A> Sort1dExt<A> for ArrayRef<A, Ix1> {
99 fn get_from_sorted_mut(&mut self, i: usize) -> A
100 where
101 A: Ord + Clone,
102 {
103 let n = self.len();
104 if n == 1 {
105 self[0].clone()
106 } else {
107 let mut rng = thread_rng();
108 let pivot_index = rng.gen_range(0..n);
109 let partition_index = self.partition_mut(pivot_index);
110 if i < partition_index {
111 self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
112 .get_from_sorted_mut(i)
113 } else if i == partition_index {
114 self[i].clone()
115 } else {
116 self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
117 .get_from_sorted_mut(i - (partition_index + 1))
118 }
119 }
120 }
121
122 fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1<usize>) -> IndexMap<usize, A>
123 where
124 A: Ord + Clone,
125 {
126 let mut deduped_indexes: Vec<usize> = indexes.to_vec();
127 deduped_indexes.sort_unstable();
128 deduped_indexes.dedup();
129
130 get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
131 }
132
133 fn partition_mut(&mut self, pivot_index: usize) -> usize
134 where
135 A: Ord + Clone,
136 {
137 let pivot_value = self[pivot_index].clone();
138 self.swap(pivot_index, 0);
139 let n = self.len();
140 let mut i = 1;
141 let mut j = n - 1;
142 loop {
143 loop {
144 if i > j {
145 break;
146 }
147 if self[i] >= pivot_value {
148 break;
149 }
150 i += 1;
151 }
152 while pivot_value <= self[j] {
153 if j == 1 {
154 break;
155 }
156 j -= 1;
157 }
158 if i >= j {
159 break;
160 } else {
161 self.swap(i, j);
162 i += 1;
163 j -= 1;
164 }
165 }
166 self.swap(0, i - 1);
167 i - 1
168 }
169
170 private_impl! {}
171}
172
173/// To retrieve multiple indexes from the sorted array in an optimized fashion,
174/// [get_many_from_sorted_mut] first of all sorts and deduplicates the
175/// `indexes` vector.
176///
177/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and
178/// deduplication, assuming that the user has already taken care of it.
179///
180/// Useful when you have to call [get_many_from_sorted_mut] multiple times
181/// using the same indexes.
182///
183/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut
184pub(crate) fn get_many_from_sorted_mut_unchecked<A>(
185 array: &mut ArrayRef1<A>,
186 indexes: &[usize],
187) -> IndexMap<usize, A>
188where
189 A: Ord + Clone,
190{
191 if indexes.is_empty() {
192 return IndexMap::new();
193 }
194
195 // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must
196 // be non-empty.
197 let mut values = vec![array[0].clone(); indexes.len()];
198 _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
199
200 // We convert the vector to a more search-friendly `IndexMap`.
201 indexes.iter().cloned().zip(values.into_iter()).collect()
202}
203
204/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`.
205///
206/// `indexes` is the list of indexes to get. `indexes` is mutable so that it
207/// can be used as scratch space for this routine; the value of `indexes` after
208/// calling this routine should be ignored.
209///
210/// `values` is a pre-allocated slice to use for writing the output. Its
211/// initial element values are ignored.
212fn _get_many_from_sorted_mut_unchecked<A>(
213 mut array: ArrayViewMut1<'_, A>,
214 indexes: &mut [usize],
215 values: &mut [A],
216) where
217 A: Ord + Clone,
218{
219 let n = array.len();
220 debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds
221 debug_assert_eq!(indexes.len(), values.len());
222
223 if indexes.is_empty() {
224 // Nothing to do in this case.
225 return;
226 }
227
228 // At this point, `n >= 1` since `indexes.len() >= 1`.
229 if n == 1 {
230 // We can only reach this point if `indexes.len() == 1`, so we only
231 // need to assign the single value, and then we're done.
232 debug_assert_eq!(indexes.len(), 1);
233 values[0] = array[0].clone();
234 return;
235 }
236
237 // We pick a random pivot index: the corresponding element is the pivot value
238 let mut rng = thread_rng();
239 let pivot_index = rng.gen_range(0..n);
240
241 // We partition the array with respect to the pivot value.
242 // The pivot value moves to `array_partition_index`.
243 // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
244 // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
245 let array_partition_index = array.partition_mut(pivot_index);
246
247 // We use a divide-and-conquer strategy, splitting the indexes we are
248 // searching for (`indexes`) and the corresponding portions of the output
249 // slice (`values`) into pieces with respect to `array_partition_index`.
250 let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
251 Ok(index) => (true, index),
252 Err(index) => (false, index),
253 };
254 let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
255 let (smaller_values, other_values) = values.split_at_mut(index_split);
256 let (bigger_indexes, bigger_values) = if found_exact {
257 other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
258 (&mut other_indexes[1..], &mut other_values[1..])
259 } else {
260 (other_indexes, other_values)
261 };
262
263 // We search recursively for the values corresponding to strictly smaller
264 // indexes to the left of `partition_index`.
265 _get_many_from_sorted_mut_unchecked(
266 array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
267 smaller_indexes,
268 smaller_values,
269 );
270
271 // We search recursively for the values corresponding to strictly bigger
272 // indexes to the right of `partition_index`. Since only the right portion
273 // of the array is passed in, the indexes need to be shifted by length of
274 // the removed portion.
275 bigger_indexes
276 .iter_mut()
277 .for_each(|x| *x -= array_partition_index + 1);
278 _get_many_from_sorted_mut_unchecked(
279 array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
280 bigger_indexes,
281 bigger_values,
282 );
283}