use rayon::prelude::*;
/// Thin wrapper around a raw `*mut T` that can be marked `Send`/`Sync`, so the
/// pointer may be captured by closures that run on other threads.
///
/// The wrapper performs no synchronization of its own: whoever dereferences
/// the pointer must guarantee that concurrent accesses never overlap.
struct SendPtr<T>(*mut T);

// `Clone`/`Copy` are written by hand instead of derived: `derive` would add
// `T: Clone` / `T: Copy` bounds even though copying the wrapper only copies
// the address and never the pointee.
impl<T> Clone for SendPtr<T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SendPtr<T> {}

// SAFETY: the wrapper only carries an address. These impls are sound as long
// as every user dereferences the pointer solely into memory that no other
// thread accesses at the same time; that obligation lies with the code that
// calls `get` and dereferences the result.
unsafe impl<T: Send> Send for SendPtr<T> {}
unsafe impl<T: Sync> Sync for SendPtr<T> {}

impl<T> SendPtr<T> {
    /// Returns the wrapped raw pointer.
    #[inline]
    fn get(self) -> *mut T {
        self.0
    }
}
/// Slices shorter than this are handed straight to `sort_unstable` by
/// `learned_sort` / `learned_sort_inplace` instead of being bucketed.
const HYBRID_THRESHOLD: usize = 32768;
/// Divisor used to choose the bucket count: `num_buckets = len / BUCKET_TARGET_SIZE`
/// (at least 1).
const BUCKET_TARGET_SIZE: usize = 512;
/// Multiplier applied to the expected bucket size in `refine_buckets` to obtain
/// its overflow threshold.
const BUCKET_OVERFLOW_FACTOR: usize = 4;
/// Buckets shorter than this are sorted with `insertion_sort` in `refine_buckets`.
const INSERTION_SORT_THRESHOLD: usize = 32;
/// Sorts `arr` in ascending order using a key-distribution bucket pass followed
/// by per-bucket refinement.
///
/// * Slices with fewer than `HYBRID_THRESHOLD` elements are sorted directly
///   with `sort_unstable`.
/// * Larger slices are keyed through `Into<i64>`, counted into
///   `len / BUCKET_TARGET_SIZE` buckets spanning `[min, max]`, scattered into an
///   auxiliary buffer of `arr.len()` elements, refined bucket by bucket and
///   finally copied back into `arr`.
///
/// The `Into<i64>` key must be consistent with `Ord` (a smaller element never
/// maps to a larger key); elements that share a key are ordered by `Ord`.
pub fn learned_sort<T>(arr: &mut [T])
where
    T: Ord + Copy + Send + Sync + Into<i64>,
{
    let n = arr.len();
    if n <= 1 {
        return;
    }
    if n < HYBRID_THRESHOLD {
        arr.sort_unstable();
        return;
    }

    let (min_val, max_val) = sample_minmax(arr);
    if min_val == max_val {
        // Every key is identical. That only means the elements themselves are
        // equal when the key conversion is injective, so let the comparison
        // sort order them instead of assuming the slice is already sorted.
        arr.sort_unstable();
        return;
    }

    let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
    let mut counts = count_buckets(arr, min_val, max_val, num_buckets);
    let offsets = prefix_sum(&counts);

    // `arr[0]` is only a filler value; `scatter` overwrites the buffer.
    // `counts` is reused by `scatter` as its per-bucket write cursors.
    let mut aux = vec![arr[0]; n];
    scatter(
        arr,
        &mut aux,
        &mut counts,
        &offsets,
        min_val,
        max_val,
        num_buckets,
    );

    refine_buckets(&mut aux, &offsets, num_buckets, n);
    arr.copy_from_slice(&aux);
}
/// Sorts `arr` in ascending order like `learned_sort`, but permutes the
/// elements into their buckets inside `arr` itself instead of going through an
/// auxiliary element buffer.
///
/// * Slices with fewer than `HYBRID_THRESHOLD` elements are sorted directly
///   with `sort_unstable`.
/// * Larger slices are keyed through `Into<i64>`, counted into
///   `len / BUCKET_TARGET_SIZE` buckets spanning `[min, max]`, rearranged with
///   `scatter_inplace` and then refined bucket by bucket.
///
/// The `Into<i64>` key must be consistent with `Ord` (a smaller element never
/// maps to a larger key); elements that share a key are ordered by `Ord`.
pub fn learned_sort_inplace<T>(arr: &mut [T])
where
    T: Ord + Copy + Send + Sync + Into<i64>,
{
    let n = arr.len();
    if n <= 1 {
        return;
    }
    if n < HYBRID_THRESHOLD {
        arr.sort_unstable();
        return;
    }

    let (min_val, max_val) = sample_minmax(arr);
    if min_val == max_val {
        // Every key is identical. That only means the elements themselves are
        // equal when the key conversion is injective, so let the comparison
        // sort order them instead of assuming the slice is already sorted.
        arr.sort_unstable();
        return;
    }

    let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
    let counts = count_buckets(arr, min_val, max_val, num_buckets);
    let offsets = prefix_sum(&counts);

    scatter_inplace(arr, &offsets, min_val, max_val, num_buckets);
    refine_buckets(arr, &offsets, num_buckets, n);
}
/// Maps the key `val` to a bucket index in `0..num_buckets`.
///
/// The distance of `val` from `min_val` is multiplied by `scale`, truncated
/// towards zero and clamped to the last bucket.
///
/// Expects `val >= min_val` and `num_buckets >= 1`.
#[inline]
fn compute_bucket(val: i64, min_val: i64, scale: f64, num_buckets: usize) -> usize {
    // For `val >= min_val` the wrapping difference, read as `u64`, is the exact
    // distance even when `val - min_val` would overflow `i64`
    // (e.g. `i64::MAX - i64::MIN`).
    let distance = val.wrapping_sub(min_val) as u64;
    let idx = (distance as f64 * scale) as usize;
    idx.min(num_buckets - 1)
}
/// Rearranges `arr` in place so that every element ends up inside the index
/// range of the bucket that `compute_bucket` assigns to it.
///
/// * `offsets` - bucket boundaries: bucket `b` owns `arr[offsets[b]..offsets[b + 1]]`,
///   so at least `num_buckets + 1` entries are read.
/// * `min_val`, `max_val` - smallest and largest key, used to derive the scale
///   passed to `compute_bucket`.
/// * `num_buckets` - number of buckets.
///
/// The order of elements inside a bucket is unspecified. The bucket sizes
/// encoded in `offsets` must match the number of elements `compute_bucket`
/// maps to each bucket; otherwise the indexing below may panic.
fn scatter_inplace<T>(
    arr: &mut [T],
    offsets: &[usize],
    min_val: i64,
    max_val: i64,
    num_buckets: usize,
) where
    T: Copy + Into<i64>,
{
    // The wrapping difference read as `u64` is the exact key span for
    // `max_val >= min_val`, even when `max_val - min_val` would overflow `i64`.
    let range = max_val.wrapping_sub(min_val) as u64 as f64;
    // Same scale as the counting pass; the 0.001 slack keeps the largest key
    // just below `num_buckets` before truncation.
    let scale = (num_buckets as f64 - 0.001) / range;

    // Next free slot of every bucket.
    let mut write_cursors: Vec<usize> = offsets[..num_buckets].to_vec();

    for bucket in 0..num_buckets {
        let bucket_end = offsets[bucket + 1];
        let mut pos = offsets[bucket];

        while pos < bucket_end {
            let home = compute_bucket(arr[pos].into(), min_val, scale, num_buckets);

            if home != bucket {
                // Follow the displacement chain: drop the carried element into
                // the next free slot of its own bucket, pick up whatever was
                // there, and repeat until an element belonging to `bucket`
                // turns up; that one fills the slot at `pos`.
                let mut carried = arr[pos];
                let mut carried_bucket = home;
                loop {
                    let dest = write_cursors[carried_bucket];
                    write_cursors[carried_bucket] += 1;

                    let displaced = arr[dest];
                    arr[dest] = carried;

                    let displaced_bucket =
                        compute_bucket(displaced.into(), min_val, scale, num_buckets);
                    if displaced_bucket == bucket {
                        arr[pos] = displaced;
                        break;
                    }
                    carried = displaced;
                    carried_bucket = displaced_bucket;
                }
            }

            // `pos` now holds an element of `bucket`; keep this bucket's own
            // cursor ahead of the scan position.
            if write_cursors[bucket] <= pos {
                write_cursors[bucket] = pos + 1;
            }
            pos += 1;
        }
    }
}
/// Returns the smallest and largest `i64` key found in `arr` as `(min, max)`.
///
/// Every element is visited once.
///
/// # Panics
///
/// Panics if `arr` is empty.
#[inline]
fn sample_minmax<T>(arr: &[T]) -> (i64, i64)
where
    T: Ord + Copy + Into<i64>,
{
    // Seed both bounds with the first key, then widen them over the slice.
    let seed: i64 = arr[0].into();
    arr.iter()
        .map(|&item| -> i64 { item.into() })
        .fold((seed, seed), |(lo, hi), key| (lo.min(key), hi.max(key)))
}
/// Counts how many elements of `arr` fall into each of `num_buckets` buckets
/// spanning the key range `[min_val, max_val]`.
///
/// Returns a vector of length `num_buckets` whose entry `b` is the number of
/// elements assigned to bucket `b`.
///
/// Expects every key to lie in `[min_val, max_val]` and `num_buckets >= 1`.
#[inline]
fn count_buckets<T>(arr: &[T], min_val: i64, max_val: i64, num_buckets: usize) -> Vec<usize>
where
    T: Copy + Into<i64>,
{
    let mut counts = vec![0usize; num_buckets];

    // Wrapping differences read as `u64` are the exact non-negative distances,
    // even when the plain `i64` subtraction would overflow (for example with
    // `min_val == i64::MIN` and `max_val == i64::MAX`).
    let range = max_val.wrapping_sub(min_val) as u64 as f64;
    // The 0.001 slack keeps the largest key just below `num_buckets` before
    // truncation; the clamp below covers any rounding beyond that.
    let scale = (num_buckets as f64 - 0.001) / range;
    let last_bucket = num_buckets - 1;

    for &item in arr {
        let key: i64 = item.into();
        let distance = key.wrapping_sub(min_val) as u64;
        let bucket_idx = ((distance as f64 * scale) as usize).min(last_bucket);
        counts[bucket_idx] += 1;
    }
    counts
}
/// Builds the exclusive prefix sums of `counts`.
///
/// The result has `counts.len() + 1` entries: entry `i` is the sum of
/// `counts[..i]`, so the first entry is `0` and the last is the grand total.
#[inline]
fn prefix_sum(counts: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(counts.len() + 1);
    let mut running = 0;
    offsets.push(running);
    offsets.extend(counts.iter().map(|&count| {
        running += count;
        running
    }));
    offsets
}
/// Copies every element of `src` into `aux`, grouped by bucket.
///
/// * `counts` - scratch space; it is overwritten with the bucket start
///   offsets and then used as the per-bucket write cursors, so on return each
///   entry holds the position one past the last element written to that bucket.
/// * `offsets` - bucket start positions in `aux`; at least `counts.len()`
///   entries are read.
/// * `min_val`, `max_val` - smallest and largest key, defining the bucket scale.
/// * `num_buckets` - number of buckets (`>= 1`).
///
/// Elements keep their relative `src` order inside each bucket.
///
/// # Panics
///
/// Panics if a write position falls outside `aux`, which happens when the
/// offsets do not match the keys produced by `Into<i64>`.
#[inline]
fn scatter<T>(
    src: &[T],
    aux: &mut [T],
    counts: &mut [usize],
    offsets: &[usize],
    min_val: i64,
    max_val: i64,
    num_buckets: usize,
) where
    T: Copy + Into<i64>,
{
    // Turn `counts` into write cursors, one per bucket.
    let cursor_count = counts.len();
    counts.copy_from_slice(&offsets[..cursor_count]);

    // Wrapping differences read as `u64` are the exact non-negative distances,
    // even when the plain `i64` subtraction would overflow.
    let range = max_val.wrapping_sub(min_val) as u64 as f64;
    let scale = (num_buckets as f64 - 0.001) / range;
    let last_bucket = num_buckets - 1;

    for &item in src {
        let key: i64 = item.into();
        let distance = key.wrapping_sub(min_val) as u64;
        let bucket_idx = ((distance as f64 * scale) as usize).min(last_bucket);

        let write_pos = counts[bucket_idx];
        counts[bucket_idx] += 1;
        // Checked on purpose: the position depends on a caller-supplied
        // `Into<i64>` implementation, so an inconsistent key must produce a
        // panic rather than an out-of-bounds write.
        aux[write_pos] = item;
    }
}
/// Sorts every bucket of `aux` independently, processing buckets in parallel.
///
/// * `aux` - slice whose elements are already grouped by bucket.
/// * `offsets` - bucket boundaries: bucket `i` is `aux[offsets[i]..offsets[i + 1]]`.
/// * `num_buckets` - number of buckets; `0` makes this a no-op.
/// * `total_len` - element count used to derive the expected bucket size.
///
/// Buckets shorter than `INSERTION_SORT_THRESHOLD` use `insertion_sort`, all
/// others use `sort_unstable`.
///
/// # Panics
///
/// Panics if `offsets` has fewer than `num_buckets + 1` entries, or if a bucket
/// range is decreasing or extends past the end of `aux`.
fn refine_buckets<T>(aux: &mut [T], offsets: &[usize], num_buckets: usize, total_len: usize)
where
    T: Ord + Copy + Send + Sync,
{
    if num_buckets == 0 {
        // Nothing to refine; also avoids dividing by zero below.
        return;
    }

    let expected_bucket_size = total_len / num_buckets;
    let overflow_threshold = expected_bucket_size * BUCKET_OVERFLOW_FACTOR;

    // Validate every range before any raw-pointer work. Each range is
    // non-decreasing and each one starts where the previous one ended, so the
    // ranges are pairwise disjoint; together with `end <= len` they all lie
    // inside `aux`.
    let len = aux.len();
    let bucket_ranges: Vec<(usize, usize)> = (0..num_buckets)
        .map(|i| {
            let (start, end) = (offsets[i], offsets[i + 1]);
            assert!(
                start <= end && end <= len,
                "bucket {} has invalid bounds {}..{} for a slice of length {}",
                i,
                start,
                end,
                len
            );
            (start, end)
        })
        .collect();

    let ptr = SendPtr(aux.as_mut_ptr());
    bucket_ranges.par_iter().for_each(move |&(start, end)| {
        let bucket_len = end - start;
        if bucket_len <= 1 {
            return;
        }
        // SAFETY: `start..end` was checked above to lie inside `aux`, and the
        // ranges handed to the different tasks are pairwise disjoint, so no two
        // tasks ever create overlapping mutable slices. `aux` stays mutably
        // borrowed by this function for the whole parallel loop.
        let bucket_slice =
            unsafe { std::slice::from_raw_parts_mut(ptr.get().add(start), bucket_len) };

        if bucket_len < INSERTION_SORT_THRESHOLD {
            insertion_sort(bucket_slice);
        } else if bucket_len > overflow_threshold {
            // NOTE(review): oversized buckets currently get exactly the same
            // treatment as regular ones; confirm whether a dedicated strategy
            // was intended for them.
            bucket_slice.sort_unstable();
        } else {
            bucket_slice.sort_unstable();
        }
    });
}
/// Sorts `arr` in ascending order with a straight insertion sort.
///
/// Elements that compare equal keep their relative order.
#[inline]
fn insertion_sort<T: Ord + Copy>(arr: &mut [T]) {
    for sorted_len in 1..arr.len() {
        let key = arr[sorted_len];
        // Scan the sorted prefix from the right for the last element that is
        // not greater than `key`; `key` belongs immediately after it.
        let insert_at = arr[..sorted_len]
            .iter()
            .rposition(|placed| *placed <= key)
            .map_or(0, |idx| idx + 1);
        // Shift the larger elements one slot right and drop `key` into place.
        arr[insert_at..=sorted_len].rotate_right(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Input length for tests that must reach the bucketed code path, i.e. be
    /// at least `HYBRID_THRESHOLD` elements long.
    const LEARNED_PATH_LEN: usize = 50_000;

    /// Guards the assumption the tests below rely on.
    #[test]
    fn test_learned_path_len_reaches_bucketed_path() {
        assert!(LEARNED_PATH_LEN >= HYBRID_THRESHOLD);
    }

    // ---- learned_sort: trivial and fallback-sized inputs ----

    #[test]
    fn test_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_single_element() {
        let mut data = vec![42i64];
        learned_sort(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_two_elements() {
        let mut data = vec![5i64, 3];
        learned_sort(&mut data);
        assert_eq!(data, vec![3, 5]);
    }

    #[test]
    fn test_small_array_uses_fallback() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_medium_array() {
        let mut data: Vec<i64> = (0..1000).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..1000).collect::<Vec<_>>());
    }

    // ---- learned_sort: inputs large enough for the bucketed path ----

    #[test]
    fn test_large_uniform_distribution() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_sorted_input() {
        let mut data: Vec<i64> = (0..LEARNED_PATH_LEN as i64).collect();
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_reverse_sorted() {
        let mut data: Vec<i64> = (0..LEARNED_PATH_LEN as i64).rev().collect();
        let expected: Vec<i64> = (0..LEARNED_PATH_LEN as i64).collect();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_duplicates() {
        let mut data: Vec<i64> = vec![5; LEARNED_PATH_LEN];
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_many_duplicates() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..LEARNED_PATH_LEN)
            .map(|_| rng.gen_range(0..10))
            .collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_negative_numbers() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..LEARNED_PATH_LEN)
            .map(|_| rng.gen_range(-500_000..500_000))
            .collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_i32_type() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i32> = (0..LEARNED_PATH_LEN)
            .map(|_| rng.gen_range(0..1_000_000))
            .collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    // ---- learned_sort_inplace ----

    #[test]
    fn test_inplace_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort_inplace(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_inplace_single_element() {
        let mut data = vec![42i64];
        learned_sort_inplace(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_inplace_small_array() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort_inplace(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_inplace_large_uniform() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_duplicates() {
        let mut data: Vec<i64> = vec![5; 50_000];
        let expected = data.clone();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_many_duplicates() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..50_000).map(|_| rng.gen_range(0..10)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_negative_numbers() {
        let mut rng = rand::thread_rng();
        let mut data: Vec<i64> = (0..50_000)
            .map(|_| rng.gen_range(-500_000..500_000))
            .collect();
        let mut expected = data.clone();
        expected.sort_unstable();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_matches_regular() {
        let mut rng = rand::thread_rng();
        let original: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut data_regular = original.clone();
        let mut data_inplace = original.clone();
        learned_sort(&mut data_regular);
        learned_sort_inplace(&mut data_inplace);
        assert_eq!(data_regular, data_inplace);
    }
}