// Compile-time selection of the membership kernels called by the `ContainsSimd` impls.
//
// The `x86_64` module is compiled in only when building for x86_64 with the `avx2`
// target feature enabled at build time (e.g. `RUSTFLAGS="-C target-feature=+avx2"`, or
// `-C target-cpu=native` on an AVX2-capable machine). Every other configuration
// compiles the `fallback` module instead. Both modules must export `contains_simd_u32`
// and `contains_simd_u64`, because the `use` lines below import those names and the
// `ContainsSimd` impls call them.
//
// NOTE(review): this is a static `cfg` check, not runtime CPU detection. A default
// x86_64 build (which does not enable `avx2`) takes the fallback path even on CPUs that
// support AVX2.
cfg_if::cfg_if! {
    if #[cfg(all(target_arch="x86_64", target_feature = "avx2"))] {
        mod x86_64;
        use x86_64::{contains_simd_u32, contains_simd_u64};
    } else {
        mod fallback;
        use fallback::{contains_simd_u32, contains_simd_u64};
    }
}
15
/// Slice membership test, implemented per element type so each type can route to
/// its own (possibly SIMD-accelerated) search routine.
///
/// Implementors only need to provide [`ContainsSimd::contains_simd`]. It is an
/// associated function rather than a method, so call it as
/// `u32::contains_simd(&slice, value)`.
pub trait ContainsSimd
where
    Self: Sized,
{
    /// Returns `true` if `target` occurs at least once in `vector`, and `false`
    /// otherwise (in particular, `false` for an empty `vector`).
    fn contains_simd(vector: &[Self], target: Self) -> bool;
}
21
22impl ContainsSimd for u32 {
23 fn contains_simd(vector: &[u32], target: u32) -> bool {
24 contains_simd_u32(vector, target)
25 }
26}
27
28impl ContainsSimd for u64 {
29 fn contains_simd(vector: &[u64], target: u64) -> bool {
30 contains_simd_u64(vector, target)
31 }
32}
33
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        Rng, SeedableRng,
    };

    use super::*;

    /// Exhaustive check against iota slices `[0, 1, ..., dim - 1]` for every `dim`
    /// in `0..max_dim`. Each value in `0..dim` must be found, and the ten values
    /// `dim..dim + 10` must not be. `dim == 0` exercises the empty slice.
    ///
    /// The `unwrap`s panic if a tested value does not fit in `T`.
    fn test_contains<T>(max_dim: usize)
    where
        T: ContainsSimd + Copy + TryFrom<usize>,
        <T as TryFrom<usize>>::Error: std::fmt::Debug,
    {
        for dim in 0..max_dim {
            let v: Vec<T> = (0..dim).map(|i| i.try_into().unwrap()).collect();

            for query in 0..dim {
                assert!(
                    T::contains_simd(&v, query.try_into().unwrap()),
                    "expected query {query} to be in iota slice of dimension {dim}"
                );
            }

            for query in dim..dim + 10 {
                assert!(
                    !T::contains_simd(&v, query.try_into().unwrap()),
                    "expected query {query} not to be in iota slice of dimension {dim}"
                );
            }
        }
    }

    #[test]
    fn test_contains_u32() {
        test_contains::<u32>(128);
    }

    #[test]
    fn test_contains_u64() {
        test_contains::<u64>(128);
    }

    /// Asserts that no element of `not_present` is reported as contained in
    /// `vector`, and that every element of `vector` is.
    fn test_contains_simd<T>(vector: &[T], not_present: &[T])
    where
        T: ContainsSimd + Copy + std::fmt::Debug,
    {
        for &item in not_present {
            assert!(
                !T::contains_simd(vector, item),
                "{item:?} unexpectedly reported as present in {vector:?}"
            );
        }
        for &item in vector {
            assert!(
                T::contains_simd(vector, item),
                "{item:?} not found in {vector:?}"
            );
        }
    }

    // Length 16, with duplicate values.
    #[test]
    fn test_contains_simd_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8], &[9]);
    }

    #[test]
    fn test_contains_simd_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8], &[9]);
    }

    // Length 8 (a multiple of 8).
    #[test]
    fn test_contains_simd_multiple_of_8_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0], &[9, 8]);
    }

    #[test]
    fn test_contains_simd_multiple_of_8_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0], &[9, 8]);
    }

    // Length 9 (not a multiple of 8). Presumably this exercises the kernels'
    // remainder/tail handling; confirm against the SIMD implementation.
    #[test]
    fn test_contains_simd_non_multiple_of_8_u32() {
        test_contains_simd::<u32>(&[5, 7, 6, 3, 2, 1, 4, 0, 11], &[9, 8]);
    }

    #[test]
    fn test_contains_simd_non_multiple_of_8_u64() {
        test_contains_simd::<u64>(&[5, 7, 6, 3, 2, 1, 4, 0, 11], &[9, 8]);
    }

    const NUM_TRIALS: usize = 1000;

    /// Over `NUM_TRIALS` seeded random vectors of length `1..1000`, checks that a
    /// randomly chosen element of each vector is found.
    fn check_random_item_is_present<T>()
    where
        T: ContainsSimd + Copy,
        StandardUniform: Distribution<T>,
    {
        let dim_dist = Uniform::new(1, 1000).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..NUM_TRIALS {
            let v: Vec<T> = (0..dim_dist.sample(&mut rng))
                .map(|_| rng.random())
                .collect();
            let item = v[rng.random_range(0..v.len())];

            assert!(T::contains_simd(&v, item));
        }
    }

    /// Over `NUM_TRIALS` seeded random vectors of length `1..1000`, checks that a
    /// random value known to be absent from each vector is not found.
    fn check_random_absent_item_is_not_found<T>()
    where
        T: ContainsSimd + Copy + PartialEq,
        StandardUniform: Distribution<T>,
    {
        let dim_dist = Uniform::new(1, 1000).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..NUM_TRIALS {
            let v: Vec<T> = (0..dim_dist.sample(&mut rng))
                .map(|_| rng.random())
                .collect();

            // Rejection-sample a value that is not in `v`. For the `u32`/`u64`
            // instantiations below, `v` is tiny compared with the value space, so
            // this almost always exits after the first draw.
            let item = loop {
                let candidate: T = rng.random();
                if !v.contains(&candidate) {
                    break candidate;
                }
            };

            assert!(!T::contains_simd(&v, item));
        }
    }

    #[test]
    fn contains_works_when_item_is_present() {
        check_random_item_is_present::<u32>();
    }

    #[test]
    fn contains_works_when_item_is_present_u64() {
        check_random_item_is_present::<u64>();
    }

    #[test]
    fn contains_works_when_item_is_not_present() {
        check_random_absent_item_is_not_found::<u32>();
    }

    #[test]
    fn contains_works_when_item_is_not_present_u64() {
        check_random_absent_item_is_not_found::<u64>();
    }
}