// simd_itertools/any.rs

1use crate::LANE_COUNT;
2use multiversion::multiversion;
3use std::slice;
4
5#[multiversion(targets = "simd")]
6fn any_simd_internal<F, T>(v: &[T], f: F) -> bool
7where
8    F: Fn(&T) -> bool,
9{
10    let mut chunks = v.chunks_exact(LANE_COUNT);
11    for chunk in chunks.by_ref() {
12        if chunk.iter().fold(false, |acc, x| acc | f(x)) {
13            return true;
14        }
15    }
16    chunks.remainder().iter().fold(false, |acc, x| acc | f(x))
17}
18
19pub trait AnySimd<'a, T>
20where
21    T: std::cmp::PartialEq,
22{
23    fn any_simd<F>(&self, f: F) -> bool
24    where
25        F: Fn(&T) -> bool;
26}
27
28impl<'a, T> AnySimd<'a, T> for slice::Iter<'a, T>
29where
30    T: std::cmp::PartialEq,
31{
32    fn any_simd<F>(&self, f: F) -> bool
33    where
34        F: Fn(&T) -> bool,
35    {
36        any_simd_internal(self.as_slice(), f)
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;

    /// Cross-checks `any_simd` against `Iterator::any` for every length in `0..5000`.
    ///
    /// Purely random data almost never contains `T::default()` (a random `u64` is
    /// zero with probability 2^-64). To make sure the `==`/`contains` predicates also
    /// exercise the `true` path, roughly half of the non-empty inputs get a
    /// `T::default()` planted at a random index. Depending on the length, that
    /// index falls inside a full chunk or inside the remainder.
    fn test_simd_for_type<T>()
    where
        T: Copy + Default + PartialEq + PartialOrd,
        Standard: Distribution<T>,
    {
        // Non-capturing closures coerce to plain fn pointers, so they share one type.
        let ops: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];
        let mut rng = rand::thread_rng();

        for len in 0..5000 {
            let mut v: Vec<T> = (0..len).map(|_| rng.gen()).collect();
            if len > 0 && rng.gen::<bool>() {
                // `len > 0` guarantees the modulus is non-zero and the index in bounds.
                let idx = rng.gen::<usize>() % len;
                v[idx] = T::default();
            }

            for (op_idx, &op) in ops.iter().enumerate() {
                let ans = v.iter().any_simd(op);
                let correct = v.iter().any(op);
                assert_eq!(
                    ans,
                    correct,
                    "Failed for op #{} with length {} and type {:?}",
                    op_idx,
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    #[test]
    fn test_simd_any() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}