kn0sys_ndarray_stats/sort.rs
1use indexmap::IndexMap;
2use ndarray::prelude::*;
3use ndarray::{Data, DataMut, Slice};
4use rand::prelude::*;
5
/// Methods for sorting and partitioning 1-D arrays.
///
/// Every method rearranges the elements of the array **in place**, which is
/// why each of them requires mutable storage (`S: DataMut`) and a totally
/// ordered, clonable element type (`A: Ord + Clone`).
pub trait Sort1dExt<A, S>
where
    S: Data<Elem = A>,
{
    /// Return the element that would occupy the `i`-th position if
    /// the array were sorted in increasing order.
    ///
    /// The array is shuffled **in place** to retrieve the desired element:
    /// no copy of the array is allocated.
    /// After the shuffling, all elements with an index smaller than `i`
    /// are less than or equal to the desired element, while all elements
    /// with an index greater than or equal to `i` are greater than or equal
    /// to the desired element.
    ///
    /// No other assumptions should be made on the ordering of the
    /// elements after this computation.
    ///
    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
    /// - average case: O(`n`);
    /// - worst case: O(`n`^2);
    ///
    /// where `n` is the number of elements in the array.
    ///
    /// **Panics** if `i` is greater than or equal to `n`.
    fn get_from_sorted_mut(&mut self, i: usize) -> A
    where
        A: Ord + Clone,
        S: DataMut;

    /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple
    /// indexes at once.
    ///
    /// It returns an `IndexMap`, with indexes as keys and retrieved elements as
    /// values. Indexes that appear more than once in `indexes` produce a
    /// single entry.
    /// The `IndexMap` is sorted with respect to indexes in increasing order:
    /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`).
    ///
    /// Like [`get_from_sorted_mut`], this rearranges the array **in place**.
    ///
    /// **Panics** if any element in `indexes` is greater than or equal to `n`,
    /// where `n` is the length of the array.
    ///
    /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut
    fn get_many_from_sorted_mut<S2>(&mut self, indexes: &ArrayBase<S2, Ix1>) -> IndexMap<usize, A>
    where
        A: Ord + Clone,
        S: DataMut,
        S2: Data<Elem = usize>;

    /// Partitions the array in increasing order based on the value initially
    /// located at `pivot_index` and returns the new index of the value.
    ///
    /// The elements are rearranged in such a way that the value initially
    /// located at `pivot_index` is moved to the position it would be in an
    /// array sorted in increasing order. The return value is the new index of
    /// the value after rearrangement. All elements smaller than the value are
    /// moved to its left and all elements equal to or greater than the value
    /// are moved to its right. The ordering of the elements in the two
    /// partitions is undefined.
    ///
    /// `self` is shuffled **in place** to operate the desired partition:
    /// no copy of the array is allocated.
    ///
    /// The method uses Hoare's partition algorithm.
    /// Complexity: O(`n`), where `n` is the number of elements in the array.
    /// Average number of element swaps: n/6 - 1/3 (see
    /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
    ///
    /// **Panics** if `pivot_index` is greater than or equal to `n`.
    ///
    /// # Example
    ///
    /// ```
    /// use ndarray::array;
    /// use kn0sys_ndarray_stats::Sort1dExt;
    ///
    /// let mut data = array![3, 1, 4, 5, 2];
    /// let pivot_index = 2;
    /// let pivot_value = data[pivot_index];
    ///
    /// // Partition by the value located at `pivot_index`.
    /// let new_index = data.partition_mut(pivot_index);
    /// // The pivot value is now located at `new_index`.
    /// assert_eq!(data[new_index], pivot_value);
    /// // Elements less than that value are moved to the left.
    /// for i in 0..new_index {
    ///     assert!(data[i] < pivot_value);
    /// }
    /// // Elements greater than or equal to that value are moved to the right.
    /// for i in (new_index + 1)..data.len() {
    ///     assert!(data[i] >= pivot_value);
    /// }
    /// ```
    fn partition_mut(&mut self, pivot_index: usize) -> usize
    where
        A: Ord + Clone,
        S: DataMut;

    private_decl! {}
}
103
104impl<A, S> Sort1dExt<A, S> for ArrayBase<S, Ix1>
105where
106 S: Data<Elem = A>,
107{
108 fn get_from_sorted_mut(&mut self, i: usize) -> A
109 where
110 A: Ord + Clone,
111 S: DataMut,
112 {
113 let n = self.len();
114 if n == 1 {
115 self[0].clone()
116 } else {
117 let mut rng = rand::rng();
118 let pivot_index = rng.random_range(0..n);
119 let partition_index = self.partition_mut(pivot_index);
120 if i < partition_index {
121 self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
122 .get_from_sorted_mut(i)
123 } else if i == partition_index {
124 self[i].clone()
125 } else {
126 self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
127 .get_from_sorted_mut(i - (partition_index + 1))
128 }
129 }
130 }
131
132 fn get_many_from_sorted_mut<S2>(&mut self, indexes: &ArrayBase<S2, Ix1>) -> IndexMap<usize, A>
133 where
134 A: Ord + Clone,
135 S: DataMut,
136 S2: Data<Elem = usize>,
137 {
138 let mut deduped_indexes: Vec<usize> = indexes.to_vec();
139 deduped_indexes.sort_unstable();
140 deduped_indexes.dedup();
141
142 get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
143 }
144
145 fn partition_mut(&mut self, pivot_index: usize) -> usize
146 where
147 A: Ord + Clone,
148 S: DataMut,
149 {
150 let pivot_value = self[pivot_index].clone();
151 self.swap(pivot_index, 0);
152 let n = self.len();
153 let mut i = 1;
154 let mut j = n - 1;
155 loop {
156 loop {
157 if i > j {
158 break;
159 }
160 if self[i] >= pivot_value {
161 break;
162 }
163 i += 1;
164 }
165 while pivot_value <= self[j] {
166 if j == 1 {
167 break;
168 }
169 j -= 1;
170 }
171 if i >= j {
172 break;
173 } else {
174 self.swap(i, j);
175 i += 1;
176 j -= 1;
177 }
178 }
179 self.swap(0, i - 1);
180 i - 1
181 }
182
183 private_impl! {}
184}
185
186/// To retrieve multiple indexes from the sorted array in an optimized fashion,
187/// [get_many_from_sorted_mut] first of all sorts and deduplicates the
188/// `indexes` vector.
189///
190/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and
191/// deduplication, assuming that the user has already taken care of it.
192///
193/// Useful when you have to call [get_many_from_sorted_mut] multiple times
194/// using the same indexes.
195///
196/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut
197pub(crate) fn get_many_from_sorted_mut_unchecked<A, S>(
198 array: &mut ArrayBase<S, Ix1>,
199 indexes: &[usize],
200) -> IndexMap<usize, A>
201where
202 A: Ord + Clone,
203 S: DataMut<Elem = A>,
204{
205 if indexes.is_empty() {
206 return IndexMap::new();
207 }
208
209 // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must
210 // be non-empty.
211 let mut values = vec![array[0].clone(); indexes.len()];
212 _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
213
214 // We convert the vector to a more search-friendly `IndexMap`.
215 indexes.iter().cloned().zip(values.into_iter()).collect()
216}
217
218/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`.
219///
220/// `indexes` is the list of indexes to get. `indexes` is mutable so that it
221/// can be used as scratch space for this routine; the value of `indexes` after
222/// calling this routine should be ignored.
223///
224/// `values` is a pre-allocated slice to use for writing the output. Its
225/// initial element values are ignored.
226fn _get_many_from_sorted_mut_unchecked<A>(
227 mut array: ArrayViewMut1<'_, A>,
228 indexes: &mut [usize],
229 values: &mut [A],
230) where
231 A: Ord + Clone,
232{
233 let n = array.len();
234 debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds
235 debug_assert_eq!(indexes.len(), values.len());
236
237 if indexes.is_empty() {
238 // Nothing to do in this case.
239 return;
240 }
241
242 // At this point, `n >= 1` since `indexes.len() >= 1`.
243 if n == 1 {
244 // We can only reach this point if `indexes.len() == 1`, so we only
245 // need to assign the single value, and then we're done.
246 debug_assert_eq!(indexes.len(), 1);
247 values[0] = array[0].clone();
248 return;
249 }
250
251 // We pick a random pivot index: the corresponding element is the pivot value
252 let mut rng = rand::rng();
253 let pivot_index = rng.random_range(0..n);
254
255 // We partition the array with respect to the pivot value.
256 // The pivot value moves to `array_partition_index`.
257 // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
258 // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
259 let array_partition_index = array.partition_mut(pivot_index);
260
261 // We use a divide-and-conquer strategy, splitting the indexes we are
262 // searching for (`indexes`) and the corresponding portions of the output
263 // slice (`values`) into pieces with respect to `array_partition_index`.
264 let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
265 Ok(index) => (true, index),
266 Err(index) => (false, index),
267 };
268 let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
269 let (smaller_values, other_values) = values.split_at_mut(index_split);
270 let (bigger_indexes, bigger_values) = if found_exact {
271 other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
272 (&mut other_indexes[1..], &mut other_values[1..])
273 } else {
274 (other_indexes, other_values)
275 };
276
277 // We search recursively for the values corresponding to strictly smaller
278 // indexes to the left of `partition_index`.
279 _get_many_from_sorted_mut_unchecked(
280 array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
281 smaller_indexes,
282 smaller_values,
283 );
284
285 // We search recursively for the values corresponding to strictly bigger
286 // indexes to the right of `partition_index`. Since only the right portion
287 // of the array is passed in, the indexes need to be shifted by length of
288 // the removed portion.
289 bigger_indexes
290 .iter_mut()
291 .for_each(|x| *x -= array_partition_index + 1);
292 _get_many_from_sorted_mut_unchecked(
293 array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
294 bigger_indexes,
295 bigger_values,
296 );
297}