use std::sync::atomic::{AtomicUsize, Ordering};
pub mod key_ptr;
pub use key_ptr::KeyPtr;
/// Element type that the overclocked sorts can order through an integer key.
///
/// The sorts convert each key to `usize` and use it, relative to the minimum key,
/// as a counting-sort bucket index.
///
/// NOTE(review): the sorts appear to assume that key order agrees with `Ord`, since
/// some paths sort by key and others by `Ord`. Confirm for every implementor.
pub trait SortableKey: Copy + Default + Ord + Send + Sync {
    /// When `true`, the sorts rebuild values with [`SortableKey::from_key`] instead of
    /// moving them. Only set this when a value is fully determined by its key.
    const IS_PRIMITIVE: bool;

    /// Integer key type. It is widened to `usize` for bucketing and narrowed back via
    /// `TryFrom<usize>` when values are rebuilt.
    type KeyType: Copy + Ord + Into<usize> + std::convert::TryFrom<usize>;

    /// Returns the sort key of `self`.
    fn extract_key(&self) -> Self::KeyType;

    /// Rebuilds a value from its key (used when `IS_PRIMITIVE` is `true`).
    fn from_key(k: Self::KeyType) -> Self;
}
/// Unsigned 32-bit integers use their own value as the key.
impl SortableKey for u32 {
    const IS_PRIMITIVE: bool = true;
    type KeyType = usize;

    #[inline]
    fn extract_key(&self) -> Self::KeyType {
        let value = *self;
        value as usize
    }

    /// Narrows the key back with an `as` cast (keys above `u32::MAX` are truncated).
    #[inline]
    fn from_key(k: Self::KeyType) -> Self {
        k as u32
    }
}
/// Signed 32-bit integers are biased by 2^31, so `i32::MIN` maps to key 0 and
/// `i32::MAX` maps to key `u32::MAX`. Key order therefore matches numeric order.
impl SortableKey for i32 {
    const IS_PRIMITIVE: bool = true;
    type KeyType = usize;

    #[inline]
    fn extract_key(&self) -> Self::KeyType {
        // Widen to i64 first so adding the bias cannot overflow.
        let shifted = i64::from(*self) - i64::from(i32::MIN);
        shifted as usize
    }

    #[inline]
    fn from_key(k: Self::KeyType) -> Self {
        // Undo the bias, then narrow back to i32.
        let unshifted = k as i64 + i64::from(i32::MIN);
        unshifted as i32
    }
}
/// Sorts `arr` in ascending order, in place.
///
/// Strategy:
/// * Input that is already sorted (per `Ord`) is detected in one pass and left alone.
/// * If the key span `max - min + 1` is at most `min(2 * len, 50_000_000)`, the whole
///   slice is counting-sorted.
/// * Otherwise the slice is partitioned. Elements whose key is within that limit of
///   the minimum key are moved to the front and counting-sorted. The remaining
///   elements are moved to the back and sorted with `sort_unstable`.
///
/// The counting sort is `counting_sort_primitive` when `T::IS_PRIMITIVE`, and
/// `counting_sort_records` otherwise.
///
/// NOTE(review): the counting-sort paths order by `extract_key`, while the early-exit
/// check and the sparse path use `Ord`. This assumes the two orders agree.
pub fn overclocked_sort<T: SortableKey>(arr: &mut [T]) {
    /// Upper bound on the counting-sort key span, regardless of input length.
    const MAX_COUNTING_RANGE: usize = 50_000_000;

    if arr.len() <= 1 {
        return;
    }
    if arr.windows(2).all(|w| w[0] <= w[1]) {
        return;
    }

    let n = arr.len();
    let (min_val, max_val) = key_bounds(arr);
    // `saturating_add` avoids overflow when the keys span all of `usize` (reachable on
    // 32-bit targets with `u32` keys). The saturated span is still larger than the
    // limit, so that case correctly falls through to the partitioned path.
    let range = (max_val - min_val).saturating_add(1);
    let max_limit = std::cmp::min(n.saturating_mul(2), MAX_COUNTING_RANGE);

    if range <= max_limit {
        counting_sort_dispatch(arr, min_val, range);
        return;
    }

    // Here max_val - min_val >= max_limit, so this addition cannot overflow.
    let split_bound = min_val + max_limit;

    // Two-pointer partition: keys <= split_bound end up in arr[..left].
    let mut left = 0;
    let mut right = n;
    while left < right {
        let key: usize = arr[left].extract_key().into();
        if key <= split_bound {
            left += 1;
        } else {
            right -= 1;
            arr.swap(left, right);
        }
    }

    // Every dense key is < every sparse key, so sorting each half sorts the whole.
    let (dense_stream, sparse_stream) = arr.split_at_mut(left);
    if !dense_stream.is_empty() {
        // The dense half holds the minimum element, so its lower bound is `min_val`.
        let (_, dense_max) = key_bounds(dense_stream);
        counting_sort_dispatch(dense_stream, min_val, dense_max - min_val + 1);
    }
    if !sparse_stream.is_empty() {
        sparse_stream.sort_unstable();
    }
}

/// Returns `(smallest_key, largest_key)` of a non-empty slice, widened to `usize`.
fn key_bounds<T: SortableKey>(arr: &[T]) -> (usize, usize) {
    let first: usize = arr[0].extract_key().into();
    arr.iter().fold((first, first), |(lo, hi), item| {
        let key: usize = item.extract_key().into();
        (lo.min(key), hi.max(key))
    })
}

/// Runs the counting sort matching `T::IS_PRIMITIVE` over keys in `[min_val, min_val + range)`.
fn counting_sort_dispatch<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
    if T::IS_PRIMITIVE {
        counting_sort_primitive(arr, min_val, range);
    } else {
        counting_sort_records(arr, min_val, range);
    }
}
/// Counting sort for types whose values are fully determined by their key
/// (`T::IS_PRIMITIVE`). Elements are counted, and then `arr` is overwritten with
/// `T::from_key` values, so no element is moved.
///
/// Every key in `arr` must lie in `[min_val, min_val + range)`. Otherwise the count
/// index panics.
///
/// Work is spread over at most `available_parallelism()` scoped threads (4 if that is
/// unknown), and never more threads than elements. Counting allocates one
/// `range`-sized table per thread.
///
/// # Panics
/// Panics if a key that occurs in `arr` cannot be converted back into `T::KeyType`.
/// That only happens if the `SortableKey` impl is inconsistent.
fn counting_sort_primitive<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
    let n = arr.len();
    if n == 0 {
        return;
    }
    let num_threads = std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(4)
        .min(n);
    let chunk_sz = (n + num_threads - 1) / num_threads;

    // Phase 1: each thread builds a histogram of its own contiguous input chunk.
    let mut local_counts = vec![vec![0usize; range]; num_threads];
    std::thread::scope(|s| {
        for (chunk, counts) in arr.chunks(chunk_sz).zip(local_counts.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    counts[key - min_val] += 1;
                }
            });
        }
    });

    // Phase 2: merge the per-thread histograms, then free them.
    let mut global_count = vec![0usize; range];
    for counts in &local_counts {
        for (total, &c) in global_count.iter_mut().zip(counts) {
            *total += c;
        }
    }
    drop(local_counts);

    // Phase 3: split the bins into groups. Each group's output is one contiguous run of
    // `arr`, so `arr` can be carved into disjoint mutable slices, one per group.
    let bins_per_thread = (range + num_threads - 1) / num_threads;
    let mut rest: &mut [T] = arr;
    let mut jobs: Vec<(&mut [T], &[usize], usize)> = Vec::with_capacity(num_threads);
    let mut next_key = min_val;
    for bin_counts in global_count.chunks(bins_per_thread) {
        let len: usize = bin_counts.iter().sum();
        let (out, tail) = std::mem::take(&mut rest).split_at_mut(len);
        rest = tail;
        if len > 0 {
            jobs.push((out, bin_counts, next_key));
        }
        next_key += bin_counts.len();
    }

    std::thread::scope(|s| {
        for (out, bin_counts, base_key) in jobs {
            s.spawn(move || {
                let mut pos = 0;
                for (offset, &freq) in bin_counts.iter().enumerate() {
                    if freq == 0 {
                        continue;
                    }
                    let key_index = base_key + offset;
                    // A non-empty bin was filled from a real element with this key, so a
                    // failed conversion means the trait impl is broken. Panic rather than
                    // leave stale values in the output.
                    let key = match <T::KeyType as std::convert::TryFrom<usize>>::try_from(
                        key_index,
                    ) {
                        Ok(key) => key,
                        Err(_) => panic!(
                            "counting_sort_primitive: key {} does not fit T::KeyType",
                            key_index
                        ),
                    };
                    out[pos..pos + freq].fill(T::from_key(key));
                    pos += freq;
                }
            });
        }
    });
}
/// Stable counting sort that moves whole records. It is used when `T::IS_PRIMITIVE`
/// is false.
///
/// Every key in `arr` must lie in `[min_val, min_val + range)`. Otherwise the count
/// index panics.
///
/// The algorithm works as follows:
/// * The input is split into contiguous chunks, one per thread.
/// * Each thread counts its chunk.
/// * The counts are turned into starting offsets in bucket-major, thread-minor order.
/// * Each thread scatters its chunk into a scratch buffer, which is then copied back
///   into `arr`.
///
/// Because of the offset order, equal keys keep their original relative order.
/// Memory is one `range`-sized table per thread (reused for counts and offsets) plus
/// an `arr.len()` scratch buffer.
fn counting_sort_records<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
    let n = arr.len();
    if n == 0 {
        return;
    }
    let num_threads = std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(4)
        .min(n);
    let chunk_sz = (n + num_threads - 1) / num_threads;

    // Phase 1: per-thread histograms of each contiguous input chunk.
    let mut per_thread = vec![vec![0usize; range]; num_threads];
    std::thread::scope(|s| {
        for (chunk, counts) in arr.chunks(chunk_sz).zip(per_thread.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    counts[key - min_val] += 1;
                }
            });
        }
    });

    // Phase 2: replace each count with its exclusive prefix sum, in place. Bucket-major
    // order places thread t's elements of a bucket after those of threads < t, which is
    // what makes the sort stable.
    let mut next = 0usize;
    for bucket in 0..range {
        for counts in per_thread.iter_mut() {
            let c = counts[bucket];
            counts[bucket] = next;
            next += c;
        }
    }

    // Phase 3: scatter. Each thread takes ownership of its own offset table (no clone).
    let mut buffer = vec![T::default(); n];
    let buf_addr = buffer.as_mut_ptr() as usize;
    std::thread::scope(|s| {
        for (chunk, mut cursor) in arr.chunks(chunk_sz).zip(per_thread) {
            s.spawn(move || {
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    let slot = &mut cursor[key - min_val];
                    let pos = *slot;
                    *slot += 1;
                    // SAFETY: the prefix sums give each (bucket, thread) pair the disjoint
                    // index range [offset, offset + count), and all these ranges together
                    // tile 0..n exactly. So `pos < n`, no two writes (within or across
                    // threads) hit the same slot, and `buffer` is not otherwise accessed
                    // until the scope has joined every thread.
                    unsafe {
                        std::ptr::write((buf_addr as *mut T).add(pos), *item);
                    }
                }
            });
        }
    });
    arr.copy_from_slice(&buffer);
}
/// Returns a new vector holding the elements of `input` in ascending order.
///
/// `input` itself is not modified. `_max_val` is unused and kept only so the signature
/// stays compatible with existing callers.
pub fn overclocked_parallel_sort(input: &[i32], _max_val: usize) -> Vec<i32> {
    let mut sorted = Vec::with_capacity(input.len());
    sorted.extend_from_slice(input);
    overclocked_sort(&mut sorted);
    sorted
}
/// Returns a new vector holding the [`KeyPtr`] elements of `input` in sorted order.
///
/// `input` itself is not modified. `_max_val` is unused and kept only so the signature
/// stays compatible with existing callers.
pub fn overclocked_kp_sort(input: &[KeyPtr], _max_val: usize) -> Vec<KeyPtr> {
    let mut sorted = input.to_owned();
    overclocked_sort(&mut sorted);
    sorted
}