use crate::algorithms::{Algorithm, AlgorithmStats};
use crate::error::{Result, ZiporaError};
use crate::memory::{SecureMemoryPool, SecurePoolConfig};
use rayon::prelude::*;
use std::sync::Arc;
use std::time::Instant;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
_mm256_and_si256, _mm256_loadu_si256,
_mm256_set1_epi32, _mm256_srlv_epi32, _mm256_storeu_si256, __m256i,
};
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
use std::arch::x86_64::{
__m512i, _mm512_and_si512, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_srlv_epi32,
_mm512_storeu_si512,
};
#[derive(Debug, Clone)]
pub struct RadixSortConfig {
pub use_parallel: bool,
pub parallel_threshold: usize,
pub radix_bits: usize,
pub use_counting_sort_threshold: usize,
pub use_simd: bool,
}
impl Default for RadixSortConfig {
fn default() -> Self {
Self {
use_parallel: true,
parallel_threshold: 10_000,
radix_bits: 8,
use_counting_sort_threshold: 256,
use_simd: cfg!(feature = "simd"),
}
}
}
pub struct RadixSort {
config: RadixSortConfig,
stats: AlgorithmStats,
}
impl RadixSort {
pub fn new() -> Self {
Self::with_config(RadixSortConfig::default())
}
pub fn with_config(config: RadixSortConfig) -> Self {
Self {
config,
stats: AlgorithmStats {
items_processed: 0,
processing_time_us: 0,
memory_used: 0,
used_parallel: false,
used_simd: false,
},
}
}
pub fn sort_u32(&mut self, data: &mut [u32]) -> Result<()> {
let start_time = Instant::now();
if data.is_empty() {
return Ok(());
}
let used_parallel =
data.len() >= self.config.parallel_threshold && self.config.use_parallel;
if used_parallel {
self.sort_u32_parallel(data)?;
} else {
self.sort_u32_sequential(data)?;
}
let elapsed = start_time.elapsed();
self.stats = AlgorithmStats {
items_processed: data.len(),
processing_time_us: elapsed.as_micros() as u64,
memory_used: self.estimate_memory_u32(data.len()),
used_parallel,
used_simd: self.config.use_simd,
};
Ok(())
}
pub fn sort_u64(&mut self, data: &mut [u64]) -> Result<()> {
let start_time = Instant::now();
if data.is_empty() {
return Ok(());
}
let used_parallel =
data.len() >= self.config.parallel_threshold && self.config.use_parallel;
if used_parallel {
self.sort_u64_parallel(data)?;
} else {
self.sort_u64_sequential(data)?;
}
let elapsed = start_time.elapsed();
self.stats = AlgorithmStats {
items_processed: data.len(),
processing_time_us: elapsed.as_micros() as u64,
memory_used: self.estimate_memory_u64(data.len()),
used_parallel,
used_simd: self.config.use_simd,
};
Ok(())
}
pub fn sort_bytes(&mut self, data: &mut Vec<Vec<u8>>) -> Result<()> {
let start_time = Instant::now();
if data.is_empty() {
return Ok(());
}
self.sort_bytes_msd(data, 0)?;
let elapsed = start_time.elapsed();
let total_bytes: usize = data.iter().map(|v| v.len()).sum();
self.stats = AlgorithmStats {
items_processed: data.len(),
processing_time_us: elapsed.as_micros() as u64,
memory_used: total_bytes + data.len() * std::mem::size_of::<Vec<u8>>(),
used_parallel: false, used_simd: false,
};
Ok(())
}
fn sort_u32_sequential(&self, data: &mut [u32]) -> Result<()> {
if data.len() <= self.config.use_counting_sort_threshold {
self.counting_sort_u32(data);
return Ok(());
}
let radix = 1usize << self.config.radix_bits;
let mask = (radix - 1) as u32;
let mut buffer = vec![0u32; data.len()];
let mut counts = vec![0usize; radix];
let max_passes = 32_usize.div_ceil(self.config.radix_bits);
for pass in 0..max_passes {
let shift = pass * self.config.radix_bits;
counts.fill(0);
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
{
if self.config.use_simd && data.len() >= 16 && shift < 24 {
if std::arch::is_x86_feature_detected!("avx512f")
&& std::arch::is_x86_feature_detected!("avx512bw")
{
unsafe {
self.count_digits_avx512(data, shift, mask, &mut counts);
}
} else {
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
counts[digit] += 1;
}
}
} else {
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
counts[digit] += 1;
}
}
}
#[cfg(not(all(target_arch = "x86_64", feature = "avx512")))]
{
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
counts[digit] += 1;
}
}
let mut pos = 0;
for count in counts.iter_mut() {
let old_count = *count;
*count = pos;
pos += old_count;
}
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
buffer[counts[digit]] = value;
counts[digit] += 1;
}
data.copy_from_slice(&buffer);
}
Ok(())
}
fn sort_u32_parallel(&self, data: &mut [u32]) -> Result<()> {
if data.len() < 2 * self.config.parallel_threshold {
return self.sort_u32_sequential(data);
}
let num_threads = rayon::current_num_threads();
let chunk_size = (data.len() + num_threads - 1) / num_threads;
data.par_chunks_mut(chunk_size).for_each(|chunk| {
let temp_sorter = RadixSort::with_config(RadixSortConfig {
use_parallel: false,
..self.config.clone()
});
let _ = temp_sorter.sort_u32_sequential(chunk);
});
self.multiway_merge_u32_chunks(data, chunk_size)?;
Ok(())
}
fn sort_u64_sequential(&self, data: &mut [u64]) -> Result<()> {
let radix = 1usize << self.config.radix_bits;
let mask = (radix - 1) as u64;
let mut buffer = vec![0u64; data.len()];
let mut counts = vec![0usize; radix];
let max_passes = (64 + self.config.radix_bits - 1) / self.config.radix_bits;
for pass in 0..max_passes {
let shift = pass * self.config.radix_bits;
counts.fill(0);
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
counts[digit] += 1;
}
let mut pos = 0;
for count in counts.iter_mut() {
let old_count = *count;
*count = pos;
pos += old_count;
}
for &value in data.iter() {
let digit = ((value >> shift) & mask) as usize;
buffer[counts[digit]] = value;
counts[digit] += 1;
}
data.copy_from_slice(&buffer);
}
Ok(())
}
fn sort_u64_parallel(&self, data: &mut [u64]) -> Result<()> {
if data.len() < 2 * self.config.parallel_threshold {
return self.sort_u64_sequential(data);
}
let num_threads = rayon::current_num_threads();
let chunk_size = (data.len() + num_threads - 1) / num_threads;
data.par_chunks_mut(chunk_size).for_each(|chunk| {
let temp_sorter = RadixSort::with_config(RadixSortConfig {
use_parallel: false,
..self.config.clone()
});
let _ = temp_sorter.sort_u64_sequential(chunk);
});
self.multiway_merge_u64_chunks(data, chunk_size)?;
Ok(())
}
fn counting_sort_u32(&self, data: &mut [u32]) {
if data.is_empty() {
return;
}
let max_val = *data.iter().max().expect("non-empty input") as usize;
let mut counts = vec![0usize; max_val + 1];
for &value in data.iter() {
counts[value as usize] += 1;
}
let mut index = 0;
for (value, &count) in counts.iter().enumerate() {
for _ in 0..count {
data[index] = value as u32;
index += 1;
}
}
}
fn sort_bytes_msd(&self, data: &mut Vec<Vec<u8>>, depth: usize) -> Result<()> {
if data.len() <= 1 {
return Ok(());
}
let mut buckets: Vec<Vec<Vec<u8>>> = vec![Vec::new(); 257];
for item in data.drain(..) {
let bucket_index = if depth < item.len() {
item[depth] as usize + 1 } else {
0 };
buckets[bucket_index].push(item);
}
for (i, mut bucket) in buckets.into_iter().enumerate() {
if bucket.len() > 1 && i > 0 {
self.sort_bytes_msd(&mut bucket, depth + 1)?;
}
data.extend(bucket);
}
Ok(())
}
fn estimate_memory_u32(&self, len: usize) -> usize {
let radix = 1usize << self.config.radix_bits;
len * std::mem::size_of::<u32>() + radix * std::mem::size_of::<usize>() }
fn estimate_memory_u64(&self, len: usize) -> usize {
let radix = 1usize << self.config.radix_bits;
len * std::mem::size_of::<u64>() + radix * std::mem::size_of::<usize>() }
fn multiway_merge_u32_chunks(&self, data: &mut [u32], chunk_size: usize) -> Result<()> {
use crate::algorithms::multiway_merge::{MultiWayMerge, VectorSource};
let mut sources = Vec::new();
let mut chunks_vec = Vec::new();
for chunk in data.chunks(chunk_size) {
chunks_vec.push(chunk.to_vec());
}
for chunk in chunks_vec {
sources.push(VectorSource::new(chunk));
}
let mut merger = MultiWayMerge::new();
let merged = merger.merge(sources)?;
data.copy_from_slice(&merged);
Ok(())
}
fn multiway_merge_u64_chunks(&self, data: &mut [u64], chunk_size: usize) -> Result<()> {
use crate::algorithms::multiway_merge::{MultiWayMerge, VectorSource};
let mut sources = Vec::new();
let mut chunks_vec = Vec::new();
for chunk in data.chunks(chunk_size) {
chunks_vec.push(chunk.to_vec());
}
for chunk in chunks_vec {
sources.push(VectorSource::new(chunk));
}
let mut merger = MultiWayMerge::new();
let merged = merger.merge(sources)?;
data.copy_from_slice(&merged);
Ok(())
}
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn count_digits_avx512(
&self,
data: &[u32],
shift: usize,
mask: u32,
counts: &mut [usize],
) {
let mut i = 0;
let shift_vec = _mm512_set1_epi32(shift as i32);
while i + 16 <= data.len() {
let values = unsafe { _mm512_loadu_si512(data[i..].as_ptr() as *const __m512i) };
let shifted = if shift > 0 {
_mm512_srlv_epi32(values, shift_vec) } else {
values
};
let mask_vec = _mm512_set1_epi32(mask as i32);
let digits = _mm512_and_si512(shifted, mask_vec);
let mut digit_array = [0u32; 16];
unsafe { _mm512_storeu_si512(digit_array.as_mut_ptr() as *mut __m512i, digits) };
for digit in digit_array.iter() {
counts[*digit as usize] += 1;
}
i += 16;
}
for &value in &data[i..] {
let digit = ((value >> shift) & mask) as usize;
counts[digit] += 1;
}
}
}
impl Default for RadixSort {
fn default() -> Self {
Self::new()
}
}
impl Algorithm for RadixSort {
type Config = RadixSortConfig;
type Input = Vec<u32>;
type Output = Vec<u32>;
fn execute(&self, config: &Self::Config, mut input: Self::Input) -> Result<Self::Output> {
let mut sorter = Self::with_config(config.clone());
sorter.sort_u32(&mut input)?;
Ok(input)
}
fn stats(&self) -> AlgorithmStats {
self.stats.clone()
}
fn estimate_memory(&self, input_size: usize) -> usize {
self.estimate_memory_u32(input_size)
}
fn supports_parallel(&self) -> bool {
true
}
fn supports_simd(&self) -> bool {
cfg!(feature = "simd")
}
}
pub struct KeyValueRadixSort<K, V> {
config: RadixSortConfig,
_phantom: std::marker::PhantomData<(K, V)>,
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct KVPair {
key: u64,
index: usize,
}
impl RadixSortable for KVPair {
fn extract_key(&self) -> u64 {
self.key
}
fn get_byte(&self, position: usize) -> Option<u8> {
if position < 8 {
Some((self.key >> ((7 - position) * 8)) as u8)
} else {
None
}
}
fn max_bytes(&self) -> usize {
8
}
}
impl<K, V> KeyValueRadixSort<K, V>
where
K: Copy + Into<u64>,
V: Clone,
{
pub fn new() -> Self {
Self {
config: RadixSortConfig::default(),
_phantom: std::marker::PhantomData,
}
}
pub fn sort_by_key(&self, data: &mut [(K, V)]) -> Result<()> {
if data.is_empty() {
return Ok(());
}
let mut indices: Vec<KVPair> = data
.iter()
.enumerate()
.map(|(i, (k, _))| KVPair { key: (*k).into(), index: i })
.collect();
let mut config = AdvancedRadixSortConfig::default();
config.use_parallel = self.config.use_parallel;
config.parallel_threshold = self.config.parallel_threshold;
config.radix_bits = self.config.radix_bits;
let mut sorter = AdvancedRadixSort::<KVPair>::with_config(config)
.unwrap_or_else(|_| AdvancedRadixSort::new().unwrap());
sorter.sort(&mut indices)?;
let original_data: Vec<(K, V)> = data.to_vec();
for (i, ki) in indices.iter().enumerate() {
data[i] = original_data[ki.index].clone();
}
Ok(())
}
}
impl<K, V> Default for KeyValueRadixSort<K, V>
where
K: Copy + Into<u64>,
V: Clone,
{
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct CpuFeatures {
pub avx2: bool,
pub bmi2: bool,
pub popcnt: bool,
pub avx512f: bool,
pub avx512bw: bool,
}
impl CpuFeatures {
pub fn detect() -> Self {
#[cfg(target_arch = "x86_64")]
{
Self {
avx2: std::arch::is_x86_feature_detected!("avx2"),
bmi2: std::arch::is_x86_feature_detected!("bmi2"),
popcnt: std::arch::is_x86_feature_detected!("popcnt"),
avx512f: std::arch::is_x86_feature_detected!("avx512f"),
avx512bw: std::arch::is_x86_feature_detected!("avx512bw"),
}
}
#[cfg(not(target_arch = "x86_64"))]
{
Self {
avx2: false,
bmi2: false,
popcnt: false,
avx512f: false,
avx512bw: false,
}
}
}
pub fn has_advanced_simd(&self) -> bool {
self.avx2 && self.bmi2
}
pub fn has_avx512(&self) -> bool {
self.avx512f && self.avx512bw
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortingStrategy {
Insertion,
TimSort,
LsdRadix,
MsdRadix,
Adaptive,
}
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
pub size: usize,
pub is_nearly_sorted: bool,
pub is_string_data: bool,
pub estimated_entropy: f64,
pub max_key_bits: usize,
}
impl DataCharacteristics {
pub fn analyze_integers<T>(data: &[T]) -> Self
where
T: Copy + Into<u64> + Ord,
{
let size = data.len();
let mut inversions = 0usize;
let mut sorted_runs = 0usize;
let mut current_run_length = 1usize;
for i in 1..data.len() {
if data[i] >= data[i - 1] {
current_run_length += 1;
} else {
inversions += 1;
if current_run_length > 1 {
sorted_runs += 1;
}
current_run_length = 1;
}
}
if current_run_length > 1 {
sorted_runs += 1;
}
let inversion_threshold = if size <= 10 {
std::cmp::max(1, size / 5) } else {
size / 10 };
let is_nearly_sorted = inversions < inversion_threshold && (sorted_runs >= 1 || inversions == 0);
let mut max_val = 0u64;
for &item in data {
max_val = max_val.max(item.into());
}
let max_key_bits = if max_val == 0 { 1 } else { 64 - max_val.leading_zeros() as usize };
let estimated_entropy = if size > 0 { (size as f64).log2() } else { 0.0 };
Self {
size,
is_nearly_sorted,
is_string_data: false,
estimated_entropy,
max_key_bits,
}
}
pub fn analyze_strings(data: &[Vec<u8>]) -> Self {
let size = data.len();
let mut inversions = 0usize;
for i in 1..data.len() {
if data[i] < data[i - 1] {
inversions += 1;
}
}
let is_nearly_sorted = inversions < std::cmp::max(1, size / 10);
let max_length = data.iter().map(|s| s.len()).max().unwrap_or(0);
let estimated_entropy = if size > 0 { (size as f64).log2() } else { 0.0 };
Self {
size,
is_nearly_sorted,
is_string_data: true,
estimated_entropy,
max_key_bits: max_length * 8, }
}
}
#[derive(Debug, Clone)]
pub struct AdvancedRadixSortConfig {
pub use_secure_memory: bool,
pub adaptive_strategy: bool,
pub force_strategy: Option<SortingStrategy>,
pub use_parallel: bool,
pub parallel_threshold: usize,
pub num_threads: usize,
pub radix_bits: usize,
pub insertion_sort_threshold: usize,
pub counting_sort_threshold: usize,
pub use_simd: bool,
pub use_work_stealing: bool,
pub prefetch_distance: usize,
pub cache_alignment: usize,
pub memory_budget: usize,
pub enable_profiling: bool,
}
impl Default for AdvancedRadixSortConfig {
fn default() -> Self {
Self {
use_secure_memory: true,
adaptive_strategy: true,
force_strategy: None,
use_parallel: true,
parallel_threshold: 10_000,
num_threads: 0, radix_bits: 8,
insertion_sort_threshold: 100,
counting_sort_threshold: 1024,
use_simd: cfg!(feature = "simd"),
use_work_stealing: true,
prefetch_distance: 2,
cache_alignment: 64,
memory_budget: 64 * 1024 * 1024, enable_profiling: false,
}
}
}
#[derive(Debug, Clone)]
pub struct AdvancedAlgorithmStats {
pub basic_stats: AlgorithmStats,
pub strategy_used: SortingStrategy,
pub cpu_features_used: CpuFeatures,
pub estimated_cache_misses: u64,
pub peak_memory_bytes: usize,
pub threads_used: usize,
pub phase_times: PhaseTimes,
}
#[derive(Debug, Clone)]
pub struct PhaseTimes {
pub analysis_time_us: u64,
pub allocation_time_us: u64,
pub sorting_time_us: u64,
pub merging_time_us: u64,
pub cleanup_time_us: u64,
}
pub trait RadixSortable: Clone + Copy + Send + Sync + Ord {
fn extract_key(&self) -> u64;
fn get_byte(&self, position: usize) -> Option<u8>;
fn max_bytes(&self) -> usize;
fn supports_parallel() -> bool {
true
}
}
impl RadixSortable for u32 {
fn extract_key(&self) -> u64 {
*self as u64
}
fn get_byte(&self, position: usize) -> Option<u8> {
if position < 4 {
Some(((*self >> (8 * (3 - position))) & 0xFF) as u8)
} else {
None
}
}
fn max_bytes(&self) -> usize {
4
}
}
impl RadixSortable for u64 {
fn extract_key(&self) -> u64 {
*self
}
fn get_byte(&self, position: usize) -> Option<u8> {
if position < 8 {
Some(((*self >> (8 * (7 - position))) & 0xFF) as u8)
} else {
None
}
}
fn max_bytes(&self) -> usize {
8
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RadixString<'a> {
data: &'a [u8],
}
impl<'a> RadixString<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data }
}
#[inline]
pub fn as_slice(&self) -> &[u8] {
self.data
}
}
impl<'a> RadixSortable for RadixString<'a> {
fn extract_key(&self) -> u64 {
let mut key = 0u64;
for (i, &byte) in self.data.iter().take(8).enumerate() {
key |= (byte as u64) << (8 * (7 - i));
}
key
}
fn get_byte(&self, position: usize) -> Option<u8> {
self.data.get(position).copied()
}
fn max_bytes(&self) -> usize {
self.data.len()
}
fn supports_parallel() -> bool {
true }
}
pub struct AdvancedRadixSort<T: RadixSortable> {
config: AdvancedRadixSortConfig,
stats: AdvancedAlgorithmStats,
memory_pool: Option<Arc<SecureMemoryPool>>,
cpu_features: CpuFeatures,
_phantom: std::marker::PhantomData<T>,
}
impl<T: RadixSortable> AdvancedRadixSort<T> {
pub fn new() -> Result<Self> {
Self::with_config(AdvancedRadixSortConfig::default())
}
pub fn with_config(config: AdvancedRadixSortConfig) -> Result<Self> {
let cpu_features = CpuFeatures::detect();
let memory_pool = if config.use_secure_memory {
let pool_config = if config.memory_budget <= 64 * 1024 {
SecurePoolConfig::small_secure()
} else if config.memory_budget <= 1024 * 1024 {
SecurePoolConfig::medium_secure()
} else {
SecurePoolConfig::large_secure()
};
Some(SecureMemoryPool::new(pool_config)?)
} else {
None
};
Ok(Self {
config,
stats: AdvancedAlgorithmStats {
basic_stats: AlgorithmStats {
items_processed: 0,
processing_time_us: 0,
memory_used: 0,
used_parallel: false,
used_simd: false,
},
strategy_used: SortingStrategy::Adaptive,
cpu_features_used: cpu_features.clone(),
estimated_cache_misses: 0,
peak_memory_bytes: 0,
threads_used: 0,
phase_times: PhaseTimes {
analysis_time_us: 0,
allocation_time_us: 0,
sorting_time_us: 0,
merging_time_us: 0,
cleanup_time_us: 0,
},
},
memory_pool,
cpu_features,
_phantom: std::marker::PhantomData,
})
}
pub fn with_memory_pool(
config: AdvancedRadixSortConfig,
memory_pool: Arc<SecureMemoryPool>,
) -> Self {
let cpu_features = CpuFeatures::detect();
Self {
config,
stats: AdvancedAlgorithmStats {
basic_stats: AlgorithmStats {
items_processed: 0,
processing_time_us: 0,
memory_used: 0,
used_parallel: false,
used_simd: false,
},
strategy_used: SortingStrategy::Adaptive,
cpu_features_used: cpu_features.clone(),
estimated_cache_misses: 0,
peak_memory_bytes: 0,
threads_used: 0,
phase_times: PhaseTimes {
analysis_time_us: 0,
allocation_time_us: 0,
sorting_time_us: 0,
merging_time_us: 0,
cleanup_time_us: 0,
},
},
memory_pool: Some(memory_pool),
cpu_features,
_phantom: std::marker::PhantomData,
}
}
pub fn sort(&mut self, data: &mut [T]) -> Result<()> {
let total_start = Instant::now();
if data.is_empty() {
return Ok(());
}
let analysis_start = Instant::now();
let strategy = self.select_strategy(data)?;
self.stats.phase_times.analysis_time_us = analysis_start.elapsed().as_micros() as u64;
let sorting_start = Instant::now();
match strategy {
SortingStrategy::Insertion => self.insertion_sort(data)?,
SortingStrategy::TimSort => self.tim_sort(data)?,
SortingStrategy::LsdRadix => self.lsd_radix_sort(data)?,
SortingStrategy::MsdRadix => self.msd_radix_sort(data, 0)?,
SortingStrategy::Adaptive => {
return Err(ZiporaError::invalid_data("Invalid adaptive strategy selection"));
}
}
self.stats.phase_times.sorting_time_us = sorting_start.elapsed().as_micros() as u64;
let total_elapsed = total_start.elapsed();
self.stats.basic_stats.items_processed = data.len();
self.stats.basic_stats.processing_time_us = total_elapsed.as_micros() as u64;
self.stats.strategy_used = strategy;
Ok(())
}
fn select_strategy(&self, data: &[T]) -> Result<SortingStrategy> {
if let Some(strategy) = self.config.force_strategy {
return Ok(strategy);
}
if !self.config.adaptive_strategy {
return Ok(SortingStrategy::LsdRadix);
}
let size = data.len();
if size <= self.config.insertion_sort_threshold {
return Ok(SortingStrategy::Insertion);
}
if self.is_nearly_sorted(data) {
return Ok(SortingStrategy::TimSort);
}
Ok(SortingStrategy::LsdRadix)
}
fn is_nearly_sorted(&self, data: &[T]) -> bool {
if data.len() < 2 {
return true;
}
let mut inversions = 0usize;
let sample_size = std::cmp::min(1000, data.len());
for i in 1..sample_size {
if data[i].extract_key() < data[i - 1].extract_key() {
inversions += 1;
}
}
inversions < sample_size / 10
}
pub fn stats(&self) -> &AdvancedAlgorithmStats {
&self.stats
}
pub fn estimate_memory(&self, input_size: usize) -> usize {
let radix_size = 1usize << self.config.radix_bits;
let base_memory = input_size * std::mem::size_of::<T>() + radix_size * std::mem::size_of::<usize>();
let parallel_memory = if self.config.use_parallel && input_size >= self.config.parallel_threshold {
let num_threads = if self.config.num_threads > 0 {
self.config.num_threads
} else {
rayon::current_num_threads()
};
base_memory * num_threads / 4 } else {
0
};
base_memory + parallel_memory
}
fn insertion_sort(&mut self, data: &mut [T]) -> Result<()> {
for i in 1..data.len() {
let key = data[i].clone();
let key_value = key.extract_key();
let mut j = i;
while j > 0 && data[j - 1].extract_key() > key_value {
data[j] = data[j - 1].clone();
j -= 1;
}
data[j] = key;
}
self.stats.basic_stats.used_parallel = false;
self.stats.basic_stats.used_simd = false;
Ok(())
}
fn tim_sort(&mut self, data: &mut [T]) -> Result<()> {
data.sort_unstable_by_key(|item| item.extract_key());
self.stats.basic_stats.used_parallel = false;
self.stats.basic_stats.used_simd = false;
Ok(())
}
fn lsd_radix_sort(&mut self, data: &mut [T]) -> Result<()> {
if data.is_empty() {
return Ok(());
}
let should_use_parallel = data.len() >= self.config.parallel_threshold &&
self.config.use_parallel &&
T::supports_parallel();
if should_use_parallel {
self.lsd_radix_sort_parallel(data)
} else {
self.lsd_radix_sort_sequential(data)
}
}
fn lsd_radix_sort_sequential(&mut self, data: &mut [T]) -> Result<()> {
let radix = 1usize << self.config.radix_bits;
let mask = (radix - 1) as u64;
let buffer = if let Some(ref pool) = self.memory_pool {
vec![data[0].clone(); data.len()]
} else {
vec![data[0].clone(); data.len()]
};
let mut buffer = buffer;
let mut counts = vec![0usize; radix];
let max_key = data.iter()
.map(|item| item.extract_key())
.max()
.unwrap_or(0);
let max_passes = if max_key == 0 { 1 } else {
(64 - max_key.leading_zeros() as usize + self.config.radix_bits - 1) / self.config.radix_bits
};
for pass in 0..max_passes {
let shift = pass * self.config.radix_bits;
counts.fill(0);
if self.config.use_simd && self.cpu_features.has_advanced_simd() && data.len() >= 16 {
self.count_digits_simd(data, shift, mask, &mut counts)?;
} else {
for item in data.iter() {
let key = item.extract_key();
let digit = ((key >> shift) & mask) as usize;
counts[digit] += 1;
}
}
let mut pos = 0;
for count in counts.iter_mut() {
let old_count = *count;
*count = pos;
pos += old_count;
}
for item in data.iter() {
let key = item.extract_key();
let digit = ((key >> shift) & mask) as usize;
buffer[counts[digit]] = item.clone();
counts[digit] += 1;
}
data.copy_from_slice(&buffer);
}
self.stats.basic_stats.used_parallel = false;
self.stats.basic_stats.used_simd = self.config.use_simd && self.cpu_features.has_advanced_simd();
Ok(())
}
fn lsd_radix_sort_parallel(&mut self, data: &mut [T]) -> Result<()> {
let num_threads = if self.config.num_threads > 0 {
self.config.num_threads
} else {
rayon::current_num_threads()
};
if data.len() < 2 * self.config.parallel_threshold {
return self.lsd_radix_sort_sequential(data);
}
let chunk_size = (data.len() + num_threads - 1) / num_threads;
data.par_chunks_mut(chunk_size).for_each(|chunk| {
if let Ok(mut temp_sorter) = AdvancedRadixSort::with_config(AdvancedRadixSortConfig {
use_parallel: false,
..self.config.clone()
}) {
let _ = temp_sorter.lsd_radix_sort_sequential(chunk);
}
});
self.multiway_merge_chunks(data, chunk_size)?;
self.stats.basic_stats.used_parallel = true;
self.stats.basic_stats.used_simd = self.config.use_simd && self.cpu_features.has_advanced_simd();
self.stats.threads_used = num_threads;
Ok(())
}
fn msd_radix_sort(&mut self, data: &mut [T], depth: usize) -> Result<()> {
if data.len() <= 1 {
return Ok(());
}
if data.len() <= self.config.insertion_sort_threshold || depth > 64 {
return self.insertion_sort(data);
}
let mut buckets: Vec<Vec<T>> = vec![Vec::new(); 257];
for &item in data.iter() {
let bucket_index = if let Some(byte) = item.get_byte(depth) {
byte as usize + 1 } else {
0 };
buckets[bucket_index].push(item);
}
let mut offset = 0;
for (i, mut bucket) in buckets.into_iter().enumerate() {
if bucket.len() > 1 && i > 0 {
self.msd_radix_sort(&mut bucket, depth + 1)?;
}
for item in bucket {
data[offset] = item;
offset += 1;
}
}
self.stats.basic_stats.used_parallel = false; self.stats.basic_stats.used_simd = false;
Ok(())
}
fn count_digits_simd(&self, data: &[T], shift: usize, mask: u64, counts: &mut [usize]) -> Result<()> {
#[cfg(target_arch = "x86_64")]
{
if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("bmi2") {
unsafe {
self.count_digits_avx2_bmi2(data, shift, mask, counts)?;
}
return Ok(());
}
}
for item in data.iter() {
let key = item.extract_key();
let digit = ((key >> shift) & mask) as usize;
counts[digit] += 1;
}
Ok(())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,bmi2")]
unsafe fn count_digits_avx2_bmi2(&self, data: &[T], shift: usize, mask: u64, counts: &mut [usize]) -> Result<()> {
let mut keys: Vec<u64> = data.iter().map(|item| item.extract_key()).collect();
let mut i = 0;
let shift_vec = _mm256_set1_epi32(shift as i32);
let mask_vec = _mm256_set1_epi32(mask as i32);
while i + 8 <= keys.len() {
let keys_u32: [u32; 8] = [
keys[i] as u32,
keys[i + 1] as u32,
keys[i + 2] as u32,
keys[i + 3] as u32,
keys[i + 4] as u32,
keys[i + 5] as u32,
keys[i + 6] as u32,
keys[i + 7] as u32,
];
let values = unsafe { _mm256_loadu_si256(keys_u32.as_ptr() as *const __m256i) };
let shifted = if shift > 0 {
unsafe { _mm256_srlv_epi32(values, shift_vec) }
} else {
values
};
let digits = unsafe { _mm256_and_si256(shifted, mask_vec) };
let mut digit_array = [0u32; 8];
unsafe { _mm256_storeu_si256(digit_array.as_mut_ptr() as *mut __m256i, digits) };
for &digit in &digit_array {
if (digit as usize) < counts.len() {
counts[digit as usize] += 1;
}
}
i += 8;
}
for &key in &keys[i..] {
let digit = ((key >> shift) & mask) as usize;
if digit < counts.len() {
counts[digit] += 1;
}
}
Ok(())
}
fn multiway_merge_chunks(&self, data: &mut [T], chunk_size: usize) -> Result<()> {
data.sort_unstable();
Ok(())
}
}
impl<T: RadixSortable> Default for AdvancedRadixSort<T> {
fn default() -> Self {
Self::new().unwrap_or_else(|e| {
panic!("AdvancedRadixSort creation failed in Default: {}. \
This indicates severe memory pressure.", e)
})
}
}
impl<T: RadixSortable> Algorithm for AdvancedRadixSort<T> {
type Config = AdvancedRadixSortConfig;
type Input = Vec<T>;
type Output = Vec<T>;
fn execute(&self, config: &Self::Config, mut input: Self::Input) -> Result<Self::Output> {
let mut sorter = Self::with_config(config.clone())?;
sorter.sort(&mut input)?;
Ok(input)
}
fn stats(&self) -> AlgorithmStats {
self.stats.basic_stats.clone()
}
fn estimate_memory(&self, input_size: usize) -> usize {
self.estimate_memory(input_size)
}
fn supports_parallel(&self) -> bool {
T::supports_parallel()
}
fn supports_simd(&self) -> bool {
self.config.use_simd
}
}
pub type AdvancedU32RadixSort = AdvancedRadixSort<u32>;
pub type AdvancedU64RadixSort = AdvancedRadixSort<u64>;
pub type AdvancedStringRadixSort<'a> = AdvancedRadixSort<RadixString<'a>>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_radix_sort_u32_empty() {
let mut sorter = RadixSort::new();
let mut data: Vec<u32> = vec![];
let result = sorter.sort_u32(&mut data);
assert!(result.is_ok());
assert!(data.is_empty());
}
#[test]
fn test_radix_sort_u32_simple() {
let mut sorter = RadixSort::new();
let mut data = vec![5, 2, 8, 1, 9, 3, 7, 4, 6];
let result = sorter.sort_u32(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
let stats = sorter.stats();
assert_eq!(stats.items_processed, 9);
}
#[test]
fn test_radix_sort_u32_large_numbers() {
let mut sorter = RadixSort::new();
let mut data = vec![u32::MAX, 1000000, 500000, 0, 999999];
let result = sorter.sort_u32(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![0, 500000, 999999, 1000000, u32::MAX]);
}
#[test]
fn test_radix_sort_u64() {
let mut sorter = RadixSort::new();
let mut data = vec![5u64, 2, 8, 1, 9, 3, 7, 4, 6];
let result = sorter.sort_u64(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[test]
fn test_radix_sort_bytes() {
let mut sorter = RadixSort::new();
let mut data = vec![
b"banana".to_vec(),
b"apple".to_vec(),
b"cherry".to_vec(),
b"date".to_vec(),
];
let result = sorter.sort_bytes(&mut data);
assert!(result.is_ok());
assert_eq!(
data,
vec![
b"apple".to_vec(),
b"banana".to_vec(),
b"cherry".to_vec(),
b"date".to_vec(),
]
);
}
#[test]
fn test_radix_sort_bytes_different_lengths() {
let mut sorter = RadixSort::new();
let mut data = vec![b"a".to_vec(), b"abc".to_vec(), b"ab".to_vec(), b"".to_vec()];
let result = sorter.sort_bytes(&mut data);
assert!(result.is_ok());
assert_eq!(
data,
vec![b"".to_vec(), b"a".to_vec(), b"ab".to_vec(), b"abc".to_vec(),]
);
}
#[test]
fn test_radix_sort_config() {
let config = RadixSortConfig {
use_parallel: false,
parallel_threshold: 100,
radix_bits: 4,
use_counting_sort_threshold: 10,
use_simd: false,
};
let mut sorter = RadixSort::with_config(config);
let mut data = vec![5u32, 2, 8, 1, 9];
let result = sorter.sort_u32(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 5, 8, 9]);
assert!(!sorter.stats().used_parallel);
}
#[test]
fn test_counting_sort_threshold() {
let config = RadixSortConfig {
use_counting_sort_threshold: 100,
..Default::default()
};
let mut sorter = RadixSort::with_config(config);
let mut data = vec![3u32, 1, 4, 1, 5, 9, 2, 6];
let result = sorter.sort_u32(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 1, 2, 3, 4, 5, 6, 9]);
}
#[test]
fn test_key_value_radix_sort() {
let sorter = KeyValueRadixSort::<u32, String>::new();
let mut data = vec![
(5, "five".to_string()),
(2, "two".to_string()),
(8, "eight".to_string()),
(1, "one".to_string()),
];
let result = sorter.sort_by_key(&mut data);
assert!(result.is_ok());
let expected = vec![
(1, "one".to_string()),
(2, "two".to_string()),
(5, "five".to_string()),
(8, "eight".to_string()),
];
assert_eq!(data, expected);
}
#[test]
fn test_algorithm_trait() {
let sorter = RadixSort::new();
assert!(sorter.supports_parallel());
let memory_estimate = sorter.estimate_memory(1000);
assert!(memory_estimate > 1000 * std::mem::size_of::<u32>());
let input = vec![3u32, 1, 4, 1, 5];
let config = RadixSortConfig::default();
let result = sorter.execute(&config, input);
assert!(result.is_ok());
let sorted = result.unwrap();
assert_eq!(sorted, vec![1, 1, 3, 4, 5]);
}
#[test]
fn test_advanced_radix_sort_u32_simple() {
let mut sorter = AdvancedU32RadixSort::new().unwrap();
let mut data = vec![5, 2, 8, 1, 9, 3, 7, 4, 6];
let result = sorter.sort(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
let stats = sorter.stats();
assert_eq!(stats.basic_stats.items_processed, 9);
assert_eq!(stats.strategy_used, SortingStrategy::Insertion); }
#[test]
fn test_advanced_radix_sort_u32_large() {
let mut sorter = AdvancedU32RadixSort::with_config(AdvancedRadixSortConfig {
insertion_sort_threshold: 50,
..Default::default()
}).unwrap();
let mut data: Vec<u32> = (0..1000).rev().collect(); let result = sorter.sort(&mut data);
assert!(result.is_ok());
let expected: Vec<u32> = (0..1000).collect();
assert_eq!(data, expected);
let stats = sorter.stats();
assert_eq!(stats.basic_stats.items_processed, 1000);
assert!(matches!(stats.strategy_used, SortingStrategy::TimSort | SortingStrategy::LsdRadix));
}
#[test]
fn test_advanced_radix_sort_u64() {
let mut sorter = AdvancedU64RadixSort::new().unwrap();
let mut data = vec![u64::MAX, 1000000, 500000, 0, 999999];
let result = sorter.sort(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![0, 500000, 999999, 1000000, u64::MAX]);
}
#[test]
fn test_advanced_radix_sort_strings() {
let mut sorter = AdvancedStringRadixSort::new().unwrap();
let strings = vec![b"banana".as_slice(), b"apple".as_slice(), b"cherry".as_slice(), b"date".as_slice()];
let mut data: Vec<RadixString> = strings.iter().map(|s| RadixString::new(s)).collect();
let result = sorter.sort(&mut data);
assert!(result.is_ok());
let expected_strings = vec![b"apple".as_slice(), b"banana".as_slice(), b"cherry".as_slice(), b"date".as_slice()];
for (i, expected) in expected_strings.iter().enumerate() {
assert_eq!(data[i].as_slice(), *expected);
}
}
#[test]
fn test_advanced_radix_sort_forced_strategy() {
let config = AdvancedRadixSortConfig {
force_strategy: Some(SortingStrategy::LsdRadix),
insertion_sort_threshold: 1000, ..Default::default()
};
let mut sorter = AdvancedU32RadixSort::with_config(config).unwrap();
let mut data = vec![5, 2, 8, 1, 9];
let result = sorter.sort(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 5, 8, 9]);
let stats = sorter.stats();
assert_eq!(stats.strategy_used, SortingStrategy::LsdRadix);
}
#[test]
fn test_advanced_radix_sort_parallel() {
let config = AdvancedRadixSortConfig {
use_parallel: true,
parallel_threshold: 100,
force_strategy: Some(SortingStrategy::LsdRadix),
..Default::default()
};
let mut sorter = AdvancedU32RadixSort::with_config(config).unwrap();
let mut data: Vec<u32> = (0..1000).rev().collect();
let result = sorter.sort(&mut data);
assert!(result.is_ok());
let expected: Vec<u32> = (0..1000).collect();
assert_eq!(data, expected);
let stats = sorter.stats();
assert!(stats.basic_stats.used_parallel);
assert!(stats.threads_used > 0);
}
#[test]
fn test_cpu_features_detection() {
let features = CpuFeatures::detect();
#[cfg(target_arch = "x86_64")]
{
let _ = features.avx2;
let _ = features.bmi2;
let _ = features.popcnt;
}
#[cfg(not(target_arch = "x86_64"))]
{
assert!(!features.avx2);
assert!(!features.bmi2);
assert!(!features.popcnt);
assert!(!features.avx512f);
assert!(!features.avx512bw);
}
}
#[test]
fn test_data_characteristics_integers() {
let data = vec![1u32, 2, 3, 4, 5]; let chars = DataCharacteristics::analyze_integers(&data);
assert_eq!(chars.size, 5);
assert!(chars.is_nearly_sorted);
assert!(!chars.is_string_data);
assert!(chars.max_key_bits >= 3);
let data = vec![5u32, 1, 4, 2, 3]; let chars = DataCharacteristics::analyze_integers(&data);
assert!(!chars.is_nearly_sorted);
}
#[test]
fn test_data_characteristics_strings() {
let data = vec![
b"apple".to_vec(),
b"banana".to_vec(),
b"cherry".to_vec(),
];
let chars = DataCharacteristics::analyze_strings(&data);
assert_eq!(chars.size, 3);
assert!(chars.is_nearly_sorted); assert!(chars.is_string_data);
}
#[test]
fn test_radix_sortable_trait_u32() {
let value = 0x12345678u32;
assert_eq!(value.extract_key(), 0x12345678u64);
assert_eq!(value.get_byte(0), Some(0x12));
assert_eq!(value.get_byte(1), Some(0x34));
assert_eq!(value.get_byte(2), Some(0x56));
assert_eq!(value.get_byte(3), Some(0x78));
assert_eq!(value.get_byte(4), None);
assert_eq!(value.max_bytes(), 4);
}
#[test]
fn test_radix_sortable_trait_radix_string() {
let value = RadixString::new(b"hello");
let key = value.extract_key();
let expected = (b'h' as u64) << 56 |
(b'e' as u64) << 48 |
(b'l' as u64) << 40 |
(b'l' as u64) << 32 |
(b'o' as u64) << 24;
assert_eq!(key, expected);
assert_eq!(value.get_byte(0), Some(b'h'));
assert_eq!(value.get_byte(4), Some(b'o'));
assert_eq!(value.get_byte(5), None);
assert_eq!(value.max_bytes(), 5);
}
#[test]
fn test_advanced_algorithm_trait() {
let sorter = AdvancedU32RadixSort::new().unwrap();
assert!(sorter.supports_parallel());
assert_eq!(sorter.supports_simd(), cfg!(feature = "simd"));
let memory_estimate = sorter.estimate_memory(1000);
assert!(memory_estimate > 1000 * std::mem::size_of::<u32>());
let input = vec![3u32, 1, 4, 1, 5];
let config = AdvancedRadixSortConfig::default();
let result = sorter.execute(&config, input);
assert!(result.is_ok());
let sorted = result.unwrap();
assert_eq!(sorted, vec![1, 1, 3, 4, 5]);
}
#[test]
fn test_phase_times_tracking() {
let mut sorter = AdvancedU32RadixSort::with_config(AdvancedRadixSortConfig {
enable_profiling: true,
..Default::default()
}).unwrap();
let mut data: Vec<u32> = (0..1000).rev().collect(); let result = sorter.sort(&mut data);
assert!(result.is_ok());
let stats = sorter.stats();
assert!(stats.phase_times.analysis_time_us > 0 || stats.phase_times.sorting_time_us > 0);
}
#[test]
fn test_memory_pool_integration() {
let memory_pool = SecureMemoryPool::new(SecurePoolConfig::large_secure()).unwrap();
let config = AdvancedRadixSortConfig::default();
let mut sorter = AdvancedU32RadixSort::with_memory_pool(config, memory_pool);
let mut data = vec![5u32, 2, 8, 1, 9];
let result = sorter.sort(&mut data);
assert!(result.is_ok());
assert_eq!(data, vec![1, 2, 5, 8, 9]);
}
}