1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
//! Derivative work of [`core::slice::sort`] licensed under `MIT OR Apache-2.0`.
//!
//! [`core::slice::sort`]: https://doc.rust-lang.org/src/core/slice/sort.rs.html

use crate::partition::CopyOnDrop;
use core::{mem, mem::ManuallyDrop, ptr};
use ndarray::{s, ArrayViewMut1, IndexLonger};

/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case.
pub fn insertion_sort<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F)
where
	F: FnMut(&T, &T) -> bool,
{
	for i in 1..v.len() {
		shift_tail(v.slice_mut(s![..i + 1]), is_less);
	}
}

/// Partially sorts a slice by shifting several out-of-order elements around.
///
/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case.
#[cold]
pub fn partial_insertion_sort<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F) -> bool
where
	F: FnMut(&T, &T) -> bool,
{
	// Maximum number of adjacent out-of-order pairs that will get shifted.
	const MAX_STEPS: usize = 5;
	// If the slice is shorter than this, don't shift any elements.
	const SHORTEST_SHIFTING: usize = 50;

	let len = v.len();
	let mut i = 1;

	for _ in 0..MAX_STEPS {
		// SAFETY: We already explicitly did the bound checking with `i < len`.
		// All our subsequent indexing is only in the range `0 <= index < len`
		unsafe {
			let v = v.view();
			// Find the next pair of adjacent out-of-order elements.
			while i < len && !is_less(v.uget(i), v.uget(i - 1)) {
				i += 1;
			}
		}

		// Are we done?
		if i == len {
			return true;
		}

		// Don't shift elements on short arrays, that has a performance cost.
		if len < SHORTEST_SHIFTING {
			return false;
		}

		// Swap the found pair of elements. This puts them in correct order.
		v.swap(i - 1, i);

		// Shift the smaller element to the left.
		shift_tail(v.slice_mut(s![..i]), is_less);
		// Shift the greater element to the right.
		shift_head(v.slice_mut(s![i..]), is_less);
	}

	// Didn't manage to sort the slice in the limited number of steps.
	false
}

/// Shifts the first element to the right until it encounters a greater or equal element.
fn shift_head<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F)
where
	F: FnMut(&T, &T) -> bool,
{
	let len = v.len();
	// SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a
	// pointer) and copying memory (`ptr::copy_nonoverlapping`).
	//
	// a. Indexing:
	//  1. We checked the size of the array to >=2.
	//  2. All the indexing that we will do is always between {0 <= index < len} at most.
	//
	// b. Memory copying
	//  1. We are obtaining pointers to references which are guaranteed to be valid.
	//  2. They cannot overlap because we obtain pointers to difference indices of the slice.
	//     Namely, `i` and `i-1`.
	//  3. If the slice is properly aligned, the elements are properly aligned.
	//     It is the caller's responsibility to make sure the slice is properly aligned.
	//
	// See comments below for further detail.
	unsafe {
		let w = v.view();
		// If the first two elements are out-of-order...
		if len >= 2 && is_less(w.uget(1), w.uget(0)) {
			// Read the first element into a stack-allocated variable. If a following comparison
			// operation panics, `hole` will get dropped and automatically write the element back
			// into the slice.
			let tmp = mem::ManuallyDrop::new(ptr::read(w.uget(0)));
			let src = v.view().index(1) as *const T;
			let dst = v.view_mut().index(0) as *mut T;
			let mut hole = CopyOnDrop {
				src: &*tmp,
				dest: src as *mut T,
			};
			ptr::copy_nonoverlapping(src, dst, 1);

			for i in 2..len {
				let w = v.view();
				if !is_less(w.uget(i), &*tmp) {
					break;
				}

				// Move `i`-th element one place to the left, thus shifting the hole to the right.
				ptr::copy_nonoverlapping(w.uget(i), v.view_mut().uget(i - 1), 1);
				hole.dest = v.view_mut().uget(i) as *mut T;
			}
			// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
		}
	}
}

/// Shifts the last element to the left until it encounters a smaller or equal element.
fn shift_tail<T, F>(mut v: ArrayViewMut1<'_, T>, is_less: &mut F)
where
	F: FnMut(&T, &T) -> bool,
{
	let len = v.len();
	// SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a
	// pointer) and copying memory (`ptr::copy_nonoverlapping`).
	//
	// a. Indexing:
	//  1. We checked the size of the array to >= 2.
	//  2. All the indexing that we will do is always between `0 <= index < len-1` at most.
	//
	// b. Memory copying
	//  1. We are obtaining pointers to references which are guaranteed to be valid.
	//  2. They cannot overlap because we obtain pointers to difference indices of the slice.
	//     Namely, `i` and `i+1`.
	//  3. If the slice is properly aligned, the elements are properly aligned.
	//     It is the caller's responsibility to make sure the slice is properly aligned.
	//
	// See comments below for further detail.
	unsafe {
		// If the last two elements are out-of-order...
		if len >= 2 {
			let w = v.view();
			if is_less(w.uget(len - 1), w.uget(len - 2)) {
				// Read the last element into a stack-allocated variable. If a following comparison
				// operation panics, `hole` will get dropped and automatically write the element back
				// into the slice.
				let tmp = ManuallyDrop::new(ptr::read(w.uget(len - 1)));
				let mut hole = CopyOnDrop {
					src: &*tmp,
					dest: v.view_mut().index(len - 2),
				};
				let src = v.view().index(len - 2) as *const T;
				let dst = v.view_mut().index(len - 1) as *mut T;
				ptr::copy_nonoverlapping(src, dst, 1);

				for i in (0..len - 2).rev() {
					let src = v.view_mut().index(i) as *mut T;
					if !is_less(&*tmp, &*src) {
						break;
					}

					// Move `i`-th element one place to the right, thus shifting the hole to the left.
					let dst = v.view_mut().index(i + 1) as *mut T;
					ptr::copy_nonoverlapping(src, dst, 1);
					hole.dest = src;
				}
				// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
			}
		}
	}
}

#[cfg(feature = "std")]
#[cfg(test)]
mod test {
	use super::insertion_sort;
	use ndarray::Array1;
	use quickcheck_macros::quickcheck;

	#[quickcheck]
	fn sorted(xs: Vec<u32>) {
		let mut array = Array1::from_vec(xs);
		insertion_sort(array.view_mut(), &mut u32::lt);
		for i in 1..array.len() {
			assert!(array[i - 1] <= array[i]);
		}
	}
}