use multiversion::multiversion;
use std::simd::{cmp::SimdPartialEq, LaneCount, Simd, SupportedLaneCount};
/// Returns the index of the first element of `keys` equal to `Some(key)`, or
/// `None` when no element matches (including when `keys` is empty).
///
/// `None` entries in `keys` never match. The search tries 8-, 4- and then
/// 2-lane SIMD comparisons, using the widest lane count the slice is long
/// enough for, and falls back to a plain scalar scan otherwise.
///
/// NOTE(review): the per-target clones generated by `multiversion` only pass
/// their target features on to `search_simd` if that helper is inlined here;
/// `#[inline]` makes this likely but not guaranteed — confirm in codegen.
#[multiversion(targets(
    "x86_64+avx512f",
    "x86_64+avx2",
    "x86_64+sse4.1",
    "aarch64+sve",
    "aarch64+neon"
))]
pub fn simd_search_i64(keys: &[Option<i64>], key: i64) -> Option<usize> {
    // `first()?` bails out with `None` on an empty slice; otherwise a hit at
    // index 0 is answered without setting up any vectors.
    if *keys.first()? == Some(key) {
        return Some(0);
    }

    // Widest lane count first. The outer `None` from `search_simd` means
    // "slice shorter than this lane count", so move on to a narrower one.
    if let Some(found) = search_simd::<8>(keys, key) {
        return found;
    }
    if let Some(found) = search_simd::<4>(keys, key) {
        return found;
    }
    if let Some(found) = search_simd::<2>(keys, key) {
        return found;
    }

    // Too short for any vector width.
    search_scalar(keys, key)
}
/// Searches `keys` for the first element equal to `Some(key)`, comparing `N`
/// elements at a time with SIMD.
///
/// Returns:
/// - `None` when `keys` has fewer than `N` elements, telling the caller to try a
///   narrower lane count (or a scalar scan);
/// - `Some(Some(index))` with the index of the first element equal to `Some(key)`;
/// - `Some(None)` when the whole slice was searched and nothing matched.
///
/// `None` entries never match.
#[inline]
// The nested `Option` deliberately separates "this lane count does not apply"
// (outer `None`) from "searched, not found" (`Some(None)`).
#[allow(clippy::option_option)]
fn search_simd<const N: usize>(keys: &[Option<i64>], key: i64) -> Option<Option<usize>>
where
    LaneCount<N>: SupportedLaneCount,
{
    if keys.len() < N {
        return None;
    }

    let key_vec = Simd::<i64, N>::splat(key);
    // `None` slots need some value in the vector. `!key` differs from `key` in
    // every bit, so a gap can never produce a lane match. A fixed sentinel such
    // as `i64::MAX` would cause a spurious rescan of every gappy chunk whenever
    // `key == i64::MAX`.
    let filler = !key;

    // `chunks_exact` yields only full `N`-element chunks, which lets the
    // compiler drop per-lane bounds checks. Leftovers are handled below.
    let chunks = keys.chunks_exact(N);
    let tail = chunks.remainder();

    for (chunk_idx, chunk) in chunks.enumerate() {
        let mut lanes = [filler; N];
        for (lane, slot) in lanes.iter_mut().zip(chunk) {
            *lane = slot.unwrap_or(filler);
        }
        if Simd::from_array(lanes).simd_eq(key_vec).any() {
            // Resolve the exact lane against the original `Option`s so the
            // result never depends on the filler encoding.
            if let Some(offset) = chunk.iter().position(|&slot| slot == Some(key)) {
                return Some(Some(chunk_idx * N + offset));
            }
        }
    }

    // Fewer than `N` trailing elements: finish with a scalar scan.
    let tail_start = keys.len() - tail.len();
    Some(
        tail.iter()
            .position(|&slot| slot == Some(key))
            .map(|offset| tail_start + offset),
    )
}
/// Plain linear scan: the index of the first slot holding `Some(key)`, or
/// `None` if there is no such slot. `None` slots never match.
#[inline]
fn search_scalar(keys: &[Option<i64>], key: i64) -> Option<usize> {
    let wanted = Some(key);
    keys.iter()
        .enumerate()
        .find_map(|(idx, slot)| (*slot == wanted).then_some(idx))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward reference result that every search path must reproduce.
    fn linear(keys: &[Option<i64>], key: i64) -> Option<usize> {
        keys.iter().position(|&k| k == Some(key))
    }

    #[test]
    fn test_simd_search_empty() {
        let keys: &[Option<i64>] = &[];
        assert_eq!(simd_search_i64(keys, 42), None);
    }

    #[test]
    fn test_simd_search_not_found() {
        let keys = vec![Some(1), Some(2), Some(3), None, Some(5)];
        assert_eq!(simd_search_i64(&keys, 42), None);
    }

    #[test]
    fn test_simd_search_first_element() {
        let keys = vec![Some(42), Some(2), Some(3)];
        assert_eq!(simd_search_i64(&keys, 42), Some(0));
    }

    #[test]
    fn test_simd_search_middle() {
        let keys = vec![Some(1), Some(2), Some(42), Some(4)];
        assert_eq!(simd_search_i64(&keys, 42), Some(2));
    }

    #[test]
    fn test_simd_search_last() {
        let keys = vec![Some(1), Some(2), Some(3), Some(42)];
        assert_eq!(simd_search_i64(&keys, 42), Some(3));
    }

    #[test]
    fn test_simd_search_with_gaps() {
        let keys = vec![Some(1), None, Some(42), None, Some(5)];
        assert_eq!(simd_search_i64(&keys, 42), Some(2));
    }

    #[test]
    fn test_simd_search_long_array() {
        let mut keys = vec![None; 20];
        keys[15] = Some(42);
        assert_eq!(simd_search_i64(&keys, 42), Some(15));
    }

    #[test]
    fn test_simd_search_very_long_array() {
        let mut keys = vec![Some(0); 100];
        keys[77] = Some(42);
        assert_eq!(simd_search_i64(&keys, 42), Some(77));
    }

    #[test]
    fn test_simd_search_all_gaps() {
        let keys = vec![None, None, None, None];
        assert_eq!(simd_search_i64(&keys, 42), None);
    }

    #[test]
    fn test_simd_search_consistency() {
        let test_cases = vec![
            vec![Some(1), Some(2), Some(3), Some(4), Some(5)],
            vec![None, Some(1), None, Some(2), None],
            vec![Some(10), Some(20), Some(30), Some(40)],
            vec![Some(1); 10],
            vec![Some(1); 100],
        ];
        for keys in test_cases {
            for search_key in [1, 2, 5, 10, 20, 42, 100] {
                let simd_result = simd_search_i64(&keys, search_key);
                let linear_result = linear(&keys, search_key);
                assert_eq!(
                    simd_result, linear_result,
                    "Mismatch for key={} in {:?}",
                    search_key, keys
                );
            }
        }
    }

    #[test]
    fn test_simd_search_extreme_keys_with_gaps() {
        // `None` slots are replaced by a filler value inside the SIMD path; they
        // must never be reported as a match, whatever key is searched for.
        for key in [i64::MAX, i64::MIN, 0, -1] {
            let mut keys: [Option<i64>; 19] = [None; 19];
            assert_eq!(simd_search_i64(&keys, key), None, "key={}", key);
            keys[17] = Some(key);
            assert_eq!(simd_search_i64(&keys, key), Some(17), "key={}", key);
        }
    }

    #[test]
    fn test_simd_search_every_length_and_position() {
        // Sweeps lengths across every lane width (8, 4, 2) and the scalar
        // fallback, so both full chunks and leftover tails are exercised.
        for len in 0..=20usize {
            let base: Vec<Option<i64>> = (0..len)
                .map(|i| if i % 3 == 0 { None } else { Some(1_000 + i as i64) })
                .collect();
            assert_eq!(simd_search_i64(&base, 42), None, "len={}", len);
            for pos in 0..len {
                let mut keys = base.clone();
                keys[pos] = Some(42);
                assert_eq!(
                    simd_search_i64(&keys, 42),
                    Some(pos),
                    "len={} pos={}",
                    len,
                    pos
                );
                assert_eq!(simd_search_i64(&keys, 42), linear(&keys, 42));
            }
        }
    }

    #[test]
    fn test_simd_search_returns_first_duplicate() {
        let mut keys = vec![Some(0); 30];
        keys[9] = Some(7);
        keys[10] = Some(7);
        keys[25] = Some(7);
        assert_eq!(simd_search_i64(&keys, 7), Some(9));
    }
}