overclocked_sort 0.2.0

A hyper-optimized Parallel Counting Sort utilizing L2 Cache-oblivious block sizing, SIMD Auto-vectorization, Prefix-Sum, and Zero-Runtime Dynamic Work Stealing.
Documentation
use std::sync::atomic::{AtomicUsize, Ordering};

pub mod key_ptr;
pub use key_ptr::KeyPtr;

/// Contract for element types that [`overclocked_sort`] can order with its
/// counting-sort based strategies.
///
/// Each element exposes a numeric key; the sorting routines turn that key into
/// a `usize` and use it as a bucket index.
pub trait SortableKey: Copy + Default + Ord + Send + Sync {
    /// `true` when an element is nothing more than its key, so sorted output can
    /// be regenerated from bucket counts via [`SortableKey::from_key`];
    /// `false` when the element carries attached data and has to be moved
    /// (scattered) as a whole.
    const IS_PRIMITIVE: bool;

    /// Representation of the key: convertible into a `usize` bucket index and
    /// fallibly convertible back from one.
    type KeyType: Copy + Ord + Into<usize> + TryFrom<usize>;

    /// Returns the numeric key used for counting.
    fn extract_key(&self) -> Self::KeyType;

    /// Rebuilds an element from its key.
    fn from_key(k: Self::KeyType) -> Self;
}

// `u32` is its own key: the value is widened to `usize` for bucketing and
// narrowed back when an element is regenerated from a bucket index.
impl SortableKey for u32 {
    type KeyType = usize;

    const IS_PRIMITIVE: bool = true;

    #[inline(always)]
    fn extract_key(&self) -> usize {
        let value = *self;
        value as usize
    }

    #[inline(always)]
    fn from_key(k: usize) -> Self {
        k as u32
    }
}

// `i32` keys are biased by 2^31 so that `i32::MIN` maps to 0 and the numeric
// order of the values is preserved in the unsigned key space.
impl SortableKey for i32 {
    type KeyType = usize;

    const IS_PRIMITIVE: bool = true;

    #[inline(always)]
    fn extract_key(&self) -> usize {
        const BIAS: i64 = 2147483648;
        let shifted = i64::from(*self) + BIAS;
        shifted as usize
    }

    #[inline(always)]
    fn from_key(k: usize) -> Self {
        const BIAS: i64 = 2147483648;
        let unshifted = k as i64 - BIAS;
        unshifted as i32
    }
}

/// Generic hybrid multi-purpose sorting function (Version 0.2.0).
///
/// Sorts `arr` in place, ascending by each element's key
/// (`extract_key()` converted to `usize`).
///
/// Strategy:
/// * Slices of length 0 or 1, and slices whose keys are already
///   non-decreasing, are returned untouched.
/// * If the key span (`max - min + 1`) is at most `min(2 * len, 50_000_000)`,
///   the whole slice goes through the counting sort that matches
///   `T::IS_PRIMITIVE`.
/// * Otherwise the slice is partitioned in place into a "dense" part (keys
///   `<= min + limit`), which is counting-sorted, and a "sparse" part, which
///   is sorted by key with `sort_unstable_by_key`.
///
/// All comparisons are made on the extracted key so that every path orders
/// elements by the same criterion. The sort is not stable: the partition step
/// swaps elements and the sparse part uses an unstable sort.
pub fn overclocked_sort<T: SortableKey>(arr: &mut [T]) {
    /// Upper bound on the number of counting buckets, whatever the input size.
    const MAX_BUCKETS: usize = 50_000_000;

    /// Key of `item` as a bucket-space `usize`.
    fn key_of<K: SortableKey>(item: &K) -> usize {
        item.extract_key().into()
    }

    /// Runs the counting sort variant selected by `K::IS_PRIMITIVE`.
    fn sort_dense<K: SortableKey>(slice: &mut [K], min_val: usize, range: usize) {
        if K::IS_PRIMITIVE {
            counting_sort_primitive(slice, min_val, range);
        } else {
            counting_sort_records(slice, min_val, range);
        }
    }

    let n = arr.len();
    if n <= 1 {
        return;
    }

    // Heuristic check: nothing to do when the keys are already in order.
    if arr.windows(2).all(|w| key_of(&w[0]) <= key_of(&w[1])) {
        return;
    }

    // Single pass for the key range.
    let (min_val, max_val) = arr
        .iter()
        .map(key_of)
        .fold((usize::MAX, 0usize), |(lo, hi), k| (lo.min(k), hi.max(k)));

    // Work with `span = range - 1`: `max - min + 1` would overflow when the
    // keys cover the whole `usize` domain.
    let span = max_val - min_val;
    let max_limit = n.saturating_mul(2).min(MAX_BUCKETS);

    if span < max_limit {
        // Keys are dense enough: counting sort over the whole slice.
        // `span + 1` cannot overflow because `span < max_limit`.
        sort_dense(arr, min_val, span + 1);
        return;
    }

    // Hybrid partitioning: keys <= split_bound go to the dense stream, the
    // rest to the sparse stream. `span >= max_limit` here, so
    // `split_bound <= max_val` and the addition cannot overflow.
    let split_bound = min_val + max_limit;
    let mut left = 0;
    let mut right = n;
    while left < right {
        if key_of(&arr[left]) <= split_bound {
            left += 1;
        } else {
            right -= 1;
            arr.swap(left, right);
        }
    }

    let (dense_stream, sparse_stream) = arr.split_at_mut(left);

    // 1. Counting-sort the dense stream, sized by its true maximum so that no
    //    more buckets than necessary are allocated.
    if let Some(dense_max) = dense_stream.iter().map(key_of).max() {
        sort_dense(dense_stream, min_val, dense_max - min_val + 1);
    }

    // 2. Comparison-sort the sparse stream by key. Every sparse key is greater
    //    than every dense key, so the two sorted streams form a sorted whole.
    sparse_stream.sort_unstable_by_key(|item| key_of(item));
}

// Parallel counting sort for key-only element types.
//
// `min_val` is the smallest key (as `usize`) present in `arr`, and `range` is
// the number of buckets; every key must satisfy
// `min_val <= key < min_val + range`.
//
// Phases:
//   1. each worker builds a histogram of its own chunk of `arr`;
//   2. the per-worker histograms are merged into one global histogram;
//   3. `arr` is overwritten with runs of regenerated values. Buckets are split
//      into contiguous groups, and each group owns one contiguous, disjoint
//      region of `arr`, so the regions are handed out with `split_at_mut` and
//      no `unsafe` is required.
//
// Panics if a key falls outside the bucket range, or if a bucket index that
// was produced from an element cannot be converted back into `T::KeyType`.
// The latter means the type's `Into<usize>` and `TryFrom<usize>` conversions
// are not inverses; skipping such a bucket would leave stale elements in the
// output.
fn counting_sort_primitive<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
    if arr.is_empty() {
        return;
    }

    // Never use more workers than elements: extra workers would only add
    // empty histograms.
    let num_threads = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4)
        .min(arr.len());
    let chunk_sz = (arr.len() + num_threads - 1) / num_threads;

    // 1. Parallel histogram: one private table per chunk, so no sharing.
    let mut local_counts = vec![vec![0usize; range]; num_threads];
    std::thread::scope(|s| {
        for (chunk, counts) in arr.chunks(chunk_sz).zip(local_counts.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    counts[key - min_val] += 1;
                }
            });
        }
    });

    // 2. Merge the per-worker tables into the global histogram.
    let mut global_count = vec![0usize; range];
    for counts in &local_counts {
        for (total, &c) in global_count.iter_mut().zip(counts) {
            *total += c;
        }
    }

    // 3. Parallel fill. Group `g` covers buckets
    //    `g * bins_per_thread .. (g + 1) * bins_per_thread` and writes exactly
    //    as many elements as those buckets contain, directly after the
    //    elements of all previous groups.
    let bins_per_thread = (range + num_threads - 1) / num_threads;
    let mut unfilled: &mut [T] = arr;
    std::thread::scope(|s| {
        for (group_idx, group) in global_count.chunks(bins_per_thread).enumerate() {
            let group_len: usize = group.iter().sum();
            // `take` moves the remaining slice out so it can be split without
            // keeping `unfilled` borrowed.
            let (dest, rest) = std::mem::take(&mut unfilled).split_at_mut(group_len);
            unfilled = rest;
            if group_len == 0 {
                continue;
            }

            let first_key = min_val + group_idx * bins_per_thread;
            s.spawn(move || {
                let mut filled = 0;
                for (i, &freq) in group.iter().enumerate() {
                    if freq == 0 {
                        continue;
                    }
                    let raw_key = first_key + i;
                    let key = match <T::KeyType as TryFrom<usize>>::try_from(raw_key) {
                        Ok(key) => key,
                        Err(_) => panic!(
                            "counting_sort_primitive: key {} cannot be converted back to KeyType",
                            raw_key
                        ),
                    };
                    dest[filled..filled + freq].fill(T::from_key(key));
                    filled += freq;
                }
            });
        }
    });
}

// Parallel counting sort (histogram + prefix sum + scatter) for element types
// that carry data besides their key.
//
// `min_val` is the smallest key (as `usize`) present in `arr`, and `range` is
// the number of buckets; every key must satisfy
// `min_val <= key < min_val + range` (an out-of-range key panics on the
// histogram index).
//
// Phases:
//   1. each worker counts the keys of its own chunk of `arr`;
//   2. the counts are turned, in place, into write cursors: iterating
//      bucket-major and worker-minor gives every (bucket, worker) pair its own
//      contiguous range of output positions;
//   3. each worker copies the elements of its chunk into a scratch buffer at
//      its cursors, and the buffer is copied back over `arr`.
//
// Elements with equal keys keep their relative input order, because workers
// own consecutive chunks and their ranges inside a bucket are laid out in
// worker order.
fn counting_sort_records<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
    if arr.is_empty() {
        return;
    }

    // Never use more workers than elements: extra workers would only add
    // empty tables.
    let num_threads = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4)
        .min(arr.len());
    let chunk_sz = (arr.len() + num_threads - 1) / num_threads;

    // 1. Parallel histogram: one private table per chunk.
    let mut tables = vec![vec![0usize; range]; num_threads];
    std::thread::scope(|s| {
        for (chunk, counts) in arr.chunks(chunk_sz).zip(tables.iter_mut()) {
            s.spawn(move || {
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    counts[key - min_val] += 1;
                }
            });
        }
    });

    // 2. Exclusive prefix sum, rewriting each count into the first output
    //    position reserved for that (bucket, worker) pair.
    let mut next_free = 0usize;
    for bucket in 0..range {
        for table in tables.iter_mut() {
            let count = table[bucket];
            table[bucket] = next_free;
            next_free += count;
        }
    }

    // 3. Parallel scatter into a scratch buffer.
    let mut buffer = vec![T::default(); arr.len()];
    let len = buffer.len();
    // Raw pointers are not `Send`; the address is passed as an integer and
    // turned back into a pointer inside each worker.
    let buf_addr = buffer.as_mut_ptr() as usize;
    let source: &[T] = arr;

    std::thread::scope(|s| {
        // Each worker takes ownership of its own cursor table.
        for (chunk, mut cursors) in source.chunks(chunk_sz).zip(tables) {
            s.spawn(move || {
                let dst = buf_addr as *mut T;
                for item in chunk {
                    let key: usize = item.extract_key().into();
                    let slot = &mut cursors[key - min_val];
                    debug_assert!(*slot < len);
                    // SAFETY: the cursors come from a prefix sum over counts
                    // of exactly these chunks, so as long as `extract_key`
                    // returns the same key it returned during the histogram
                    // pass, every cursor stays below `len` and no two workers
                    // ever receive the same position. `buffer` outlives the
                    // scope and is not accessed otherwise until all workers
                    // have been joined.
                    unsafe {
                        dst.add(*slot).write(*item);
                    }
                    *slot += 1;
                }
            });
        }
    });

    arr.copy_from_slice(&buffer);
}

// Backward-compatibility bridge for the v0.1 API: returns a sorted copy of
// `input` and leaves the input untouched. `_max_val` is accepted only to keep
// the old signature and is ignored.
pub fn overclocked_parallel_sort(input: &[i32], _max_val: usize) -> Vec<i32> {
    let mut sorted = Vec::with_capacity(input.len());
    sorted.extend_from_slice(input);
    overclocked_sort(sorted.as_mut_slice());
    sorted
}

/// Backward-compatibility bridge for the v0.1 API: returns a sorted copy of
/// `input` and leaves the input untouched. `_max_val` is accepted only to keep
/// the old signature and is ignored.
pub fn overclocked_kp_sort(input: &[KeyPtr], _max_val: usize) -> Vec<KeyPtr> {
    let mut records = Vec::from(input);
    overclocked_sort(&mut records[..]);
    records
}