// Compile-time backend selection. Exactly one of the two private modules is
// compiled in, and both are expected to export `contains_simd_u32` and
// `contains_simd_u64`, so the trait impls in this file can call them without
// knowing which backend is active.
cfg_if::cfg_if! {
// Taken only when the crate is built for x86_64 *and* the `avx2` target
// feature is enabled at compile time (e.g. via `-C target-feature=+avx2` or
// `-C target-cpu=native`); this is a static check, not runtime detection.
if #[cfg(all(target_arch="x86_64", target_feature = "avx2"))] {
mod x86_64;
use x86_64::{contains_simd_u32, contains_simd_u64};
} else {
// Every other target / feature combination uses the `fallback` module.
mod fallback;
use fallback::{contains_simd_u32, contains_simd_u64};
}
}
/// Membership test over a slice of `Self`.
///
/// Implemented in this file for `u32` and `u64`; each impl forwards to the
/// `contains_simd_*` function of the backend module chosen at compile time.
pub trait ContainsSimd: Sized {
/// Returns `true` if `target` occurs anywhere in `vector`, `false` otherwise.
///
/// The unit tests in this file rely on an empty `vector` yielding `false`.
fn contains_simd(vector: &[Self], target: Self) -> bool;
}
/// `u32` membership test, forwarded to the backend selected at compile time.
impl ContainsSimd for u32 {
    /// Returns whether `target` is an element of `vector` by delegating to
    /// `contains_simd_u32`.
    fn contains_simd(vector: &[Self], target: Self) -> bool {
        contains_simd_u32(vector, target)
    }
}
/// `u64` membership test, forwarded to the backend selected at compile time.
impl ContainsSimd for u64 {
    /// Returns whether `target` is an element of `vector` by delegating to
    /// `contains_simd_u64`.
    fn contains_simd(vector: &[Self], target: Self) -> bool {
        contains_simd_u64(vector, target)
    }
}
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform, Uniform},
        Rng, SeedableRng,
    };

    use super::*;

    /// Number of random slices generated by each randomized test.
    const NUM_TRIALS: usize = 1000;

    /// Fixed RNG seed so the randomized tests are reproducible.
    const SEED: u64 = 42;

    /// Exhaustively checks `contains_simd` against "iota" slices `[0, 1, .., dim - 1]`
    /// for every `dim` in `0..max_dim` (the empty slice included).
    ///
    /// Every value inside the slice must be reported as present, and the ten
    /// values immediately past its end must be reported as absent.
    fn test_contains<T>(max_dim: usize)
    where
        T: ContainsSimd + Copy + TryFrom<usize>,
        <T as TryFrom<usize>>::Error: std::fmt::Debug,
    {
        let to_t = |i: usize| T::try_from(i).expect("test indices must fit in T");

        for dim in 0..max_dim {
            let v: Vec<T> = (0..dim).map(to_t).collect();

            for query in 0..dim {
                assert!(
                    T::contains_simd(&v, to_t(query)),
                    "expected query {query} to be in iota slice of dimension {dim}"
                );
            }

            for query in dim..dim + 10 {
                assert!(
                    !T::contains_simd(&v, to_t(query)),
                    "expected query {query} not to be in iota slice of dimension {dim}"
                );
            }
        }
    }

    #[test]
    fn test_contains_u32() {
        test_contains::<u32>(128);
    }

    #[test]
    fn test_contains_u64() {
        test_contains::<u64>(128);
    }

    /// Asserts that every element of `vector` is found in `vector`, and that no
    /// element of `not_present` is.
    fn test_contains_simd<T>(vector: &[T], not_present: &[T])
    where
        T: ContainsSimd + Copy + std::fmt::Debug,
    {
        for &item in not_present {
            assert!(
                !T::contains_simd(vector, item),
                "{item:?} unexpectedly reported as present in {vector:?}"
            );
        }
        for &item in vector {
            assert!(
                T::contains_simd(vector, item),
                "{item:?} unexpectedly reported as absent from {vector:?}"
            );
        }
    }

    // 16 elements, with duplicates.
    #[test]
    fn test_contains_simd_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8], &[9]);
    }

    #[test]
    fn test_contains_simd_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8], &[9]);
    }

    // Exactly 8 elements.
    #[test]
    fn test_contains_simd_multiple_of_8_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0], &[9, 8]);
    }

    #[test]
    fn test_contains_simd_multiple_of_8_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0], &[9, 8]);
    }

    // 9 elements: one past a multiple of 8, so the last element is a leftover.
    #[test]
    fn test_contains_simd_non_multiple_of_8_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0, 11], &[9, 8]);
    }

    #[test]
    fn test_contains_simd_non_multiple_of_8_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0, 11], &[9, 8]);
    }

    /// `u64` values that differ only in their upper 32 bits must not be
    /// confused: none of the other tests uses a value above `u32::MAX` in a
    /// way that would expose a comparison looking at the lower half only.
    #[test]
    fn contains_u64_compares_upper_bits() {
        const HIGH: u64 = 1 << 32;

        for len in 1..=17u64 {
            let v: Vec<u64> = (0..len).map(|i| HIGH | i).collect();
            for i in 0..len {
                assert!(
                    u64::contains_simd(&v, HIGH | i),
                    "expected {:#x} to be found (len {len})",
                    HIGH | i
                );
                assert!(
                    !u64::contains_simd(&v, i),
                    "{i:#x} matches only in the lower 32 bits (len {len})"
                );
                assert!(
                    !u64::contains_simd(&v, (HIGH << 1) | i),
                    "{:#x} matches only in the lower 32 bits (len {len})",
                    (HIGH << 1) | i
                );
            }
        }
    }

    /// Randomized check: an element picked from a random slice (length in
    /// `1..1000`) must always be reported as present.
    fn check_present_item_is_found<T>()
    where
        T: ContainsSimd + Copy + std::fmt::Debug,
        StandardUniform: Distribution<T>,
    {
        let dim_dist = Uniform::new(1, 1000).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);

        for _ in 0..NUM_TRIALS {
            let len = dim_dist.sample(&mut rng);
            let v: Vec<T> = (0..len)
                .map(|_| StandardUniform.sample(&mut rng))
                .collect();

            let index_of_item = rng.random_range(0..v.len());
            let item = v[index_of_item];
            assert!(
                T::contains_simd(&v, item),
                "element {item:?} at index {index_of_item} of a slice of length {} was not found",
                v.len()
            );
        }
    }

    /// Randomized check: a value that is not in a random slice (length in
    /// `1..1000`) must always be reported as absent.
    fn check_absent_item_is_not_found<T>()
    where
        T: ContainsSimd + Copy + PartialEq + std::fmt::Debug,
        StandardUniform: Distribution<T>,
    {
        let dim_dist = Uniform::new(1, 1000).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);

        for _ in 0..NUM_TRIALS {
            let len = dim_dist.sample(&mut rng);
            let v: Vec<T> = (0..len)
                .map(|_| StandardUniform.sample(&mut rng))
                .collect();

            // Rejection-sample until we draw a value the slice does not hold.
            let mut item: T = StandardUniform.sample(&mut rng);
            while v.contains(&item) {
                item = StandardUniform.sample(&mut rng);
            }
            assert!(
                !T::contains_simd(&v, item),
                "{item:?} was reported as present in a slice of length {} that does not hold it",
                v.len()
            );
        }
    }

    #[test]
    fn contains_works_when_item_is_present() {
        check_present_item_is_found::<u32>();
    }

    #[test]
    fn contains_works_when_item_is_present_u64() {
        check_present_item_is_found::<u64>();
    }

    #[test]
    fn contains_works_when_item_is_not_present() {
        check_absent_item_is_not_found::<u32>();
    }

    #[test]
    fn contains_works_when_item_is_not_present_u64() {
        check_absent_item_is_not_found::<u64>();
    }
}