// ndarray_slice/merge_sort.rs

1//! Derivative work of [`core::slice::sort`] licensed under `MIT OR Apache-2.0`.
2//!
3//! [`core::slice::sort`]: https://doc.rust-lang.org/src/core/slice/sort.rs.html
4
5#![cfg(feature = "alloc")]
6
7use crate::insertion_sort::insertion_sort_shift_left;
8use crate::partition::reverse;
9use core::{cmp, mem, ptr};
10use ndarray::{ArrayView1, ArrayViewMut1, IndexLonger, s};
11
/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
#[warn(unsafe_op_in_unsafe_fn)]
unsafe fn merge<T, F>(v: ArrayViewMut1<'_, T>, mid: usize, buf: *mut T, is_less: &mut F)
where
	F: FnMut(&T, &T) -> bool,
{
	let len = v.len();

	// The merge process first copies the shorter run into `buf`. Then it traces the newly copied
	// run and the longer run forwards (or backwards), comparing their next unconsumed elements and
	// copying the lesser (or greater) one into `v`.
	//
	// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
	// consumed first, then we must copy whatever is left of the shorter run into the remaining
	// hole in `v`.
	//
	// Intermediate state of the process is always tracked by `hole`, which serves two purposes:
	// 1. Protects integrity of `v` from panics in `is_less`.
	// 2. Fills the remaining hole in `v` if the longer run gets consumed first.
	//
	// Panic safety:
	//
	// If `is_less` panics at any point during the process, `hole` will get dropped and fill the
	// hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
	// object it initially held exactly once.
	//
	// NOTE(review): unlike the pointer-based `core::slice::sort` original, this port tracks
	// positions as indices (`start`, `end`, `dest`) and goes through `view()`/`view_mut().index()`
	// for every element of `v` — presumably because an `ArrayViewMut1` need not be contiguous in
	// memory; only `buf` is addressed by raw pointer arithmetic.
	let mut hole;

	if mid <= len - mid {
		// The left run is shorter: copy `v[..mid]` into `buf` and merge forwards.

		// SAFETY: buf must have enough capacity for `v[..mid]`.
		unsafe {
			for i in 0..mid {
				ptr::copy_nonoverlapping(&v[i], buf.add(i), 1);
			}
			hole = MergeHole {
				buf,
				start: 0,
				end: mid,
				dest: 0,
				v,
			};
		}

		// Initially, these indices point to the beginnings of their arrays.
		let left = &mut hole.start; // next unconsumed element of the buffered left run
		let mut right = mid; // next unconsumed element of the right run, inside `v`
		let out = &mut hole.dest; // next slot of `v` to fill

		while *left < hole.end && right < len {
			// Consume the lesser side.
			// If equal, prefer the left run to maintain stability.

			// SAFETY: left and right must be valid and part of v same for out.
			unsafe {
				let w = hole.v.view();
				let is_l = is_less(w.uget(right), &*hole.buf.add(*left));
				let to_copy = if is_l {
					w.uget(right)
				} else {
					&*hole.buf.add(*left)
				};
				ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(*out), 1);
				*out += 1;
				if is_l {
					right += 1;
				} else {
					*left += 1;
				}
			}
		}
	} else {
		// The right run is shorter: copy `v[mid..]` into `buf` and merge backwards.

		// SAFETY: buf must have enough capacity for `v[mid..]`.
		unsafe {
			for i in 0..len - mid {
				ptr::copy_nonoverlapping(&v[mid + i], buf.add(i), 1);
			}
			hole = MergeHole {
				buf,
				start: 0,
				end: len - mid,
				dest: mid,
				v,
			};
		}

		// Initially, these indices point past the ends of their arrays.
		let left = &mut hole.dest; // one past the last unconsumed element of the left run in `v`
		let right = &mut hole.end; // one past the last unconsumed buffered element
		let mut out = len; // one past the last slot of `v` still to fill

		while 0 < *left && 0 < *right {
			// Consume the greater side.
			// If equal, prefer the right run to maintain stability.

			// SAFETY: left and right must be valid and part of v same for out.
			unsafe {
				let w = hole.v.view();
				let is_l = is_less(&*hole.buf.add(*right - 1), w.uget(*left - 1));
				if is_l {
					*left -= 1;
				} else {
					*right -= 1;
				}
				let to_copy = if is_l {
					w.uget(*left)
				} else {
					&*hole.buf.add(*right)
				};
				out -= 1;
				ptr::copy_nonoverlapping(to_copy, hole.v.view_mut().index(out), 1);
			}
		}
	}
	// Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
	// it will now be copied into the hole in `v`.

	// When dropped, copies the buffered range `start..end` into `v[dest..]`.
	struct MergeHole<'a, T> {
		buf: *mut T,  // scratch storage holding the copy of the shorter run
		start: usize, // first still-unconsumed buffered element
		end: usize,   // one past the last still-unconsumed buffered element

		v: ArrayViewMut1<'a, T>,
		dest: usize, // index in `v` where the unconsumed elements belong
	}

	impl<T> Drop for MergeHole<'_, T> {
		fn drop(&mut self) {
			// Refill the hole: move the still-unconsumed `buf[start..end]` into `v[dest..]`.
			// This runs on normal exit as well as during unwinding from a panicking `is_less`,
			// which is what restores the "every element exactly once" invariant of `v`.
			// SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
			unsafe {
				let len = self.end - self.start;
				for i in 0..len {
					let src = self.buf.add(self.start + i);
					let dst = self.v.view_mut().index(self.dest + i);
					ptr::copy_nonoverlapping(src, dst, 1);
				}
			}
		}
	}
}
195
/// This merge sort borrows some (but not all) ideas from TimSort, which used to be described in
/// detail [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt). However Python
/// has switched to a Powersort based implementation.
///
/// The algorithm identifies strictly descending and non-descending subsequences, which are called
/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed
/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are
/// satisfied:
///
/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len`
/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len`
///
/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case.
///
/// Allocation is delegated to the caller: `elem_alloc_fn`/`elem_dealloc_fn` provide the element
/// scratch buffer used for merging (requested capacity is `len / 2`), while
/// `run_alloc_fn`/`run_dealloc_fn` back the stack of pending runs. Both allocations are released
/// via RAII guards (`BufGuard`, `RunVec`), including when `is_less` panics.
pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>(
	mut v: ArrayViewMut1<'_, T>,
	is_less: &mut CmpF,
	elem_alloc_fn: ElemAllocF,
	elem_dealloc_fn: ElemDeallocF,
	run_alloc_fn: RunAllocF,
	run_dealloc_fn: RunDeallocF,
) where
	CmpF: FnMut(&T, &T) -> bool,
	ElemAllocF: Fn(usize) -> *mut T,
	ElemDeallocF: Fn(*mut T, usize),
	RunAllocF: Fn(usize) -> *mut TimSortRun,
	RunDeallocF: Fn(*mut TimSortRun, usize),
{
	// Slices of up to this length get sorted using insertion sort.
	const MAX_INSERTION: usize = 20;

	// The caller should have already checked that `T` is not zero-sized; the pointer arithmetic
	// in `merge` relies on it.
	debug_assert!(mem::size_of::<T>() > 0);

	let len = v.len();

	// Short arrays get sorted in-place via insertion sort to avoid allocations.
	if len <= MAX_INSERTION {
		if len >= 2 {
			insertion_sort_shift_left(v, 1, is_less);
		}
		return;
	}

	// Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it
	// shallow copies of the contents of `v` without risking the dtors running on copies if
	// `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run,
	// which will always have length at most `len / 2`.
	let buf = BufGuard::new(len / 2, elem_alloc_fn, elem_dealloc_fn);
	let buf_ptr = buf.buf_ptr.as_ptr();

	let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn);

	// `start..end` delimits the run currently being built; both advance monotonically.
	let mut end = 0;
	let mut start = 0;

	// Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the
	// code-gen is usually better. For the most sensitive types such as integers, these are merged
	// bidirectionally at once. So there is no benefit in scanning backwards.
	while end < len {
		// Find the next natural run, reversing it in-place if it was strictly descending so that
		// every run on the stack is non-descending.
		let (streak_end, was_reversed) = find_streak(v.slice(s![start..]), is_less);
		end += streak_end;
		if was_reversed {
			reverse(v.slice_mut(s![start..end]));
		}

		// Insert some more elements into the run if it's too short. Insertion sort is faster than
		// merge sort on short sequences, so this significantly improves performance.
		end = provide_sorted_batch(v.view_mut(), start, end, is_less);

		// Push this run onto the stack.
		runs.push(TimSortRun {
			start,
			len: end - start,
		});
		start = end;

		// Merge some pairs of adjacent runs to satisfy the invariants.
		while let Some(r) = collapse(runs.as_slice(), len) {
			let left = runs[r];
			let right = runs[r + 1];
			let merge_slice = v.slice_mut(s![left.start..right.start + right.len]);
			// SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and
			// neither side may be on length 0.
			unsafe {
				merge(merge_slice, left.len, buf_ptr, is_less);
			}
			// Replace the pair with the merged run, keeping the stack ordered by start index.
			runs[r + 1] = TimSortRun {
				start: left.start,
				len: left.len + right.len,
			};
			runs.remove(r);
		}
	}

	// Finally, exactly one run must remain in the stack.
	debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);

	// Examines the stack of runs and identifies the next pair of runs to merge. More specifically,
	// if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the
	// algorithm should continue building a new run instead, `None` is returned.
	//
	// TimSort is infamous for its buggy implementations, as described here:
	// http://envisage-project.eu/timsort-specification-and-verification/
	//
	// The gist of the story is: we must enforce the invariants on the top four runs on the stack.
	// Enforcing them on just top three is not sufficient to ensure that the invariants will still
	// hold for *all* runs in the stack.
	//
	// This function correctly checks invariants for the top four runs. Additionally, if the top
	// run starts at index 0, it will always demand a merge operation until the stack is fully
	// collapsed, in order to complete the sort.
	#[inline]
	fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> {
		let n = runs.len();
		// A merge is required when the top run reaches the end of the slice (final collapse) or
		// when any of the length invariants on the top four runs is violated.
		if n >= 2
			&& (runs[n - 1].start + runs[n - 1].len == stop
				|| runs[n - 2].len <= runs[n - 1].len
				|| (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
				|| (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
		{
			// Merge the smaller of the two candidate pairs to keep merges balanced.
			if n >= 3 && runs[n - 3].len < runs[n - 1].len {
				Some(n - 3)
			} else {
				Some(n - 2)
			}
		} else {
			None
		}
	}

	// Extremely basic versions of Vec.
	// Their use is super limited and by having the code here, it allows reuse between the sort
	// implementations.

	// RAII owner of the element scratch buffer: frees the raw allocation on drop, but never runs
	// element destructors (the buffer only ever holds shallow copies).
	struct BufGuard<T, ElemDeallocF>
	where
		ElemDeallocF: Fn(*mut T, usize),
	{
		buf_ptr: ptr::NonNull<T>,
		capacity: usize,
		elem_dealloc_fn: ElemDeallocF,
	}

	impl<T, ElemDeallocF> BufGuard<T, ElemDeallocF>
	where
		ElemDeallocF: Fn(*mut T, usize),
	{
		// Allocates `len` elements via `elem_alloc_fn`; panics if the allocator returns null.
		fn new<ElemAllocF>(
			len: usize,
			elem_alloc_fn: ElemAllocF,
			elem_dealloc_fn: ElemDeallocF,
		) -> Self
		where
			ElemAllocF: Fn(usize) -> *mut T,
		{
			Self {
				buf_ptr: ptr::NonNull::new(elem_alloc_fn(len)).unwrap(),
				capacity: len,
				elem_dealloc_fn,
			}
		}
	}

	impl<T, ElemDeallocF> Drop for BufGuard<T, ElemDeallocF>
	where
		ElemDeallocF: Fn(*mut T, usize),
	{
		fn drop(&mut self) {
			(self.elem_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
		}
	}

	// Minimal growable stack of `TimSortRun`s backed by caller-provided allocation callbacks.
	struct RunVec<RunAllocF, RunDeallocF>
	where
		RunAllocF: Fn(usize) -> *mut TimSortRun,
		RunDeallocF: Fn(*mut TimSortRun, usize),
	{
		buf_ptr: ptr::NonNull<TimSortRun>,
		capacity: usize,
		len: usize,
		run_alloc_fn: RunAllocF,
		run_dealloc_fn: RunDeallocF,
	}

	impl<RunAllocF, RunDeallocF> RunVec<RunAllocF, RunDeallocF>
	where
		RunAllocF: Fn(usize) -> *mut TimSortRun,
		RunDeallocF: Fn(*mut TimSortRun, usize),
	{
		fn new(run_alloc_fn: RunAllocF, run_dealloc_fn: RunDeallocF) -> Self {
			// Most slices can be sorted with at most 16 runs in-flight.
			const START_RUN_CAPACITY: usize = 16;

			Self {
				buf_ptr: ptr::NonNull::new(run_alloc_fn(START_RUN_CAPACITY)).unwrap(),
				capacity: START_RUN_CAPACITY,
				len: 0,
				run_alloc_fn,
				run_dealloc_fn,
			}
		}

		// Appends `val`, doubling the allocation first if full.
		fn push(&mut self, val: TimSortRun) {
			if self.len == self.capacity {
				let old_capacity = self.capacity;
				let old_buf_ptr = self.buf_ptr.as_ptr();

				self.capacity *= 2;
				self.buf_ptr = ptr::NonNull::new((self.run_alloc_fn)(self.capacity)).unwrap();

				// SAFETY: buf_ptr new and old were correctly allocated and old_buf_ptr has
				// old_capacity valid elements (here `len == old_capacity`, so every copied
				// element is initialized).
				unsafe {
					ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr.as_ptr(), old_capacity);
				}

				(self.run_dealloc_fn)(old_buf_ptr, old_capacity);
			}

			// SAFETY: The invariant was just checked.
			unsafe {
				self.buf_ptr.as_ptr().add(self.len).write(val);
			}
			self.len += 1;
		}

		// Removes the element at `index`, shifting later elements left. Panics if out of bounds.
		fn remove(&mut self, index: usize) {
			if index >= self.len {
				panic!("Index out of bounds");
			}

			// SAFETY: buf_ptr needs to be valid and len invariant upheld.
			unsafe {
				// the place we are taking from.
				let ptr = self.buf_ptr.as_ptr().add(index);

				// Shift everything down to fill in that spot.
				ptr::copy(ptr.add(1), ptr, self.len - index - 1);
			}
			self.len -= 1;
		}

		fn as_slice(&self) -> &[TimSortRun] {
			// SAFETY: Safe as long as buf_ptr is valid and len invariant was upheld.
			unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr.as_ptr(), self.len) }
		}

		fn len(&self) -> usize {
			self.len
		}
	}

	impl<RunAllocF, RunDeallocF> core::ops::Index<usize> for RunVec<RunAllocF, RunDeallocF>
	where
		RunAllocF: Fn(usize) -> *mut TimSortRun,
		RunDeallocF: Fn(*mut TimSortRun, usize),
	{
		type Output = TimSortRun;

		fn index(&self, index: usize) -> &Self::Output {
			if index < self.len {
				// SAFETY: buf_ptr and len invariant must be upheld.
				unsafe {
					return &*(self.buf_ptr.as_ptr().add(index));
				}
			}

			panic!("Index out of bounds");
		}
	}

	impl<RunAllocF, RunDeallocF> core::ops::IndexMut<usize> for RunVec<RunAllocF, RunDeallocF>
	where
		RunAllocF: Fn(usize) -> *mut TimSortRun,
		RunDeallocF: Fn(*mut TimSortRun, usize),
	{
		fn index_mut(&mut self, index: usize) -> &mut Self::Output {
			if index < self.len {
				// SAFETY: buf_ptr and len invariant must be upheld.
				unsafe {
					return &mut *(self.buf_ptr.as_ptr().add(index));
				}
			}

			panic!("Index out of bounds");
		}
	}

	impl<RunAllocF, RunDeallocF> Drop for RunVec<RunAllocF, RunDeallocF>
	where
		RunAllocF: Fn(usize) -> *mut TimSortRun,
		RunDeallocF: Fn(*mut TimSortRun, usize),
	{
		fn drop(&mut self) {
			// As long as TimSortRun is Copy we don't need to drop them individually but just the
			// whole allocation.
			(self.run_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity);
		}
	}
}
495
/// Internal type used by merge_sort.
///
/// Describes one already-sorted run inside the array being sorted: `len`
/// contiguous elements beginning at index `start`.
#[derive(Clone, Copy, Debug)]
pub struct TimSortRun {
	len: usize,
	start: usize,
}
502
503/// Takes a range as denoted by start and end, that is already sorted and extends it to the right if
504/// necessary with sorts optimized for smaller ranges such as insertion sort.
505fn provide_sorted_batch<T, F>(
506	mut v: ArrayViewMut1<'_, T>,
507	start: usize,
508	mut end: usize,
509	is_less: &mut F,
510) -> usize
511where
512	F: FnMut(&T, &T) -> bool,
513{
514	let len = v.len();
515	assert!(end >= start && end <= len);
516
517	// This value is a balance between least comparisons and best performance, as
518	// influenced by for example cache locality.
519	const MIN_INSERTION_RUN: usize = 10;
520
521	// Insert some more elements into the run if it's too short. Insertion sort is faster than
522	// merge sort on short sequences, so this significantly improves performance.
523	let start_end_diff = end - start;
524
525	if start_end_diff < MIN_INSERTION_RUN && end < len {
526		// v[start_found..end] are elements that are already sorted in the input. We want to extend
527		// the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is
528		// more efficient that trying to push those already sorted elements to the left.
529		end = cmp::min(start + MIN_INSERTION_RUN, len);
530		let presorted_start = cmp::max(start_end_diff, 1);
531
532		insertion_sort_shift_left(v.slice_mut(s![start..end]), presorted_start, is_less);
533	}
534
535	end
536}
537
538// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first
539/// value that is not part of said streak, and a bool denoting whether the streak was reversed.
540/// Streaks can be increasing or decreasing.
541fn find_streak<T, F>(v: ArrayView1<'_, T>, is_less: &mut F) -> (usize, bool)
542where
543	F: FnMut(&T, &T) -> bool,
544{
545	let len = v.len();
546
547	if len < 2 {
548		return (len, false);
549	}
550
551	let mut end = 2;
552
553	// SAFETY: See below specific.
554	unsafe {
555		// SAFETY: We checked that len >= 2, so 0 and 1 are valid indices.
556		let assume_reverse = is_less(v.uget(1), v.uget(0));
557
558		// SAFETY: We know end >= 2 and check end < len.
559		// From that follows that accessing v at end and end - 1 is safe.
560		if assume_reverse {
561			while end < len && is_less(v.uget(end), v.uget(end - 1)) {
562				end += 1;
563			}
564
565			(end, true)
566		} else {
567			while end < len && !is_less(v.uget(end), v.uget(end - 1)) {
568				end += 1;
569			}
570			(end, false)
571		}
572	}
573}