1use rayon::prelude::*;
9
/// A raw mutable pointer that is allowed to cross thread boundaries.
///
/// Raw pointers are neither `Send` nor `Sync`, so they cannot be captured by the
/// closures handed to rayon. This wrapper opts back in; whoever dereferences the
/// pointer is responsible for making sure different threads never touch the same
/// elements at the same time.
#[derive(Clone, Copy)]
struct SendPtr<T>(*mut T);

// SAFETY: the wrapper adds no shared state of its own; moving it to another thread
// is only as dangerous as what is later done through the pointer, which every use
// site must justify separately (disjoint, in-bounds accesses).
unsafe impl<T: Send> Send for SendPtr<T> {}
// SAFETY: sharing `&SendPtr<T>` only lets other threads copy the address out; all
// actual reads/writes go through `unsafe` code at the use site.
unsafe impl<T: Sync> Sync for SendPtr<T> {}

impl<T> SendPtr<T> {
    /// Unwraps the stored pointer.
    ///
    /// NOTE(review): presumably a by-value method (instead of reading `.0` directly)
    /// so that closures capture the whole `Send` wrapper rather than just the raw
    /// pointer field under Rust 2021 precise closure capture — confirm edition.
    #[inline]
    fn get(self) -> *mut T {
        let SendPtr(raw) = self;
        raw
    }
}
24
/// Slices shorter than this bypass the bucket model entirely and are sorted with
/// `sort_unstable`.
const HYBRID_THRESHOLD: usize = 32_768;

/// Desired average number of elements per bucket: the bucket count is
/// `n / BUCKET_TARGET_SIZE`, but never less than 1.
const BUCKET_TARGET_SIZE: usize = 512;

/// `refine_buckets` treats a bucket as overflowing when it holds more than this
/// multiple of the expected bucket size (`total_len / num_buckets`).
const BUCKET_OVERFLOW_FACTOR: usize = 4;

/// Buckets with fewer elements than this are sorted with `insertion_sort`.
const INSERTION_SORT_THRESHOLD: usize = 32;
38
39pub fn learned_sort<T>(arr: &mut [T])
62where
63 T: Ord + Copy + Send + Sync + Into<i64>,
64{
65 let n = arr.len();
66
67 if n <= 1 {
69 return;
70 }
71
72 if n < HYBRID_THRESHOLD {
74 arr.sort_unstable();
75 return;
76 }
77
78 let (min_val, max_val) = sample_minmax(arr);
80
81 if min_val == max_val {
83 return; }
85
86 let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
88
89 let mut counts = count_buckets(arr, min_val, max_val, num_buckets);
91
92 let offsets = prefix_sum(&counts);
94
95 let mut aux = vec![arr[0]; n]; scatter(arr, &mut aux, &mut counts, &offsets, min_val, max_val, num_buckets);
98
99 refine_buckets(&mut aux, &offsets, num_buckets, n);
101
102 arr.copy_from_slice(&aux);
104}
105
106pub fn learned_sort_inplace<T>(arr: &mut [T])
149where
150 T: Ord + Copy + Send + Sync + Into<i64>,
151{
152 let n = arr.len();
153
154 if n <= 1 {
156 return;
157 }
158
159 if n < HYBRID_THRESHOLD {
161 arr.sort_unstable();
162 return;
163 }
164
165 let (min_val, max_val) = sample_minmax(arr);
167
168 if min_val == max_val {
170 return; }
172
173 let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
175
176 let counts = count_buckets(arr, min_val, max_val, num_buckets);
178
179 let offsets = prefix_sum(&counts);
181
182 scatter_inplace(arr, &offsets, min_val, max_val, num_buckets);
184
185 refine_buckets(arr, &offsets, num_buckets, n);
187}
188
/// Maps a key to its bucket: `floor((val - min_val) * scale)`, clamped to the last
/// bucket.
///
/// The subtraction is done in i64, so callers must ensure `val - min_val` cannot
/// overflow. A negative product saturates to bucket 0 through the float-to-usize cast.
#[inline]
fn compute_bucket(val: i64, min_val: i64, scale: f64, num_buckets: usize) -> usize {
    let distance = (val - min_val) as f64;
    let raw_index = (distance * scale) as usize;
    std::cmp::min(raw_index, num_buckets - 1)
}
195
196fn scatter_inplace<T>(
201 arr: &mut [T],
202 offsets: &[usize],
203 min_val: i64,
204 max_val: i64,
205 num_buckets: usize,
206) where
207 T: Copy + Into<i64>,
208{
209 let range = (max_val - min_val) as f64;
210 let scale = (num_buckets as f64 - 0.001) / range;
211
212 let mut write_cursors: Vec<usize> = offsets[..num_buckets].to_vec();
214
215 for bucket in 0..num_buckets {
217 let bucket_start = offsets[bucket];
218 let bucket_end = offsets[bucket + 1];
219
220 let mut pos = bucket_start;
222 while pos < bucket_end {
223 let current_val: i64 = arr[pos].into();
224 let target_bucket = compute_bucket(current_val, min_val, scale, num_buckets);
225
226 if target_bucket == bucket {
227 if write_cursors[bucket] <= pos {
230 write_cursors[bucket] = pos + 1;
231 }
232 pos += 1;
233 continue;
234 }
235
236 let mut current = arr[pos];
238 let mut current_bucket = target_bucket;
239
240 loop {
241 let dest_pos = write_cursors[current_bucket];
243
244 write_cursors[current_bucket] += 1;
246
247 let next = arr[dest_pos];
249 arr[dest_pos] = current;
250
251 let next_bucket = compute_bucket(next.into(), min_val, scale, num_buckets);
252
253 if next_bucket == bucket {
255 arr[pos] = next;
257 break;
258 }
259
260 current = next;
261 current_bucket = next_bucket;
262 }
263
264 if write_cursors[bucket] <= pos {
266 write_cursors[bucket] = pos + 1;
267 }
268 pos += 1;
269 }
270 }
271}
272
/// Returns the smallest and largest key (`T::into`) in `arr` as `(min, max)`.
///
/// Despite the name, every element is examined; no sampling takes place.
/// Panics (index out of bounds) if `arr` is empty.
#[inline]
fn sample_minmax<T>(arr: &[T]) -> (i64, i64)
where
    T: Ord + Copy + Into<i64>,
{
    let seed: i64 = arr[0].into();
    arr.iter().fold((seed, seed), |(lo, hi), &item| {
        let key: i64 = item.into();
        (lo.min(key), hi.max(key))
    })
}
295
/// Builds a histogram of how many elements of `arr` fall into each of
/// `num_buckets` linearly spaced buckets spanning `min_val..=max_val`.
///
/// The bucket index is `floor((key - min_val) * (num_buckets - 0.001) / (max_val - min_val))`,
/// clamped to `num_buckets - 1`. Callers must ensure `max_val - min_val` does not
/// overflow i64. `scatter` and `scatter_inplace` must use exactly this formula,
/// because these counts size their bucket regions.
#[inline]
fn count_buckets<T>(arr: &[T], min_val: i64, max_val: i64, num_buckets: usize) -> Vec<usize>
where
    T: Copy + Into<i64>,
{
    let histogram = vec![0usize; num_buckets];
    let scale = (num_buckets as f64 - 0.001) / (max_val - min_val) as f64;

    arr.iter().fold(histogram, |mut hist, &item| {
        let key: i64 = item.into();
        let slot = (((key - min_val) as f64 * scale) as usize).min(num_buckets - 1);
        hist[slot] += 1;
        hist
    })
}
316
/// Converts per-bucket counts into bucket boundary offsets.
///
/// The result has `counts.len() + 1` entries:
///
/// - `offsets[0] == 0`;
/// - `offsets[i + 1] == offsets[i] + counts[i]`;
/// - the last entry is the total count.
///
/// Bucket `i` therefore spans `offsets[i]..offsets[i + 1]`.
#[inline]
fn prefix_sum(counts: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(counts.len() + 1);
    offsets.push(0);

    let mut running = 0usize;
    offsets.extend(counts.iter().map(|&count| {
        running += count;
        running
    }));

    offsets
}
332
/// Copies `src` into `aux` so that bucket `i` occupies `aux[offsets[i]..offsets[i + 1]]`.
///
/// Elements are written in input order, so relative order within a bucket is
/// preserved.
///
/// `counts` is reused as per-bucket write cursors. On return, `counts[i]` is one
/// past the last slot written for bucket `i`. That equals `offsets[i + 1]` when
/// `offsets` was derived from `count_buckets` over this same `src`.
///
/// The bucket formula must match `count_buckets` exactly. Otherwise a bucket can
/// receive more elements than its region holds.
///
/// # Panics
///
/// - Panics if `offsets` is shorter than `counts`.
/// - Panics if a computed write position falls outside `aux`, i.e. `offsets` and
///   `src` are inconsistent.
#[inline]
fn scatter<T>(
    src: &[T],
    aux: &mut [T],
    counts: &mut [usize],
    offsets: &[usize],
    min_val: i64,
    max_val: i64,
    num_buckets: usize,
) where
    T: Copy + Into<i64>,
{
    // Turn the counts buffer into write cursors starting at each bucket's offset.
    counts.copy_from_slice(&offsets[..counts.len()]);

    let range = (max_val - min_val) as f64;
    let scale = (num_buckets as f64 - 0.001) / range;

    for &item in src {
        let val: i64 = item.into();
        let bucket_idx = (((val - min_val) as f64 * scale) as usize).min(num_buckets - 1);

        let write_pos = counts[bucket_idx];
        counts[bucket_idx] += 1;

        // Checked indexing on purpose: this is a safe fn whose write positions depend
        // on the caller's offsets agreeing with the formula above. A mismatch now
        // panics instead of writing out of bounds (previously `get_unchecked_mut`).
        aux[write_pos] = item;
    }
}
374
375fn refine_buckets<T>(aux: &mut [T], offsets: &[usize], num_buckets: usize, total_len: usize)
378where
379 T: Ord + Copy + Send + Sync,
380{
381 let expected_bucket_size = total_len / num_buckets;
382 let overflow_threshold = expected_bucket_size * BUCKET_OVERFLOW_FACTOR;
383
384 let ptr = SendPtr(aux.as_mut_ptr());
385
386 let bucket_ranges: Vec<(usize, usize)> = (0..num_buckets)
388 .map(|i| (offsets[i], offsets[i + 1]))
389 .collect();
390
391 bucket_ranges.par_iter().for_each(move |&(start, end)| {
393 let bucket_len = end - start;
394 if bucket_len <= 1 {
395 return;
396 }
397
398 let bucket_slice = unsafe { std::slice::from_raw_parts_mut(ptr.get().add(start), bucket_len) };
400
401 if bucket_len < INSERTION_SORT_THRESHOLD {
402 insertion_sort(bucket_slice);
403 } else if bucket_len > overflow_threshold {
404 bucket_slice.sort_unstable();
406 } else {
407 bucket_slice.sort_unstable();
409 }
410 });
411}
412
/// Stable straight insertion sort (O(n^2)), intended for very small slices.
#[inline]
fn insertion_sort<T: Ord + Copy>(arr: &mut [T]) {
    for unsorted in 1..arr.len() {
        let key = arr[unsorted];

        // Walk left past every element strictly greater than `key`; equal elements
        // stay in front, which keeps the sort stable.
        let mut slot = unsorted;
        while slot > 0 && arr[slot - 1] > key {
            slot -= 1;
        }

        // Shift the greater run one step right and drop `key` into the gap.
        arr.copy_within(slot..unsorted, slot + 1);
        arr[slot] = key;
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Length used by the "large" tests: twice `HYBRID_THRESHOLD`, so these inputs
    /// take the bucket-sort path instead of the `sort_unstable` fallback.
    const LARGE: usize = HYBRID_THRESHOLD * 2;

    /// Returns `len` uniformly random values drawn from `lo..hi`.
    fn random_i64s(len: usize, lo: i64, hi: i64) -> Vec<i64> {
        let mut rng = rand::thread_rng();
        (0..len).map(|_| rng.gen_range(lo..hi)).collect()
    }

    /// Sorts a copy of `data` with `sort` and asserts it matches `sort_unstable`.
    fn assert_sorts_like_std<T>(data: &[T], sort: fn(&mut [T]))
    where
        T: Ord + Copy + std::fmt::Debug,
    {
        let mut expected = data.to_vec();
        expected.sort_unstable();
        let mut actual = data.to_vec();
        sort(&mut actual);
        assert_eq!(actual, expected);
    }

    /// Mostly tiny values plus two far-away outliers. Nearly every element lands in
    /// a single bucket, which exercises the overflowing-bucket path.
    fn skewed_data() -> Vec<i64> {
        let mut data = random_i64s(LARGE, 0, 100);
        data.push(-1_000_000_000);
        data.push(1_000_000_000);
        data
    }

    #[test]
    fn test_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_single_element() {
        let mut data = vec![42i64];
        learned_sort(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_two_elements() {
        let mut data = vec![5i64, 3];
        learned_sort(&mut data);
        assert_eq!(data, vec![3, 5]);
    }

    #[test]
    fn test_small_array_uses_fallback() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_medium_array() {
        let mut data: Vec<i64> = (0..1000).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..1000).collect::<Vec<_>>());
    }

    #[test]
    fn test_large_uniform_distribution() {
        assert_sorts_like_std(&random_i64s(100_000, 0, 1_000_000), learned_sort);
    }

    #[test]
    fn test_threshold_boundary() {
        // The shortest input that takes the bucket-sort path.
        let mut data: Vec<i64> = (0..HYBRID_THRESHOLD as i64).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..HYBRID_THRESHOLD as i64).collect::<Vec<_>>());
    }

    #[test]
    fn test_sorted_input() {
        let mut data: Vec<i64> = (0..LARGE as i64).collect();
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_reverse_sorted() {
        let mut data: Vec<i64> = (0..LARGE as i64).rev().collect();
        let expected: Vec<i64> = (0..LARGE as i64).collect();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_duplicates() {
        let mut data: Vec<i64> = vec![5; LARGE];
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_many_duplicates() {
        assert_sorts_like_std(&random_i64s(LARGE, 0, 10), learned_sort);
    }

    #[test]
    fn test_negative_numbers() {
        assert_sorts_like_std(&random_i64s(LARGE, -500_000, 500_000), learned_sort);
    }

    #[test]
    fn test_skewed_distribution() {
        assert_sorts_like_std(&skewed_data(), learned_sort);
    }

    #[test]
    fn test_i32_type() {
        let mut rng = rand::thread_rng();
        let data: Vec<i32> = (0..LARGE).map(|_| rng.gen_range(0..1_000_000)).collect();
        assert_sorts_like_std(&data, learned_sort);
    }

    // In-place variant.

    #[test]
    fn test_inplace_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort_inplace(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_inplace_single_element() {
        let mut data = vec![42i64];
        learned_sort_inplace(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_inplace_small_array() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort_inplace(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_inplace_large_uniform() {
        assert_sorts_like_std(&random_i64s(100_000, 0, 1_000_000), learned_sort_inplace);
    }

    #[test]
    fn test_inplace_threshold_boundary() {
        let mut data: Vec<i64> = (0..HYBRID_THRESHOLD as i64).rev().collect();
        learned_sort_inplace(&mut data);
        assert_eq!(data, (0..HYBRID_THRESHOLD as i64).collect::<Vec<_>>());
    }

    #[test]
    fn test_inplace_sorted_input() {
        let mut data: Vec<i64> = (0..LARGE as i64).collect();
        let expected = data.clone();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_reverse_sorted() {
        let mut data: Vec<i64> = (0..LARGE as i64).rev().collect();
        let expected: Vec<i64> = (0..LARGE as i64).collect();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_duplicates() {
        let mut data: Vec<i64> = vec![5; 50_000];
        let expected = data.clone();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_many_duplicates() {
        assert_sorts_like_std(&random_i64s(50_000, 0, 10), learned_sort_inplace);
    }

    #[test]
    fn test_inplace_negative_numbers() {
        assert_sorts_like_std(&random_i64s(50_000, -500_000, 500_000), learned_sort_inplace);
    }

    #[test]
    fn test_inplace_skewed_distribution() {
        assert_sorts_like_std(&skewed_data(), learned_sort_inplace);
    }

    #[test]
    fn test_inplace_matches_regular() {
        let original = random_i64s(100_000, 0, 1_000_000);

        let mut data_regular = original.clone();
        let mut data_inplace = original.clone();

        learned_sort(&mut data_regular);
        learned_sort_inplace(&mut data_inplace);

        assert_eq!(data_regular, data_inplace);
    }
}