ndarray_stats/sort.rs
1use indexmap::IndexMap;
2use ndarray::prelude::*;
3use ndarray::{Data, DataMut, Slice};
4use rand::prelude::*;
5use rand::thread_rng;
6
7/// Methods for sorting and partitioning 1-D arrays.
8pub trait Sort1dExt<A, S>
9where
10 S: Data<Elem = A>,
11{
12 /// Return the element that would occupy the `i`-th position if
13 /// the array were sorted in increasing order.
14 ///
15 /// The array is shuffled **in place** to retrieve the desired element:
16 /// no copy of the array is allocated.
17 /// After the shuffling, all elements with an index smaller than `i`
18 /// are smaller than the desired element, while all elements with
19 /// an index greater or equal than `i` are greater than or equal
20 /// to the desired element.
21 ///
22 /// No other assumptions should be made on the ordering of the
23 /// elements after this computation.
24 ///
25 /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
26 /// - average case: O(`n`);
27 /// - worst case: O(`n`^2);
28 /// where n is the number of elements in the array.
29 ///
30 /// **Panics** if `i` is greater than or equal to `n`.
31 fn get_from_sorted_mut(&mut self, i: usize) -> A
32 where
33 A: Ord + Clone,
34 S: DataMut;
35
36 /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple
37 /// indexes at once.
38 /// It returns an `IndexMap`, with indexes as keys and retrieved elements as
39 /// values.
40 /// The `IndexMap` is sorted with respect to indexes in increasing order:
41 /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`).
42 ///
43 /// **Panics** if any element in `indexes` is greater than or equal to `n`,
44 /// where `n` is the length of the array..
45 ///
46 /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut
47 fn get_many_from_sorted_mut<S2>(&mut self, indexes: &ArrayBase<S2, Ix1>) -> IndexMap<usize, A>
48 where
49 A: Ord + Clone,
50 S: DataMut,
51 S2: Data<Elem = usize>;
52
53 /// Partitions the array in increasing order based on the value initially
54 /// located at `pivot_index` and returns the new index of the value.
55 ///
56 /// The elements are rearranged in such a way that the value initially
57 /// located at `pivot_index` is moved to the position it would be in an
58 /// array sorted in increasing order. The return value is the new index of
59 /// the value after rearrangement. All elements smaller than the value are
60 /// moved to its left and all elements equal or greater than the value are
61 /// moved to its right. The ordering of the elements in the two partitions
62 /// is undefined.
63 ///
64 /// `self` is shuffled **in place** to operate the desired partition:
65 /// no copy of the array is allocated.
66 ///
67 /// The method uses Hoare's partition algorithm.
68 /// Complexity: O(`n`), where `n` is the number of elements in the array.
69 /// Average number of element swaps: n/6 - 1/3 (see
70 /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
71 ///
72 /// **Panics** if `pivot_index` is greater than or equal to `n`.
73 ///
74 /// # Example
75 ///
76 /// ```
77 /// use ndarray::array;
78 /// use ndarray_stats::Sort1dExt;
79 ///
80 /// let mut data = array![3, 1, 4, 5, 2];
81 /// let pivot_index = 2;
82 /// let pivot_value = data[pivot_index];
83 ///
84 /// // Partition by the value located at `pivot_index`.
85 /// let new_index = data.partition_mut(pivot_index);
86 /// // The pivot value is now located at `new_index`.
87 /// assert_eq!(data[new_index], pivot_value);
88 /// // Elements less than that value are moved to the left.
89 /// for i in 0..new_index {
90 /// assert!(data[i] < pivot_value);
91 /// }
92 /// // Elements greater than or equal to that value are moved to the right.
93 /// for i in (new_index + 1)..data.len() {
94 /// assert!(data[i] >= pivot_value);
95 /// }
96 /// ```
97 fn partition_mut(&mut self, pivot_index: usize) -> usize
98 where
99 A: Ord + Clone,
100 S: DataMut;
101
102 private_decl! {}
103}
104
105impl<A, S> Sort1dExt<A, S> for ArrayBase<S, Ix1>
106where
107 S: Data<Elem = A>,
108{
109 fn get_from_sorted_mut(&mut self, i: usize) -> A
110 where
111 A: Ord + Clone,
112 S: DataMut,
113 {
114 let n = self.len();
115 if n == 1 {
116 self[0].clone()
117 } else {
118 let mut rng = thread_rng();
119 let pivot_index = rng.gen_range(0..n);
120 let partition_index = self.partition_mut(pivot_index);
121 if i < partition_index {
122 self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
123 .get_from_sorted_mut(i)
124 } else if i == partition_index {
125 self[i].clone()
126 } else {
127 self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
128 .get_from_sorted_mut(i - (partition_index + 1))
129 }
130 }
131 }
132
133 fn get_many_from_sorted_mut<S2>(&mut self, indexes: &ArrayBase<S2, Ix1>) -> IndexMap<usize, A>
134 where
135 A: Ord + Clone,
136 S: DataMut,
137 S2: Data<Elem = usize>,
138 {
139 let mut deduped_indexes: Vec<usize> = indexes.to_vec();
140 deduped_indexes.sort_unstable();
141 deduped_indexes.dedup();
142
143 get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
144 }
145
146 fn partition_mut(&mut self, pivot_index: usize) -> usize
147 where
148 A: Ord + Clone,
149 S: DataMut,
150 {
151 let pivot_value = self[pivot_index].clone();
152 self.swap(pivot_index, 0);
153 let n = self.len();
154 let mut i = 1;
155 let mut j = n - 1;
156 loop {
157 loop {
158 if i > j {
159 break;
160 }
161 if self[i] >= pivot_value {
162 break;
163 }
164 i += 1;
165 }
166 while pivot_value <= self[j] {
167 if j == 1 {
168 break;
169 }
170 j -= 1;
171 }
172 if i >= j {
173 break;
174 } else {
175 self.swap(i, j);
176 i += 1;
177 j -= 1;
178 }
179 }
180 self.swap(0, i - 1);
181 i - 1
182 }
183
184 private_impl! {}
185}
186
187/// To retrieve multiple indexes from the sorted array in an optimized fashion,
188/// [get_many_from_sorted_mut] first of all sorts and deduplicates the
189/// `indexes` vector.
190///
191/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and
192/// deduplication, assuming that the user has already taken care of it.
193///
194/// Useful when you have to call [get_many_from_sorted_mut] multiple times
195/// using the same indexes.
196///
197/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut
198pub(crate) fn get_many_from_sorted_mut_unchecked<A, S>(
199 array: &mut ArrayBase<S, Ix1>,
200 indexes: &[usize],
201) -> IndexMap<usize, A>
202where
203 A: Ord + Clone,
204 S: DataMut<Elem = A>,
205{
206 if indexes.is_empty() {
207 return IndexMap::new();
208 }
209
210 // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must
211 // be non-empty.
212 let mut values = vec![array[0].clone(); indexes.len()];
213 _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
214
215 // We convert the vector to a more search-friendly `IndexMap`.
216 indexes.iter().cloned().zip(values.into_iter()).collect()
217}
218
219/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`.
220///
221/// `indexes` is the list of indexes to get. `indexes` is mutable so that it
222/// can be used as scratch space for this routine; the value of `indexes` after
223/// calling this routine should be ignored.
224///
225/// `values` is a pre-allocated slice to use for writing the output. Its
226/// initial element values are ignored.
227fn _get_many_from_sorted_mut_unchecked<A>(
228 mut array: ArrayViewMut1<'_, A>,
229 indexes: &mut [usize],
230 values: &mut [A],
231) where
232 A: Ord + Clone,
233{
234 let n = array.len();
235 debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds
236 debug_assert_eq!(indexes.len(), values.len());
237
238 if indexes.is_empty() {
239 // Nothing to do in this case.
240 return;
241 }
242
243 // At this point, `n >= 1` since `indexes.len() >= 1`.
244 if n == 1 {
245 // We can only reach this point if `indexes.len() == 1`, so we only
246 // need to assign the single value, and then we're done.
247 debug_assert_eq!(indexes.len(), 1);
248 values[0] = array[0].clone();
249 return;
250 }
251
252 // We pick a random pivot index: the corresponding element is the pivot value
253 let mut rng = thread_rng();
254 let pivot_index = rng.gen_range(0..n);
255
256 // We partition the array with respect to the pivot value.
257 // The pivot value moves to `array_partition_index`.
258 // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
259 // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
260 let array_partition_index = array.partition_mut(pivot_index);
261
262 // We use a divide-and-conquer strategy, splitting the indexes we are
263 // searching for (`indexes`) and the corresponding portions of the output
264 // slice (`values`) into pieces with respect to `array_partition_index`.
265 let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
266 Ok(index) => (true, index),
267 Err(index) => (false, index),
268 };
269 let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
270 let (smaller_values, other_values) = values.split_at_mut(index_split);
271 let (bigger_indexes, bigger_values) = if found_exact {
272 other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
273 (&mut other_indexes[1..], &mut other_values[1..])
274 } else {
275 (other_indexes, other_values)
276 };
277
278 // We search recursively for the values corresponding to strictly smaller
279 // indexes to the left of `partition_index`.
280 _get_many_from_sorted_mut_unchecked(
281 array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
282 smaller_indexes,
283 smaller_values,
284 );
285
286 // We search recursively for the values corresponding to strictly bigger
287 // indexes to the right of `partition_index`. Since only the right portion
288 // of the array is passed in, the indexes need to be shifted by length of
289 // the removed portion.
290 bigger_indexes
291 .iter_mut()
292 .for_each(|x| *x -= array_partition_index + 1);
293 _get_many_from_sorted_mut_unchecked(
294 array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
295 bigger_indexes,
296 bigger_values,
297 );
298}