1use crate::LANE_COUNT;
2use multiversion::multiversion;
3use std::slice;
4
5#[multiversion(targets = "simd")]
6fn any_simd_internal<F, T>(v: &[T], f: F) -> bool
7where
8 F: Fn(&T) -> bool,
9{
10 let mut chunks = v.chunks_exact(LANE_COUNT);
11 for chunk in chunks.by_ref() {
12 if chunk.iter().fold(false, |acc, x| acc | f(x)) {
13 return true;
14 }
15 }
16 chunks.remainder().iter().fold(false, |acc, x| acc | f(x))
17}
18
/// Extension trait: tests whether any element satisfies a predicate, using a
/// chunked scan intended to be vectorization-friendly.
///
/// Unlike [`Iterator::any`], an implementation may evaluate the predicate on
/// more elements than strictly necessary. The implementation for
/// [`std::slice::Iter`] evaluates it on whole chunks at a time. `f` should
/// therefore be cheap and free of observable side effects.
///
/// The trait places no bound on `T`. Element comparison, if any, is up to the
/// predicate supplied by the caller.
pub trait AnySimd<'a, T> {
    /// Returns `true` if `f` returns `true` for at least one element, and
    /// `false` otherwise (including when there are no elements).
    fn any_simd<F>(&self, f: F) -> bool
    where
        F: Fn(&T) -> bool;
}
27
28impl<'a, T> AnySimd<'a, T> for slice::Iter<'a, T>
29where
30 T: std::cmp::PartialEq,
31{
32 fn any_simd<F>(&self, f: F) -> bool
33 where
34 F: Fn(&T) -> bool,
35 {
36 any_simd_internal(self.as_slice(), f)
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;

    /// Checks `any_simd` against `Iterator::any` for element type `T`.
    ///
    /// 1. For every length in `0..5000`, a random vector is scanned with each
    ///    predicate in `ops`, and the two results must agree.
    /// 2. For every length up to three `LANE_COUNT` chunks, a vector of
    ///    `T::default()` containing a single non-default element is scanned
    ///    with that element at every possible position. Random data rarely
    ///    produces a lone match in a particular chunk or in the remainder
    ///    (a random `u64` is practically never zero), so this pins down those
    ///    paths explicitly.
    fn test_simd_for_type<T>()
    where
        // `PartialOrd` implies `PartialEq`, which `==`/`!=`/`contains` need.
        T: Copy + Default + PartialOrd,
        Standard: Distribution<T>,
    {
        // Non-capturing closures, coerced to one fn-pointer type so they can
        // share an array that is built once and reused for every length.
        let ops: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];
        let mut rng = rand::thread_rng();

        for len in 0..5000 {
            let v: Vec<T> = (0..len).map(|_| rng.gen()).collect();
            for op in ops {
                assert_eq!(
                    v.iter().any_simd(op),
                    v.iter().any(op),
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }

        // Any value that differs from the default serves as the lone match.
        let hit = loop {
            let candidate: T = rng.gen();
            if candidate != T::default() {
                break candidate;
            }
        };
        let is_hit = |x: &T| *x != T::default();

        for len in 0..=3 * crate::LANE_COUNT {
            let mut v = vec![T::default(); len];
            assert!(
                !v.iter().any_simd(is_hit),
                "False positive for length {} and type {:?}",
                len,
                std::any::type_name::<T>()
            );
            for pos in 0..len {
                v[pos] = hit;
                assert!(
                    v.iter().any_simd(is_hit),
                    "Missed match at index {} for length {} and type {:?}",
                    pos,
                    len,
                    std::any::type_name::<T>()
                );
                v[pos] = T::default();
            }
        }
    }

    #[test]
    fn test_simd_any() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}