Skip to main content

overclocked_sort/
lib.rs

1pub mod key_ptr;
2pub use key_ptr::KeyPtr;
3
/// Contract for element types handled by the overclocked radix / counting sort.
///
/// An implementor exposes an integer-like key through [`SortableKey::extract_key`];
/// the sorting code converts that key to `usize` (via `Into<usize>`) to choose a
/// histogram bucket.
pub trait SortableKey: Copy + Default + Ord + Send + Sync {
    /// Key type: convertible into `usize` and recoverable from a `usize`.
    type KeyType: Copy + Ord + Into<usize> + TryFrom<usize>;

    /// `true` when the value is nothing but its key, so it can be rebuilt from
    /// the key alone (pure counting); `false` when extra data is attached and
    /// whole elements have to be moved (scatter).
    const IS_PRIMITIVE: bool;

    /// Returns the numeric key used for counting.
    fn extract_key(&self) -> Self::KeyType;

    /// Rebuilds a value from its key.
    fn from_key(k: Self::KeyType) -> Self;
}
18
19// implementation for u32
20impl SortableKey for u32 {
21    type KeyType = usize;
22    #[inline(always)]
23    fn extract_key(&self) -> usize { *self as usize }
24    const IS_PRIMITIVE: bool = true;
25    #[inline(always)]
26    fn from_key(k: usize) -> Self { k as u32 }
27}
28
29// implementation for i32
30impl SortableKey for i32 {
31    type KeyType = usize;
32    #[inline(always)]
33    fn extract_key(&self) -> usize { (*self as i64 + 2147483648) as usize }
34    const IS_PRIMITIVE: bool = true;
35    #[inline(always)]
36    fn from_key(k: usize) -> Self { (k as i64 - 2147483648) as i32 }
37}
38
39/// Generic hybrid multi-purpose sorting function (Version 0.2.1)
40pub fn overclocked_sort<T: SortableKey>(arr: &mut [T]) {
41    if arr.len() <= 1 { return; }
42
43    // Lightweight probe: delegate clearly monotonic/all-equal patterns to std's PDQ path.
44    let n = arr.len();
45    let probe_points = std::cmp::min(n, 1024);
46    if probe_points >= 64 {
47        let stride = (n - 1) / (probe_points - 1);
48        let mut prev_probe = arr[0].extract_key().into();
49        let mut probe_desc_breaks = 0usize;
50        for i in 1..probe_points {
51            let idx = i * stride;
52            let k = arr[idx].extract_key().into();
53            if k < prev_probe {
54                probe_desc_breaks += 1;
55            }
56            prev_probe = k;
57        }
58        if probe_desc_breaks == 0 {
59            arr.sort_unstable();
60            return;
61        }
62        if probe_desc_breaks >= probe_points - 2 {
63            arr.sort_unstable();
64            return;
65        }
66    }
67    
68    // Single pass: detect monotonic patterns while collecting exact min/max range.
69    let first_key = arr[0].extract_key().into();
70    let mut prev_key = first_key;
71    let mut min_val = first_key;
72    let mut max_val = first_key;
73    let mut non_decreasing = true;
74    let mut non_increasing = true;
75    let mut desc_breaks = 0usize;
76    for item in arr.iter().skip(1) {
77        let k = item.extract_key().into();
78        if k < prev_key {
79            non_decreasing = false;
80            desc_breaks += 1;
81        }
82        if k > prev_key {
83            non_increasing = false;
84        }
85        prev_key = k;
86        if k < min_val { min_val = k; }
87        if k > max_val { max_val = k; }
88    }
89    if non_decreasing {
90        return;
91    }
92    if non_increasing {
93        arr.reverse();
94        return;
95    }
96
97    let range = max_val - min_val + 1;
98    let max_limit = std::cmp::min(1_000_000, std::cmp::max(1, n / 4));
99
100    // Pattern-defeating quicksort in std is very strong for near-sorted sparse-ish runs.
101    let near_sorted_threshold = std::cmp::max(1, n / 16);
102    if range > max_limit && (desc_breaks <= near_sorted_threshold || desc_breaks >= (n - 1).saturating_sub(near_sorted_threshold)) {
103        arr.sort_unstable();
104        return;
105    }
106
107    // Keep counting only in tight ranges to avoid large histogram overhead on edge patterns.
108    if range <= max_limit {
109        // Array is fully dense, use Overclocked Sort directly
110        if T::IS_PRIMITIVE {
111            counting_sort_primitive(arr, min_val, range);
112        } else {
113            counting_sort_records(arr, min_val, range);
114        }
115    } else {
116        // Sparse inputs go straight to the internal comparison sort.
117        fallback_sort(arr);
118    }
119}
120
121fn fallback_sort<T: SortableKey>(arr: &mut [T]) {
122    if arr.len() <= 1 {
123        return;
124    }
125    if arr.len() <= 32 {
126        insertion_sort(arr);
127        return;
128    }
129    radix_sort_by_key(arr);
130}
131
132fn radix_sort_by_key<T: SortableKey>(arr: &mut [T]) {
133    let len = arr.len();
134    let mut current = arr.to_vec();
135    let mut next = vec![T::default(); len];
136    let passes = std::mem::size_of::<usize>();
137
138    for pass in 0..passes {
139        let shift = pass * 8;
140        let mut counts = [0usize; 256];
141
142        for item in current.iter() {
143            let bucket = (item.extract_key().into() >> shift) & 0xFF;
144            counts[bucket] += 1;
145        }
146
147        let mut offset = 0usize;
148        for count in counts.iter_mut() {
149            let current_count = *count;
150            *count = offset;
151            offset += current_count;
152        }
153
154        for item in current.iter() {
155            let bucket = (item.extract_key().into() >> shift) & 0xFF;
156            next[counts[bucket]] = *item;
157            counts[bucket] += 1;
158        }
159
160        std::mem::swap(&mut current, &mut next);
161    }
162
163    arr.copy_from_slice(&current);
164}
165
166#[allow(dead_code)]
167fn quicksort_recursive<T: SortableKey>(arr: &mut [T]) {
168    const INSERTION_THRESHOLD: usize = 24;
169
170    if arr.len() <= INSERTION_THRESHOLD {
171        insertion_sort(arr);
172        return;
173    }
174
175    let len = arr.len();
176    let mid = len / 2;
177    let pivot = median_of_three(arr[0], arr[mid], arr[len - 1]);
178
179    let mut left = 0usize;
180    let mut right = len - 1;
181
182    loop {
183        while arr[left] < pivot {
184            left += 1;
185        }
186        while arr[right] > pivot {
187            if right == 0 {
188                break;
189            }
190            right -= 1;
191        }
192        if left >= right {
193            break;
194        }
195        arr.swap(left, right);
196        left += 1;
197        if right == 0 {
198            break;
199        }
200        right -= 1;
201    }
202
203    let split_index = right + 1;
204    let (lo, hi) = arr.split_at_mut(split_index);
205    if lo.len() < hi.len() {
206        if !lo.is_empty() {
207            quicksort_recursive(lo);
208        }
209        if !hi.is_empty() {
210            quicksort_recursive(hi);
211        }
212    } else {
213        if !hi.is_empty() {
214            quicksort_recursive(hi);
215        }
216        if !lo.is_empty() {
217            quicksort_recursive(lo);
218        }
219    }
220}
221
222#[allow(dead_code)]
223fn insertion_sort<T: SortableKey>(arr: &mut [T]) {
224    for i in 1..arr.len() {
225        let value = arr[i];
226        let mut j = i;
227        while j > 0 && arr[j - 1] > value {
228            arr[j] = arr[j - 1];
229            j -= 1;
230        }
231        arr[j] = value;
232    }
233}
234
235#[allow(dead_code)]
236#[inline(always)]
237fn median_of_three<T: SortableKey>(a: T, b: T, c: T) -> T {
238    if a < b {
239        if b < c {
240            b
241        } else if a < c {
242            c
243        } else {
244            a
245        }
246    } else if a < c {
247        a
248    } else if b < c {
249        c
250    } else {
251        b
252    }
253}
254
255// Memory-efficient Overclocked Primitive Counting Sort (Parallelized)
256fn counting_sort_primitive<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
257    let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
258    let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
259    
260    let mut local_counts = vec![vec![0usize; range]; num_threads];
261    
262    // 1. Parallel Histogram
263    std::thread::scope(|s| {
264        for (i, count) in local_counts.iter_mut().enumerate() {
265            let start = i * chunk_sz;
266            let end = std::cmp::min(start + chunk_sz, arr.len());
267            if start >= arr.len() { continue; }
268            let slice = &arr[start..end];
269            s.spawn(move || {
270                for item in slice {
271                    count[item.extract_key().into() - min_val] += 1;
272                }
273            });
274        }
275    });
276    
277    // 2. Merge counts to global and compute bin offsets
278    let mut global_count = vec![0usize; range];
279    for counts in &local_counts {
280        for (i, &c) in counts.iter().enumerate() {
281            global_count[i] += c;
282        }
283    }
284    
285    let mut bin_offsets = vec![0usize; range];
286    let mut offset = 0;
287    for (i, &c) in global_count.iter().enumerate() {
288        bin_offsets[i] = offset;
289        offset += c;
290    }
291    
292    // 3. Parallel Fill (Assign ranges of bins to threads)
293    let arr_ptr = arr.as_mut_ptr() as usize;
294    let global_count_ref = &global_count;
295    let bin_offsets_ref = &bin_offsets;
296    let bins_per_thread = (range + num_threads - 1) / num_threads;
297    
298    std::thread::scope(|s| {
299        for t in 0..num_threads {
300            s.spawn(move || {
301                let start_bin = t * bins_per_thread;
302                let end_bin = std::cmp::min(start_bin + bins_per_thread, range);
303                
304                for i in start_bin..end_bin {
305                    let freq = global_count_ref[i];
306                    if freq == 0 { continue; }
307                    
308                    let val_k = min_val + i;
309                    if let Ok(key) = T::KeyType::try_from(val_k) {
310                        let val = T::from_key(key);
311                        let target_offset = bin_offsets_ref[i];
312                        
313                        unsafe {
314                            let ptr = (arr_ptr as *mut T).add(target_offset);
315                            for j in 0..freq {
316                                std::ptr::write(ptr.add(j), val);
317                            }
318                        }
319                    }
320                }
321            });
322        }
323    });
324}
325
326// Memory-efficient Prefix Sum + Parallel Scatter for structs (KeyPtr)
327fn counting_sort_records<T: SortableKey>(arr: &mut [T], min_val: usize, range: usize) {
328    let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
329    let chunk_sz = (arr.len() + num_threads - 1) / num_threads;
330    
331    let mut local_counts = vec![vec![0usize; range]; num_threads];
332    
333    // 1. Parallel Histogram
334    std::thread::scope(|s| {
335        for (i, count) in local_counts.iter_mut().enumerate() {
336            let start = i * chunk_sz;
337            let end = std::cmp::min(start + chunk_sz, arr.len());
338            if start >= arr.len() { continue; }
339            let slice = &arr[start..end];
340            s.spawn(move || {
341                for item in slice {
342                    count[item.extract_key().into() - min_val] += 1;
343                }
344            });
345        }
346    });
347    
348    // 2. Compute Global Offsets per thread
349    let mut global_offsets = vec![vec![0usize; range]; num_threads];
350    let mut total_offset = 0;
351    for val in 0..range {
352        for t in 0..num_threads {
353            global_offsets[t][val] = total_offset;
354            total_offset += local_counts[t][val];
355        }
356    }
357    
358    // 3. Parallel Scatter
359    let mut buffer = vec![T::default(); arr.len()];
360    let buf_ptr = buffer.as_mut_ptr() as usize;
361    let global_offsets_ref = &global_offsets;
362    
363    let arr_ref: &[T] = arr; // Shared read-only reference for threads
364    
365    std::thread::scope(|s| {
366        for t in 0..num_threads {
367            let current_offsets_tpl = global_offsets_ref[t].clone();
368            s.spawn(move || {
369                let mut current_offsets = current_offsets_tpl;
370                let start = t * chunk_sz;
371                let end = std::cmp::min(start + chunk_sz, arr_ref.len());
372                if start >= arr_ref.len() { return; }
373                let slice = &arr_ref[start..end];
374                
375                for item in slice {
376                    let bucket = item.extract_key().into() - min_val;
377                    let target_pos = current_offsets[bucket];
378                    unsafe {
379                        let ptr = (buf_ptr as *mut T).add(target_pos);
380                        std::ptr::write(ptr, *item);
381                    }
382                    current_offsets[bucket] += 1;
383                }
384            });
385        }
386    });
387    
388    arr.copy_from_slice(&buffer);
389}
390
391// Backward Compatibility Bridges mapping v0.1 functions
392pub fn overclocked_parallel_sort(input: &[i32], _max_val: usize) -> Vec<i32> {
393    let mut copy = input.to_vec();
394    overclocked_sort(&mut copy);
395    copy
396}
397
398pub fn overclocked_kp_sort(input: &[KeyPtr], _max_val: usize) -> Vec<KeyPtr> {
399    let mut copy = input.to_vec();
400    overclocked_sort(&mut copy);
401    copy
402}