use super::config::RadixSortConfig;
use super::advanced::RadixSortable;
use crate::algorithms::{Algorithm, AlgorithmStats};
use crate::error::{Result, ZiporaError};
use rayon::prelude::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
__m256i, _mm256_and_si256, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_srlv_epi32,
_mm256_storeu_si256,
};
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
use std::arch::x86_64::{
__m512i, _mm512_and_si512, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_srlv_epi32,
_mm512_storeu_si512,
};
/// Radix sorter for `u32`/`u64` slices (LSD, one `radix_bits`-wide digit per pass)
/// and for byte strings (in-place MSD).
///
/// Behaviour is driven by a [`RadixSortConfig`] (digit width, counting-sort and
/// parallel thresholds, parallel/SIMD switches). The sort methods write their
/// [`AlgorithmStats`] into this struct, and `stats()` reads them back.
pub struct RadixSort {
/// Tuning settings read by every sort path.
config: RadixSortConfig,
/// Statistics written by the public sort methods.
stats: AlgorithmStats,
}
impl RadixSort {
pub fn new() -> Self {
Self::with_config(RadixSortConfig::default())
}
pub fn with_config(config: RadixSortConfig) -> Self {
Self {
config,
stats: AlgorithmStats {
items_processed: 0,
processing_time_us: 0,
memory_used: 0,
used_parallel: false,
used_simd: false,
},
}
}
/// Sorts `data` in ascending order and records statistics for the call.
///
/// The parallel path is used when `use_parallel` is set and `data` holds at least
/// `parallel_threshold` elements; otherwise the slice is sorted sequentially.
/// Statistics are refreshed on every call, including for empty input, so
/// `stats()` never reports an earlier run.
///
/// `used_simd` is best-effort. It is true only when `use_simd` is set, the build
/// contains the AVX-512 digit histogram (`avx512` feature on x86_64), and the
/// running CPU supports it.
///
/// # Errors
/// Propagates any error from the sequential or parallel sort paths. On error the
/// statistics are left unchanged.
pub fn sort_u32(&mut self, data: &mut [u32]) -> Result<()> {
    let start_time = std::time::Instant::now();
    let used_parallel = !data.is_empty()
        && self.config.use_parallel
        && data.len() >= self.config.parallel_threshold;
    if used_parallel {
        self.sort_u32_parallel(data)?;
    } else if !data.is_empty() {
        self.sort_u32_sequential(data)?;
    }
    let elapsed = start_time.elapsed();

    // The only SIMD code for u32 is the AVX-512 digit histogram. It exists only
    // when built with the `avx512` feature on x86_64, and runs only on capable CPUs.
    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    let simd_available = std::arch::is_x86_feature_detected!("avx512f")
        && std::arch::is_x86_feature_detected!("avx512bw");
    #[cfg(not(all(target_arch = "x86_64", feature = "avx512")))]
    let simd_available = false;

    self.stats = AlgorithmStats {
        items_processed: data.len(),
        processing_time_us: u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX),
        memory_used: if data.is_empty() {
            0
        } else {
            self.estimate_memory_u32(data.len())
        },
        used_parallel,
        used_simd: !data.is_empty() && self.config.use_simd && simd_available,
    };
    Ok(())
}
/// Sorts `data` in ascending order and records statistics for the call.
///
/// The parallel path is used when `use_parallel` is set and `data` holds at least
/// `parallel_threshold` elements; otherwise the slice is sorted sequentially.
/// Statistics are refreshed on every call, including for empty input, so
/// `stats()` never reports an earlier run. The u64 path contains no SIMD code, so
/// `used_simd` is always `false`.
///
/// # Errors
/// Propagates any error from the sequential or parallel sort paths. On error the
/// statistics are left unchanged.
pub fn sort_u64(&mut self, data: &mut [u64]) -> Result<()> {
    let start_time = std::time::Instant::now();
    let used_parallel = !data.is_empty()
        && self.config.use_parallel
        && data.len() >= self.config.parallel_threshold;
    if used_parallel {
        self.sort_u64_parallel(data)?;
    } else if !data.is_empty() {
        self.sort_u64_sequential(data)?;
    }
    let elapsed = start_time.elapsed();
    self.stats = AlgorithmStats {
        items_processed: data.len(),
        processing_time_us: u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX),
        memory_used: if data.is_empty() {
            0
        } else {
            self.estimate_memory_u64(data.len())
        },
        used_parallel,
        // sort_u64_sequential / sort_u64_parallel have no vectorised code path.
        used_simd: false,
    };
    Ok(())
}
pub fn stats(&self) -> &AlgorithmStats {
&self.stats
}
/// Sorts byte strings lexicographically, then records statistics for the call.
///
/// Comparison is byte-wise, and a string that is a prefix of another sorts first.
/// The work is done in place by `sort_bytes_msd`, always sequentially and without
/// SIMD. Statistics are refreshed on every call, including for empty input.
///
/// # Errors
/// Propagates any error from `sort_bytes_msd`. On error the statistics are left
/// unchanged.
pub fn sort_bytes(&mut self, data: &mut Vec<Vec<u8>>) -> Result<()> {
    let start_time = std::time::Instant::now();
    if !data.is_empty() {
        self.sort_bytes_msd(data.as_mut_slice(), 0)?;
    }
    let elapsed = start_time.elapsed();
    let total_bytes: usize = data.iter().map(Vec::len).sum();
    self.stats = AlgorithmStats {
        items_processed: data.len(),
        processing_time_us: u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX),
        memory_used: total_bytes + data.len() * std::mem::size_of::<Vec<u8>>(),
        used_parallel: false,
        used_simd: false,
    };
    Ok(())
}
/// Sorts `data` in place with an LSD radix sort over `radix_bits`-wide digits.
///
/// Slices no longer than `use_counting_sort_threshold` are handed to
/// `counting_sort_u32`. Each radix pass does three steps:
/// 1. Build a digit histogram. This uses AVX-512 when the `avx512` feature is
///    compiled in, `use_simd` is set and the CPU supports it.
/// 2. Turn the histogram into bucket start offsets.
/// 3. Scatter stably into a scratch buffer.
///
/// A pass in which every element has the same digit is skipped, because the
/// stable scatter would reproduce the current order.
fn sort_u32_sequential(&self, data: &mut [u32]) -> Result<()> {
    if data.len() < 2 {
        return Ok(());
    }
    if data.len() <= self.config.use_counting_sort_threshold {
        self.counting_sort_u32(data);
        return Ok(());
    }
    let radix_bits = self.config.radix_bits;
    let radix = 1usize << radix_bits;
    let mask = (radix - 1) as u32;
    let mut buffer = vec![0u32; data.len()];
    let mut counts = vec![0usize; radix];
    for pass in 0..32_usize.div_ceil(radix_bits) {
        let shift = pass * radix_bits;
        counts.fill(0);

        // Histogram this pass's digits. Use the SIMD kernel when it is available;
        // otherwise fall through to the single scalar loop below.
        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        let counted_by_simd = if self.config.use_simd
            && data.len() >= 16
            && shift < 24
            && std::arch::is_x86_feature_detected!("avx512f")
            && std::arch::is_x86_feature_detected!("avx512bw")
        {
            // SAFETY: AVX-512F support was confirmed at runtime just above.
            unsafe { self.count_digits_avx512(data, shift, mask, &mut counts) };
            true
        } else {
            false
        };
        #[cfg(not(all(target_arch = "x86_64", feature = "avx512")))]
        let counted_by_simd = false;
        if !counted_by_simd {
            for &value in data.iter() {
                counts[((value >> shift) & mask) as usize] += 1;
            }
        }

        // Every element shares this digit, so the stable scatter would be a no-op.
        if counts[((data[0] >> shift) & mask) as usize] == data.len() {
            continue;
        }

        // Exclusive prefix sum: counts[d] becomes the first output slot for digit d.
        let mut next_slot = 0;
        for count in counts.iter_mut() {
            next_slot += std::mem::replace(count, next_slot);
        }
        for &value in data.iter() {
            let slot = &mut counts[((value >> shift) & mask) as usize];
            buffer[*slot] = value;
            *slot += 1;
        }
        data.copy_from_slice(&buffer);
    }
    Ok(())
}
/// Sorts `data` in two stages:
/// 1. Radix-sort about `rayon::current_num_threads()` contiguous chunks in parallel.
/// 2. K-way merge the sorted chunks with `multiway_merge_u32_chunks`.
///
/// Inputs shorter than twice `parallel_threshold` are sorted sequentially instead.
///
/// # Errors
/// Returns an error reported by any chunk sort, or by the merge.
fn sort_u32_parallel(&self, data: &mut [u32]) -> Result<()> {
    if data.len() < self.config.parallel_threshold.saturating_mul(2) {
        return self.sort_u32_sequential(data);
    }
    // `max(1)` keeps `par_chunks_mut` from panicking on a zero chunk size.
    let chunk_size = data.len().div_ceil(rayon::current_num_threads()).max(1);
    // The sequential path never reads `use_parallel`, so `self` can sort each chunk
    // directly. Chunk errors are propagated rather than discarded.
    data.par_chunks_mut(chunk_size)
        .try_for_each(|chunk| self.sort_u32_sequential(chunk))?;
    self.multiway_merge_u32_chunks(data, chunk_size)
}
/// Sorts `data` in place with an LSD radix sort over `radix_bits`-wide digits of
/// each `u64`.
///
/// Each pass histograms one digit, turns the counts into bucket start offsets, and
/// scatters stably into a scratch buffer. A pass in which every element has the
/// same digit is skipped, because the scatter would reproduce the current order.
/// For small keys this avoids most of the `ceil(64 / radix_bits)` passes.
fn sort_u64_sequential(&self, data: &mut [u64]) -> Result<()> {
    if data.len() < 2 {
        return Ok(());
    }
    let radix_bits = self.config.radix_bits;
    let radix = 1usize << radix_bits;
    let mask = (radix - 1) as u64;
    let mut buffer = vec![0u64; data.len()];
    let mut counts = vec![0usize; radix];
    for pass in 0..64_usize.div_ceil(radix_bits) {
        let shift = pass * radix_bits;
        let digit_of = |value: u64| ((value >> shift) & mask) as usize;

        counts.fill(0);
        for &value in data.iter() {
            counts[digit_of(value)] += 1;
        }

        // Every element shares this digit, so the stable scatter would be a no-op.
        if counts[digit_of(data[0])] == data.len() {
            continue;
        }

        // Exclusive prefix sum: counts[d] becomes the first output slot for digit d.
        let mut next_slot = 0;
        for count in counts.iter_mut() {
            next_slot += std::mem::replace(count, next_slot);
        }
        for &value in data.iter() {
            let slot = &mut counts[digit_of(value)];
            buffer[*slot] = value;
            *slot += 1;
        }
        data.copy_from_slice(&buffer);
    }
    Ok(())
}
/// Sorts `data` in two stages:
/// 1. Radix-sort about `rayon::current_num_threads()` contiguous chunks in parallel.
/// 2. K-way merge the sorted chunks with `multiway_merge_u64_chunks`.
///
/// Inputs shorter than twice `parallel_threshold` are sorted sequentially instead.
///
/// # Errors
/// Returns an error reported by any chunk sort, or by the merge.
fn sort_u64_parallel(&self, data: &mut [u64]) -> Result<()> {
    if data.len() < self.config.parallel_threshold.saturating_mul(2) {
        return self.sort_u64_sequential(data);
    }
    // `max(1)` keeps `par_chunks_mut` from panicking on a zero chunk size.
    let chunk_size = data.len().div_ceil(rayon::current_num_threads()).max(1);
    // The sequential path never reads `use_parallel`, so `self` can sort each chunk
    // directly. Chunk errors are propagated rather than discarded.
    data.par_chunks_mut(chunk_size)
        .try_for_each(|chunk| self.sort_u64_sequential(chunk))?;
    self.multiway_merge_u64_chunks(data, chunk_size)
}
/// Sorts `data` in place by counting how often each value occurs.
///
/// The histogram covers only the `min..=max` range of the input. If that range is
/// large compared with `data.len()`, the slice is sorted with `sort_unstable`
/// instead. For example, a few values near `u32::MAX` would otherwise need a
/// multi-gigabyte histogram. The output is identical either way, because equal
/// `u32`s are indistinguishable.
fn counting_sort_u32(&self, data: &mut [u32]) {
    // Histogram size always permitted, so small inputs over a narrow range still count.
    const MIN_SPAN_BUDGET: usize = 1 << 10;
    let (Some(&min), Some(&max)) = (data.iter().min(), data.iter().max()) else {
        return;
    };
    // Compare before adding 1, so the histogram length cannot overflow a 32-bit usize.
    let span = (max - min) as usize;
    if span >= data.len().saturating_mul(4).max(MIN_SPAN_BUDGET) {
        data.sort_unstable();
        return;
    }
    let mut counts = vec![0usize; span + 1];
    for &value in data.iter() {
        counts[(value - min) as usize] += 1;
    }
    // Rewrite the slice as runs of each value, in ascending order.
    let mut start = 0;
    for (offset, &count) in counts.iter().enumerate() {
        let end = start + count;
        data[start..end].fill(min + offset as u32);
        start = end;
    }
}
/// Sorts `data` lexicographically in place with an MSD ("American flag") radix
/// sort, starting at byte offset `depth`.
///
/// Every item in `data` must share the same first `depth` bytes. This holds at
/// `depth == 0` and for every bucket this function recurses into. A string with
/// no byte at `depth` goes to bucket 0, ahead of all longer strings, which matches
/// `Vec<u8>`'s `Ord`.
///
/// Two shortcuts keep the recursion shallow:
/// - Slices of at most 32 items are finished with `sort_unstable`.
/// - When every item continues with the same byte, `depth` advances inside a loop
///   instead of recursing. A long shared prefix therefore no longer costs one stack
///   frame per byte.
fn sort_bytes_msd(&self, data: &mut [Vec<u8>], depth: usize) -> Result<()> {
    // At or below this length a comparison sort beats a 257-bucket pass.
    const SMALL_SLICE: usize = 32;
    let mut depth = depth;
    loop {
        if data.len() <= 1 {
            return Ok(());
        }
        if data.len() <= SMALL_SLICE {
            // All items share their first `depth` bytes. Comparing whole strings
            // therefore orders them exactly as comparing from `depth` onward would.
            data.sort_unstable();
            return Ok(());
        }

        // Bucket 0: no byte at `depth`. Bucket `byte + 1`: the byte at `depth` is `byte`.
        let level = depth;
        let bucket_of =
            move |item: &Vec<u8>| item.get(level).map_or(0, |&byte| usize::from(byte) + 1);

        let mut counts = [0usize; 257];
        for item in data.iter() {
            counts[bucket_of(item)] += 1;
        }

        // Every item continues with the same byte, so nothing moves at this level.
        if counts[1..].contains(&data.len()) {
            depth += 1;
            continue;
        }

        // Bucket b occupies bounds[b]..bounds[b + 1] once partitioned.
        let mut bounds = [0usize; 258];
        for b in 0..257 {
            bounds[b + 1] = bounds[b] + counts[b];
        }

        // In-place permutation. Scan each bucket's unfilled region and swap each
        // misplaced item into the next free slot of its own bucket.
        let mut next_free = bounds;
        for b in 0..257 {
            while next_free[b] < bounds[b + 1] {
                let pos = next_free[b];
                let target = bucket_of(&data[pos]);
                if target == b {
                    next_free[b] += 1;
                } else {
                    data.swap(pos, next_free[target]);
                    next_free[target] += 1;
                }
            }
        }

        // Bucket 0 holds identical strings and needs no further work.
        for b in 1..257 {
            let (start, end) = (bounds[b], bounds[b + 1]);
            if end - start > 1 {
                self.sort_bytes_msd(&mut data[start..end], depth + 1)?;
            }
        }
        return Ok(());
    }
}
/// Approximate extra bytes needed to sort `len` `u32`s: a same-sized scratch
/// buffer plus one `usize` counter per radix bucket.
fn estimate_memory_u32(&self, len: usize) -> usize {
    let bucket_count = 1usize << self.config.radix_bits;
    let scratch_bytes = len * std::mem::size_of::<u32>();
    let histogram_bytes = bucket_count * std::mem::size_of::<usize>();
    scratch_bytes + histogram_bytes
}
/// Approximate extra bytes needed to sort `len` `u64`s: a same-sized scratch
/// buffer plus one `usize` counter per radix bucket.
fn estimate_memory_u64(&self, len: usize) -> usize {
    let bucket_count = 1usize << self.config.radix_bits;
    let scratch_bytes = len * std::mem::size_of::<u64>();
    let histogram_bytes = bucket_count * std::mem::size_of::<usize>();
    scratch_bytes + histogram_bytes
}
/// Merges the consecutive, individually sorted `chunk_size`-element runs of `data`
/// into one sorted sequence, then writes it back into `data`.
///
/// # Errors
/// Propagates any error from `MultiWayMerge::merge`.
fn multiway_merge_u32_chunks(&self, data: &mut [u32], chunk_size: usize) -> Result<()> {
    use crate::algorithms::multiway_merge::{MultiWayMerge, SliceSource};
    // At most one run (this includes empty input): already sorted.
    if chunk_size >= data.len() {
        return Ok(());
    }
    let sources: Vec<_> = data
        .chunks(chunk_size)
        .map(|run| SliceSource::new(run))
        .collect();
    let mut merger = MultiWayMerge::new();
    let merged = merger.merge(sources)?;
    data.copy_from_slice(&merged);
    Ok(())
}
/// Merges the consecutive, individually sorted `chunk_size`-element runs of `data`
/// into one sorted sequence, then writes it back into `data`.
///
/// # Errors
/// Propagates any error from `MultiWayMerge::merge`.
fn multiway_merge_u64_chunks(&self, data: &mut [u64], chunk_size: usize) -> Result<()> {
    use crate::algorithms::multiway_merge::{MultiWayMerge, SliceSource};
    // At most one run (this includes empty input): already sorted.
    if chunk_size >= data.len() {
        return Ok(());
    }
    let sources: Vec<_> = data
        .chunks(chunk_size)
        .map(|run| SliceSource::new(run))
        .collect();
    let mut merger = MultiWayMerge::new();
    let merged = merger.merge(sources)?;
    data.copy_from_slice(&merged);
    Ok(())
}
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
/// Adds the digit `(value >> shift) & mask` of every element of `data` to `counts`.
///
/// Digits are extracted sixteen at a time with AVX-512 shifts and masks; leftover
/// elements are handled by a scalar loop.
///
/// # Safety
/// The running CPU must support AVX-512F, because this function is compiled with
/// `#[target_feature(enable = "avx512f")]`. `counts` must have more than `mask`
/// entries, or the (bounds-checked) histogram update panics.
#[target_feature(enable = "avx512f")]
unsafe fn count_digits_avx512(
    &self,
    data: &[u32],
    shift: usize,
    mask: u32,
    counts: &mut [usize],
) {
    let lanes = data.chunks_exact(16);
    let tail = lanes.remainder();
    // Loop-invariant broadcasts and scratch space.
    let shift_vec = _mm512_set1_epi32(shift as i32);
    let mask_vec = _mm512_set1_epi32(mask as i32);
    let mut digit_lanes = [0u32; 16];
    for lane in lanes {
        // SAFETY: `lane` holds exactly 16 u32s (64 bytes), and this load tolerates misalignment.
        let values = unsafe { _mm512_loadu_si512(lane.as_ptr() as *const __m512i) };
        let shifted = if shift == 0 {
            values
        } else {
            _mm512_srlv_epi32(values, shift_vec)
        };
        let digits = _mm512_and_si512(shifted, mask_vec);
        // SAFETY: `digit_lanes` is 64 bytes, and this store tolerates misalignment.
        unsafe { _mm512_storeu_si512(digit_lanes.as_mut_ptr() as *mut __m512i, digits) };
        for &digit in digit_lanes.iter() {
            counts[digit as usize] += 1;
        }
    }
    for &value in tail {
        counts[((value >> shift) & mask) as usize] += 1;
    }
}
}
impl Default for RadixSort {
fn default() -> Self {
Self::new()
}
}
impl Algorithm for RadixSort {
    type Config = RadixSortConfig;
    type Input = Vec<u32>;
    type Output = Vec<u32>;

    /// Sorts `input` with a fresh sorter built from `config` and returns the result.
    ///
    /// `self`'s own configuration is not used. The temporary sorter's statistics are
    /// dropped, since this method only has `&self` and cannot update `self.stats`.
    fn execute(&self, config: &Self::Config, mut input: Self::Input) -> Result<Self::Output> {
        let mut sorter = Self::with_config(config.clone());
        sorter.sort_u32(&mut input)?;
        Ok(input)
    }

    /// Returns a copy of the statistics stored in this sorter.
    fn stats(&self) -> AlgorithmStats {
        self.stats.clone()
    }

    /// Estimates the extra bytes needed to sort `input_size` `u32`s.
    fn estimate_memory(&self, input_size: usize) -> usize {
        self.estimate_memory_u32(input_size)
    }

    fn supports_parallel(&self) -> bool {
        true
    }

    /// The only SIMD code in this sorter is the AVX-512 digit histogram. It is
    /// compiled only with the `avx512` feature on x86_64, so report SIMD support
    /// exactly for that build configuration.
    fn supports_simd(&self) -> bool {
        cfg!(all(target_arch = "x86_64", feature = "avx512"))
    }
}
/// Sorts `(key, value)` slices by key, radix-sorting keys widened to `u64`.
pub struct KeyValueRadixSort<K, V> {
/// Radix-sort settings; `sort_by_key` forwards the parallel and digit-width options.
config: RadixSortConfig,
/// Ties the struct to its key/value types without storing any of them.
_phantom: std::marker::PhantomData<(K, V)>,
}
/// Sort record for `KeyValueRadixSort`.
///
/// Pairs a key, widened to `u64`, with the position of its `(key, value)` element
/// in the input slice. The derived `Ord` compares `key` first, then `index`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) struct KVPair {
    /// Sort key; `RadixSortable::get_byte` exposes its big-endian bytes.
    key: u64,
    /// Index of the originating element in the slice being sorted.
    index: usize,
}
impl RadixSortable for KVPair {
fn extract_key(&self) -> u64 {
self.key
}
fn get_byte(&self, position: usize) -> Option<u8> {
if position < 8 {
Some((self.key >> ((7 - position) * 8)) as u8)
} else {
None
}
}
fn max_bytes(&self) -> usize {
8
}
}
impl<K, V> KeyValueRadixSort<K, V>
where
    K: Copy + Into<u64>,
    V: Clone,
{
    /// Creates a sorter with the default [`RadixSortConfig`].
    pub fn new() -> Self {
        Self {
            config: RadixSortConfig::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sorts `data` in place by key (ascending), moving each value with its key.
    ///
    /// The steps are:
    /// 1. Widen each key to `u64` and pair it with its original index.
    /// 2. Sort those pairs with `AdvancedRadixSort`.
    /// 3. Apply the resulting permutation in place by following cycles with swaps,
    ///    so values are moved, never cloned.
    ///
    /// NOTE(review): whether equal keys keep their input order depends on the
    /// stability of `AdvancedRadixSort` — confirm before relying on it.
    ///
    /// If `AdvancedRadixSort` rejects this sorter's parallel and digit-width
    /// settings, a default-configured sorter is used instead.
    ///
    /// # Errors
    /// Returns an error if the fallback sorter cannot be created, or if sorting fails.
    pub fn sort_by_key(&self, data: &mut [(K, V)]) -> Result<()> {
        if data.len() < 2 {
            return Ok(());
        }
        let mut ranked: Vec<KVPair> = data
            .iter()
            .enumerate()
            .map(|(index, (key, _))| KVPair {
                key: (*key).into(),
                index,
            })
            .collect();

        let config = super::config::AdvancedRadixSortConfig {
            use_parallel: self.config.use_parallel,
            parallel_threshold: self.config.parallel_threshold,
            radix_bits: self.config.radix_bits,
            ..Default::default()
        };
        let mut sorter = match super::advanced::AdvancedRadixSort::<KVPair>::with_config(config) {
            Ok(sorter) => sorter,
            // Keep the best-effort fallback, but report a failure of the fallback
            // itself instead of panicking inside library code.
            Err(_) => super::advanced::AdvancedRadixSort::<KVPair>::new()?,
        };
        sorter.sort(&mut ranked)?;

        // targets[i] = final position of the element currently at index i.
        let mut targets = vec![0usize; data.len()];
        for (rank, pair) in ranked.iter().enumerate() {
            targets[pair.index] = rank;
        }
        // Each swap puts one element in its final slot, so the loop does O(n) swaps.
        for i in 0..data.len() {
            while targets[i] != i {
                let dest = targets[i];
                data.swap(i, dest);
                targets.swap(i, dest);
            }
        }
        Ok(())
    }
}
impl<K, V> Default for KeyValueRadixSort<K, V>
where
K: Copy + Into<u64>,
V: Clone,
{
fn default() -> Self {
Self::new()
}
}