1use crate::position::PositionSimd;
2use std::slice;
3
/// Extension trait providing a `find`-style search driven by a predicate.
///
/// `'a` is the lifetime of the references handed back to the caller and `T`
/// is the element type being searched. The trait name suggests a
/// SIMD-accelerated search; how the search is actually carried out is left
/// to each implementor.
pub trait FindSimd<'a, T: PartialEq> {
    /// Searches for an element accepted by the predicate `f`.
    ///
    /// Returns `Some` with a reference (valid for `'a`) to the matching
    /// element, or `None` when the implementor reports no match.
    fn find_simd<F: Fn(&T) -> bool + 'a>(&self, f: F) -> Option<&'a T>;
}
12
13impl<'a, T> FindSimd<'a, T> for slice::Iter<'a, T>
14where
15 T: std::cmp::PartialEq,
16{
17 fn find_simd<F>(&self, f: F) -> Option<&'a T>
18 where
19 F: Fn(&T) -> bool + 'a,
20 {
21 match self.position_simd(f) {
22 Some(idx) => Some(&self.as_slice()[idx]),
23 None => None,
24 }
25 }
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    /// Slice lengths `0..MAX_LEN` are exercised for every element type.
    const MAX_LEN: usize = 5000;

    /// Cross-checks `find_simd` against `Iterator::find` for element type `T`.
    ///
    /// For every length in `0..MAX_LEN` and every predicate, a fresh random
    /// vector is generated and both searches must agree on the result.
    fn test_simd_for_type<T>()
    where
        T: PartialEq + PartialOrd + Default + Debug,
        Standard: Distribution<T>,
    {
        // Predicates in the shape `find_simd` expects (`&T`).
        let simd_preds: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];
        // The same predicates in the shape `Iterator::find` expects (`&&T`).
        // Entry `i` here must stay logically identical to `simd_preds[i]`.
        let scalar_preds: [fn(&&T) -> bool; 5] = [
            |x: &&T| **x == T::default(),
            |x: &&T| **x != T::default(),
            |x: &&T| **x < T::default(),
            |x: &&T| **x > T::default(),
            |x: &&T| [T::default()].contains(x),
        ];

        let mut rng = rand::thread_rng();

        for len in 0..MAX_LEN {
            for (op_simd, op_scalar) in simd_preds.iter().zip(scalar_preds) {
                let v: Vec<T> = (0..len).map(|_| rng.gen()).collect();

                // The standard-library search is the reference result.
                let expected = v.iter().find(op_scalar);
                let actual = v.iter().find_simd(op_simd);
                assert_eq!(
                    expected,
                    actual,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    /// Runs the cross-check for every primitive numeric type under test.
    #[test]
    fn test_simd() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}