Skip to main content

learned_partition_sort/
learned_sort.rs

//! Core implementation of Learned Partition Sort.
//!
//! The algorithm works in three phases:
//! 1. **Sample**: Scan the data to learn its min/max bounds
//! 2. **Scatter**: Count elements per bucket, then distribute them to buckets using
//!    indices computed from those bounds
//! 3. **Refine**: Sort each bucket (in parallel with Rayon)

8use rayon::prelude::*;
9
/// Thin `Copy` wrapper that lets a raw `*mut T` be captured by closures running on
/// other threads (raw pointers are neither `Send` nor `Sync` on their own).
///
/// SAFETY contract: every thread that dereferences the wrapped pointer must touch a
/// region of memory disjoint from all other threads.
#[derive(Clone, Copy)]
struct SendPtr<T>(*mut T);

// SAFETY: only the address crosses threads; soundness of the eventual dereferences is
// the caller's job under the disjoint-access contract above.
unsafe impl<T: Send> Send for SendPtr<T> {}
// SAFETY: a shared `&SendPtr<T>` merely exposes a copy of the address (see contract).
unsafe impl<T: Sync> Sync for SendPtr<T> {}

impl<T> SendPtr<T> {
    /// Hands back the wrapped pointer; dereferencing it is an `unsafe` operation.
    #[inline]
    fn get(self) -> *mut T {
        let SendPtr(raw) = self;
        raw
    }
}
24
/// Inputs shorter than this are handed straight to `sort_unstable`.
/// The original benchmarks show LPS winning from N ≈ 100K; 32K is the cutoff chosen.
const HYBRID_THRESHOLD: usize = 32 * 1024;

/// Desired average number of elements per bucket, kept small so that each
/// bucket's sort stays cache-resident (the original note targets L2).
const BUCKET_TARGET_SIZE: usize = 512;

/// A bucket holding more than this many times the expected bucket size means the
/// linear distribution model mispredicted that region of the data.
const BUCKET_OVERFLOW_FACTOR: usize = 4;

/// Buckets shorter than this are sorted with insertion sort.
const INSERTION_SORT_THRESHOLD: usize = 32;
38
39/// Sorts a slice using the Learned Partition Sort algorithm.
40///
41/// This algorithm achieves O(N) complexity on well-distributed numerical data
42/// by learning the data distribution and using it to directly calculate
43/// element positions.
44///
45/// # Algorithm
46///
47/// 1. For small inputs (< 8192 elements), falls back to `sort_unstable`
48/// 2. Samples data to find min/max bounds
49/// 3. Distributes elements to buckets based on calculated positions
50/// 4. Sorts each bucket in parallel using Rayon
51///
52/// # Examples
53///
54/// ```
55/// use learned_partition_sort::learned_sort;
56///
57/// let mut data: Vec<i64> = vec![5, 2, 8, 1, 9];
58/// learned_sort(&mut data);
59/// assert_eq!(data, vec![1, 2, 5, 8, 9]);
60/// ```
61pub fn learned_sort<T>(arr: &mut [T])
62where
63    T: Ord + Copy + Send + Sync + Into<i64>,
64{
65    let n = arr.len();
66
67    // Guard: empty or single element
68    if n <= 1 {
69        return;
70    }
71
72    // Hybridization: use standard sort for small arrays
73    if n < HYBRID_THRESHOLD {
74        arr.sort_unstable();
75        return;
76    }
77
78    // Phase 1: Sample to find min/max
79    let (min_val, max_val) = sample_minmax(arr);
80
81    // Handle edge case: all elements are identical
82    if min_val == max_val {
83        return; // Already sorted (all same)
84    }
85
86    // Calculate number of buckets
87    let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
88
89    // Phase 2: Count elements per bucket
90    let mut counts = count_buckets(arr, min_val, max_val, num_buckets);
91
92    // Convert counts to start offsets via prefix sum
93    let offsets = prefix_sum(&counts);
94
95    // Phase 3: Scatter elements to auxiliary buffer
96    let mut aux = vec![arr[0]; n]; // Allocate auxiliary buffer
97    scatter(arr, &mut aux, &mut counts, &offsets, min_val, max_val, num_buckets);
98
99    // Phase 4: Refine - sort each bucket in parallel
100    refine_buckets(&mut aux, &offsets, num_buckets, n);
101
102    // Copy back to original array
103    arr.copy_from_slice(&aux);
104}
105
106/// Sorts a slice using an in-place variant of Learned Partition Sort.
107///
108/// This version uses **O(num_buckets)** additional memory instead of O(N),
109/// making it suitable for very large arrays (100M+ elements) or memory-constrained
110/// environments where allocating an auxiliary buffer would fail.
111///
112/// # ⚠️ Performance Warning
113///
114/// This function is **~5x slower** than [`learned_sort`] due to cache-unfriendly
115/// cycle-sort memory access patterns. Only use when memory is more important than speed.
116///
117/// # Trade-offs
118///
119/// | Metric | `learned_sort` | `learned_sort_inplace` |
120/// |--------|----------------|------------------------|
121/// | Memory | 2N (data + aux) | N + ~200KB |
122/// | Time @ 100M | 3.3s | 16.3s |
123/// | Throughput | 30 Melem/s | 6 Melem/s |
124///
125/// # When to Use
126///
127/// - ✅ Sorting 100M+ elements on 8GB RAM machines
128/// - ✅ Serverless/embedded with strict memory limits
129/// - ❌ When speed matters more than memory
130///
131/// # Algorithm
132///
133/// Uses cycle-sort style in-place permutation:
134/// 1. Count elements per bucket
135/// 2. Compute bucket start positions  
136/// 3. Follow permutation cycles to move elements to correct buckets in-place
137/// 4. Sort each bucket in parallel
138///
139/// # Examples
140///
141/// ```
142/// use learned_partition_sort::learned_sort_inplace;
143///
144/// let mut data: Vec<i64> = vec![5, 2, 8, 1, 9];
145/// learned_sort_inplace(&mut data);
146/// assert_eq!(data, vec![1, 2, 5, 8, 9]);
147/// ```
148pub fn learned_sort_inplace<T>(arr: &mut [T])
149where
150    T: Ord + Copy + Send + Sync + Into<i64>,
151{
152    let n = arr.len();
153
154    // Guard: empty or single element
155    if n <= 1 {
156        return;
157    }
158
159    // Hybridization: use standard sort for small arrays
160    if n < HYBRID_THRESHOLD {
161        arr.sort_unstable();
162        return;
163    }
164
165    // Phase 1: Sample to find min/max
166    let (min_val, max_val) = sample_minmax(arr);
167
168    // Handle edge case: all elements are identical
169    if min_val == max_val {
170        return; // Already sorted (all same)
171    }
172
173    // Calculate number of buckets
174    let num_buckets = (n / BUCKET_TARGET_SIZE).max(1);
175
176    // Phase 2: Count elements per bucket
177    let counts = count_buckets(arr, min_val, max_val, num_buckets);
178
179    // Convert counts to start offsets via prefix sum
180    let offsets = prefix_sum(&counts);
181
182    // Phase 3: In-place permutation to bucket positions
183    scatter_inplace(arr, &offsets, min_val, max_val, num_buckets);
184
185    // Phase 4: Refine - sort each bucket in parallel
186    refine_buckets(arr, &offsets, num_buckets, n);
187}
188
/// Maps `val` to its bucket under the linear model `(val - min_val) * scale`.
///
/// The projection is truncated toward zero. Negative projections saturate to
/// bucket 0, and anything at or beyond the last bucket is clamped to
/// `num_buckets - 1`.
#[inline]
fn compute_bucket(val: i64, min_val: i64, scale: f64, num_buckets: usize) -> usize {
    let projected = (val - min_val) as f64 * scale;
    let last_bucket = num_buckets - 1;
    usize::min(projected as usize, last_bucket)
}
195
196/// Permutes elements in-place to their bucket positions using optimized cycle-sort.
197///
198/// Uses O(num_buckets) additional memory for write cursors.
199/// Optimized for cache locality by processing buckets sequentially.
200fn scatter_inplace<T>(
201    arr: &mut [T],
202    offsets: &[usize],
203    min_val: i64,
204    max_val: i64,
205    num_buckets: usize,
206) where
207    T: Copy + Into<i64>,
208{
209    let range = (max_val - min_val) as f64;
210    let scale = (num_buckets as f64 - 0.001) / range;
211
212    // Write cursors: next position to write in each bucket
213    let mut write_cursors: Vec<usize> = offsets[..num_buckets].to_vec();
214
215    // Process each bucket region
216    for bucket in 0..num_buckets {
217        let bucket_start = offsets[bucket];
218        let bucket_end = offsets[bucket + 1];
219
220        // Process each position in the bucket region
221        let mut pos = bucket_start;
222        while pos < bucket_end {
223            let current_val: i64 = arr[pos].into();
224            let target_bucket = compute_bucket(current_val, min_val, scale, num_buckets);
225
226            if target_bucket == bucket {
227                // Element already in correct bucket region
228                // Advance write cursor if needed
229                if write_cursors[bucket] <= pos {
230                    write_cursors[bucket] = pos + 1;
231                }
232                pos += 1;
233                continue;
234            }
235
236            // Element needs to move - follow the permutation cycle
237            let mut current = arr[pos];
238            let mut current_bucket = target_bucket;
239
240            loop {
241                // Get destination position
242                let dest_pos = write_cursors[current_bucket];
243                
244                // Advance cursor for this bucket
245                write_cursors[current_bucket] += 1;
246
247                // Swap
248                let next = arr[dest_pos];
249                arr[dest_pos] = current;
250                
251                let next_bucket = compute_bucket(next.into(), min_val, scale, num_buckets);
252
253                // Check if cycle is complete
254                if next_bucket == bucket {
255                    // Put the final element back
256                    arr[pos] = next;
257                    break;
258                }
259
260                current = next;
261                current_bucket = next_bucket;
262            }
263
264            // Update write cursor for current bucket
265            if write_cursors[bucket] <= pos {
266                write_cursors[bucket] = pos + 1;
267            }
268            pos += 1;
269        }
270    }
271}
272
/// Returns the smallest and largest `i64` key found in `arr`.
///
/// Despite the name this is a full scan rather than a sample: sampling 1% of the
/// data can miss outliers.
///
/// Panics if `arr` is empty.
#[inline]
fn sample_minmax<T>(arr: &[T]) -> (i64, i64)
where
    T: Ord + Copy + Into<i64>,
{
    let first: i64 = arr[0].into();
    arr.iter().fold((first, first), |(lo, hi), &item| {
        let key: i64 = item.into();
        (lo.min(key), hi.max(key))
    })
}
295
/// Builds a histogram of how many elements of `arr` fall into each bucket.
///
/// Entry `i` of the result is the element count for bucket `i`. The mapping is
/// `(value - min_val) * scale`, truncated and clamped to the last bucket.
#[inline]
fn count_buckets<T>(arr: &[T], min_val: i64, max_val: i64, num_buckets: usize) -> Vec<usize>
where
    T: Copy + Into<i64>,
{
    let mut histogram = vec![0usize; num_buckets];
    let span = (max_val - min_val) as f64;
    // Shaving 0.001 off the bucket count nominally lands `max_val` in the last bucket
    // instead of one past it; the clamp below absorbs any rounding.
    let scale = (num_buckets as f64 - 0.001) / span;

    arr.iter()
        .map(|&item| {
            let key: i64 = item.into();
            ((key - min_val) as f64 * scale) as usize
        })
        .for_each(|raw| histogram[raw.min(num_buckets - 1)] += 1);

    histogram
}
316
/// Turns per-bucket counts into bucket boundaries (an exclusive prefix sum).
///
/// The result has `counts.len() + 1` entries. Entry `i` is where bucket `i`
/// begins, and the final entry is the total count, i.e. the end of the last bucket.
#[inline]
fn prefix_sum(counts: &[usize]) -> Vec<usize> {
    std::iter::once(0)
        .chain(counts.iter().scan(0usize, |running, &count| {
            *running += count;
            Some(*running)
        }))
        .collect()
}
332
/// Scatters `src` into `aux`, grouping elements by bucket.
///
/// `offsets[i]` is where bucket `i` starts in `aux` (the prefix sum of the bucket
/// counts). `counts` is reused as per-bucket write cursors: it is overwritten with
/// `offsets[..counts.len()]`, and on return `counts[i]` is the end of bucket `i`.
/// Elements keep their relative order within a bucket.
///
/// # Panics
///
/// - Panics if a write would land past the end of `aux`.
/// - In debug builds, also panics if the buckets were not filled exactly up to
///   `offsets`.
///
/// Both conditions mean `counts`/`offsets` were not produced with the same
/// `min_val`/`max_val`/`num_buckets` mapping used here.
#[inline]
fn scatter<T>(
    src: &[T],
    aux: &mut [T],
    counts: &mut [usize],
    offsets: &[usize],
    min_val: i64,
    max_val: i64,
    num_buckets: usize,
) where
    T: Copy + Into<i64>,
{
    // Each bucket's write cursor starts at that bucket's first slot.
    counts.copy_from_slice(&offsets[..counts.len()]);

    let range = (max_val - min_val) as f64;
    let scale = (num_buckets as f64 - 0.001) / range;

    for &item in src {
        let val: i64 = item.into();
        let bucket_idx = (((val - min_val) as f64 * scale) as usize).min(num_buckets - 1);

        let write_pos = counts[bucket_idx];
        counts[bucket_idx] += 1;

        // Checked store. `write_pos` is only valid if this mapping matches the one
        // `counts`/`offsets` were computed with (a separate function using the same
        // float formula); indexing turns a mismatch into a panic rather than an
        // out-of-bounds write.
        aux[write_pos] = item;
    }

    debug_assert!(
        counts
            .iter()
            .zip(&offsets[1..])
            .all(|(end, next_start)| end == next_start),
        "every bucket must be filled exactly up to the next bucket's start"
    );
}
374
375/// Sorts each bucket in parallel using Rayon.
376/// Small buckets use insertion sort, larger ones use `sort_unstable`.
377fn refine_buckets<T>(aux: &mut [T], offsets: &[usize], num_buckets: usize, total_len: usize)
378where
379    T: Ord + Copy + Send + Sync,
380{
381    let expected_bucket_size = total_len / num_buckets;
382    let overflow_threshold = expected_bucket_size * BUCKET_OVERFLOW_FACTOR;
383
384    let ptr = SendPtr(aux.as_mut_ptr());
385
386    // Create bucket ranges
387    let bucket_ranges: Vec<(usize, usize)> = (0..num_buckets)
388        .map(|i| (offsets[i], offsets[i + 1]))
389        .collect();
390
391    // Sort buckets in parallel
392    bucket_ranges.par_iter().for_each(move |&(start, end)| {
393        let bucket_len = end - start;
394        if bucket_len <= 1 {
395            return;
396        }
397
398        // SAFETY: Each bucket range is non-overlapping due to prefix sum construction
399        let bucket_slice = unsafe { std::slice::from_raw_parts_mut(ptr.get().add(start), bucket_len) };
400
401        if bucket_len < INSERTION_SORT_THRESHOLD {
402            insertion_sort(bucket_slice);
403        } else if bucket_len > overflow_threshold {
404            // Distribution model failed for this bucket - use robust fallback
405            bucket_slice.sort_unstable();
406        } else {
407            // Normal bucket - still use sort_unstable (fast for small slices)
408            bucket_slice.sort_unstable();
409        }
410    });
411}
412
/// Straight insertion sort for very small slices (the tiny buckets).
///
/// Quadratic in the worst case, but with almost no setup overhead, which is what
/// matters at these sizes. Equal elements keep their relative order.
#[inline]
fn insertion_sort<T: Ord + Copy>(arr: &mut [T]) {
    for unsorted in 1..arr.len() {
        let pivot = arr[unsorted];
        let mut hole = unsorted;
        // Slide strictly larger elements one slot right until the pivot's spot opens up.
        while hole > 0 {
            let left = arr[hole - 1];
            if left <= pivot {
                break;
            }
            arr[hole] = left;
            hole -= 1;
        }
        arr[hole] = pivot;
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Comfortably above `HYBRID_THRESHOLD`, so inputs of this length exercise the
    /// learned (bucketing) path rather than the `sort_unstable` fallback.
    const LARGE: usize = HYBRID_THRESHOLD * 3;

    /// Deterministic RNG so that any failing case can be reproduced exactly.
    fn seeded(seed: u64) -> StdRng {
        StdRng::seed_from_u64(seed)
    }

    /// Sorts copies of `data` with both entry points and compares each with `sort_unstable`.
    fn check_both<T>(data: Vec<T>)
    where
        T: Ord + Copy + Send + Sync + Into<i64> + std::fmt::Debug,
    {
        let mut expected = data.clone();
        expected.sort_unstable();

        let mut regular = data.clone();
        learned_sort(&mut regular);
        assert_eq!(regular, expected, "learned_sort produced a wrong order");

        let mut inplace = data;
        learned_sort_inplace(&mut inplace);
        assert_eq!(inplace, expected, "learned_sort_inplace produced a wrong order");
    }

    #[test]
    fn test_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_single_element() {
        let mut data = vec![42i64];
        learned_sort(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_two_elements() {
        let mut data = vec![5i64, 3];
        learned_sort(&mut data);
        assert_eq!(data, vec![3, 5]);
    }

    #[test]
    fn test_small_array_uses_fallback() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_medium_array() {
        let mut data: Vec<i64> = (0..1000).rev().collect();
        learned_sort(&mut data);
        assert_eq!(data, (0..1000).collect::<Vec<_>>());
    }

    #[test]
    fn test_large_uniform_distribution() {
        let mut rng = seeded(1);
        let mut data: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_sorted_input() {
        let mut data: Vec<i64> = (0..LARGE as i64).collect();
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_reverse_sorted() {
        let mut data: Vec<i64> = (0..LARGE as i64).rev().collect();
        let expected: Vec<i64> = (0..LARGE as i64).collect();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_duplicates() {
        let mut data: Vec<i64> = vec![5; LARGE];
        let expected = data.clone();
        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_many_duplicates() {
        let mut rng = seeded(2);
        let mut data: Vec<i64> = (0..LARGE).map(|_| rng.gen_range(0..10)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_negative_numbers() {
        let mut rng = seeded(3);
        let mut data: Vec<i64> = (0..LARGE).map(|_| rng.gen_range(-500_000..500_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_i32_type() {
        let mut rng = seeded(4);
        let mut data: Vec<i32> = (0..LARGE).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort(&mut data);
        assert_eq!(data, expected);
    }

    // ============ Tests covering both entry points on the learned path ============

    #[test]
    fn test_threshold_boundary() {
        // n == HYBRID_THRESHOLD is the smallest input that takes the learned path.
        let mut rng = seeded(5);
        let data: Vec<i64> = (0..HYBRID_THRESHOLD)
            .map(|_| rng.gen_range(-1_000..1_000))
            .collect();
        check_both(data);
    }

    #[test]
    fn test_skewed_distribution() {
        // Cubing uniform values concentrates elements in the low buckets, which
        // exercises the oversized-bucket branch of `refine_buckets`.
        let mut rng = seeded(6);
        let data: Vec<i64> = (0..LARGE)
            .map(|_| {
                let x: i64 = rng.gen_range(0..1_000);
                x * x * x
            })
            .collect();
        check_both(data);
    }

    #[test]
    fn test_wide_i64_range() {
        // Values drawn from a range whose span is exactly i64::MAX.
        let mut rng = seeded(7);
        let data: Vec<i64> = (0..LARGE)
            .map(|_| rng.gen_range(i64::MIN / 2..=i64::MAX / 2))
            .collect();
        check_both(data);
    }

    #[test]
    fn test_i32_full_range() {
        let mut rng = seeded(8);
        let data: Vec<i32> = (0..LARGE)
            .map(|_| rng.gen_range(i32::MIN..=i32::MAX))
            .collect();
        check_both(data);
    }

    #[test]
    fn test_u16_type() {
        let mut rng = seeded(9);
        let data: Vec<u16> = (0..LARGE).map(|_| rng.gen_range(0..=u16::MAX)).collect();
        check_both(data);
    }

    #[test]
    fn test_two_distinct_values() {
        let data: Vec<i64> = (0..LARGE).map(|i| (i % 2) as i64).collect();
        check_both(data);
    }

    // ============ Tests for learned_sort_inplace ============

    #[test]
    fn test_inplace_empty_slice() {
        let mut data: Vec<i64> = vec![];
        learned_sort_inplace(&mut data);
        assert!(data.is_empty());
    }

    #[test]
    fn test_inplace_single_element() {
        let mut data = vec![42i64];
        learned_sort_inplace(&mut data);
        assert_eq!(data, vec![42]);
    }

    #[test]
    fn test_inplace_small_array() {
        let mut data: Vec<i64> = (0..100).rev().collect();
        learned_sort_inplace(&mut data);
        assert_eq!(data, (0..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_inplace_large_uniform() {
        let mut rng = seeded(10);
        let mut data: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_duplicates() {
        let mut data: Vec<i64> = vec![5; LARGE];
        let expected = data.clone();
        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_many_duplicates() {
        let mut rng = seeded(11);
        let mut data: Vec<i64> = (0..LARGE).map(|_| rng.gen_range(0..10)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_negative_numbers() {
        let mut rng = seeded(12);
        let mut data: Vec<i64> = (0..LARGE).map(|_| rng.gen_range(-500_000..500_000)).collect();
        let mut expected = data.clone();
        expected.sort_unstable();

        learned_sort_inplace(&mut data);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_inplace_sorted_and_reverse_sorted() {
        check_both((0..LARGE as i64).collect());
        check_both((0..LARGE as i64).rev().collect());
    }

    #[test]
    fn test_inplace_matches_regular() {
        let mut rng = seeded(13);
        let original: Vec<i64> = (0..100_000).map(|_| rng.gen_range(0..1_000_000)).collect();

        let mut data_regular = original.clone();
        let mut data_inplace = original.clone();

        learned_sort(&mut data_regular);
        learned_sort_inplace(&mut data_inplace);

        assert_eq!(data_regular, data_inplace);
    }
}