pub mod key_ptr;
pub use key_ptr::KeyPtr;
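/// Key extraction contract used by the sorting routines below.
///
/// `extract_key` maps a value to an order-preserving key that can be widened
/// into a `usize` bucket index; `from_key` reverses the mapping so that values
/// fully determined by their key (`IS_PRIMITIVE == true`) can be rebuilt from
/// the key instead of being moved around. The `Send + Sync` bounds let the
/// counting sorts share slices across scoped threads.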
pub trait SortableKey: Copy + Default + Ord + Send + Sync {
type KeyType: Copy + Ord + Into<usize> + TryFrom<usize>;
fn extract_key(&self) -> Self::KeyType;
const IS_PRIMITIVE: bool;
fn from_key(k: Self::KeyType) -> Self;
}
impl SortableKey for u32 {
type KeyType = usize;
#[inline(always)]
fn extract_key(&self) -> usize { *self as usize }
const IS_PRIMITIVE: bool = true;
#[inline(always)]
fn from_key(k: usize) -> Self { k as u32 }
}
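// i32 keys are biased by 2^31 (2147483648) so that signed order maps onto
// unsigned `usize` order; `from_key` undoes the bias.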
impl SortableKey for i32 {
type KeyType = usize;
#[inline(always)]
fn extract_key(&self) -> usize { (*self as i64 + 2147483648) as usize }
const IS_PRIMITIVE: bool = true;
#[inline(always)]
fn from_key(k: usize) -> Self { (k as i64 - 2147483648) as i32 }
}
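/// Adaptive sort entry point.
///
/// The routine first samples up to 1024 evenly spaced probes; if the probes
/// look (nearly) sorted or (nearly) reversed, it hands the slice straight to
/// `sort_unstable`, which is fast on such inputs. Otherwise it makes one full
/// pass to collect min/max and monotonicity, returns early for fully sorted
/// input, reverses fully descending input, uses a parallel counting sort when
/// the key range is small relative to the input, and falls back to
/// `fallback_sort` for wide ranges.
///
/// Minimal usage sketch (assumes `overclocked_sort` is in scope; marked
/// `ignore` so it is not run as a doctest):
///
/// ```ignore
/// let mut v = vec![5u32, 2, 9, 1];
/// overclocked_sort(&mut v);
/// assert_eq!(v, vec![1, 2, 5, 9]);
/// ```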
pub fn overclocked_sort<T: SortableKey>(arr: &mut [T]) {
if arr.len() <= 1 { return; }
let n = arr.len();
let probe_points = std::cmp::min(n, 1024);
if probe_points >= 64 {
let stride = (n - 1) / (probe_points - 1);
let mut prev_probe = arr[0].extract_key().into();
let mut probe_desc_breaks = 0usize;
for i in 1..probe_points {
let idx = i * stride;
let k = arr[idx].extract_key().into();
if k < prev_probe {
probe_desc_breaks += 1;
}
prev_probe = k;
}
// Probes suggest the input is (nearly) sorted or (nearly) reversed;
// the standard library's unstable sort handles both patterns quickly.
if probe_desc_breaks == 0 || probe_desc_breaks >= probe_points - 2 {
arr.sort_unstable();
return;
}
}
let first_key = arr[0].extract_key().into();
let mut prev_key = first_key;
let mut min_val = first_key;
let mut max_val = first_key;
let mut non_decreasing = true;
let mut non_increasing = true;
let mut desc_breaks = 0usize;
for item in arr.iter().skip(1) {
let k = item.extract_key().into();
if k < prev_key {
non_decreasing = false;
desc_breaks += 1;
}
if k > prev_key {
non_increasing = false;
}
prev_key = k;
if k < min_val { min_val = k; }
if k > max_val { max_val = k; }
}
if non_decreasing {
return;
}
if non_increasing {
arr.reverse();
return;
}
let range = max_val - min_val + 1;
let max_limit = std::cmp::min(1_000_000, std::cmp::max(1, n / 4));
let near_sorted_threshold = std::cmp::max(1, n / 16);
if range > max_limit && (desc_breaks <= near_sorted_threshold || desc_breaks >= (n - 1).saturating_sub(near_sorted_threshold)) {
arr.sort_unstable();
return;
}
if range <= max_limit {
if T::IS_PRIMITIVE {
counting_sort_primitive(arr, min_val, range);
} else {
counting_sort_records(arr, min_val, range);
}
} else {
fallback_sort(arr);
}
}
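// Fallback for wide key ranges: insertion sort for very small slices,
// otherwise an LSD radix sort on the extracted keys.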
fn fallback_sort<T: SortableKey>(arr: &mut [T]) {
if arr.len() <= 1 {
return;
}
if arr.len() <= 32 {
insertion_sort(arr);
return;
}
radix_sort_by_key(arr);
}
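// Least-significant-digit radix sort over the `usize` key, one byte per pass.
// Each pass builds a 256-entry histogram, turns it into prefix offsets, and
// scatters into a second buffer; the two buffers are ping-ponged between
// passes. `passes` equals `size_of::<usize>()`, which is always even (2, 4,
// or 8), so the sorted data ends up back in `current` before the final copy.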
fn radix_sort_by_key<T: SortableKey>(arr: &mut [T]) {
let len = arr.len();
let mut current = arr.to_vec();
let mut next = vec![T::default(); len];
let passes = std::mem::size_of::<usize>();
for pass in 0..passes {
let shift = pass * 8;
let mut counts = [0usize; 256];
for item in current.iter() {
let bucket = (item.extract_key().into() >> shift) & 0xFF;
counts[bucket] += 1;
}
let mut offset = 0usize;
for count in counts.iter_mut() {
let current_count = *count;
*count = offset;
offset += current_count;
}
for item in current.iter() {
let bucket = (item.extract_key().into() >> shift) & 0xFF;
next[counts[bucket]] = *item;
counts[bucket] += 1;
}
std::mem::swap(&mut current, &mut next);
}
arr.copy_from_slice(&current);
}
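// Alternative path, kept but not called (hence `#[allow(dead_code)]`): a
// Hoare-style quicksort with a median-of-three pivot and an insertion-sort
// cutoff for small partitions.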
#[allow(dead_code)]
fn quicksort_recursive<T: SortableKey>(arr: &mut [T]) {
const INSERTION_THRESHOLD: usize = 24;
if arr.len() <= INSERTION_THRESHOLD {
insertion_sort(arr);
return;
}
let len = arr.len();
let mid = len / 2;
let pivot = median_of_three(arr[0], arr[mid], arr[len - 1]);
let mut left = 0usize;
let mut right = len - 1;
loop {
while arr[left] < pivot {
left += 1;
}
while arr[right] > pivot {
if right == 0 {
break;
}
right -= 1;
}
if left >= right {
break;
}
arr.swap(left, right);
left += 1;
if right == 0 {
break;
}
right -= 1;
}
let split_index = right + 1;
let (lo, hi) = arr.split_at_mut(split_index);
if lo.len() < hi.len() {
if !lo.is_empty() {
quicksort_recursive(lo);
}
if !hi.is_empty() {
quicksort_recursive(hi);
}
} else {
if !hi.is_empty() {
quicksort_recursive(hi);
}
if !lo.is_empty() {
quicksort_recursive(lo);
}
}
}
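// In-place insertion sort; used by `fallback_sort` and as the quicksort
// small-partition cutoff.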
#[allow(dead_code)]
fn insertion_sort<T: SortableKey>(arr: &mut [T]) {
for i in 1..arr.len() {
let value = arr[i];
let mut j = i;
while j > 0 && arr[j - 1] > value {
arr[j] = arr[j - 1];
j -= 1;
}
arr[j] = value;
}
}
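// Median of three values, used as the quicksort pivot.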
#[allow(dead_code)]
#[inline(always)]
fn median_of_three<T: SortableKey>(a: T, b: T, c: T) -> T {
if a < b {
if b < c {
b
} else if a < c {
c
} else {
a
}
} else if a < c {
a
} else if b < c {
c
} else {
b
}
}
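// Parallel counting sort for key-only ("primitive") values: each thread
// histograms its chunk, the histograms are merged into global bin offsets,
// and each thread then rebuilds the values for its assigned bins directly
// from the keys via `from_key`, writing through a raw pointer. The bin
// ranges are disjoint, so the unsafe writes from different threads never
// alias.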
fn counting_sort_primitive<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
let mut local_counts = vec![vec![0usize; range]; num_threads];
std::thread::scope(|s| {
for (i, count) in local_counts.iter_mut().enumerate() {
let start = i * chunk_sz;
let end = std::cmp::min(start + chunk_sz, arr.len());
if start >= arr.len() { continue; }
let slice = &arr[start..end];
s.spawn(move || {
for item in slice {
count[item.extract_key().into() - min_val] += 1;
}
});
}
});
let mut global_count = vec![0usize; range];
for counts in &local_counts {
for (i, &c) in counts.iter().enumerate() {
global_count[i] += c;
}
}
let mut bin_offsets = vec![0usize; range];
let mut offset = 0;
for (i, &c) in global_count.iter().enumerate() {
bin_offsets[i] = offset;
offset += c;
}
let arr_ptr = arr.as_mut_ptr() as usize;
let global_count_ref = &global_count;
let bin_offsets_ref = &bin_offsets;
let bins_per_thread = (range + num_threads - 1) / num_threads;
std::thread::scope(|s| {
for t in 0..num_threads {
s.spawn(move || {
let start_bin = t * bins_per_thread;
let end_bin = std::cmp::min(start_bin + bins_per_thread, range);
for i in start_bin..end_bin {
let freq = global_count_ref[i];
if freq == 0 { continue; }
let val_k = min_val + i;
if let Ok(key) = T::KeyType::try_from(val_k) {
let val = T::from_key(key);
let target_offset = bin_offsets_ref[i];
unsafe {
let ptr = (arr_ptr as *mut T).add(target_offset);
for j in 0..freq {
std::ptr::write(ptr.add(j), val);
}
}
}
}
});
}
});
}
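// Parallel, stable counting sort for values that carry payload beyond the
// key: per-thread histograms are turned into per-(thread, bucket) offsets,
// each thread scatters its own chunk into a shared temporary buffer through
// disjoint raw-pointer writes, and the buffer is copied back at the end.
// Offsets are laid out bucket-major and thread-minor, which preserves the
// original order of equal keys.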
fn counting_sort_records<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
let mut local_counts = vec![vec![0usize; range]; num_threads];
std::thread::scope(|s| {
for (i, count) in local_counts.iter_mut().enumerate() {
let start = i * chunk_sz;
let end = std::cmp::min(start + chunk_sz, arr.len());
if start >= arr.len() { continue; }
let slice = &arr[start..end];
s.spawn(move || {
for item in slice {
count[item.extract_key().into() - min_val] += 1;
}
});
}
});
let mut global_offsets = vec![vec![0usize; range]; num_threads];
let mut total_offset = 0;
for val in 0..range {
for t in 0..num_threads {
global_offsets[t][val] = total_offset;
total_offset += local_counts[t][val];
}
}
let mut buffer = vec![T::default(); arr.len()];
let buf_ptr = buffer.as_mut_ptr() as usize;
let global_offsets_ref = &global_offsets;
let arr_ref: &[T] = arr;
std::thread::scope(|s| {
for t in 0..num_threads {
let current_offsets_tpl = global_offsets_ref[t].clone();
s.spawn(move || {
let mut current_offsets = current_offsets_tpl;
let start = t * chunk_sz;
let end = std::cmp::min(start + chunk_sz, arr_ref.len());
if start >= arr_ref.len() { return; }
let slice = &arr_ref[start..end];
for item in slice {
let bucket = item.extract_key().into() - min_val;
let target_pos = current_offsets[bucket];
unsafe {
let ptr = (buf_ptr as *mut T).add(target_pos);
std::ptr::write(ptr, *item);
}
current_offsets[bucket] += 1;
}
});
}
});
arr.copy_from_slice(&buffer);
}
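/// Convenience wrapper: clones the input, sorts the copy with
/// `overclocked_sort`, and returns it. The `_max_val` argument is ignored;
/// the sort derives its own bounds.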
pub fn overclocked_parallel_sort(input: &[i32], _max_val: usize) -> Vec<i32> {
let mut copy = input.to_vec();
overclocked_sort(&mut copy);
copy
}
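/// `KeyPtr` counterpart of `overclocked_parallel_sort`: clones, sorts, and
/// returns the copy.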
pub fn overclocked_kp_sort(input: &[KeyPtr], _max_val: usize) -> Vec<KeyPtr> {
let mut copy = input.to_vec();
overclocked_sort(&mut copy);
copy
}