// fgumi_lib/sort/radix.rs
1//! Radix sort for coordinate-based keys.
2//!
3//! This module provides O(n) radix sort for coordinate sorting, which is
4//! faster than comparison-based O(n log n) sorts for large arrays.
5//!
6//! # Algorithm
7//!
8//! Uses LSD (Least Significant Digit) radix sort with 8-bit radix (256 buckets).
9//! Adaptive: only sorts the bytes actually needed based on max values in dataset.
10
11use std::cmp::Ordering;
12
/// Packed coordinate key for radix sorting.
///
/// Packs tid (16 bits) + pos (32 bits) + reverse (1 bit) into a u64 for
/// efficient radix sorting.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
#[repr(transparent)]
pub struct PackedCoordinateKey(u64);

impl PackedCoordinateKey {
    /// Create a new packed coordinate key.
    ///
    /// Layout: `[tid:16][pos:32][reverse:1][padding:15]`
    #[inline]
    #[must_use]
    #[allow(clippy::cast_sign_loss)]
    pub fn new(tid: i32, pos: i32, reverse: bool) -> Self {
        // Negative tid/pos denote "unmapped"; saturate those fields to all
        // ones so unmapped keys order after every mapped key.
        let tid_field: u64 = if tid < 0 { 0xFFFF } else { (tid as u64) & 0xFFFF };
        let pos_field: u64 = if pos < 0 { 0xFFFF_FFFF } else { (pos as u64) & 0xFFFF_FFFF };
        let rev_field = u64::from(reverse);

        // Assemble [tid:16][pos:32][reverse:1][padding:15], tid most
        // significant so plain u64 ordering sorts by tid, then pos, then strand.
        Self((tid_field << 48) | (pos_field << 16) | (rev_field << 15))
    }

    /// Get the raw u64 value for sorting.
    #[inline]
    #[must_use]
    pub fn as_u64(self) -> u64 {
        self.0
    }
}

impl PartialOrd for PackedCoordinateKey {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for PackedCoordinateKey {
    fn cmp(&self, other: &Self) -> Ordering {
        // Ordering of the packed bits is exactly coordinate order.
        u64::cmp(&self.0, &other.0)
    }
}
57
/// Threshold below which we use insertion sort instead of radix sort.
///
/// A radix pass always walks 256 counting buckets, so its fixed overhead only
/// pays off once the input is comfortably larger than the bucket count.
const RADIX_THRESHOLD: usize = 256;
60
61// ============================================================================
62// Samtools-style Adaptive Radix Sort
63// ============================================================================
64
65/// Adaptive radix sort for coordinate keys, following samtools approach.
66///
67/// Key optimizations from samtools:
68/// 1. Only sort the bytes actually needed (based on max tid/pos in dataset)
69/// 2. Pack pos+reverse+tid into little-endian bytes for LSD radix sort
70/// 3. Unmapped reads (tid=-1) sort to end
71///
72/// # Arguments
73/// * `entries` - Slice of (`packed_key`, `record_index`) pairs to sort in place
74/// * `nref` - Number of reference sequences (for mapping unmapped tid)
75#[allow(clippy::uninit_vec, unsafe_code)]
76pub fn radix_sort_coordinate_adaptive<T: Clone>(
77    entries: &mut [(u64, T)],
78    max_tid: u32,
79    max_pos: u64,
80) {
81    let n = entries.len();
82    if n < RADIX_THRESHOLD {
83        insertion_sort_by_key(entries, |(k, _)| *k);
84        return;
85    }
86
87    // Calculate bytes needed for pos and tid (like samtools)
88    let pos_bytes = bytes_needed_u64(max_pos);
89    let tid_bytes = bytes_needed_u32(max_tid);
90    let total_bytes = pos_bytes + tid_bytes;
91
92    if total_bytes == 0 {
93        return; // All same value, already sorted
94    }
95
96    // Allocate auxiliary buffer
97    let mut aux: Vec<(u64, T)> = Vec::with_capacity(n);
98    unsafe {
99        aux.set_len(n);
100    }
101
102    let mut src = entries as *mut [(u64, T)];
103    let mut dst = aux.as_mut_slice() as *mut [(u64, T)];
104
105    // LSD radix sort - byte by byte from least significant
106    for byte_idx in 0..total_bytes {
107        let src_slice = unsafe { &*src };
108        let dst_slice = unsafe { &mut *dst };
109
110        // Count occurrences of each byte value
111        let mut counts = [0usize; 256];
112        for (key, _) in src_slice {
113            let byte = ((key >> (byte_idx * 8)) & 0xFF) as usize;
114            counts[byte] += 1;
115        }
116
117        // Convert to cumulative offsets
118        let mut total = 0;
119        for count in &mut counts {
120            let c = *count;
121            *count = total;
122            total += c;
123        }
124
125        // Scatter elements to destination
126        for item in src_slice {
127            let byte = ((item.0 >> (byte_idx * 8)) & 0xFF) as usize;
128            let dest_idx = counts[byte];
129            counts[byte] += 1;
130            dst_slice[dest_idx] = item.clone();
131        }
132
133        // Swap src and dst
134        std::mem::swap(&mut src, &mut dst);
135    }
136
137    // If odd number of passes, copy back to original buffer
138    if total_bytes % 2 == 1 {
139        let src_slice = unsafe { &*src };
140        entries.clone_from_slice(src_slice);
141    }
142}
143
/// Calculate number of bytes needed to represent a u64 value.
///
/// Returns 0 for 0, otherwise the minimal byte count in `1..=8`.
#[inline]
fn bytes_needed_u64(val: u64) -> usize {
    match val {
        0 => 0,
        v => ((u64::BITS - v.leading_zeros()) as usize).div_ceil(8),
    }
}
152
/// Calculate number of bytes needed to represent a u32 value.
///
/// Returns 0 for 0, otherwise the minimal byte count in `1..=4`.
#[inline]
fn bytes_needed_u32(val: u32) -> usize {
    match val {
        0 => 0,
        v => ((u32::BITS - v.leading_zeros()) as usize).div_ceil(8),
    }
}
161
/// Pack coordinate fields for radix sorting (samtools style).
///
/// Returns a key suitable for LSD radix sort where:
/// - Lower bytes contain pos+reverse (sorted first)
/// - Upper bytes contain tid (sorted last)
/// - Unmapped reads (tid=-1) map to nref to sort at end
#[inline]
#[must_use]
#[allow(clippy::cast_sign_loss)]
pub fn pack_coordinate_for_radix(tid: i32, pos: i32, reverse: bool, nref: u32) -> u64 {
    // Map unmapped (-1) to nref so they sort to the end
    let tid_val = if tid < 0 { nref } else { tid as u32 };

    // pos shifted left by 1, with reverse bit in LSB.
    // Add 1 to pos so -1 becomes 0 (unmapped positions sort first within
    // unmapped tid). Widen to i64 before the +1 so pos == i32::MAX cannot
    // overflow (the old `pos + 1` panicked in debug builds for that input).
    let pos_val = (((i64::from(pos) + 1) as u64) << 1) | u64::from(reverse);

    // Pack: lower 40 bits = pos+reverse, upper 24 bits = tid
    // This gives us room for pos up to ~500 billion and tid up to 16 million
    pos_val | (u64::from(tid_val) << 40)
}
183
184/// Radix sort for packed coordinate keys.
185///
186/// Uses 8-bit radix (256 buckets) with 8 passes for 64-bit keys.
187/// Falls back to insertion sort for small arrays.
188#[allow(clippy::uninit_vec, unsafe_code)]
189pub fn radix_sort_u64<T: Clone>(entries: &mut [(u64, T)]) {
190    if entries.len() < RADIX_THRESHOLD {
191        // Use insertion sort for small arrays
192        insertion_sort_by_key(entries, |(k, _)| *k);
193        return;
194    }
195
196    let n = entries.len();
197
198    // Allocate auxiliary buffer
199    let mut aux: Vec<(u64, T)> = Vec::with_capacity(n);
200    unsafe {
201        aux.set_len(n);
202    }
203
204    let mut src = entries as *mut [(u64, T)];
205    let mut dst = aux.as_mut_slice() as *mut [(u64, T)];
206
207    // 8 passes, one for each byte (LSB first)
208    for pass in 0..8 {
209        let shift = pass * 8;
210
211        let src_slice = unsafe { &*src };
212        let dst_slice = unsafe { &mut *dst };
213
214        // Count occurrences of each byte value
215        let mut counts = [0usize; 256];
216        for (key, _) in src_slice {
217            let byte = ((key >> shift) & 0xFF) as usize;
218            counts[byte] += 1;
219        }
220
221        // Convert to cumulative offsets
222        let mut total = 0;
223        for count in &mut counts {
224            let c = *count;
225            *count = total;
226            total += c;
227        }
228
229        // Scatter elements to destination
230        for item in src_slice {
231            let byte = ((item.0 >> shift) & 0xFF) as usize;
232            let dest_idx = counts[byte];
233            counts[byte] += 1;
234            dst_slice[dest_idx] = item.clone();
235        }
236
237        // Swap src and dst
238        std::mem::swap(&mut src, &mut dst);
239    }
240
241    // After 8 passes (even number), data is back in original buffer
242}
243
244/// Radix sort for packed coordinate keys with associated data.
245#[allow(unsafe_code)]
246pub fn radix_sort_coordinate<T: Clone>(entries: &mut [(PackedCoordinateKey, T)]) {
247    if entries.len() < RADIX_THRESHOLD {
248        insertion_sort_by_key(entries, |(k, _)| k.0);
249        return;
250    }
251
252    // Convert to u64 keys for radix sort
253    // SAFETY: PackedCoordinateKey is #[repr(transparent)] over u64,
254    // so (PackedCoordinateKey, T) has the same layout as (u64, T)
255    let entries_u64: &mut [(u64, T)] =
256        unsafe { std::slice::from_raw_parts_mut(entries.as_mut_ptr().cast(), entries.len()) };
257
258    radix_sort_u64(entries_u64);
259}
260
261// ============================================================================
262// Fixed-Array Heap (ks_heapadjust style)
263// ============================================================================
264
/// In-place heap adjustment (sift-down), following samtools `ks_heapadjust`.
///
/// This is more efficient than Rust's `BinaryHeap` for the merge phase because:
/// 1. No allocation on push/pop
/// 2. Single sift-down operation per element
/// 3. Works with fixed-size array
///
/// This version sifts with safe swaps instead of the raw-pointer "hole"
/// technique: the old code duplicated `heap[i]` via `ptr::read`, so a
/// panicking comparator would have caused a double drop of that element
/// during unwinding.
///
/// # Arguments
/// * `heap` - The heap array (max-heap by default, use reverse comparator for min-heap)
/// * `i` - Index to sift down from
/// * `n` - Heap size (may be less than array length)
/// * `lt` - Less-than comparator (for max-heap, swap if child > parent)
#[inline]
pub fn heap_sift_down<T, F>(heap: &mut [T], mut i: usize, n: usize, lt: &F)
where
    F: Fn(&T, &T) -> bool,
{
    loop {
        let left = 2 * i + 1;
        if left >= n {
            break; // No children within the heap; we're done.
        }

        // Pick the larger of the two children (per `lt`).
        let right = left + 1;
        let mut child = left;
        if right < n && lt(&heap[left], &heap[right]) {
            child = right;
        }

        // Parent already >= largest child: heap property restored.
        if !lt(&heap[i], &heap[child]) {
            break;
        }

        // Move the sifted value one level down and continue from there.
        heap.swap(i, child);
        i = child;
    }
}
314
315/// Build a heap from an unsorted array (heapify).
316#[inline]
317pub fn heap_make<T, F>(heap: &mut [T], lt: &F)
318where
319    F: Fn(&T, &T) -> bool,
320{
321    let n = heap.len();
322    if n <= 1 {
323        return;
324    }
325
326    // Start from last non-leaf node and sift down
327    for i in (0..n / 2).rev() {
328        heap_sift_down(heap, i, n, lt);
329    }
330}
331
332/// Pop the top element and restore heap property.
333///
334/// Returns the new heap size (n - 1). The popped element is moved to heap[n-1].
335#[inline]
336pub fn heap_pop<T, F>(heap: &mut [T], n: usize, lt: &F) -> usize
337where
338    F: Fn(&T, &T) -> bool,
339{
340    if n == 0 {
341        return 0;
342    }
343    if n == 1 {
344        return 0;
345    }
346
347    // Swap top with last element
348    heap.swap(0, n - 1);
349
350    // Sift down the new top
351    let new_n = n - 1;
352    if new_n > 0 {
353        heap_sift_down(heap, 0, new_n, lt);
354    }
355
356    new_n
357}
358
/// Replace the top element and restore heap property.
///
/// More efficient than pop + push when the heap size stays the same.
///
/// The old top value at `heap[0]` is dropped by the assignment.
///
/// # Panics
/// Panics if `heap` is empty (the `heap[0]` index is unconditional).
#[inline]
pub fn heap_replace_top<T, F>(heap: &mut [T], new_value: T, n: usize, lt: &F)
where
    F: Fn(&T, &T) -> bool,
{
    heap[0] = new_value;
    heap_sift_down(heap, 0, n, lt);
}
370
371// ============================================================================
372// Helper Functions
373// ============================================================================
374
/// Binary insertion sort for small arrays.
///
/// Uses binary search to find insertion point, reducing comparisons from
/// O(n²) to O(n log n) while maintaining O(n²) moves.
#[inline]
pub fn insertion_sort_by_key<T, K: Ord, F: Fn(&T) -> K>(arr: &mut [T], key_fn: F) {
    for end in 1..arr.len() {
        let probe = key_fn(&arr[end]);
        // First index in the sorted prefix whose key exceeds `probe`;
        // using `<=` keeps equal keys in their original order (stable).
        let slot = arr[..end].partition_point(|e| key_fn(e) <= probe);
        if slot != end {
            arr[slot..=end].rotate_right(1);
        }
    }
}
392
/// Binary insertion sort with comparison function.
#[inline]
pub fn binary_insertion_sort<T, F>(arr: &mut [T], compare: F)
where
    F: Fn(&T, &T) -> Ordering,
{
    for i in 1..arr.len() {
        // Find the first position in the sorted prefix whose element compares
        // Greater than arr[i]; inserting there keeps equal elements stable.
        let slot = {
            let (sorted, rest) = arr.split_at(i);
            let probe = &rest[0];
            sorted.partition_point(|e| compare(e, probe) != Ordering::Greater)
        };

        // Rotate the element into place (no-op when it's already positioned).
        if slot < i {
            arr[slot..=i].rotate_right(1);
        }
    }
}
419
420/// Hybrid sort: uses binary insertion for small arrays, parallel sort for large.
421pub fn hybrid_sort<T: Send, F>(arr: &mut [T], compare: F, parallel: bool)
422where
423    F: Fn(&T, &T) -> Ordering + Sync,
424{
425    const INSERTION_THRESHOLD: usize = 32;
426
427    if arr.len() <= INSERTION_THRESHOLD {
428        binary_insertion_sort(arr, compare);
429    } else if parallel {
430        use rayon::prelude::*;
431        arr.par_sort_unstable_by(|a, b| compare(a, b));
432    } else {
433        arr.sort_unstable_by(|a, b| compare(a, b));
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Packed keys order by tid, then pos, then strand (forward < reverse).
    #[test]
    fn test_packed_coordinate_key() {
        let k1 = PackedCoordinateKey::new(0, 100, false);
        let k2 = PackedCoordinateKey::new(0, 200, false);
        let k3 = PackedCoordinateKey::new(1, 100, false);
        let k4 = PackedCoordinateKey::new(0, 100, true);

        assert!(k1 < k2); // Same tid, pos1 < pos2
        assert!(k1 < k3); // tid1 < tid2
        assert!(k1 < k4); // Same tid+pos, forward < reverse
    }

    /// tid = -1 saturates the tid field, so unmapped keys sort last.
    #[test]
    fn test_packed_coordinate_key_unmapped() {
        let mapped = PackedCoordinateKey::new(0, 100, false);
        let unmapped = PackedCoordinateKey::new(-1, -1, false);

        assert!(mapped < unmapped); // Unmapped sorts last
    }

    /// NOTE(review): with 5 elements (< RADIX_THRESHOLD) this exercises only
    /// the insertion-sort fallback, not the radix passes.
    #[test]
    fn test_radix_sort_small() {
        let mut entries: Vec<(u64, i32)> = vec![(5, 50), (3, 30), (8, 80), (1, 10), (4, 40)];

        radix_sort_u64(&mut entries);

        assert_eq!(entries[0], (1, 10));
        assert_eq!(entries[1], (3, 30));
        assert_eq!(entries[2], (4, 40));
        assert_eq!(entries[3], (5, 50));
        assert_eq!(entries[4], (8, 80));
    }

    /// 1000 elements (> RADIX_THRESHOLD) drives the actual 8-pass radix path.
    #[test]
    fn test_radix_sort_large() {
        let mut entries: Vec<(u64, usize)> = (0..1000).rev().map(|i| (i as u64, i)).collect();

        radix_sort_u64(&mut entries);

        for (i, (key, _)) in entries.iter().enumerate() {
            assert_eq!(*key, i as u64);
        }
    }

    #[test]
    fn test_insertion_sort() {
        let mut arr = vec![5, 3, 8, 1, 4, 2, 7, 6];
        binary_insertion_sort(&mut arr, std::cmp::Ord::cmp);
        assert_eq!(arr, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_insertion_sort_by_key() {
        let mut arr: Vec<(i32, &str)> = vec![(5, "five"), (3, "three"), (8, "eight"), (1, "one")];
        insertion_sort_by_key(&mut arr, |(k, _)| *k);
        assert_eq!(arr[0].0, 1);
        assert_eq!(arr[1].0, 3);
        assert_eq!(arr[2].0, 5);
        assert_eq!(arr[3].0, 8);
    }

    /// 100 elements exceed INSERTION_THRESHOLD, so this hits the
    /// sort_unstable_by branch (parallel = false).
    #[test]
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    fn test_hybrid_sort() {
        let mut arr: Vec<i32> = (0..100).rev().collect();
        hybrid_sort(&mut arr, std::cmp::Ord::cmp, false);
        for (i, &v) in arr.iter().enumerate() {
            assert_eq!(v, i as i32);
        }
    }

    /// Heapify + repeated pop yields a full heapsort; with the inverted
    /// comparator the min comes out first.
    #[test]
    fn test_heap_operations() {
        // Test min-heap (use > as lt for min-heap)
        let lt = |a: &i32, b: &i32| *a > *b;

        let mut heap = vec![5, 3, 8, 1, 4, 2, 7, 6];
        heap_make(&mut heap, &lt);

        // Pop elements should come out in sorted order
        let mut sorted = Vec::new();
        let mut n = heap.len();
        while n > 0 {
            sorted.push(heap[0]);
            n = heap_pop(&mut heap, n, &lt);
        }

        assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    /// Samtools-style packing: pos+reverse in the low bits, tid above them,
    /// unmapped (tid=-1) mapped to nref so it sorts after every real tid.
    #[test]
    fn test_pack_coordinate_for_radix() {
        let nref = 100u32;

        // Mapped read
        let k1 = pack_coordinate_for_radix(0, 100, false, nref);
        let k2 = pack_coordinate_for_radix(0, 200, false, nref);
        let k3 = pack_coordinate_for_radix(1, 100, false, nref);
        let k4 = pack_coordinate_for_radix(0, 100, true, nref);

        assert!(k1 < k2); // Same tid, pos1 < pos2
        assert!(k1 < k3); // tid1 < tid2
        assert!(k1 < k4); // Same tid+pos, forward < reverse

        // Unmapped sorts last
        let unmapped = pack_coordinate_for_radix(-1, -1, false, nref);
        assert!(k1 < unmapped);
        assert!(k3 < unmapped);
    }

    #[test]
    fn test_bytes_needed() {
        assert_eq!(bytes_needed_u64(0), 0);
        assert_eq!(bytes_needed_u64(255), 1);
        assert_eq!(bytes_needed_u64(256), 2);
        assert_eq!(bytes_needed_u64(65535), 2);
        assert_eq!(bytes_needed_u64(65536), 3);
        assert_eq!(bytes_needed_u64(u64::MAX), 8);

        assert_eq!(bytes_needed_u32(0), 0);
        assert_eq!(bytes_needed_u32(255), 1);
        assert_eq!(bytes_needed_u32(256), 2);
        assert_eq!(bytes_needed_u32(u32::MAX), 4);
    }

    /// NOTE(review): with only 5 entries this takes the insertion-sort
    /// fallback inside radix_sort_coordinate_adaptive, so the byte-count
    /// arguments (max_tid, max_pos) are not actually exercised here.
    #[test]
    fn test_radix_sort_adaptive() {
        let mut entries: Vec<(u64, usize)> = vec![
            (pack_coordinate_for_radix(1, 500, false, 100), 0),
            (pack_coordinate_for_radix(0, 100, true, 100), 1),
            (pack_coordinate_for_radix(0, 100, false, 100), 2),
            (pack_coordinate_for_radix(2, 0, false, 100), 3),
            (pack_coordinate_for_radix(-1, -1, false, 100), 4), // unmapped
        ];

        // Max tid=2, max pos=(500+1)<<1=1002
        radix_sort_coordinate_adaptive(&mut entries, 100, 1002);

        // Expected order: (0,100,false), (0,100,true), (1,500,false), (2,0,false), unmapped
        assert_eq!(entries[0].1, 2); // tid=0, pos=100, rev=false
        assert_eq!(entries[1].1, 1); // tid=0, pos=100, rev=true
        assert_eq!(entries[2].1, 0); // tid=1, pos=500, rev=false
        assert_eq!(entries[3].1, 3); // tid=2, pos=0, rev=false
        assert_eq!(entries[4].1, 4); // unmapped
    }
}