1pub mod key_ptr;
2pub use key_ptr::KeyPtr;
3
/// An element that can be ordered through an integer-like key.
///
/// Implementors must be cheap to copy and shareable across threads, because
/// the sorting routines in this file copy elements freely and read them from
/// scoped worker threads.
pub trait SortableKey: Copy + Default + Ord + Send + Sync {
    /// The key type. It must convert losslessly into `usize` and fallibly
    /// back from `usize`.
    type KeyType: Copy + Ord + Into<usize> + TryFrom<usize>;

    /// Selects the counting-sort flavour used by `overclocked_sort`: when
    /// `true`, elements are rebuilt from their keys via [`Self::from_key`];
    /// when `false`, the original elements are moved into place.
    const IS_PRIMITIVE: bool;

    /// Returns the key this element is ordered by.
    fn extract_key(&self) -> Self::KeyType;

    /// Builds an element from a key.
    fn from_key(k: Self::KeyType) -> Self;
}
18
19impl SortableKey for u32 {
21 type KeyType = usize;
22 #[inline(always)]
23 fn extract_key(&self) -> usize { *self as usize }
24 const IS_PRIMITIVE: bool = true;
25 #[inline(always)]
26 fn from_key(k: usize) -> Self { k as u32 }
27}
28
29impl SortableKey for i32 {
31 type KeyType = usize;
32 #[inline(always)]
33 fn extract_key(&self) -> usize { (*self as i64 + 2147483648) as usize }
34 const IS_PRIMITIVE: bool = true;
35 #[inline(always)]
36 fn from_key(k: usize) -> Self { (k as i64 - 2147483648) as i32 }
37}
38
39pub fn overclocked_sort<T: SortableKey>(arr: &mut [T]) {
41 if arr.len() <= 1 { return; }
42
43 let n = arr.len();
45 let probe_points = std::cmp::min(n, 1024);
46 if probe_points >= 64 {
47 let stride = (n - 1) / (probe_points - 1);
48 let mut prev_probe = arr[0].extract_key().into();
49 let mut probe_desc_breaks = 0usize;
50 for i in 1..probe_points {
51 let idx = i * stride;
52 let k = arr[idx].extract_key().into();
53 if k < prev_probe {
54 probe_desc_breaks += 1;
55 }
56 prev_probe = k;
57 }
58 if probe_desc_breaks == 0 {
59 arr.sort_unstable();
60 return;
61 }
62 if probe_desc_breaks >= probe_points - 2 {
63 arr.sort_unstable();
64 return;
65 }
66 }
67
68 let first_key = arr[0].extract_key().into();
70 let mut prev_key = first_key;
71 let mut min_val = first_key;
72 let mut max_val = first_key;
73 let mut non_decreasing = true;
74 let mut non_increasing = true;
75 let mut desc_breaks = 0usize;
76 for item in arr.iter().skip(1) {
77 let k = item.extract_key().into();
78 if k < prev_key {
79 non_decreasing = false;
80 desc_breaks += 1;
81 }
82 if k > prev_key {
83 non_increasing = false;
84 }
85 prev_key = k;
86 if k < min_val { min_val = k; }
87 if k > max_val { max_val = k; }
88 }
89 if non_decreasing {
90 return;
91 }
92 if non_increasing {
93 arr.reverse();
94 return;
95 }
96
97 let range = max_val - min_val + 1;
98 let max_limit = std::cmp::min(1_000_000, std::cmp::max(1, n / 4));
99
100 let near_sorted_threshold = std::cmp::max(1, n / 16);
102 if range > max_limit && (desc_breaks <= near_sorted_threshold || desc_breaks >= (n - 1).saturating_sub(near_sorted_threshold)) {
103 arr.sort_unstable();
104 return;
105 }
106
107 if range <= max_limit {
109 if T::IS_PRIMITIVE {
111 counting_sort_primitive(arr, min_val, range);
112 } else {
113 counting_sort_records(arr, min_val, range);
114 }
115 } else {
116 fallback_sort(arr);
118 }
119}
120
121fn fallback_sort<T: SortableKey>(arr: &mut [T]) {
122 if arr.len() <= 1 {
123 return;
124 }
125 if arr.len() <= 32 {
126 insertion_sort(arr);
127 return;
128 }
129 radix_sort_by_key(arr);
130}
131
132fn radix_sort_by_key<T: SortableKey>(arr: &mut [T]) {
133 let len = arr.len();
134 let mut current = arr.to_vec();
135 let mut next = vec![T::default(); len];
136 let passes = std::mem::size_of::<usize>();
137
138 for pass in 0..passes {
139 let shift = pass * 8;
140 let mut counts = [0usize; 256];
141
142 for item in current.iter() {
143 let bucket = (item.extract_key().into() >> shift) & 0xFF;
144 counts[bucket] += 1;
145 }
146
147 let mut offset = 0usize;
148 for count in counts.iter_mut() {
149 let current_count = *count;
150 *count = offset;
151 offset += current_count;
152 }
153
154 for item in current.iter() {
155 let bucket = (item.extract_key().into() >> shift) & 0xFF;
156 next[counts[bucket]] = *item;
157 counts[bucket] += 1;
158 }
159
160 std::mem::swap(&mut current, &mut next);
161 }
162
163 arr.copy_from_slice(¤t);
164}
165
166#[allow(dead_code)]
167fn quicksort_recursive<T: SortableKey>(arr: &mut [T]) {
168 const INSERTION_THRESHOLD: usize = 24;
169
170 if arr.len() <= INSERTION_THRESHOLD {
171 insertion_sort(arr);
172 return;
173 }
174
175 let len = arr.len();
176 let mid = len / 2;
177 let pivot = median_of_three(arr[0], arr[mid], arr[len - 1]);
178
179 let mut left = 0usize;
180 let mut right = len - 1;
181
182 loop {
183 while arr[left] < pivot {
184 left += 1;
185 }
186 while arr[right] > pivot {
187 if right == 0 {
188 break;
189 }
190 right -= 1;
191 }
192 if left >= right {
193 break;
194 }
195 arr.swap(left, right);
196 left += 1;
197 if right == 0 {
198 break;
199 }
200 right -= 1;
201 }
202
203 let split_index = right + 1;
204 let (lo, hi) = arr.split_at_mut(split_index);
205 if lo.len() < hi.len() {
206 if !lo.is_empty() {
207 quicksort_recursive(lo);
208 }
209 if !hi.is_empty() {
210 quicksort_recursive(hi);
211 }
212 } else {
213 if !hi.is_empty() {
214 quicksort_recursive(hi);
215 }
216 if !lo.is_empty() {
217 quicksort_recursive(lo);
218 }
219 }
220}
221
222#[allow(dead_code)]
223fn insertion_sort<T: SortableKey>(arr: &mut [T]) {
224 for i in 1..arr.len() {
225 let value = arr[i];
226 let mut j = i;
227 while j > 0 && arr[j - 1] > value {
228 arr[j] = arr[j - 1];
229 j -= 1;
230 }
231 arr[j] = value;
232 }
233}
234
235#[allow(dead_code)]
236#[inline(always)]
237fn median_of_three<T: SortableKey>(a: T, b: T, c: T) -> T {
238 if a < b {
239 if b < c {
240 b
241 } else if a < c {
242 c
243 } else {
244 a
245 }
246 } else if a < c {
247 a
248 } else if b < c {
249 c
250 } else {
251 b
252 }
253}
254
255fn counting_sort_primitive<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
257 let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
258 let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
259
260 let mut local_counts = vec![vec![0usize; range]; num_threads];
261
262 std::thread::scope(|s| {
264 for (i, count) in local_counts.iter_mut().enumerate() {
265 let start = i * chunk_sz;
266 let end = std::cmp::min(start + chunk_sz, arr.len());
267 if start >= arr.len() { continue; }
268 let slice = &arr[start..end];
269 s.spawn(move || {
270 for item in slice {
271 count[item.extract_key().into() - min_val] += 1;
272 }
273 });
274 }
275 });
276
277 let mut global_count = vec![0usize; range];
279 for counts in &local_counts {
280 for (i, &c) in counts.iter().enumerate() {
281 global_count[i] += c;
282 }
283 }
284
285 let mut bin_offsets = vec![0usize; range];
286 let mut offset = 0;
287 for (i, &c) in global_count.iter().enumerate() {
288 bin_offsets[i] = offset;
289 offset += c;
290 }
291
292 let arr_ptr = arr.as_mut_ptr() as usize;
294 let global_count_ref = &global_count;
295 let bin_offsets_ref = &bin_offsets;
296 let bins_per_thread = (range + num_threads - 1) / num_threads;
297
298 std::thread::scope(|s| {
299 for t in 0..num_threads {
300 s.spawn(move || {
301 let start_bin = t * bins_per_thread;
302 let end_bin = std::cmp::min(start_bin + bins_per_thread, range);
303
304 for i in start_bin..end_bin {
305 let freq = global_count_ref[i];
306 if freq == 0 { continue; }
307
308 let val_k = min_val + i;
309 if let Ok(key) = T::KeyType::try_from(val_k) {
310 let val = T::from_key(key);
311 let target_offset = bin_offsets_ref[i];
312
313 unsafe {
314 let ptr = (arr_ptr as *mut T).add(target_offset);
315 for j in 0..freq {
316 std::ptr::write(ptr.add(j), val);
317 }
318 }
319 }
320 }
321 });
322 }
323 });
324}
325
326fn counting_sort_records<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
328 let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
329 let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
330
331 let mut local_counts = vec![vec![0usize; range]; num_threads];
332
333 std::thread::scope(|s| {
335 for (i, count) in local_counts.iter_mut().enumerate() {
336 let start = i * chunk_sz;
337 let end = std::cmp::min(start + chunk_sz, arr.len());
338 if start >= arr.len() { continue; }
339 let slice = &arr[start..end];
340 s.spawn(move || {
341 for item in slice {
342 count[item.extract_key().into() - min_val] += 1;
343 }
344 });
345 }
346 });
347
348 let mut global_offsets = vec![vec![0usize; range]; num_threads];
350 let mut total_offset = 0;
351 for val in 0..range {
352 for t in 0..num_threads {
353 global_offsets[t][val] = total_offset;
354 total_offset += local_counts[t][val];
355 }
356 }
357
358 let mut buffer = vec![T::default(); arr.len()];
360 let buf_ptr = buffer.as_mut_ptr() as usize;
361 let global_offsets_ref = &global_offsets;
362
363 let arr_ref: &[T] = arr; std::thread::scope(|s| {
366 for t in 0..num_threads {
367 let current_offsets_tpl = global_offsets_ref[t].clone();
368 s.spawn(move || {
369 let mut current_offsets = current_offsets_tpl;
370 let start = t * chunk_sz;
371 let end = std::cmp::min(start + chunk_sz, arr_ref.len());
372 if start >= arr_ref.len() { return; }
373 let slice = &arr_ref[start..end];
374
375 for item in slice {
376 let bucket = item.extract_key().into() - min_val;
377 let target_pos = current_offsets[bucket];
378 unsafe {
379 let ptr = (buf_ptr as *mut T).add(target_pos);
380 std::ptr::write(ptr, *item);
381 }
382 current_offsets[bucket] += 1;
383 }
384 });
385 }
386 });
387
388 arr.copy_from_slice(&buffer);
389}
390
391pub fn overclocked_parallel_sort(input: &[i32], _max_val: usize) -> Vec<i32> {
393 let mut copy = input.to_vec();
394 overclocked_sort(&mut copy);
395 copy
396}
397
398pub fn overclocked_kp_sort(input: &[KeyPtr], _max_val: usize) -> Vec<KeyPtr> {
399 let mut copy = input.to_vec();
400 overclocked_sort(&mut copy);
401 copy
402}