use crate::error::{Result, ZiporaError};
use std::cmp::Ordering;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Tuning knobs that decide which vectorised code paths a comparator may take.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// Allow the AVX-512 kernels (only compiled with the `avx512` crate feature).
    pub use_avx512: bool,
    /// Allow the AVX2 kernels.
    pub use_avx2: bool,
    /// Whether BMI2 was detected; populated by `Default` for callers to inspect.
    pub use_bmi2: bool,
    /// Smallest input length for which the AVX2 kernels are attempted.
    pub min_vector_size: usize,
    /// How many vector-sized chunks ahead the comparison kernels prefetch;
    /// `0` disables prefetching.
    pub prefetch_distance: usize,
}

impl Default for SimdConfig {
    /// Builds a configuration from the features the running CPU reports.
    ///
    /// On targets other than `x86_64` every feature flag is `false`, so the
    /// scalar fallbacks are always selected.
    fn default() -> Self {
        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        let use_avx512 = is_x86_feature_detected!("avx512f");
        #[cfg(not(all(target_arch = "x86_64", feature = "avx512")))]
        let use_avx512 = false;

        // `is_x86_feature_detected!` is a hard compile error outside x86, so the
        // probes must be gated exactly like the kernels that depend on them.
        #[cfg(target_arch = "x86_64")]
        let (use_avx2, use_bmi2) = (
            is_x86_feature_detected!("avx2"),
            is_x86_feature_detected!("bmi2"),
        );
        #[cfg(not(target_arch = "x86_64"))]
        let (use_avx2, use_bmi2) = (false, false);

        Self {
            use_avx512,
            use_avx2,
            use_bmi2,
            min_vector_size: 8,
            prefetch_distance: 2,
        }
    }
}
/// Runs `i32` slice comparison, minimum search and sorted-merge operations,
/// choosing between vectorised and scalar code according to its `SimdConfig`.
pub struct SimdComparator {
/// Feature flags and size thresholds consulted by the dispatching methods.
config: SimdConfig,
}
impl SimdComparator {
/// Creates a comparator configured with `SimdConfig::default()`.
pub fn new() -> Self {
    let config = SimdConfig::default();
    Self::with_config(config)
}
pub fn with_config(config: SimdConfig) -> Self {
Self { config }
}
/// Compares two equally long `i32` slices element by element.
///
/// Element `k` of the returned vector is `left[k].cmp(&right[k])`.
///
/// # Errors
///
/// Returns the error built by `ZiporaError::invalid_parameter` when the
/// slices have different lengths.
pub fn compare_i32_slices(&self, left: &[i32], right: &[i32]) -> Result<Vec<Ordering>> {
    if left.len() != right.len() {
        return Err(ZiporaError::invalid_parameter(
            "Input slices must have equal length",
        ));
    }

    // The configuration can be supplied by the caller, so the flag alone does
    // not prove the CPU supports AVX-512; confirm it before entering the kernel.
    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    if self.config.use_avx512 && left.len() >= 16 && is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F availability was verified at runtime just above and
        // both slices were checked to have the same length.
        return unsafe { self.compare_i32_simd_avx512(left, right) };
    }

    if self.config.use_avx2 && left.len() >= self.config.min_vector_size {
        self.compare_i32_simd(left, right)
    } else {
        Ok(left.iter().zip(right).map(|(a, b)| a.cmp(b)).collect())
    }
}
/// AVX-512 kernel behind `compare_i32_slices`: compares 16 lanes per step and
/// finishes the remaining elements with scalar comparisons.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `right` must be at least as long as
/// `left` because full chunks of `right` are loaded using `left`'s length.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn compare_i32_simd_avx512(&self, left: &[i32], right: &[i32]) -> Result<Vec<Ordering>> {
    const LANES: usize = 16;
    let len = left.len();
    let mut out = Vec::with_capacity(len);
    let mut offset = 0;

    while offset + LANES <= len {
        unsafe {
            let lhs = left.as_ptr().add(offset);
            let rhs = right.as_ptr().add(offset);

            // Prefetch only while the look-ahead address is still inside the slice.
            if self.config.prefetch_distance > 0
                && offset + LANES * self.config.prefetch_distance < len
            {
                _mm_prefetch(lhs.add(LANES * self.config.prefetch_distance) as *const i8, _MM_HINT_T0);
                _mm_prefetch(rhs.add(LANES * self.config.prefetch_distance) as *const i8, _MM_HINT_T0);
            }

            let a = _mm512_loadu_si512(lhs as *const __m512i);
            let b = _mm512_loadu_si512(rhs as *const __m512i);
            let equal = _mm512_cmpeq_epi32_mask(a, b);
            let greater = _mm512_cmpgt_epi32_mask(a, b);

            // One mask bit per lane: equality wins, then greater-than, else less.
            out.extend((0..LANES).map(|lane| {
                if (equal >> lane) & 1 != 0 {
                    Ordering::Equal
                } else if (greater >> lane) & 1 != 0 {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            }));
        }
        offset += LANES;
    }

    for k in offset..len {
        out.push(left[k].cmp(&right[k]));
    }
    Ok(out)
}
/// AVX2 kernel behind `compare_i32_slices`: compares 8 lanes per step and
/// falls back to scalar comparison when AVX2 is not detected at runtime.
///
/// `right` is expected to be as long as `left`; `compare_i32_slices` checks
/// this before dispatching here.
#[cfg(target_arch = "x86_64")]
fn compare_i32_simd(&self, left: &[i32], right: &[i32]) -> Result<Vec<Ordering>> {
    const LANES: usize = 8;

    if !is_x86_feature_detected!("avx2") {
        let scalar: Vec<Ordering> = left
            .iter()
            .zip(right.iter())
            .map(|(a, b)| a.cmp(b))
            .collect();
        return Ok(scalar);
    }

    let len = left.len();
    let mut out = Vec::with_capacity(len);
    let mut offset = 0;

    while offset + LANES <= len {
        // SAFETY: AVX2 was detected above; `offset + LANES <= len` keeps the
        // loads inside `left`, and `right` is expected to have the same length.
        let (eq_bits, gt_bits) = unsafe {
            let lhs = left.as_ptr().add(offset);
            let rhs = right.as_ptr().add(offset);

            if self.config.prefetch_distance > 0
                && offset + LANES * self.config.prefetch_distance < len
            {
                let ahead = LANES * self.config.prefetch_distance;
                _mm_prefetch(lhs.add(ahead) as *const i8, _MM_HINT_T0);
                _mm_prefetch(rhs.add(ahead) as *const i8, _MM_HINT_T0);
            }

            let a = _mm256_loadu_si256(lhs as *const __m256i);
            let b = _mm256_loadu_si256(rhs as *const __m256i);
            (
                _mm256_movemask_epi8(_mm256_cmpeq_epi32(a, b)) as u32,
                _mm256_movemask_epi8(_mm256_cmpgt_epi32(a, b)) as u32,
            )
        };

        // The byte mask carries four bits per 32-bit lane.
        for lane in 0..LANES {
            let lane_bits = 0xFu32 << (lane * 4);
            out.push(if eq_bits & lane_bits != 0 {
                Ordering::Equal
            } else if gt_bits & lane_bits != 0 {
                Ordering::Greater
            } else {
                Ordering::Less
            });
        }
        offset += LANES;
    }

    for k in offset..len {
        out.push(left[k].cmp(&right[k]));
    }
    Ok(out)
}
/// Scalar stand-in for the AVX2 comparison kernel on non-x86_64 targets.
#[cfg(not(target_arch = "x86_64"))]
fn compare_i32_simd(&self, left: &[i32], right: &[i32]) -> Result<Vec<Ordering>> {
    let mut orderings = Vec::with_capacity(left.len().min(right.len()));
    for (a, b) in left.iter().zip(right) {
        orderings.push(a.cmp(b));
    }
    Ok(orderings)
}
/// Finds the smallest value in `values` together with its index.
///
/// Returns `None` for an empty slice. When the minimum occurs more than once,
/// the lowest index is reported.
pub fn find_min_i32(&self, values: &[i32]) -> Option<(usize, i32)> {
    if values.is_empty() {
        return None;
    }

    // The configuration can be supplied by the caller, so the flag alone does
    // not prove the CPU supports AVX-512; confirm it before entering the kernel.
    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    if self.config.use_avx512 && values.len() >= 16 && is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F availability was verified at runtime just above.
        return unsafe { self.find_min_i32_simd_avx512(values) };
    }

    if self.config.use_avx2 && values.len() >= self.config.min_vector_size {
        self.find_min_i32_simd(values)
    } else {
        values
            .iter()
            .copied()
            .enumerate()
            .min_by_key(|&(_, value)| value)
    }
}
/// AVX-512 kernel behind `find_min_i32`: scans 16 lanes per step while tracking
/// the running minimum and its index, then finishes the tail with scalar code.
///
/// Returns `None` for an empty slice. Ties resolve to the lowest index because
/// every update uses a strict `<` and the lowest matching lane is selected.
///
/// # Safety
///
/// The CPU must support AVX-512F.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn find_min_i32_simd_avx512(&self, values: &[i32]) -> Option<(usize, i32)> {
if values.is_empty() {
return None;
}
let chunk_size = 16;
let mut global_min = i32::MAX;
let mut global_min_idx = 0;
let mut i = 0;
// Per-lane running minima over all chunks processed so far.
let mut min_vec = _mm512_set1_epi32(i32::MAX);
while i + chunk_size <= values.len() {
unsafe {
let vec = _mm512_loadu_si512(values.as_ptr().add(i) as *const __m512i);
let new_min = _mm512_min_epi32(min_vec, vec);
// A mask other than 0xFFFF means at least one lane got smaller, so only
// then is the costlier horizontal reduction worth doing.
if _mm512_cmpeq_epi32_mask(new_min, min_vec) != 0xFFFF {
let chunk_min = _mm512_reduce_min_epi32(vec);
if chunk_min < global_min {
global_min = chunk_min;
// Locate the first lane in this chunk that holds the new minimum.
let min_broadcast = _mm512_set1_epi32(chunk_min);
let mask = _mm512_cmpeq_epi32_mask(vec, min_broadcast);
global_min_idx = i + mask.trailing_zeros() as usize;
}
min_vec = new_min;
}
i += chunk_size;
}
}
// Scalar pass over the elements that did not fill a whole chunk.
while i < values.len() {
if values[i] < global_min {
global_min = values[i];
global_min_idx = i;
}
i += 1;
}
Some((global_min_idx, global_min))
}
/// AVX2-assisted minimum search behind `find_min_i32`.
///
/// Loads eight values at a time through a 256-bit register, scans them for a
/// new minimum, and handles the tail element by element. Falls back to an
/// iterator scan when AVX2 is not detected or the slice is empty. Ties resolve
/// to the lowest index.
#[cfg(target_arch = "x86_64")]
fn find_min_i32_simd(&self, values: &[i32]) -> Option<(usize, i32)> {
    const LANES: usize = 8;

    if values.is_empty() || !is_x86_feature_detected!("avx2") {
        return values
            .iter()
            .copied()
            .enumerate()
            .min_by_key(|&(_, value)| value);
    }

    let mut best_idx = 0;
    let mut best = i32::MAX;
    let mut offset = 0;

    while offset + LANES <= values.len() {
        let mut lanes = [0i32; LANES];
        // SAFETY: AVX2 was detected above, `offset + LANES <= values.len()` keeps
        // the load in bounds, and `lanes` has room for exactly 256 bits.
        unsafe {
            let chunk = _mm256_loadu_si256(values.as_ptr().add(offset) as *const __m256i);
            _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, chunk);
        }
        for (lane, &value) in lanes.iter().enumerate() {
            if value < best {
                best = value;
                best_idx = offset + lane;
            }
        }
        offset += LANES;
    }

    for (idx, &value) in values.iter().enumerate().skip(offset) {
        if value < best {
            best = value;
            best_idx = idx;
        }
    }

    Some((best_idx, best))
}
/// Scalar stand-in for the AVX2 minimum search on non-x86_64 targets.
///
/// Returns the lowest index holding the smallest value, or `None` when
/// `values` is empty.
#[cfg(not(target_arch = "x86_64"))]
fn find_min_i32_simd(&self, values: &[i32]) -> Option<(usize, i32)> {
    let mut best: Option<(usize, i32)> = None;
    for (idx, &value) in values.iter().enumerate() {
        match best {
            // Keep the earlier entry unless the new value is strictly smaller.
            Some((_, current)) if current <= value => {}
            _ => best = Some((idx, value)),
        }
    }
    best
}
/// Merges two ascending `i32` slices into one ascending vector.
///
/// When both inputs contain the same value, the copies from `left` come first.
/// Both inputs are expected to be sorted already; this is not checked.
pub fn merge_sorted_i32(&self, left: &[i32], right: &[i32]) -> Vec<i32> {
    let total = left.len() + right.len();

    // The configuration can be supplied by the caller, so the flag alone does
    // not prove the CPU supports AVX-512; confirm it before entering the kernel.
    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    if self.config.use_avx512 && total >= 32 && is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F availability was verified at runtime just above.
        return unsafe { self.merge_sorted_i32_simd_avx512(left, right) };
    }

    // Saturate so that an extreme `min_vector_size` cannot overflow the threshold.
    let simd_threshold = self.config.min_vector_size.saturating_mul(2);
    if self.config.use_avx2 && total >= simd_threshold {
        self.merge_sorted_i32_simd(left, right)
    } else {
        self.merge_sorted_i32_scalar(left, right)
    }
}
/// AVX-512 variant of the sorted merge: interleaves the inputs with scalar
/// comparisons, then bulk-copies whichever input still has elements left.
///
/// # Safety
///
/// The CPU must support AVX-512F.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn merge_sorted_i32_simd_avx512(&self, left: &[i32], right: &[i32]) -> Vec<i32> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut l, mut r) = (0, 0);

    while l < left.len() && r < right.len() {
        let (a, b) = (left[l], right[r]);
        if a <= b {
            merged.push(a);
            l += 1;
        } else {
            merged.push(b);
            r += 1;
        }
    }

    // At most one of the two tails is non-empty at this point.
    let left_tail = &left[l..];
    let right_tail = &right[r..];
    unsafe {
        if !left_tail.is_empty() {
            self.simd_copy_i32_avx512(left_tail, &mut merged);
        }
        if !right_tail.is_empty() {
            self.simd_copy_i32_avx512(right_tail, &mut merged);
        }
    }
    merged
}
/// AVX2 variant of the sorted merge: interleaves the inputs with scalar
/// comparisons, then bulk-copies whichever input still has elements left.
/// Delegates to the scalar merge when AVX2 is not detected at runtime.
#[cfg(target_arch = "x86_64")]
fn merge_sorted_i32_simd(&self, left: &[i32], right: &[i32]) -> Vec<i32> {
    if !is_x86_feature_detected!("avx2") {
        return self.merge_sorted_i32_scalar(left, right);
    }

    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut l, mut r) = (0, 0);

    while l < left.len() && r < right.len() {
        let (a, b) = (left[l], right[r]);
        if a <= b {
            merged.push(a);
            l += 1;
        } else {
            merged.push(b);
            r += 1;
        }
    }

    // At most one of the two tails is non-empty at this point.
    let left_tail = &left[l..];
    let right_tail = &right[r..];
    // SAFETY: AVX2 support was confirmed at the top of this function.
    unsafe {
        if !left_tail.is_empty() {
            self.simd_copy_i32(left_tail, &mut merged);
        }
        if !right_tail.is_empty() {
            self.simd_copy_i32(right_tail, &mut merged);
        }
    }
    merged
}
/// Appends `src` to `dest`, moving whole 16-element chunks through 512-bit
/// registers and the remaining elements with an ordinary slice copy.
///
/// # Safety
///
/// The CPU must support AVX-512F.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_i32_avx512(&self, src: &[i32], dest: &mut Vec<i32>) {
    const LANES: usize = 16;

    dest.reserve(src.len());
    let start = dest.len();
    // Length of the prefix that divides evenly into full vectors.
    let vectorised = src.len() - src.len() % LANES;

    unsafe {
        let out = dest.as_mut_ptr().add(start);
        for offset in (0..vectorised).step_by(LANES) {
            let chunk = _mm512_loadu_si512(src.as_ptr().add(offset) as *const __m512i);
            _mm512_storeu_si512(out.add(offset) as *mut __m512i, chunk);
        }
        // The reserved capacity covers everything written above.
        dest.set_len(start + vectorised);
    }

    dest.extend_from_slice(&src[vectorised..]);
}
/// Appends `src` to `dest`, moving whole 8-element chunks through 256-bit
/// registers and the remaining elements with an ordinary slice copy.
///
/// # Safety
///
/// The CPU must support AVX2; this function uses AVX2 intrinsics without
/// checking for them itself.
#[cfg(target_arch = "x86_64")]
unsafe fn simd_copy_i32(&self, src: &[i32], dest: &mut Vec<i32>) {
    const LANES: usize = 8;

    dest.reserve(src.len());
    let start = dest.len();
    // Length of the prefix that divides evenly into full vectors.
    let vectorised = src.len() - src.len() % LANES;

    unsafe {
        let out = dest.as_mut_ptr().add(start);
        for offset in (0..vectorised).step_by(LANES) {
            let chunk = _mm256_loadu_si256(src.as_ptr().add(offset) as *const __m256i);
            _mm256_storeu_si256(out.add(offset) as *mut __m256i, chunk);
        }
        // The reserved capacity covers everything written above.
        dest.set_len(start + vectorised);
    }

    dest.extend_from_slice(&src[vectorised..]);
}
/// Stand-in for the AVX2 merge on non-x86_64 targets; always uses the scalar merge.
#[cfg(not(target_arch = "x86_64"))]
fn merge_sorted_i32_simd(&self, left: &[i32], right: &[i32]) -> Vec<i32> {
    Self::merge_sorted_i32_scalar(self, left, right)
}
/// Portable two-way merge of ascending slices; on equal values the element
/// from `left` is emitted first.
fn merge_sorted_i32_scalar(&self, left: &[i32], right: &[i32]) -> Vec<i32> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let mut lhs = left.iter().copied().peekable();
    let mut rhs = right.iter().copied().peekable();

    // Interleave while both sides still have a head element to compare.
    while let (Some(&a), Some(&b)) = (lhs.peek(), rhs.peek()) {
        if a <= b {
            merged.push(a);
            lhs.next();
        } else {
            merged.push(b);
            rhs.next();
        }
    }

    // Exactly one of these still yields elements (or neither does).
    merged.extend(lhs);
    merged.extend(rhs);
    merged
}
/// Reports whether the AVX2 code paths can actually be used: the configuration
/// must enable them and the CPU must report AVX2. Always `false` on targets
/// other than x86_64.
pub fn simd_available(&self) -> bool {
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2") && self.config.use_avx2
    }
}
pub fn config(&self) -> &SimdConfig {
&self.config
}
}
impl Default for SimdComparator {
fn default() -> Self {
Self::new()
}
}
pub struct SimdOperations;
impl SimdOperations {
pub fn parallel_compare_i32(pairs: &[(i32, i32)]) -> Vec<Ordering> {
let comparator = SimdComparator::new();
if pairs.is_empty() {
return Vec::new();
}
let (left, right): (Vec<i32>, Vec<i32>) = pairs.iter().copied().unzip();
comparator
.compare_i32_slices(&left, &right)
.unwrap_or_else(|_| pairs.iter().map(|(a, b)| a.cmp(b)).collect())
}
pub fn find_multiple_mins(arrays: &[&[i32]]) -> Vec<Option<(usize, i32)>> {
let comparator = SimdComparator::new();
arrays
.iter()
.map(|arr| comparator.find_min_i32(arr))
.collect()
}
pub fn merge_multiple_sorted(arrays: Vec<Vec<i32>>) -> Vec<i32> {
if arrays.is_empty() {
return Vec::new();
}
if arrays.len() == 1 {
return arrays.into_iter().next().expect("non-empty arrays");
}
let comparator = SimdComparator::new();
let mut current_arrays = arrays;
while current_arrays.len() > 1 {
let mut next_arrays = Vec::new();
let mut i = 0;
while i + 1 < current_arrays.len() {
let merged =
comparator.merge_sorted_i32(¤t_arrays[i], ¤t_arrays[i + 1]);
next_arrays.push(merged);
i += 2;
}
if i < current_arrays.len() {
next_arrays.push(
current_arrays
.into_iter()
.nth(i)
.expect("valid merge index"),
);
}
current_arrays = next_arrays;
}
current_arrays
.into_iter()
.next()
.expect("single remaining array")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built comparator carries a usable vector threshold.
    #[test]
    fn test_simd_comparator_creation() {
        let min_len = SimdComparator::new().config().min_vector_size;
        assert!(min_len > 0);
    }

    /// Short inputs produce element-wise orderings.
    #[test]
    fn test_compare_i32_slices() {
        let orderings = SimdComparator::new()
            .compare_i32_slices(&[1, 5, 3, 8, 2], &[2, 4, 3, 6, 1])
            .unwrap();
        let expected = [
            Ordering::Less,
            Ordering::Greater,
            Ordering::Equal,
            Ordering::Greater,
            Ordering::Greater,
        ];
        assert_eq!(orderings, expected);
    }

    /// The minimum is reported together with its position.
    #[test]
    fn test_find_min_i32() {
        let found = SimdComparator::new().find_min_i32(&[5, 2, 8, 1, 9, 3]);
        assert_eq!(found.unwrap(), (3, 1));
    }

    /// An empty slice has no minimum.
    #[test]
    fn test_find_min_empty() {
        assert!(SimdComparator::new().find_min_i32(&[]).is_none());
    }

    /// Two interleaved inputs of equal length merge into one ascending run.
    #[test]
    fn test_merge_sorted_i32() {
        let merged = SimdComparator::new().merge_sorted_i32(&[1, 3, 5, 7], &[2, 4, 6, 8]);
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    /// Inputs of different lengths merge correctly as well.
    #[test]
    fn test_merge_uneven_arrays() {
        let merged = SimdComparator::new().merge_sorted_i32(&[1, 5, 9], &[2, 3, 4, 6, 7, 8]);
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// Pair-wise comparison yields one ordering per pair.
    #[test]
    fn test_parallel_compare() {
        let orderings = SimdOperations::parallel_compare_i32(&[(1, 2), (5, 3), (4, 4), (9, 7)]);
        let expected = [
            Ordering::Less,
            Ordering::Greater,
            Ordering::Equal,
            Ordering::Greater,
        ];
        assert_eq!(orderings, expected);
    }

    /// Each input slice gets its own minimum.
    #[test]
    fn test_find_multiple_mins() {
        let arrays: [&[i32]; 3] = [&[5, 2, 8, 1], &[9, 3, 7, 4], &[6]];
        let mins = SimdOperations::find_multiple_mins(&arrays);
        assert_eq!(mins, [Some((3, 1)), Some((1, 3)), Some((0, 6))]);
    }

    /// An odd number of inputs is still merged completely.
    #[test]
    fn test_merge_multiple_sorted() {
        let inputs = vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]];
        let merged = SimdOperations::merge_multiple_sorted(inputs);
        assert_eq!(merged, (1..=9).collect::<Vec<i32>>());
    }

    /// The availability probe must be callable on every platform.
    #[test]
    fn test_simd_available() {
        let _available = SimdComparator::new().simd_available();
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn test_mismatched_slice_lengths() {
        let outcome = SimdComparator::new().compare_i32_slices(&[1, 2, 3], &[1, 2]);
        assert!(outcome.is_err());
    }

    /// Sixteen elements with a lowered threshold exercise the vector path where available.
    #[test]
    fn test_large_array_simd_path() {
        let config = SimdConfig {
            min_vector_size: 4,
            ..SimdConfig::default()
        };
        let comparator = SimdComparator::with_config(config);
        let left = (0..16).collect::<Vec<i32>>();
        let right = (1..17).collect::<Vec<i32>>();
        let orderings = comparator.compare_i32_slices(&left, &right).unwrap();
        assert!(orderings.iter().all(|&ord| ord == Ordering::Less));
        assert_eq!(orderings.len(), 16);
    }
}