simd_itertools/
find.rs

1use crate::position::PositionSimd;
2use std::slice;
3
/// Extension trait that exposes a `find`-style lookup under the name
/// `find_simd`.
///
/// `'a` is the lifetime of the element reference handed back to the caller and
/// `T` is the element type, which must support equality comparison.
pub trait FindSimd<'a, T: PartialEq> {
    /// Searches for an element accepted by the predicate `f`.
    ///
    /// Returns `Some(&element)` when a match exists and `None` otherwise.
    /// This file's tests check implementations against `Iterator::find`, so
    /// the first matching element is the expected result.
    ///
    /// The receiver is only borrowed (`&self`); the predicate must be valid
    /// for at least `'a`.
    fn find_simd<F: Fn(&T) -> bool + 'a>(&self, f: F) -> Option<&'a T>;
}
12
13impl<'a, T> FindSimd<'a, T> for slice::Iter<'a, T>
14where
15    T: std::cmp::PartialEq,
16{
17    fn find_simd<F>(&self, f: F) -> Option<&'a T>
18    where
19        F: Fn(&T) -> bool + 'a,
20    {
21        match self.position_simd(f) {
22            Some(idx) => Some(&self.as_slice()[idx]),
23            None => None,
24        }
25    }
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    /// Cross-checks `find_simd` against `Iterator::find` for element type `T`.
    ///
    /// For every length in `0..5000` and every predicate, a fresh vector of
    /// random values is generated and both searches must agree.
    fn test_simd_for_type<T>()
    where
        T: rand::distributions::uniform::SampleUniform
            + PartialEq
            + PartialOrd
            + Copy
            + Default
            + Debug,
        Standard: Distribution<T>,
    {
        // Predicates in the shape `find_simd` expects (`&T`).
        let simd_predicates = [
            |value: &T| *value == T::default(),
            |value: &T| *value != T::default(),
            |value: &T| *value < T::default(),
            |value: &T| *value > T::default(),
            |value: &T| [T::default()].contains(value),
        ];
        // The same predicates in the shape `Iterator::find` expects (`&&T`).
        let scalar_predicates = [
            |value: &&T| **value == T::default(),
            |value: &&T| **value != T::default(),
            |value: &&T| **value < T::default(),
            |value: &&T| **value > T::default(),
            |value: &&T| [T::default()].contains(value),
        ];

        for len in 0..5000_usize {
            for (simd_pred, scalar_pred) in simd_predicates.iter().zip(scalar_predicates) {
                let mut rng = rand::thread_rng();
                let data: Vec<T> = (0..len).map(|_| rng.gen()).collect();

                // Reference result from the standard library vs. the result
                // under test.
                let expected = data.iter().find(scalar_pred);
                let actual = data.iter().find_simd(simd_pred);
                assert_eq!(
                    expected,
                    actual,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    /// Runs the cross-check for every primitive numeric type covered here.
    #[test]
    fn test_simd() {
        // Signed integers.
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        // Unsigned integers.
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        // Pointer-sized integers.
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        // Floating point.
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}