use crate::sort::bam_fields;
use crate::sort::keys::{RawCoordinateKey, RawSortKey, SortContext};
use crate::sort::radix::bytes_needed_u64;
use crate::sort::segmented_buf::SegmentedBuf;
use fgumi_raw_bam::{RawRecord, RawRecordView};
use std::cmp::Ordering;
use std::io::{Read, Write};
/// Introspection interface for in-memory sort buffers, used to probe
/// memory consumption and fill level (e.g. when deciding to spill).
pub trait ProbeableBuffer {
    /// Bytes currently in use (record data plus ref-index bookkeeping).
    fn memory_usage(&self) -> usize;
    /// Bytes reserved by the underlying allocations (>= `memory_usage`).
    fn allocated_capacity(&self) -> usize;
    /// Number of buffered records.
    fn len(&self) -> usize;
    /// True when no records are buffered.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Number of backing storage segments currently allocated.
    fn num_segments(&self) -> usize;
}
/// A coordinate sort key packed into a single `u64`.
///
/// Layout: reference id in bits 34 and up, `pos + 1` in bits 1..34
/// (the +1 lets `pos == -1` encode as zero), and the reverse-strand
/// flag in bit 0. Unmapped records use `u64::MAX`, so they compare
/// greater than every mapped key.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
pub struct PackedCoordinateKey(pub u64);

impl PackedCoordinateKey {
    /// Builds the packed key for a mapped record, or the unmapped
    /// sentinel when `tid` is negative. `_nref` is accepted for
    /// interface compatibility but not consulted.
    #[inline]
    #[must_use]
    #[allow(clippy::cast_sign_loss)]
    pub fn new(tid: i32, pos: i32, reverse: bool, _nref: u32) -> Self {
        if tid < 0 {
            return Self::unmapped();
        }
        // Bias positions by one so that pos == -1 packs as zero.
        let pos_field = (i64::from(pos) as u64).wrapping_add(1);
        let tid_field = u64::from(tid as u32);
        let strand_bit = u64::from(reverse);
        Self((tid_field << 34) | (pos_field << 1) | strand_bit)
    }

    /// Sentinel key for unmapped records; sorts after all mapped keys.
    #[inline]
    #[must_use]
    pub fn unmapped() -> Self {
        Self(u64::MAX)
    }
}
/// A sortable reference to a record stored in a `RecordBuffer`.
///
/// `#[repr(C)]` with an explicit padding word gives the struct a
/// fixed, stable 24-byte layout.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct RecordRef {
    /// Packed coordinate sort key (see `PackedCoordinateKey`).
    pub sort_key: u64,
    /// Byte offset of the record's inline header within the buffer.
    pub offset: u64,
    /// Length in bytes of the record payload (header excluded).
    pub len: u32,
    // Explicit padding keeps the size and alignment stable.
    padding: u32,
}
// Equality and ordering for `RecordRef` consider only `sort_key`, so
// sorting a slice of refs orders records by coordinate. Note that this
// makes `eq` coarser than identity: refs to different records with the
// same key compare equal.
impl PartialEq for RecordRef {
    fn eq(&self, other: &Self) -> bool {
        self.sort_key == other.sort_key
    }
}
impl Eq for RecordRef {}
impl PartialOrd for RecordRef {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for RecordRef {
    fn cmp(&self, other: &Self) -> Ordering {
        self.sort_key.cmp(&other.sort_key)
    }
}
/// Header written inline in front of every record payload in
/// `RecordBuffer`'s data area; mirrors `RecordRef` minus the offset
/// (which is implied by the header's own position).
#[repr(C)]
#[derive(Copy, Clone, Debug)]
struct InlineHeader {
    sort_key: u64,   // packed coordinate key of the record that follows
    record_len: u32, // payload length in bytes (header excluded)
    padding: u32,    // keeps the header 16 bytes / 8-byte aligned
}
/// Size in bytes of the inline header preceding each record.
const HEADER_SIZE: usize = std::mem::size_of::<InlineHeader>();
/// Growable in-memory buffer of BAM records plus an index of sortable
/// `RecordRef`s keyed by packed coordinate.
pub struct RecordBuffer {
    data: SegmentedBuf,   // raw record bytes, each prefixed by an InlineHeader
    refs: Vec<RecordRef>, // one sortable ref per pushed record
    nref: u32,            // reference-sequence count, forwarded to key extraction
}
/// Size of each backing segment of the record data buffer (256 MiB).
const SORT_SEGMENT_SIZE: usize = 256 * 1024 * 1024;
/// Shared body for `par_sort_into_chunks` on both buffer types.
///
/// Sorts the ref index and materialises each chunk as owned
/// `(key, RawRecord)` pairs. Small inputs (or a single thread)
/// produce one fully sorted chunk; otherwise each rayon chunk is
/// sorted independently and returned unmerged.
///
/// NOTE: the single-chunk branch `return`s from the *enclosing*
/// method, so this macro must be invoked in tail position.
macro_rules! par_sort_into_chunks_impl {
    ($self:expr, $threads:expr, $sort_fn:ident, $key_fn:expr) => {{
        use rayon::prelude::*;
        let n = $self.refs.len();
        // Serial path: one fully sorted chunk.
        if $threads <= 1 || n < RADIX_THRESHOLD * 2 || n <= 10_000 {
            $sort_fn(&mut $self.refs);
            let chunk = $self
                .refs
                .iter()
                .map(|r| ($key_fn(r), RawRecord::from($self.get_record(r).to_vec())))
                .collect();
            return vec![chunk];
        }
        let chunk_size = n.div_ceil($threads);
        // Sort each contiguous run of refs on its own rayon worker.
        $self.refs.par_chunks_mut(chunk_size).for_each(|chunk| {
            $sort_fn(chunk);
        });
        // Copy each chunk's records out in (now sorted) order.
        $self
            .refs
            .chunks(chunk_size)
            .map(|chunk| {
                chunk
                    .iter()
                    .map(|r| ($key_fn(r), RawRecord::from($self.get_record(r).to_vec())))
                    .collect()
            })
            .collect()
    }};
}
impl RecordBuffer {
    /// Creates a buffer pre-sized for `estimated_records` records
    /// totalling `estimated_bytes` of payload (plus per-record inline
    /// headers). `nref` is forwarded to coordinate-key extraction.
    #[must_use]
    pub fn with_capacity(estimated_records: usize, estimated_bytes: usize, nref: u32) -> Self {
        Self {
            data: SegmentedBuf::with_capacity(
                estimated_bytes + estimated_records * HEADER_SIZE,
                SORT_SEGMENT_SIZE,
            ),
            refs: Vec::with_capacity(estimated_records),
            nref,
        }
    }

    /// Appends a raw BAM record and records its packed coordinate key.
    ///
    /// The bytes are stored contiguously behind an `InlineHeader` so
    /// the record can later be handed out as a single slice.
    ///
    /// # Errors
    /// Fails if the record is too short to contain the coordinate
    /// fields, does not fit in one segment, or exceeds `u32::MAX`
    /// bytes.
    #[inline]
    pub fn push_coordinate(&mut self, record: &[u8]) -> anyhow::Result<()> {
        // Minimum bytes needed to read ref_id / pos / flags.
        const MIN_BAM_RECORD_LEN: usize = 16;
        anyhow::ensure!(
            record.len() >= MIN_BAM_RECORD_LEN,
            "BAM record is truncated: need at least {} bytes to extract coordinate fields, got {}",
            MIN_BAM_RECORD_LEN,
            record.len(),
        );
        let total_bytes = HEADER_SIZE + record.len();
        anyhow::ensure!(
            total_bytes <= SORT_SEGMENT_SIZE,
            "BAM record of {} bytes (+ {} byte header) exceeds segment size of {} bytes; \
            this is likely a malformed BAM file",
            record.len(),
            HEADER_SIZE,
            SORT_SEGMENT_SIZE,
        );
        let len = u32::try_from(record.len())
            .map_err(|_| anyhow::anyhow!("record length {} exceeds u32::MAX", record.len()))?;
        let sort_key = extract_coordinate_key_inline(record, self.nref);
        // Reserve header + payload contiguously so get_record can
        // return a single slice.
        let offset = self.data.reserve_contiguous(total_bytes) as u64;
        let header = InlineHeader { sort_key, record_len: len, padding: 0 };
        self.data.extend_in_place(&header.sort_key.to_le_bytes());
        self.data.extend_in_place(&header.record_len.to_le_bytes());
        self.data.extend_in_place(&header.padding.to_le_bytes());
        self.data.extend_in_place(record);
        self.refs.push(RecordRef { sort_key, offset, len, padding: 0 });
        Ok(())
    }

    /// Sorts the ref index by coordinate key (serial radix sort).
    pub fn sort(&mut self) {
        radix_sort_record_refs(&mut self.refs);
    }

    /// Sorts the ref index by coordinate key using rayon.
    pub fn par_sort(&mut self) {
        parallel_radix_sort_record_refs(&mut self.refs);
    }

    /// Sorts and copies records out as per-thread chunks of owned
    /// `(key, record)` pairs; one fully sorted chunk for small inputs.
    pub fn par_sort_into_chunks(
        &mut self,
        threads: usize,
    ) -> Vec<Vec<(RawCoordinateKey, RawRecord)>> {
        par_sort_into_chunks_impl!(self, threads, radix_sort_record_refs, |r: &RecordRef| {
            RawCoordinateKey { sort_key: r.sort_key }
        })
    }

    /// Returns the payload bytes of `r`, skipping the inline header.
    #[inline]
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    pub fn get_record(&self, r: &RecordRef) -> &[u8] {
        self.data.slice(r.offset as usize + HEADER_SIZE, r.len as usize)
    }

    /// Iterates records in the current order of the ref index (sorted
    /// once `sort`/`par_sort` has been called).
    pub fn iter_sorted(&self) -> impl Iterator<Item = &[u8]> {
        self.refs.iter().map(|r| self.get_record(r))
    }

    /// The ref index in its current order.
    #[must_use]
    pub fn refs(&self) -> &[RecordRef] {
        &self.refs
    }

    /// Bytes in use: record data plus the ref index.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.data.len() + self.refs.len() * std::mem::size_of::<RecordRef>()
    }

    /// Bytes reserved by the underlying allocations.
    #[must_use]
    pub fn allocated_capacity(&self) -> usize {
        self.data.allocated_capacity() + self.refs.capacity() * std::mem::size_of::<RecordRef>()
    }

    /// Number of data segments currently allocated.
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.data.num_segments()
    }

    /// Number of buffered records.
    #[must_use]
    pub fn len(&self) -> usize {
        self.refs.len()
    }

    /// True when no records are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.refs.is_empty()
    }

    /// Removes all buffered records.
    pub fn clear(&mut self) {
        self.data.clear();
        self.refs.clear();
    }

    /// Reference-sequence count this buffer was created with.
    #[must_use]
    pub fn nref(&self) -> u32 {
        self.nref
    }
}
// Trait plumbing: delegate to the inherent methods of the same name.
// Inside these bodies the inherent method takes precedence, so the
// calls do not recurse.
impl ProbeableBuffer for RecordBuffer {
    fn memory_usage(&self) -> usize {
        self.memory_usage()
    }
    fn allocated_capacity(&self) -> usize {
        self.allocated_capacity()
    }
    fn len(&self) -> usize {
        self.len()
    }
    fn num_segments(&self) -> usize {
        self.num_segments()
    }
}
/// Extracts the packed coordinate sort key directly from a raw BAM
/// record without fully parsing it.
///
/// `PackedCoordinateKey::new` already maps any negative `tid` to the
/// unmapped sentinel, so no separate unmapped branch is needed here.
#[inline]
#[must_use]
pub fn extract_coordinate_key_inline(bam: &[u8], nref: u32) -> u64 {
    let tid = bam_fields::ref_id(bam);
    let pos = bam_fields::pos(bam);
    let reverse = RawRecordView::new(bam).flags() & bam_fields::flags::REVERSE != 0;
    PackedCoordinateKey::new(tid, pos, reverse, nref).0
}
/// A 40-byte composite sort key for grouping templates, compared
/// field-by-field in declaration order (`primary` first,
/// `name_hash_upper` last). `Pod`/`Zeroable` allow the key to be
/// reinterpreted as raw bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, bytemuck::Pod, bytemuck::Zeroable)]
pub struct TemplateKey {
    /// tid1 (high 16 bits) | tid2 (next 16 bits) | bias-encoded pos1.
    pub primary: u64,
    /// Bias-encoded pos2 (high 32 bits) | !neg1 bit | !neg2 bit.
    pub secondary: u64,
    /// Cell-barcode hash (0 when absent).
    pub cb_hash: u64,
    /// Packs library id, molecular id, and the inverted MI flag
    /// (see `TemplateKey::new` for the exact bit positions).
    pub tertiary: u64,
    /// Name hash shifted left 1 | is_upper bit — the final tie-breaker.
    pub name_hash_upper: u64,
}
impl TemplateKey {
    /// Packs template coordinates and identifiers into the five key
    /// words. Positions are biased by XOR with `0x8000_0000` so that
    /// negative positions sort before non-negative ones as unsigned
    /// values; a tid of `i32::MAX` uses the `0xFFFF` sentinel.
    #[allow(clippy::too_many_arguments, clippy::cast_sign_loss)]
    #[must_use]
    pub fn new(
        tid1: i32,
        pos1: i32,
        neg1: bool,
        tid2: i32,
        pos2: i32,
        neg2: bool,
        cb_hash: u64,
        library: u32,
        mi: (u64, bool),
        name_hash: u64,
        is_upper: bool,
    ) -> Self {
        // 16-bit tid field; i32::MAX is the "no reference" sentinel.
        let pack_tid = |tid: i32| -> u64 {
            if tid == i32::MAX {
                0xFFFF_u64
            } else {
                (tid.max(0) as u64) & 0xFFFF
            }
        };
        // Bias a position so it orders correctly as an unsigned value.
        let pack_pos = |pos: i32| -> u64 { u64::from((pos as u32) ^ 0x8000_0000) & 0xFFFF_FFFF };
        let primary = (pack_tid(tid1) << 48) | (pack_tid(tid2) << 32) | pack_pos(pos1);
        let secondary = (pack_pos(pos2) << 32) | (u64::from(!neg1) << 1) | u64::from(!neg2);
        let tertiary = ((u64::from(library) & 0xFFFF) << 48)
            | ((mi.0 & 0xFFFF_FFFF_FFFF) << 1)
            | u64::from(!mi.1);
        let name_hash_upper = (name_hash << 1) | u64::from(is_upper);
        Self { primary, secondary, cb_hash, tertiary, name_hash_upper }
    }

    /// Key for an unmapped template: maximal primary/secondary words
    /// so it sorts after every mapped template; `is_read2` lands in
    /// the low bit of the name-hash word as a tie-breaker.
    #[must_use]
    pub fn unmapped(name_hash: u64, cb_hash: u64, is_read2: bool) -> Self {
        Self {
            primary: u64::MAX,
            secondary: u64::MAX,
            cb_hash,
            tertiary: 0,
            name_hash_upper: (name_hash << 1) | u64::from(is_read2),
        }
    }

    /// The all-zero key (also what `Default` returns).
    #[inline]
    #[must_use]
    pub fn zeroed() -> Self {
        Self { primary: 0, secondary: 0, cb_hash: 0, tertiary: 0, name_hash_upper: 0 }
    }
}
impl Default for TemplateKey {
fn default() -> Self {
Self::zeroed()
}
}
impl TemplateKey {
    /// Serializes the key as 40 little-endian bytes, the five words in
    /// comparison order.
    #[inline]
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 40] {
        let words =
            [self.primary, self.secondary, self.cb_hash, self.tertiary, self.name_hash_upper];
        let mut buf = [0u8; 40];
        for (i, word) in words.iter().enumerate() {
            buf[i * 8..(i + 1) * 8].copy_from_slice(&word.to_le_bytes());
        }
        buf
    }

    /// Deserializes a key previously produced by [`Self::to_bytes`].
    #[inline]
    #[must_use]
    pub fn from_bytes(buf: &[u8; 40]) -> Self {
        // Reads the little-endian u64 starting at byte `off`.
        fn word(buf: &[u8; 40], off: usize) -> u64 {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(&buf[off..off + 8]);
            u64::from_le_bytes(bytes)
        }
        Self {
            primary: word(buf, 0),
            secondary: word(buf, 8),
            cb_hash: word(buf, 16),
            tertiary: word(buf, 24),
            name_hash_upper: word(buf, 32),
        }
    }
}
impl PartialOrd for TemplateKey {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Lexicographic comparison over all five words, `primary` first.
impl Ord for TemplateKey {
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = (self.primary, self.secondary, self.cb_hash, self.tertiary, self.name_hash_upper);
        let rhs =
            (other.primary, other.secondary, other.cb_hash, other.tertiary, other.name_hash_upper);
        lhs.cmp(&rhs)
    }
}
impl TemplateKey {
    /// Compares every word except `name_hash_upper`, i.e. groups
    /// records sharing coordinates, barcode, library and MI while
    /// ignoring the per-read tie-breaker.
    #[inline]
    #[must_use]
    pub fn core_cmp(&self, other: &Self) -> Ordering {
        (self.primary, self.secondary, self.cb_hash, self.tertiary)
            .cmp(&(other.primary, other.secondary, other.cb_hash, other.tertiary))
    }
}
/// `RawSortKey` plumbing for `TemplateKey`.
impl RawSortKey for TemplateKey {
    /// Keys serialize to a fixed 40 bytes (five little-endian u64s).
    const SERIALIZED_SIZE: Option<usize> = Some(40);
    /// Intentionally unimplemented: building a template key needs
    /// library lookups that a bare BAM slice plus `SortContext`
    /// cannot provide, so callers must use the inline extraction path.
    fn extract(_bam: &[u8], _ctx: &SortContext) -> Self {
        unreachable!(
            "TemplateKey::extract() should not be called directly. \
            Use extract_template_key_inline() with LibraryLookup instead."
        )
    }
    /// Writes the 40-byte serialized form.
    #[inline]
    fn write_to<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_all(&self.to_bytes())
    }
    /// Reads back a key written by `write_to`.
    #[inline]
    fn read_from<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        let mut buf = [0u8; 40];
        reader.read_exact(&mut buf)?;
        Ok(Self::from_bytes(&buf))
    }
}
/// Header preceding each record in `TemplateRecordBuffer`'s data
/// area: the full 40-byte key plus the payload length, padded to 48
/// bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct TemplateInlineHeader {
    /// Template sort key of the record that follows.
    pub key: TemplateKey,
    /// Payload length in bytes (header excluded).
    pub record_len: u32,
    /// Explicit padding to reach 48 bytes.
    pub padding: u32,
}
/// Fixed byte size of `TemplateInlineHeader`; the compile-time
/// assertion below guarantees it matches the actual layout.
pub const TEMPLATE_HEADER_SIZE: usize = 48;
const _: () = assert!(
    std::mem::size_of::<TemplateInlineHeader>() == TEMPLATE_HEADER_SIZE,
    "TEMPLATE_HEADER_SIZE must match size_of::<TemplateInlineHeader>()"
);
impl TemplateInlineHeader {
    /// Serializes the header as 48 little-endian bytes: the five key
    /// words, then `record_len`, then `padding`.
    #[inline]
    #[must_use]
    pub fn to_bytes(&self) -> [u8; TEMPLATE_HEADER_SIZE] {
        let words = [
            self.key.primary,
            self.key.secondary,
            self.key.cb_hash,
            self.key.tertiary,
            self.key.name_hash_upper,
        ];
        let mut buf = [0u8; TEMPLATE_HEADER_SIZE];
        for (i, word) in words.iter().enumerate() {
            buf[i * 8..(i + 1) * 8].copy_from_slice(&word.to_le_bytes());
        }
        buf[40..44].copy_from_slice(&self.record_len.to_le_bytes());
        buf[44..48].copy_from_slice(&self.padding.to_le_bytes());
        buf
    }

    /// Parses a header from the leading bytes of `data`; `padding` is
    /// normalised to zero. Panics if `data` holds fewer than 44 bytes.
    #[inline]
    #[must_use]
    pub fn read_from(data: &[u8]) -> Self {
        // Little-endian u64 at byte offset `off`.
        let word = |off: usize| -> u64 {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(&data[off..off + 8]);
            u64::from_le_bytes(bytes)
        };
        let key = TemplateKey {
            primary: word(0),
            secondary: word(8),
            cb_hash: word(16),
            tertiary: word(24),
            name_hash_upper: word(32),
        };
        let mut len_bytes = [0u8; 4];
        len_bytes.copy_from_slice(&data[40..44]);
        Self { key, record_len: u32::from_le_bytes(len_bytes), padding: 0 }
    }
}
/// A sortable reference to a record in a `TemplateRecordBuffer`,
/// carrying a full copy of the 40-byte key plus the record location.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct TemplateRecordRef {
    /// Template sort key (duplicated from the inline header).
    pub key: TemplateKey,
    /// Byte offset of the record's inline header within the buffer.
    pub offset: u64,
    /// Payload length in bytes (header excluded).
    pub len: u32,
    /// Explicit padding for a stable `#[repr(C)]` layout.
    pub padding: u32,
}
// NOTE(review): unlike `RecordRef`, equality/ordering here are by
// buffer `offset` (i.e. insertion order), NOT by `key`. The radix and
// insertion sorts in this module compare `key` fields explicitly, so
// this `Ord` appears intended for identity/position comparisons —
// confirm before relying on it for key ordering.
impl PartialEq for TemplateRecordRef {
    fn eq(&self, other: &Self) -> bool {
        self.offset == other.offset
    }
}
impl Eq for TemplateRecordRef {}
impl PartialOrd for TemplateRecordRef {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for TemplateRecordRef {
    fn cmp(&self, other: &Self) -> Ordering {
        self.offset.cmp(&other.offset)
    }
}
/// In-memory buffer of BAM records indexed by `TemplateKey` — the
/// template-sorting counterpart of `RecordBuffer`.
pub struct TemplateRecordBuffer {
    data: SegmentedBuf,           // record bytes, each behind a TemplateInlineHeader
    refs: Vec<TemplateRecordRef>, // one sortable ref per pushed record
}
impl TemplateRecordBuffer {
    /// Creates a buffer pre-sized for `estimated_records` records
    /// totalling `estimated_bytes` of payload plus per-record headers.
    #[must_use]
    pub fn with_capacity(estimated_records: usize, estimated_bytes: usize) -> Self {
        let header_bytes = estimated_records * TEMPLATE_HEADER_SIZE;
        Self {
            data: SegmentedBuf::with_capacity(estimated_bytes + header_bytes, SORT_SEGMENT_SIZE),
            refs: Vec::with_capacity(estimated_records),
        }
    }

    /// Appends a raw BAM record under the caller-computed `key`.
    ///
    /// # Errors
    /// Fails if the record does not fit in one segment or exceeds
    /// `u32::MAX` bytes.
    #[inline]
    pub fn push(&mut self, record: &[u8], key: TemplateKey) -> anyhow::Result<()> {
        let total_bytes = TEMPLATE_HEADER_SIZE + record.len();
        anyhow::ensure!(
            total_bytes <= SORT_SEGMENT_SIZE,
            "BAM record of {} bytes (+ {} byte header) exceeds segment size of {} bytes; \
            this is likely a malformed BAM file",
            record.len(),
            TEMPLATE_HEADER_SIZE,
            SORT_SEGMENT_SIZE,
        );
        let record_len = u32::try_from(record.len())
            .map_err(|_| anyhow::anyhow!("record length {} exceeds u32::MAX", record.len()))?;
        // Reserve header + payload contiguously so get_record can
        // return a single slice.
        let offset = self.data.reserve_contiguous(total_bytes) as u64;
        let header = TemplateInlineHeader { key, record_len, padding: 0 };
        self.data.extend_in_place(&header.to_bytes());
        self.data.extend_in_place(record);
        self.refs.push(TemplateRecordRef { key, offset, len: record_len, padding: 0 });
        Ok(())
    }

    /// Sorts the ref index by full template key (serial radix sort).
    pub fn sort(&mut self) {
        radix_sort_template_refs(&mut self.refs);
    }

    /// Sorts the ref index by full template key using rayon.
    pub fn par_sort(&mut self) {
        parallel_radix_sort_template_refs(&mut self.refs);
    }

    /// Returns the payload bytes of `r`, skipping the inline header.
    #[inline]
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    pub fn get_record(&self, r: &TemplateRecordRef) -> &[u8] {
        self.data.slice(r.offset as usize + TEMPLATE_HEADER_SIZE, r.len as usize)
    }

    /// Iterates records in the current order of the ref index (sorted
    /// once `sort`/`par_sort` has been called).
    pub fn iter_sorted(&self) -> impl Iterator<Item = &[u8]> {
        self.refs.iter().map(|r| self.get_record(r))
    }

    /// The ref index in its current order.
    #[must_use]
    pub fn refs(&self) -> &[TemplateRecordRef] {
        &self.refs
    }

    /// Bytes in use: record data plus the ref index.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.data.len() + self.refs.len() * std::mem::size_of::<TemplateRecordRef>()
    }

    /// Bytes reserved by the underlying allocations.
    #[must_use]
    pub fn allocated_capacity(&self) -> usize {
        self.data.allocated_capacity()
            + self.refs.capacity() * std::mem::size_of::<TemplateRecordRef>()
    }

    /// Number of data segments currently allocated.
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.data.num_segments()
    }

    /// Number of buffered records.
    #[must_use]
    pub fn len(&self) -> usize {
        self.refs.len()
    }

    /// True when no records are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.refs.is_empty()
    }

    /// Returns the key stored in `r` (a cheap `Copy`).
    #[inline]
    #[must_use]
    pub fn get_key(&self, r: &TemplateRecordRef) -> TemplateKey {
        r.key
    }

    /// Iterates `(key, record)` pairs in the current ref order.
    pub fn iter_sorted_keyed(&self) -> impl Iterator<Item = (TemplateKey, &[u8])> {
        self.refs.iter().map(|r| (self.get_key(r), self.get_record(r)))
    }

    /// Removes all buffered records.
    pub fn clear(&mut self) {
        self.data.clear();
        self.refs.clear();
    }

    /// Sorts and copies records out as per-thread chunks of owned
    /// `(key, record)` pairs; one fully sorted chunk for small inputs.
    pub fn par_sort_into_chunks(&mut self, threads: usize) -> Vec<Vec<(TemplateKey, RawRecord)>> {
        par_sort_into_chunks_impl!(
            self,
            threads,
            radix_sort_template_refs,
            |r: &TemplateRecordRef| r.key
        )
    }
}
// Trait plumbing: delegate to the inherent methods of the same name
// (inherent methods take precedence in these bodies, so no recursion).
impl ProbeableBuffer for TemplateRecordBuffer {
    fn memory_usage(&self) -> usize {
        self.memory_usage()
    }
    fn allocated_capacity(&self) -> usize {
        self.allocated_capacity()
    }
    fn len(&self) -> usize {
        self.len()
    }
    fn num_segments(&self) -> usize {
        self.num_segments()
    }
}
const RADIX_THRESHOLD: usize = 256;
#[allow(clippy::uninit_vec, unsafe_code)]
pub fn radix_sort_record_refs(refs: &mut [RecordRef]) {
let n = refs.len();
if n < RADIX_THRESHOLD {
insertion_sort_refs(refs);
return;
}
let max_key = refs.iter().map(|r| r.sort_key).max().unwrap_or(0);
let bytes_needed =
if max_key == 0 { 0 } else { ((64 - max_key.leading_zeros()) as usize).div_ceil(8) };
if bytes_needed == 0 {
return; }
let mut aux: Vec<RecordRef> = Vec::with_capacity(n);
unsafe {
aux.set_len(n);
}
let mut src = refs as *mut [RecordRef];
let mut dst = aux.as_mut_slice() as *mut [RecordRef];
for byte_idx in 0..bytes_needed {
let src_slice = unsafe { &*src };
let dst_slice = unsafe { &mut *dst };
let mut counts = [0usize; 256];
for r in src_slice {
let byte = ((r.sort_key >> (byte_idx * 8)) & 0xFF) as usize;
counts[byte] += 1;
}
let mut total = 0;
for count in &mut counts {
let c = *count;
*count = total;
total += c;
}
for r in src_slice {
let byte = ((r.sort_key >> (byte_idx * 8)) & 0xFF) as usize;
let dest_idx = counts[byte];
counts[byte] += 1;
dst_slice[dest_idx] = *r;
}
std::mem::swap(&mut src, &mut dst);
}
if bytes_needed % 2 == 1 {
let src_slice = unsafe { &*src };
refs.copy_from_slice(src_slice);
}
}
pub fn parallel_radix_sort_record_refs(refs: &mut [RecordRef]) {
use rayon::prelude::*;
let n = refs.len();
if n < RADIX_THRESHOLD * 2 {
radix_sort_record_refs(refs);
return;
}
let n_threads = rayon::current_num_threads();
if n_threads > 1 && n > 10_000 {
let chunk_size = n.div_ceil(n_threads);
refs.par_chunks_mut(chunk_size).for_each(|chunk| {
radix_sort_record_refs(chunk);
});
let chunk_boundaries: Vec<_> = refs
.chunks(chunk_size)
.scan(0, |pos, chunk| {
let start = *pos;
*pos += chunk.len();
Some(start..*pos)
})
.collect();
merge_sorted_chunks(refs, &chunk_boundaries);
} else {
radix_sort_record_refs(refs);
}
}
/// K-way merges chunks of `refs` that are each already sorted by
/// `sort_key`, using a heap of per-chunk cursors. Ties on the key are
/// broken by chunk index, so the merge preserves the chunk-local
/// stable order and is therefore stable overall.
fn merge_sorted_chunks(refs: &mut [RecordRef], chunk_ranges: &[std::ops::Range<usize>]) {
    use crate::sort::radix::{heap_make, heap_sift_down};
    // Cursor into one sorted chunk: current key, chunk id, position.
    struct HeapEntry {
        key: u64,
        chunk_idx: usize,
        pos: usize,
    }
    if chunk_ranges.len() <= 1 {
        return;
    }
    let n = refs.len();
    let mut result: Vec<RecordRef> = Vec::with_capacity(n);
    let mut heap: Vec<HeapEntry> = Vec::with_capacity(chunk_ranges.len());
    // Seed the heap with the head of every non-empty chunk.
    for (chunk_idx, range) in chunk_ranges.iter().enumerate() {
        if !range.is_empty() {
            heap.push(HeapEntry { key: refs[range.start].sort_key, chunk_idx, pos: range.start });
        }
    }
    if heap.is_empty() {
        return;
    }
    // Inverted comparator: the heap helpers take a "less" predicate,
    // so reporting greater-than keeps the smallest (key, chunk_idx)
    // entry at the root.
    let lt = |a: &HeapEntry, b: &HeapEntry| -> bool { (a.key, a.chunk_idx) > (b.key, b.chunk_idx) };
    heap_make(&mut heap, lt);
    let mut heap_size = heap.len();
    while heap_size > 0 {
        // Emit the smallest remaining entry.
        let entry = &heap[0];
        result.push(refs[entry.pos]);
        let chunk_idx = entry.chunk_idx;
        let next_pos = entry.pos + 1;
        let range = &chunk_ranges[chunk_idx];
        if next_pos < range.end {
            // Advance this chunk's cursor in place and re-heapify.
            heap[0] = HeapEntry { key: refs[next_pos].sort_key, chunk_idx, pos: next_pos };
            heap_sift_down(&mut heap, 0, heap_size, lt);
        } else {
            // Chunk exhausted: shrink the heap.
            heap_size -= 1;
            if heap_size > 0 {
                heap.swap(0, heap_size);
                heap_sift_down(&mut heap, 0, heap_size, lt);
            }
        }
    }
    refs.copy_from_slice(&result);
}
/// Stable binary-insertion sort for small `RecordRef` slices; the
/// `<=` predicate in `partition_point` keeps equal keys in insertion
/// order.
fn insertion_sort_refs(refs: &mut [RecordRef]) {
    for idx in 1..refs.len() {
        let current = refs[idx].sort_key;
        let dest = refs[..idx].partition_point(|r| r.sort_key <= current);
        if dest == idx {
            continue;
        }
        refs[dest..=idx].rotate_right(1);
    }
}
/// Stable radix sort of `TemplateRecordRef`s over the full 5-word
/// `TemplateKey`: sort by `primary`, then recursively sub-sort runs
/// of equal values by the remaining fields.
#[allow(clippy::uninit_vec, unsafe_code)]
pub fn radix_sort_template_refs(refs: &mut [TemplateRecordRef]) {
    let n = refs.len();
    if n < RADIX_THRESHOLD {
        insertion_sort_template_refs(refs);
        return;
    }
    let mut aux: Vec<TemplateRecordRef> = Vec::with_capacity(n);
    // SAFETY: `aux` is pure scratch; every slot is written by a full
    // counting-sort pass before it is ever read.
    unsafe {
        aux.set_len(n);
    }
    let max_primary = refs.iter().map(|r| r.key.primary).max().unwrap_or(0);
    let bytes_needed = bytes_needed_u64(max_primary);
    if bytes_needed > 0 {
        radix_sort_template_field(refs, &mut aux, |r| r.key.primary, bytes_needed);
    }
    // Records with equal `primary` still need ordering by the rest of
    // the key.
    sub_sort_runs(refs, &mut aux, |r| r.key.primary, &REMAINING_FIELDS_AFTER_PRIMARY);
}
/// Key-field accessors in comparison order after `primary`.
const REMAINING_FIELDS_AFTER_PRIMARY: [fn(&TemplateRecordRef) -> u64; 4] =
    [|r| r.key.secondary, |r| r.key.cb_hash, |r| r.key.tertiary, |r| r.key.name_hash_upper];
/// Runs at or below this length are finished with insertion sort.
const SUB_SORT_INSERTION_THRESHOLD: usize = 64;
/// Recursively finishes a radix sort: after `refs` has been ordered by
/// `run_field`, each maximal run of equal `run_field` values is sorted
/// by the next accessor in `remaining_fields`, and so on. Short runs
/// use insertion sort, which compares the whole key — correct here
/// because a run already agrees on all earlier fields.
fn sub_sort_runs<F>(
    refs: &mut [TemplateRecordRef],
    aux: &mut [TemplateRecordRef],
    run_field: F,
    remaining_fields: &[fn(&TemplateRecordRef) -> u64],
) where
    F: Fn(&TemplateRecordRef) -> u64,
{
    if remaining_fields.is_empty() {
        return;
    }
    let n = refs.len();
    let mut start = 0;
    while start < n {
        // Find the end of the run of equal run_field values.
        let val = run_field(&refs[start]);
        let mut end = start + 1;
        while end < n && run_field(&refs[end]) == val {
            end += 1;
        }
        let run = &mut refs[start..end];
        let run_len = run.len();
        if run_len > 1 {
            if run_len <= SUB_SORT_INSERTION_THRESHOLD {
                insertion_sort_template_refs(run);
            } else {
                // Radix-sort the run by the next field, using the
                // matching slice of the scratch buffer.
                let next_field = remaining_fields[0];
                let max_val = run.iter().map(next_field).max().unwrap_or(0);
                let bytes_needed = bytes_needed_u64(max_val);
                if bytes_needed > 0 {
                    let run_aux = &mut aux[start..end];
                    radix_sort_template_field(run, run_aux, next_field, bytes_needed);
                }
                // Recurse for any fields after this one.
                if remaining_fields.len() > 1 {
                    sub_sort_runs(run, &mut aux[start..end], next_field, &remaining_fields[1..]);
                }
            }
        }
        start = end;
    }
}
/// One multi-pass LSD counting sort of `refs` by the `u64` produced by
/// `get_field`, using `aux` as equal-length scratch. Stable with
/// respect to the field being sorted.
#[allow(clippy::uninit_vec, unsafe_code)]
fn radix_sort_template_field<F>(
    refs: &mut [TemplateRecordRef],
    aux: &mut [TemplateRecordRef],
    get_field: F,
    bytes_needed: usize,
) where
    F: Fn(&TemplateRecordRef) -> u64,
{
    let mut src = refs as *mut [TemplateRecordRef];
    let mut dst = aux as *mut [TemplateRecordRef];
    for byte_idx in 0..bytes_needed {
        // SAFETY: src/dst alternate between `refs` and `aux`, two
        // distinct, equal-length slices borrowed for this call.
        let src_slice = unsafe { &*src };
        let dst_slice = unsafe { &mut *dst };
        // Histogram of the byte currently being sorted on.
        let mut counts = [0usize; 256];
        for r in src_slice {
            let byte = ((get_field(r) >> (byte_idx * 8)) & 0xFF) as usize;
            counts[byte] += 1;
        }
        // Exclusive prefix sum: counts[b] = first output slot for b.
        let mut total = 0;
        for count in &mut counts {
            let c = *count;
            *count = total;
            total += c;
        }
        // Stable scatter into the destination slice. (Iterating the
        // slice directly; the length-limited take() was redundant.)
        for item in src_slice {
            let byte = ((get_field(item) >> (byte_idx * 8)) & 0xFF) as usize;
            let dest_idx = counts[byte];
            counts[byte] += 1;
            dst_slice[dest_idx] = *item;
        }
        std::mem::swap(&mut src, &mut dst);
    }
    // After an odd number of passes the sorted data lives in `aux`.
    if bytes_needed % 2 == 1 {
        let src_slice = unsafe { &*src };
        refs.copy_from_slice(src_slice);
    }
}
pub fn parallel_radix_sort_template_refs(refs: &mut [TemplateRecordRef]) {
use rayon::prelude::*;
let n = refs.len();
if n < RADIX_THRESHOLD * 2 {
radix_sort_template_refs(refs);
return;
}
let n_threads = rayon::current_num_threads();
if n_threads > 1 && n > 10_000 {
let chunk_size = n.div_ceil(n_threads);
refs.par_chunks_mut(chunk_size).for_each(|chunk| {
radix_sort_template_refs(chunk);
});
let chunk_boundaries: Vec<_> = refs
.chunks(chunk_size)
.scan(0, |pos, chunk| {
let start = *pos;
*pos += chunk.len();
Some(start..*pos)
})
.collect();
merge_sorted_template_chunks(refs, &chunk_boundaries);
} else {
radix_sort_template_refs(refs);
}
}
/// K-way merges per-chunk sorted runs of `TemplateRecordRef`s using a
/// heap of chunk cursors; ties on the full key are broken by chunk
/// index, keeping the merge stable across chunks.
fn merge_sorted_template_chunks(
    refs: &mut [TemplateRecordRef],
    chunk_ranges: &[std::ops::Range<usize>],
) {
    use crate::sort::radix::{heap_make, heap_sift_down};
    // Cursor into one sorted chunk: current key, chunk id, position.
    struct HeapEntry {
        key: TemplateKey,
        chunk_idx: usize,
        pos: usize,
    }
    if chunk_ranges.len() <= 1 {
        return;
    }
    let n = refs.len();
    let mut result: Vec<TemplateRecordRef> = Vec::with_capacity(n);
    let mut heap: Vec<HeapEntry> = Vec::with_capacity(chunk_ranges.len());
    // Seed the heap with the head of every non-empty chunk.
    for (chunk_idx, range) in chunk_ranges.iter().enumerate() {
        if !range.is_empty() {
            heap.push(HeapEntry { key: refs[range.start].key, chunk_idx, pos: range.start });
        }
    }
    if heap.is_empty() {
        return;
    }
    // Inverted comparator: the heap helpers take a "less" predicate,
    // so reporting greater-than keeps the smallest entry at the root.
    let lt = |a: &HeapEntry, b: &HeapEntry| -> bool {
        match a.key.cmp(&b.key) {
            std::cmp::Ordering::Greater => true,
            std::cmp::Ordering::Less => false,
            std::cmp::Ordering::Equal => a.chunk_idx > b.chunk_idx,
        }
    };
    heap_make(&mut heap, lt);
    let mut heap_size = heap.len();
    while heap_size > 0 {
        // Emit the smallest remaining entry.
        let entry = &heap[0];
        result.push(refs[entry.pos]);
        let chunk_idx = entry.chunk_idx;
        let next_pos = entry.pos + 1;
        let range = &chunk_ranges[chunk_idx];
        if next_pos < range.end {
            // Advance this chunk's cursor in place and re-heapify.
            heap[0] = HeapEntry { key: refs[next_pos].key, chunk_idx, pos: next_pos };
            heap_sift_down(&mut heap, 0, heap_size, lt);
        } else {
            // Chunk exhausted: shrink the heap.
            heap_size -= 1;
            if heap_size > 0 {
                heap.swap(0, heap_size);
                heap_sift_down(&mut heap, 0, heap_size, lt);
            }
        }
    }
    refs.copy_from_slice(&result);
}
/// Stable binary-insertion sort for small `TemplateRecordRef` slices;
/// the `<=` predicate keeps equal keys in insertion order.
fn insertion_sort_template_refs(refs: &mut [TemplateRecordRef]) {
    for idx in 1..refs.len() {
        let current = refs[idx].key;
        let dest = refs[..idx].partition_point(|r| r.key <= current);
        if dest == idx {
            continue;
        }
        refs[dest..=idx].rotate_right(1);
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_packed_coord_key_ordering() {
assert!(
PackedCoordinateKey::new(0, 100, false, 10)
< PackedCoordinateKey::new(1, 100, false, 10)
);
assert!(
PackedCoordinateKey::new(0, 100, false, 10)
< PackedCoordinateKey::new(0, 200, false, 10)
);
assert!(
PackedCoordinateKey::new(0, 100, false, 10)
< PackedCoordinateKey::new(0, 100, true, 10)
);
assert!(PackedCoordinateKey::new(9, 1_000_000, true, 10) < PackedCoordinateKey::unmapped());
}
#[test]
fn test_template_key_ordering() {
let k1 = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 0, false);
let k2 = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (2, true), 0, false);
assert!(k1 < k2);
let ka = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 0, false);
let kb = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, false), 0, false);
assert!(ka < kb);
let lower = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 12345, false);
let upper = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 12345, true);
assert!(lower < upper, "is_upper=false should sort before is_upper=true");
let first_hash_lo =
TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 100, false);
let first_hash_hi =
TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 100, true);
let second_hash =
TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 200, false);
assert!(first_hash_lo < second_hash);
assert!(first_hash_hi < second_hash);
}
#[test]
fn test_radix_sort_record_refs() {
let mut refs = vec![
RecordRef { sort_key: 100, offset: 0, len: 10, padding: 0 },
RecordRef { sort_key: 50, offset: 100, len: 10, padding: 0 },
RecordRef { sort_key: 200, offset: 200, len: 10, padding: 0 },
RecordRef { sort_key: 50, offset: 300, len: 10, padding: 0 }, RecordRef { sort_key: 1, offset: 400, len: 10, padding: 0 },
];
radix_sort_record_refs(&mut refs);
assert_eq!(refs[0].sort_key, 1);
assert_eq!(refs[1].sort_key, 50);
assert_eq!(refs[2].sort_key, 50);
assert_eq!(refs[3].sort_key, 100);
assert_eq!(refs[4].sort_key, 200);
assert_eq!(refs[1].offset, 100);
assert_eq!(refs[2].offset, 300);
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_radix_sort_large() {
let mut refs: Vec<RecordRef> = (0..1000)
.map(|i| RecordRef {
sort_key: (999 - i) as u64, offset: i as u64 * 100,
len: 10,
padding: 0,
})
.collect();
radix_sort_record_refs(&mut refs);
for (i, r) in refs.iter().enumerate() {
assert_eq!(r.sort_key, i as u64, "Expected sort_key {i} at index {i}");
}
}
#[test]
fn test_radix_sort_empty() {
let mut refs: Vec<RecordRef> = Vec::new();
radix_sort_record_refs(&mut refs);
assert!(refs.is_empty());
}
#[test]
fn test_radix_sort_single() {
let mut refs = vec![RecordRef { sort_key: 42, offset: 0, len: 10, padding: 0 }];
radix_sort_record_refs(&mut refs);
assert_eq!(refs.len(), 1);
assert_eq!(refs[0].sort_key, 42);
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_radix_sort_all_same_keys() {
let mut refs: Vec<RecordRef> = (0..100)
.map(|i| RecordRef { sort_key: 42, offset: i as u64 * 100, len: 10, padding: 0 })
.collect();
radix_sort_record_refs(&mut refs);
for (i, r) in refs.iter().enumerate() {
assert_eq!(r.sort_key, 42);
assert_eq!(r.offset, i as u64 * 100, "Stability violated at index {i}");
}
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_radix_sort_all_zero_keys() {
let mut refs: Vec<RecordRef> = (0..50)
.map(|i| RecordRef { sort_key: 0, offset: i as u64 * 100, len: 10, padding: 0 })
.collect();
radix_sort_record_refs(&mut refs);
for (i, r) in refs.iter().enumerate() {
assert_eq!(r.sort_key, 0);
assert_eq!(r.offset, i as u64 * 100);
}
}
#[test]
fn test_radix_sort_max_key() {
let mut refs = vec![
RecordRef { sort_key: u64::MAX, offset: 0, len: 10, padding: 0 },
RecordRef { sort_key: 0, offset: 100, len: 10, padding: 0 },
RecordRef { sort_key: u64::MAX / 2, offset: 200, len: 10, padding: 0 },
];
radix_sort_record_refs(&mut refs);
assert_eq!(refs[0].sort_key, 0);
assert_eq!(refs[1].sort_key, u64::MAX / 2);
assert_eq!(refs[2].sort_key, u64::MAX);
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_parallel_radix_sort() {
let mut refs: Vec<RecordRef> = (0..50_000)
.map(|i| RecordRef {
sort_key: (49_999 - i) as u64, offset: i as u64 * 100,
len: 10,
padding: 0,
})
.collect();
parallel_radix_sort_record_refs(&mut refs);
for (i, r) in refs.iter().enumerate() {
assert_eq!(r.sort_key, i as u64, "Expected sort_key {i} at index {i}");
}
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_parallel_radix_sort_stability() {
let mut refs: Vec<RecordRef> = (0..20_000)
.map(|i| RecordRef {
sort_key: (i / 100) as u64, offset: i as u64, len: 10,
padding: 0,
})
.collect();
parallel_radix_sort_record_refs(&mut refs);
for i in 1..refs.len() {
assert!(refs[i - 1].sort_key <= refs[i].sort_key, "Not sorted at index {i}");
if refs[i - 1].sort_key == refs[i].sort_key {
assert!(refs[i - 1].offset < refs[i].offset, "Stability violated at index {i}");
}
}
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_radix_sort_template_refs_stability() {
let key = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 12345, false);
let mut refs: Vec<TemplateRecordRef> = (0..500)
.map(|i| TemplateRecordRef {
key, offset: i as u64, len: 10,
padding: 0,
})
.collect();
radix_sort_template_refs(&mut refs);
for i in 1..refs.len() {
assert!(
refs[i - 1].offset < refs[i].offset,
"Template radix sort stability violated at index {}: offset {} should be < {}",
i,
refs[i - 1].offset,
refs[i].offset
);
}
}
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_parallel_radix_sort_template_refs_stability() {
let mut refs: Vec<TemplateRecordRef> = (0..20_000)
.map(|i| {
let group = i / 100;
let key = TemplateKey::new(
0,
100,
false,
0,
200,
false,
0,
0,
(1, true),
group as u64, false,
);
TemplateRecordRef {
key,
offset: i as u64, len: 10,
padding: 0,
}
})
.collect();
parallel_radix_sort_template_refs(&mut refs);
for i in 1..refs.len() {
let prev_key = &refs[i - 1].key;
let curr_key = &refs[i].key;
assert!(prev_key <= curr_key, "Not sorted at index {i}");
if prev_key == curr_key {
assert!(
refs[i - 1].offset < refs[i].offset,
"Parallel template sort stability violated at index {}: offset {} should be < {}",
i,
refs[i - 1].offset,
refs[i].offset
);
}
}
}
#[test]
fn test_template_record_buffer_sort_stability() {
let mut buffer = TemplateRecordBuffer::with_capacity(100, 10000);
let key = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 12345, false);
for i in 0..100u8 {
let mut record = vec![
0, 0, 0, 0, 100, 0, 0, 0, 2, 0, 0, 0, 0, 0, 99, 0, 1, 0, 0, 0, 0, 0, 0, 0, 200, 0, 0, 0, 0, 0, 0, 0, b'A', 0, i, 0xFF, ];
while record.len() < 40 {
record.push(0);
}
buffer.push(&record, key).expect("push should succeed in tests");
}
buffer.sort();
let mut prev_seq_byte = None;
for rec in buffer.iter_sorted() {
let seq_byte = rec.get(34).copied().unwrap_or(0);
if let Some(prev) = prev_seq_byte {
assert!(
prev < seq_byte,
"TemplateRecordBuffer::sort() stability violated: {prev} should be < {seq_byte}"
);
}
prev_seq_byte = Some(seq_byte);
}
}
#[test]
fn test_template_key_cb_hash_ordering() {
let k1 = TemplateKey::new(0, 100, false, 0, 200, false, 10, 0, (1, true), 0, false);
let k2 = TemplateKey::new(0, 100, false, 0, 200, false, 20, 0, (1, true), 0, false);
assert!(k1 < k2, "lower cb_hash should sort before higher cb_hash");
}
#[test]
fn test_template_key_cb_hash_zero_sorts_first() {
let k_no_cb = TemplateKey::new(0, 100, false, 0, 200, false, 0, 0, (1, true), 0, false);
let k_cb = TemplateKey::new(0, 100, false, 0, 200, false, 42, 0, (1, true), 0, false);
assert!(k_no_cb < k_cb, "cb_hash=0 should sort before non-zero cb_hash");
}
#[test]
fn test_template_key_cb_hash_between_secondary_and_tertiary() {
let k1 = TemplateKey::new(0, 100, false, 0, 200, false, 10, 0, (1, true), 0, false);
let k2 = TemplateKey::new(0, 100, false, 0, 200, false, 20, 0, (0, true), 0, false);
assert!(k1 < k2, "cb_hash should be compared before tertiary (library/MI)");
}
#[test]
fn test_template_key_serialization_with_cb_hash() {
let key =
TemplateKey::new(1, 500, true, 2, 600, false, 0xDEAD_BEEF, 3, (7, true), 999, false);
let bytes = key.to_bytes();
assert_eq!(bytes.len(), 40);
let roundtrip = TemplateKey::from_bytes(&bytes);
assert_eq!(key, roundtrip, "serialization roundtrip should preserve cb_hash");
assert_eq!(roundtrip.cb_hash, 0xDEAD_BEEF);
}
#[test]
fn test_template_key_unmapped_with_cb_hash() {
let key = TemplateKey::unmapped(12345, 0xCAFE, false);
assert_eq!(key.primary, u64::MAX);
assert_eq!(key.secondary, u64::MAX);
assert_eq!(key.cb_hash, 0xCAFE, "unmapped should preserve cb_hash");
assert_eq!(key.tertiary, 0);
}
/// `core_cmp` must take `cb_hash` into account but ignore `name_hash`.
#[test]
fn test_template_key_core_cmp_includes_cb_hash() {
    // Differ only in cb_hash: core_cmp must observe the difference.
    let a = TemplateKey::new(0, 100, false, 0, 200, false, 10, 0, (1, true), 0, false);
    let b = TemplateKey::new(0, 100, false, 0, 200, false, 20, 0, (1, true), 0, false);
    assert_eq!(a.core_cmp(&b), std::cmp::Ordering::Less, "core_cmp should include cb_hash");
    // Differ only in name_hash: core_cmp must treat them as equal.
    let c = TemplateKey::new(0, 100, false, 0, 200, false, 10, 0, (1, true), 100, false);
    let d = TemplateKey::new(0, 100, false, 0, 200, false, 10, 0, (1, true), 200, false);
    assert_eq!(c.core_cmp(&d), std::cmp::Ordering::Equal, "core_cmp should ignore name_hash");
}
/// `TemplateKey::zeroed()` must leave `cb_hash` at zero.
#[test]
fn test_template_key_zeroed_has_zero_cb_hash() {
    assert_eq!(TemplateKey::zeroed().cb_hash, 0);
}
/// `TemplateKey::default()` must leave `cb_hash` at zero, matching `zeroed()`.
#[test]
fn test_template_key_default_has_zero_cb_hash() {
    assert_eq!(TemplateKey::default().cb_hash, 0);
}
/// Builds a minimal 40-byte synthetic BAM-record body for the sort tests.
///
/// The first 34 bytes are a canned prefix (with `b'A'` + NUL at offsets
/// 32..34); `index` is then appended little-endian at offsets 34..36 so each
/// record is distinguishable, followed by a `0xFF` sentinel at offset 36 and
/// zero padding up to the fixed 40-byte length the buffers expect.
fn make_bam_record(index: u16) -> Vec<u8> {
    let mut record = vec![
        0, 0, 0, 0, 100, 0, 0, 0, 2, 0, 0, 0, 0, 0, 99, 0, 1, 0, 0, 0, 0, 0, 0, 0, 200, 0, 0, 0,
        0, 0, 0, 0, b'A', 0,
    ];
    // Little-endian encoding of the index (replaces manual shift/mask splitting).
    record.extend_from_slice(&index.to_le_bytes());
    record.push(0xFF);
    // Zero-pad to the fixed 40-byte record size.
    record.resize(40, 0);
    record
}
/// With one thread, `par_sort_into_chunks` must fall back to producing a
/// single, fully sorted chunk containing every record.
#[test]
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn test_par_sort_into_chunks_single_threaded_fallback() {
    let n = 100;
    let mut buffer = TemplateRecordBuffer::with_capacity(n, n * 50);
    // Push in descending key order so the sort has real work to do.
    for i in 0..n {
        let key =
            TemplateKey::new(0, (n - i) as i32, false, 0, 200, false, 0, 0, (1, true), 0, false);
        buffer.push(&make_bam_record(i as u16), key).expect("push should succeed in tests");
    }
    let chunks = buffer.par_sort_into_chunks(1);
    assert_eq!(chunks.len(), 1, "single-threaded should produce exactly 1 chunk");
    assert_eq!(chunks[0].len(), n, "the single chunk should contain all records");
    // Verify the chunk is sorted by comparing each adjacent pair.
    for i in 1..chunks[0].len() {
        assert!(
            chunks[0][i - 1].0 <= chunks[0][i].0,
            "single chunk should be sorted at index {i}"
        );
    }
}
/// With enough records and threads, `par_sort_into_chunks` should take the
/// parallel path: multiple chunks, each internally sorted, jointly covering
/// every input record.
#[test]
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn test_par_sort_into_chunks_parallel_path() {
    // Large enough to clear the parallel-path thresholds.
    let n: usize = 10_500;
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(4)
        .build()
        .expect("failed to build rayon thread pool");
    let chunks = pool.install(|| {
        let mut buffer = TemplateRecordBuffer::with_capacity(n, n * 50);
        // Push in descending key order so the sort has real work to do.
        for i in 0..n {
            let key = TemplateKey::new(
                0,
                (n - i) as i32,
                false,
                0,
                200,
                false,
                0,
                0,
                (1, true),
                0,
                false,
            );
            buffer.push(&make_bam_record(i as u16), key).expect("push should succeed in tests");
        }
        buffer.par_sort_into_chunks(4)
    });
    assert!(
        chunks.len() > 1,
        "expected multiple chunks from parallel path, got {}",
        chunks.len()
    );
    // Each chunk must be internally sorted.
    for (ci, chunk) in chunks.iter().enumerate() {
        for i in 1..chunk.len() {
            assert!(chunk[i - 1].0 <= chunk[i].0, "chunk {ci} not sorted at index {i}");
        }
    }
    // No records may be lost or duplicated across chunks.
    let total: usize = chunks.iter().map(Vec::len).sum();
    assert_eq!(total, n, "total records across chunks should equal input count");
}
/// Builds a minimal 40-byte coordinate-sortable BAM-record body.
///
/// Writes `tid` and `pos` little-endian at the front (the fields the
/// coordinate sort keys on), then a run of canned fixed fields — presumably
/// the BAM fixed-record layout (l_read_name=2, mapq, bin, n_cigar, flag,
/// l_seq=1, mate refs = -1, tlen=0) — a one-character read name `"A\0"`,
/// and zero padding to the fixed 40-byte length the buffers expect.
fn make_coordinate_bam_record(tid: i32, pos: i32) -> Vec<u8> {
    let mut record = Vec::with_capacity(40);
    record.extend_from_slice(&tid.to_le_bytes());
    record.extend_from_slice(&pos.to_le_bytes());
    record.push(2);
    record.push(0);
    record.extend_from_slice(&0u16.to_le_bytes());
    record.extend_from_slice(&0u16.to_le_bytes());
    record.extend_from_slice(&0u16.to_le_bytes());
    record.extend_from_slice(&1u32.to_le_bytes());
    record.extend_from_slice(&(-1i32).to_le_bytes());
    record.extend_from_slice(&(-1i32).to_le_bytes());
    record.extend_from_slice(&0i32.to_le_bytes());
    record.push(b'A');
    record.push(0);
    // Zero-pad to the fixed 40-byte record size (replaces the while-push loop).
    record.resize(40, 0);
    record
}
/// Coordinate-sort variant of the single-threaded fallback: one thread must
/// yield exactly one fully sorted chunk with all records.
#[test]
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn test_par_sort_into_chunks_coordinate_single_threaded() {
    let nref = 10u32;
    let n = 100;
    let mut buffer = RecordBuffer::with_capacity(n, n * 50, nref);
    // Push positions in descending order so the sort has real work to do.
    for i in 0..n {
        let pos = (n - i) as i32;
        buffer
            .push_coordinate(&make_coordinate_bam_record(0, pos))
            .expect("push_coordinate should succeed in tests");
    }
    let chunks = buffer.par_sort_into_chunks(1);
    assert_eq!(chunks.len(), 1, "single-threaded should produce exactly 1 chunk");
    assert_eq!(chunks[0].len(), n, "the single chunk should contain all records");
    // Verify the chunk is sorted by comparing each adjacent pair.
    for i in 1..chunks[0].len() {
        assert!(
            chunks[0][i - 1].0 <= chunks[0][i].0,
            "single chunk should be sorted at index {i}"
        );
    }
}
/// Coordinate-sort variant of the parallel path: several internally sorted
/// chunks that together cover every record.
#[test]
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn test_par_sort_into_chunks_coordinate_parallel() {
    let nref = 10u32;
    // Large enough to clear the parallel-path thresholds.
    let n: usize = 10_500;
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(4)
        .build()
        .expect("failed to build rayon thread pool");
    let chunks = pool.install(|| {
        let mut buffer = RecordBuffer::with_capacity(n, n * 50, nref);
        // Push positions in descending order so the sort has real work to do.
        for i in 0..n {
            let pos = (n - i) as i32;
            buffer
                .push_coordinate(&make_coordinate_bam_record(0, pos))
                .expect("push_coordinate should succeed in tests");
        }
        buffer.par_sort_into_chunks(4)
    });
    assert!(
        chunks.len() > 1,
        "expected multiple chunks from parallel path, got {}",
        chunks.len()
    );
    // Each chunk must be internally sorted.
    for (ci, chunk) in chunks.iter().enumerate() {
        for i in 1..chunk.len() {
            assert!(chunk[i - 1].0 <= chunk[i].0, "chunk {ci} not sorted at index {i}");
        }
    }
    // No records may be lost or duplicated across chunks.
    let total: usize = chunks.iter().map(Vec::len).sum();
    assert_eq!(total, n, "total records across chunks should equal input count");
}
/// Property tests checking that the MSD radix sort of template record refs
/// (`radix_sort_template_refs`) orders records identically to a reference
/// comparison sort on `TemplateKey`.
mod proptest_msd {
use super::*;
use proptest::{prop_assert_eq, proptest};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// Wraps a key into a TemplateRecordRef with a fixed length and the given
// offset (offset doubles as a record id for the tests).
fn make_ref(key: TemplateKey, offset: u64) -> TemplateRecordRef {
TemplateRecordRef { key, offset, len: 10, padding: 0 }
}
// Deterministic pseudo-random u64 derived from a (seed, index) pair via
// the std SipHash-based DefaultHasher — no extra RNG dependency needed.
fn hash_pair(a: u64, b: u64) -> u64 {
let mut h = DefaultHasher::new();
(a, b).hash(&mut h);
h.finish()
}
proptest! {
// With only a handful of distinct primary keys (heavy duplication in the
// leading radix digit), the MSD sort must still match a comparison sort.
#[test]
fn msd_sort_matches_reference(
n_primaries in 1_usize..=4,
seed in proptest::num::u64::ANY,
) {
let primaries: Vec<u64> = (0..n_primaries)
.map(|i| hash_pair(seed, i as u64))
.collect();
let n = 300;
let mut refs: Vec<TemplateRecordRef> = Vec::with_capacity(n);
for i in 0..n {
let h = hash_pair(seed, (i + n_primaries) as u64);
// Cycle through the small primary pool; derive the remaining key
// fields from h with distinct mixers so they vary independently.
let primary = primaries[i % n_primaries];
let key = TemplateKey {
primary,
secondary: h,
cb_hash: h.wrapping_mul(2_654_435_761),
tertiary: h.wrapping_mul(40503),
name_hash_upper: h.rotate_left(17),
};
refs.push(make_ref(key, i as u64));
}
// Reference ordering: stable comparison sort on the full key.
let mut expected = refs.clone();
expected.sort_by(|a, b| a.key.cmp(&b.key));
radix_sort_template_refs(&mut refs);
for i in 0..n {
prop_assert_eq!(
refs[i].key, expected[i].key,
"Mismatch at index {}: MSD key {:?} != reference {:?}",
i, refs[i].key, expected[i].key
);
}
}
// Fully random keys (every field derived from a fresh hash): the MSD
// sort must again agree with the comparison sort.
#[test]
fn msd_sort_matches_reference_random_keys(
seed in proptest::num::u64::ANY,
) {
let n = 500;
let mut refs: Vec<TemplateRecordRef> = Vec::with_capacity(n);
for i in 0..n {
let h = hash_pair(seed, i as u64);
// Distinct multiplicative/rotation mixers decorrelate the fields.
let key = TemplateKey {
primary: h,
secondary: h.wrapping_mul(6_364_136_223_846_793_005),
cb_hash: h.wrapping_mul(2_654_435_761),
tertiary: h.wrapping_mul(40503),
name_hash_upper: h.rotate_left(17),
};
refs.push(make_ref(key, i as u64));
}
// Reference ordering: stable comparison sort on the full key.
let mut expected = refs.clone();
expected.sort_by(|a, b| a.key.cmp(&b.key));
radix_sort_template_refs(&mut refs);
for i in 0..n {
prop_assert_eq!(
refs[i].key, expected[i].key,
"Mismatch at index {}: MSD key {:?} != reference {:?}",
i, refs[i].key, expected[i].key
);
}
}
}
}
}