ndarray_slice/merge_sort.rs
1//! Derivative work of [`core::slice::sort`] licensed under `MIT OR Apache-2.0`.
2//!
3//! [`core::slice::sort`]: https://doc.rust-lang.org/src/core/slice/sort.rs.html
4
5#![cfg(feature = "alloc")]
6
7use crate::insertion_sort::insertion_sort_shift_left;
8use crate::partition::reverse;
9use core::{cmp, mem, ptr};
10use ndarray::{ArrayView1, ArrayViewMut1, IndexLonger, s};
11
/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn merge<T, F>(v: ArrayViewMut1<'_, T>, mid: usize, buf: *mut T, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    let len = v.len();

    // The merge process first copies the shorter run into `buf`. Then it traces the newly copied
    // run and the longer run forwards (or backwards), comparing their next unconsumed elements and
    // copying the lesser (or greater) one into `v`.
    //
    // As soon as the shorter run is fully consumed, the process is done. If the longer run gets
    // consumed first, then we must copy whatever is left of the shorter run into the remaining
    // hole in `v`.
    //
    // Intermediate state of the process is always tracked by `hole`, which serves two purposes:
    // 1. Protects integrity of `v` from panics in `is_less`.
    // 2. Fills the remaining hole in `v` if the longer run gets consumed first.
    //
    // Panic safety:
    //
    // If `is_less` panics at any point during the process, `hole` will get dropped and fill the
    // hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
    // object it initially held exactly once.
    let mut hole;

    if mid <= len - mid {
        // The left run is shorter: copy it out and merge forwards from the front.

        // SAFETY: buf must have enough capacity for `v[..mid]`.
        unsafe {
            // Shallow (bitwise) copies only; `hole` tracks which of these copies are still
            // unconsumed so no element is ever duplicated or dropped twice.
            for i in 0..mid {
                ptr::copy_nonoverlapping(&v[i], buf.add(i), 1);
            }
            hole = MergeHole {
                buf,
                start: 0,
                end: mid,
                dest: 0,
                v,
            };
        }

        // Initially, these indices point to the beginnings of their runs:
        // `left` into `buf`, `right` into `v[mid..]`, `out` into the hole at the front of `v`.
        let left = &mut hole.start;
        let mut right = mid;
        let out = &mut hole.dest;

        while *left < hole.end && right < len {
            // Consume the lesser side.
            // If equal, prefer the left run to maintain stability.

            // SAFETY: left and right must be valid and part of v same for out.
            unsafe {
                let w = hole.v.view();
                let is_l = is_less(w.uget(right), &*hole.buf.add(*left));
                let to_copy = if is_l {
                    w.uget(right)
                } else {
                    &*hole.buf.add(*left)
                };
                ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(*out), 1);
                *out += 1;
                if is_l {
                    right += 1;
                } else {
                    *left += 1;
                }
            }
        }
    } else {
        // The right run is shorter: copy it out and merge backwards from the back.

        // SAFETY: buf must have enough capacity for `v[mid..]`.
        unsafe {
            for i in 0..len - mid {
                ptr::copy_nonoverlapping(&v[mid + i], buf.add(i), 1);
            }
            hole = MergeHole {
                buf,
                start: 0,
                end: len - mid,
                dest: mid,
                v,
            };
        }

        // Initially, these indices point past the ends of their runs: `left` is the number of
        // unconsumed left-run elements in `v`, `right` the number of unconsumed elements in
        // `buf`, and `out` one past the hole at the back of `v`.
        let left = &mut hole.dest;
        let right = &mut hole.end;
        let mut out = len;

        while 0 < *left && 0 < *right {
            // Consume the greater side.
            // If equal, prefer the right run to maintain stability.

            // SAFETY: left and right must be valid and part of v same for out.
            unsafe {
                let w = hole.v.view();
                let is_l = is_less(&*hole.buf.add(*right - 1), w.uget(*left - 1));
                if is_l {
                    *left -= 1;
                } else {
                    *right -= 1;
                }
                let to_copy = if is_l {
                    w.uget(*left)
                } else {
                    &*hole.buf.add(*right)
                };
                out -= 1;
                ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(out), 1);
            }
        }
    }
    // Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
    // it will now be copied into the hole in `v`.

    // When dropped, copies the range `buf[start..end]` into `v[dest..]`.
    struct MergeHole<'a, T> {
        // Scratch buffer holding a shallow copy of the shorter run.
        buf: *mut T,
        // First unconsumed index in `buf`.
        start: usize,
        // One past the last unconsumed index in `buf`.
        end: usize,

        // The view being merged into.
        v: ArrayViewMut1<'a, T>,
        // First index of the hole in `v` that the unconsumed buffer elements belong to.
        dest: usize,
    }

    impl<T> Drop for MergeHole<'_, T> {
        fn drop(&mut self) {
            // SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
            unsafe {
                let len = self.end - self.start;
                for i in 0..len {
                    let src = self.buf.add(self.start + i);
                    let dst = self.v.view_mut().index(self.dest + i);
                    ptr::copy_nonoverlapping(src, dst, 1);
                }
            }
        }
    }
}
195
/// This merge sort borrows some (but not all) ideas from TimSort, which used to be described in
/// detail [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt). However Python
/// has switched to a Powersort based implementation.
///
/// The algorithm identifies strictly descending and non-descending subsequences, which are called
/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed
/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are
/// satisfied:
///
/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len`
/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len`
///
/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case.
///
/// Allocation is delegated to the caller through the `*_alloc_fn`/`*_dealloc_fn` pairs so this
/// code stays usable in `no_std` + `alloc` settings.
///
/// # Panics
///
/// Panics if `elem_alloc_fn` or `run_alloc_fn` returns a null pointer.
pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>(
    mut v: ArrayViewMut1<'_, T>,
    is_less: &mut CmpF,
    elem_alloc_fn: ElemAllocF,
    elem_dealloc_fn: ElemDeallocF,
    run_alloc_fn: RunAllocF,
    run_dealloc_fn: RunDeallocF,
) where
    CmpF: FnMut(&T, &T) -> bool,
    ElemAllocF: Fn(usize) -> *mut T,
    ElemDeallocF: Fn(*mut T, usize),
    RunAllocF: Fn(usize) -> *mut TimSortRun,
    RunDeallocF: Fn(*mut TimSortRun, usize),
{
    // Slices of up to this length get sorted using insertion sort.
    const MAX_INSERTION: usize = 20;

    // The caller should have already checked that.
    debug_assert!(mem::size_of::<T>() > 0);

    let len = v.len();

    // Short arrays get sorted in-place via insertion sort to avoid allocations.
    if len <= MAX_INSERTION {
        if len >= 2 {
            insertion_sort_shift_left(v, 1, is_less);
        }
        return;
    }

    // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it
    // shallow copies of the contents of `v` without risking the dtors running on copies if
    // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run,
    // which will always have length at most `len / 2`.
    let buf = BufGuard::new(len / 2, elem_alloc_fn, elem_dealloc_fn);
    let buf_ptr = buf.buf_ptr.as_ptr();

    let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn);

    // `v[..start]` is fully sorted into runs already on the stack; `end` is the exclusive bound
    // of the run currently being formed.
    let mut end = 0;
    let mut start = 0;

    // Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the
    // code-gen is usually better. For the most sensitive types such as integers, these are merged
    // bidirectionally at once. So there is no benefit in scanning backwards.
    while end < len {
        // `end == start` here, so the streak found in `v[start..]` spans `v[start..end]` after
        // this addition.
        let (streak_end, was_reversed) = find_streak(v.slice(s![start..]), is_less);
        end += streak_end;
        if was_reversed {
            reverse(v.slice_mut(s![start..end]));
        }

        // Insert some more elements into the run if it's too short. Insertion sort is faster than
        // merge sort on short sequences, so this significantly improves performance.
        end = provide_sorted_batch(v.view_mut(), start, end, is_less);

        // Push this run onto the stack.
        runs.push(TimSortRun {
            start,
            len: end - start,
        });
        start = end;

        // Merge some pairs of adjacent runs to satisfy the invariants.
        while let Some(r) = collapse(runs.as_slice(), len) {
            let left = runs[r];
            let right = runs[r + 1];
            let merge_slice = v.slice_mut(s![left.start..right.start + right.len]);
            // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and
            // neither side may be on length 0.
            unsafe {
                merge(merge_slice, left.len, buf_ptr, is_less);
            }
            runs[r + 1] = TimSortRun {
                start: left.start,
                len: left.len + right.len,
            };
            runs.remove(r);
        }
    }

    // Finally, exactly one run must remain in the stack.
    debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);

    // Examines the stack of runs and identifies the next pair of runs to merge. More specifically,
    // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the
    // algorithm should continue building a new run instead, `None` is returned.
    //
    // TimSort is infamous for its buggy implementations, as described here:
    // http://envisage-project.eu/timsort-specification-and-verification/
    //
    // The gist of the story is: we must enforce the invariants on the top four runs on the stack.
    // Enforcing them on just top three is not sufficient to ensure that the invariants will still
    // hold for *all* runs in the stack.
    //
    // This function correctly checks invariants for the top four runs. Additionally, if the top
    // run starts at index 0, it will always demand a merge operation until the stack is fully
    // collapsed, in order to complete the sort.
    #[inline]
    fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> {
        let n = runs.len();
        if n >= 2
            && (runs[n - 1].start + runs[n - 1].len == stop
                || runs[n - 2].len <= runs[n - 1].len
                || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
                || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
        {
            // Merge the pair that keeps the stack closest to the invariants; preferring the
            // deeper pair when the third run from the top is the smaller one.
            if n >= 3 && runs[n - 3].len < runs[n - 1].len {
                Some(n - 3)
            } else {
                Some(n - 2)
            }
        } else {
            None
        }
    }

    // Extremely basic versions of Vec.
    // Their use is super limited and by having the code here, it allows reuse between the sort
    // implementations.

    // Owns the element scratch allocation and returns it through `elem_dealloc_fn` on drop
    // (including unwinding from a panicking comparison).
    struct BufGuard<T, ElemDeallocF>
    where
        ElemDeallocF: Fn(*mut T, usize),
    {
        buf_ptr: ptr::NonNull<T>,
        capacity: usize,
        elem_dealloc_fn: ElemDeallocF,
    }

    impl<T, ElemDeallocF> BufGuard<T, ElemDeallocF>
    where
        ElemDeallocF: Fn(*mut T, usize),
    {
        // Allocates space for `len` elements. Panics if `elem_alloc_fn` returns null.
        fn new<ElemAllocF>(
            len: usize,
            elem_alloc_fn: ElemAllocF,
            elem_dealloc_fn: ElemDeallocF,
        ) -> Self
        where
            ElemAllocF: Fn(usize) -> *mut T,
        {
            Self {
                buf_ptr: ptr::NonNull::new(elem_alloc_fn(len)).unwrap(),
                capacity: len,
                elem_dealloc_fn,
            }
        }
    }

    impl<T, ElemDeallocF> Drop for BufGuard<T, ElemDeallocF>
    where
        ElemDeallocF: Fn(*mut T, usize),
    {
        fn drop(&mut self) {
            (self.elem_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
        }
    }

    // Minimal growable stack of `TimSortRun`s backed by the caller-provided allocator pair.
    struct RunVec<RunAllocF, RunDeallocF>
    where
        RunAllocF: Fn(usize) -> *mut TimSortRun,
        RunDeallocF: Fn(*mut TimSortRun, usize),
    {
        buf_ptr: ptr::NonNull<TimSortRun>,
        capacity: usize,
        len: usize,
        run_alloc_fn: RunAllocF,
        run_dealloc_fn: RunDeallocF,
    }

    impl<RunAllocF, RunDeallocF> RunVec<RunAllocF, RunDeallocF>
    where
        RunAllocF: Fn(usize) -> *mut TimSortRun,
        RunDeallocF: Fn(*mut TimSortRun, usize),
    {
        fn new(run_alloc_fn: RunAllocF, run_dealloc_fn: RunDeallocF) -> Self {
            // Most slices can be sorted with at most 16 runs in-flight.
            const START_RUN_CAPACITY: usize = 16;

            Self {
                buf_ptr: ptr::NonNull::new(run_alloc_fn(START_RUN_CAPACITY)).unwrap(),
                capacity: START_RUN_CAPACITY,
                len: 0,
                run_alloc_fn,
                run_dealloc_fn,
            }
        }

        // Appends `val`, doubling the allocation first when full.
        fn push(&mut self, val: TimSortRun) {
            if self.len == self.capacity {
                let old_capacity = self.capacity;
                let old_buf_ptr = self.buf_ptr.as_ptr();

                self.capacity *= 2;
                self.buf_ptr = ptr::NonNull::new((self.run_alloc_fn)(self.capacity)).unwrap();

                // SAFETY: buf_ptr new and old were correctly allocated and old_buf_ptr has
                // old_capacity valid elements.
                unsafe {
                    ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr.as_ptr(), old_capacity);
                }

                (self.run_dealloc_fn)(old_buf_ptr, old_capacity);
            }

            // SAFETY: The invariant was just checked.
            unsafe {
                self.buf_ptr.as_ptr().add(self.len).write(val);
            }
            self.len += 1;
        }

        // Removes the element at `index`, shifting later elements down. Panics when out of bounds.
        fn remove(&mut self, index: usize) {
            if index >= self.len {
                panic!("Index out of bounds");
            }

            // SAFETY: buf_ptr needs to be valid and len invariant upheld.
            unsafe {
                // the place we are taking from.
                let ptr = self.buf_ptr.as_ptr().add(index);

                // Shift everything down to fill in that spot.
                ptr::copy(ptr.add(1), ptr, self.len - index - 1);
            }
            self.len -= 1;
        }

        fn as_slice(&self) -> &[TimSortRun] {
            // SAFETY: Safe as long as buf_ptr is valid and len invariant was upheld.
            unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr.as_ptr(), self.len) }
        }

        fn len(&self) -> usize {
            self.len
        }
    }

    impl<RunAllocF, RunDeallocF> core::ops::Index<usize> for RunVec<RunAllocF, RunDeallocF>
    where
        RunAllocF: Fn(usize) -> *mut TimSortRun,
        RunDeallocF: Fn(*mut TimSortRun, usize),
    {
        type Output = TimSortRun;

        fn index(&self, index: usize) -> &Self::Output {
            if index < self.len {
                // SAFETY: buf_ptr and len invariant must be upheld.
                unsafe {
                    return &*(self.buf_ptr.as_ptr().add(index));
                }
            }

            panic!("Index out of bounds");
        }
    }

    impl<RunAllocF, RunDeallocF> core::ops::IndexMut<usize> for RunVec<RunAllocF, RunDeallocF>
    where
        RunAllocF: Fn(usize) -> *mut TimSortRun,
        RunDeallocF: Fn(*mut TimSortRun, usize),
    {
        fn index_mut(&mut self, index: usize) -> &mut Self::Output {
            if index < self.len {
                // SAFETY: buf_ptr and len invariant must be upheld.
                unsafe {
                    return &mut *(self.buf_ptr.as_ptr().add(index));
                }
            }

            panic!("Index out of bounds");
        }
    }

    impl<RunAllocF, RunDeallocF> Drop for RunVec<RunAllocF, RunDeallocF>
    where
        RunAllocF: Fn(usize) -> *mut TimSortRun,
        RunDeallocF: Fn(*mut TimSortRun, usize),
    {
        fn drop(&mut self) {
            // As long as TimSortRun is Copy we don't need to drop them individually but just the
            // whole allocation.
            (self.run_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
        }
    }
}
495
/// Internal type used by merge_sort.
///
/// Describes one sorted run on the merge stack.
#[derive(Clone, Copy, Debug)]
pub struct TimSortRun {
    // Number of elements in the run.
    len: usize,
    // Start index of the run within the slice being sorted.
    start: usize,
}
502
503/// Takes a range as denoted by start and end, that is already sorted and extends it to the right if
504/// necessary with sorts optimized for smaller ranges such as insertion sort.
505fn provide_sorted_batch<T, F>(
506 mut v: ArrayViewMut1<'_, T>,
507 start: usize,
508 mut end: usize,
509 is_less: &mut F,
510) -> usize
511where
512 F: FnMut(&T, &T) -> bool,
513{
514 let len = v.len();
515 assert!(end >= start && end <= len);
516
517 // This value is a balance between least comparisons and best performance, as
518 // influenced by for example cache locality.
519 const MIN_INSERTION_RUN: usize = 10;
520
521 // Insert some more elements into the run if it's too short. Insertion sort is faster than
522 // merge sort on short sequences, so this significantly improves performance.
523 let start_end_diff = end - start;
524
525 if start_end_diff < MIN_INSERTION_RUN && end < len {
526 // v[start_found..end] are elements that are already sorted in the input. We want to extend
527 // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is
528 // more efficient that trying to push those already sorted elements to the left.
529 end = cmp::min(start + MIN_INSERTION_RUN, len);
530 let presorted_start = cmp::max(start_end_diff, 1);
531
532 insertion_sort_shift_left(v.slice_mut(s![start..end]), presorted_start, is_less);
533 }
534
535 end
536}
537
538// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first
539/// value that is not part of said streak, and a bool denoting whether the streak was reversed.
540/// Streaks can be increasing or decreasing.
541fn find_streak<T, F>(v: ArrayView1<'_, T>, is_less: &mut F) -> (usize, bool)
542where
543 F: FnMut(&T, &T) -> bool,
544{
545 let len = v.len();
546
547 if len < 2 {
548 return (len, false);
549 }
550
551 let mut end = 2;
552
553 // SAFETY: See below specific.
554 unsafe {
555 // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices.
556 let assume_reverse = is_less(v.uget(1), v.uget(0));
557
558 // SAFETY: We know end >= 2 and check end < len.
559 // From that follows that accessing v at end and end - 1 is safe.
560 if assume_reverse {
561 while end < len && is_less(v.uget(end), v.uget(end - 1)) {
562 end += 1;
563 }
564
565 (end, true)
566 } else {
567 while end < len && !is_less(v.uget(end), v.uget(end - 1)) {
568 end += 1;
569 }
570 (end, false)
571 }
572 }
573}