use std::time::Instant;
/// Minimal 64-bit linear congruential generator used to produce deterministic
/// benchmark keys. The same seed always yields the same key sequence.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Advances the state (`state * MULTIPLIER + INCREMENT`, wrapping) and
    /// returns the high 32 bits of the new state reinterpreted as an `i32`.
    fn next(&mut self) -> i32 {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1;
        let advanced = self.state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        self.state = advanced;
        // The high half of an LCG state has far better statistical quality than the low half.
        (advanced >> 32) as i32
    }
}
/// Baseline comparison sort: orders `arr` by its `i32` key using the standard
/// library's unstable sort (payload order among equal keys is unspecified).
pub fn fallback_unstable_sort(arr: &mut [(i32, u32)]) {
    arr.sort_unstable_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
}
/// Radix recursion cutoff: a sub-range `left..=right` whose span `right - left`
/// is below this value (i.e. at most `INSERTION_SORT_THRESHOLD` elements) is
/// finished with insertion sort instead of another bucketing pass.
const INSERTION_SORT_THRESHOLD: usize = 32;
/// Sorts `arr` by its `i32` key with an MSD radix sort that scatters each
/// range through a scratch buffer of the same length as `arr`.
///
/// Slices of length 0 or 1 are left untouched and allocate nothing.
pub fn wakesort_oop(arr: &mut [(i32, u32)]) {
    let n = arr.len();
    if n > 1 {
        // One scratch buffer is shared by every recursion level.
        let mut scratch: Vec<(i32, u32)> = vec![(0, 0); n];
        // Start from byte 3, the most significant byte of the key.
        wakesort_recursive_oop(arr, &mut scratch, 0, n - 1, 3);
    }
}
/// Stable insertion sort of `arr[left..=right]` by the `i32` key.
///
/// Elements are only shifted past strictly greater keys, so equal keys keep
/// their relative order. Requires `left <= right` and `right < arr.len()`.
fn insertion_sort(arr: &mut [(i32, u32)], left: usize, right: usize) {
    for current in left + 1..=right {
        let pending = arr[current];
        let mut hole = current;
        while hole > left {
            let prev = arr[hole - 1];
            if prev.0 <= pending.0 {
                break;
            }
            arr[hole] = prev;
            hole -= 1;
        }
        arr[hole] = pending;
    }
}
/// Returns byte `byte_pos` (0 = least significant, 3 = most significant) of
/// `val` as a bucket index in `0..256`.
///
/// For the most significant byte the sign bit is flipped, so bucket order
/// matches signed order: negative keys land in buckets `0..128`, non-negative
/// keys in `128..256`. `byte_pos` must be in `0..=3`.
fn get_byte(val: i32, byte_pos: i32) -> usize {
    let bits = val as u32;
    let biased = if byte_pos == 3 { bits ^ 0x8000_0000 } else { bits };
    let shift = byte_pos * 8;
    ((biased >> shift) & 0xFF) as usize
}
/// Recursive worker for [`wakesort_oop`]: sorts `arr[left..=right]` by key.
///
/// Counts byte `byte_pos` of every key in the range (3 = most significant,
/// sign-adjusted by `get_byte`). It then scatters the range into `temp` in
/// bucket order, copies it back, and recurses into each bucket holding more
/// than one element on the next lower byte. The scatter walks the range front
/// to back, so equal keys keep their relative order. Ranges whose span
/// `right - left` is below `INSERTION_SORT_THRESHOLD` use insertion sort.
///
/// `temp` is scratch space: only `temp[..right - left + 1]` is written, so it
/// must be at least that long. Requires `left <= right` and `right < arr.len()`.
fn wakesort_recursive_oop(
    arr: &mut [(i32, u32)],
    temp: &mut [(i32, u32)],
    left: usize,
    right: usize,
    byte_pos: i32,
) {
    if right - left < INSERTION_SORT_THRESHOLD {
        insertion_sort(arr, left, right);
        return;
    }
    // All key bytes consumed: every key in this range is equal.
    if byte_pos < 0 {
        return;
    }
    let len = right - left + 1;

    let mut counts = [0usize; 256];
    for &(key, _) in &arr[left..=right] {
        counts[get_byte(key, byte_pos)] += 1;
    }

    // If every key shares this byte (common when keys span a narrow range),
    // the scatter + copy-back below would be an identity permutation costing
    // two full passes, so descend straight to the next byte instead.
    if counts[get_byte(arr[left].0, byte_pos)] == len {
        wakesort_recursive_oop(arr, temp, left, right, byte_pos - 1);
        return;
    }

    // offsets[b]: absolute index in `arr` where bucket b begins.
    let mut offsets = [0usize; 256];
    let mut sum = left;
    for (offset, &count) in offsets.iter_mut().zip(counts.iter()) {
        *offset = sum;
        sum += count;
    }
    let bucket_starts = offsets;

    // Stable scatter into `temp` (indices relative to `left`), then copy back.
    for &item in &arr[left..=right] {
        let b = get_byte(item.0, byte_pos);
        temp[offsets[b] - left] = item;
        offsets[b] += 1;
    }
    arr[left..=right].copy_from_slice(&temp[..len]);

    for (b, &count) in counts.iter().enumerate() {
        if count > 1 {
            let start = bucket_starts[b];
            wakesort_recursive_oop(arr, temp, start, start + count - 1, byte_pos - 1);
        }
    }
}
/// Sorts `arr` by its `i32` key with an MSD radix sort that permutes elements
/// within the slice itself (no scratch buffer is allocated).
///
/// The relative order of equal keys is not guaranteed to be preserved.
pub fn wakesort_inplace(arr: &mut [(i32, u32)]) {
    match arr.len() {
        0 | 1 => {}
        // Start from byte 3, the most significant byte of the key.
        n => wakesort_recursive_inplace(arr, 0, n - 1, 3),
    }
}
/// Recursive worker for [`wakesort_inplace`]: sorts `arr[left..=right]` by key.
///
/// Buckets the range on byte `byte_pos` of each key (3 = most significant,
/// sign-adjusted by `get_byte`). Elements are moved into their buckets by
/// repeated swaps directly in `arr`, then every bucket with more than one
/// element is recursed into on the next lower byte. Ranges whose span
/// `right - left` is below `INSERTION_SORT_THRESHOLD` use insertion sort.
///
/// Requires `left <= right` and `right < arr.len()`.
fn wakesort_recursive_inplace(arr: &mut [(i32, u32)], left: usize, right: usize, byte_pos: i32) {
    if right - left < INSERTION_SORT_THRESHOLD {
        insertion_sort(arr, left, right);
        return;
    }
    // All key bytes consumed: every key in this range is equal.
    if byte_pos < 0 {
        return;
    }

    // Histogram of the current byte over the range.
    let mut histogram = [0usize; 256];
    for &(key, _) in &arr[left..=right] {
        histogram[get_byte(key, byte_pos)] += 1;
    }

    // heads[d]: next unfilled slot of bucket d; tails[d]: one past its last slot.
    let mut heads = [0usize; 256];
    let mut tails = [0usize; 256];
    let mut cursor = left;
    for digit in 0..256 {
        heads[digit] = cursor;
        cursor += histogram[digit];
        tails[digit] = cursor;
    }
    let bucket_starts = heads;

    // For each unfilled slot of bucket `digit`, keep swapping its occupant to
    // the next free slot of the bucket it belongs to, until the slot holds an
    // element whose byte equals `digit`; then advance to the next slot.
    for digit in 0..256 {
        while heads[digit] < tails[digit] {
            let slot = heads[digit];
            let mut home = get_byte(arr[slot].0, byte_pos);
            while home != digit {
                arr.swap(slot, heads[home]);
                heads[home] += 1;
                home = get_byte(arr[slot].0, byte_pos);
            }
            heads[digit] = slot + 1;
        }
    }

    for (digit, &count) in histogram.iter().enumerate() {
        if count > 1 {
            let first = bucket_starts[digit];
            wakesort_recursive_inplace(arr, first, first + count - 1, byte_pos - 1);
        }
    }
}
/// Input sizes benchmarked by `main`; for each size, every sorter is run once
/// on its own copy of the same generated data.
const FALLBACK_CASES: &[usize] = &[100_000, 1_000_000, 10_000_000];
/// Benchmarks `fallback_unstable_sort`, `wakesort_oop` and `wakesort_inplace`
/// on identical pseudo-random inputs for each size in `FALLBACK_CASES`.
///
/// After timing, it checks that both radix outputs are ordered by key and are
/// permutations of the baseline output.
///
/// # Panics
/// Panics if either radix sort produces unsorted output or loses/duplicates
/// elements.
fn main() {
    println!("Benchmarking WakeSort (MSD Radix Sort) vs Unstable Sort (O(N log N))...");
    println!("These sizes represent Fallback cases of OverclockedSort where max_val is very large.\n");

    // Checks key order only; payload order among equal keys may legitimately differ.
    let is_sorted_by_key = |arr: &[(i32, u32)]| arr.windows(2).all(|w| w[0].0 <= w[1].0);

    let mut rng = Lcg { state: 42 };
    for &size in FALLBACK_CASES {
        println!("Array size: {}", size);
        // Payload = original index, so lost or duplicated elements are detectable below.
        let data: Vec<(i32, u32)> = (0..size)
            .map(|i| (rng.next(), u32::try_from(i).expect("benchmark sizes fit in u32")))
            .collect();
        let mut data1 = data.clone();
        let mut data2 = data.clone();
        // Last consumer: move instead of cloning a third time.
        let mut data3 = data;

        let t1 = Instant::now();
        fallback_unstable_sort(&mut data1);
        let d1 = t1.elapsed();
        println!(" std::sort_unstable : {:?}", d1);

        let t2 = Instant::now();
        wakesort_oop(&mut data2);
        let d2 = t2.elapsed();
        println!(" WakeSort OOP : {:?}", d2);

        let t3 = Instant::now();
        wakesort_inplace(&mut data3);
        let d3 = t3.elapsed();
        println!(" WakeSort In-place : {:?}", d3);

        assert!(is_sorted_by_key(&data2), "WakeSort OOP is not sorted!");
        assert!(is_sorted_by_key(&data3), "WakeSort In-place is not sorted!");

        // Multiset check: sort the whole tuples in place. The outputs aren't
        // needed afterwards, so this avoids three extra full-size copies.
        data1.sort_unstable();
        data2.sort_unstable();
        data3.sort_unstable();
        assert!(data1 == data2, "WakeSort OOP lost elements!");
        assert!(data1 == data3, "WakeSort In-place lost elements!");
        println!();
    }
}