Skip to main content

diskann_vector/contains/
mod.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6cfg_if::cfg_if! {
7    if #[cfg(all(target_arch="x86_64", target_feature = "avx2"))] {
8        mod x86_64;
9        use x86_64::{contains_simd_u32, contains_simd_u64};
10    } else {
11        mod fallback;
12        use fallback::{contains_simd_u32, contains_simd_u64};
13    }
14}
15
/// A SIMD-accelerated version of
/// [`std::slice::contains`](https://doc.rust-lang.org/std/primitive.slice.html#method.contains).
///
/// Implementors answer the question "does this slice hold an element equal to this
/// value?", typically using vector instructions where the target supports them.
pub trait ContainsSimd
where
    Self: Sized,
{
    /// Returns `true` if some element of `vector` equals `target`, and `false`
    /// otherwise (an empty `vector` never contains anything).
    fn contains_simd(vector: &[Self], target: Self) -> bool;
}
21
22impl ContainsSimd for u32 {
23    fn contains_simd(vector: &[u32], target: u32) -> bool {
24        contains_simd_u32(vector, target)
25    }
26}
27
28impl ContainsSimd for u64 {
29    fn contains_simd(vector: &[u64], target: u64) -> bool {
30        contains_simd_u64(vector, target)
31    }
32}
33
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform, Uniform},
        Rng, SeedableRng,
    };

    use super::*;

    /// Test `contains_simd` for every slice length in `0..max_dim`.
    ///
    /// For each length `dim`, this test builds the iota slice `[0, 1, ..., dim - 1]`. It
    /// then checks that every value in `0..dim` is found, and that the ten values
    /// `dim..dim + 10` (which are not in the slice) are not found.
    ///
    /// Because every length below `max_dim` is exercised and every position is queried,
    /// this covers matches at all possible locations, including lengths that leave a
    /// partial, non-vector-width remainder for whatever lane width the backend uses.
    fn test_contains<T>(max_dim: usize)
    where
        T: ContainsSimd + Copy + TryFrom<usize> + PartialEq,
        <T as TryFrom<usize>>::Error: std::fmt::Debug,
    {
        for dim in 0..max_dim {
            let v: Vec<T> = (0..dim).map(|i| i.try_into().unwrap()).collect();

            // Every element of the iota slice must be found.
            for query in 0..dim {
                assert!(
                    T::contains_simd(&v, query.try_into().unwrap()),
                    "expected query {query} to be in iota slice of dimension {dim}",
                );
            }

            // None of the values just past the end of the slice may be found.
            for query in dim..dim + 10 {
                assert!(
                    !T::contains_simd(&v, query.try_into().unwrap()),
                    "expected query {query} not to be in iota slice of dimension {dim}",
                );
            }
        }
    }

    #[test]
    fn test_contains_u32() {
        test_contains::<u32>(128);
    }

    #[test]
    fn test_contains_u64() {
        test_contains::<u64>(128);
    }

    #[test]
    fn test_contains_simd_u32() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8];
        test_contains_simd::<u32>(vector, vec![9]);
    }

    #[test]
    fn test_contains_simd_u64() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8];
        test_contains_simd::<u64>(vector, vec![9]);
    }

    #[test]
    fn test_contains_simd_multiple_of_8_u32() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0];
        test_contains_simd::<u32>(vector, vec![9, 8]);
    }

    #[test]
    fn test_contains_simd_multiple_of_8_u64() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0];
        test_contains_simd::<u64>(vector, vec![9, 8]);
    }

    #[test]
    fn test_contains_simd_non_multiple_of_8_u32() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0, 11];
        test_contains_simd::<u32>(vector, vec![9, 8]);
    }

    #[test]
    fn test_contains_simd_non_multiple_of_8_u64() {
        let vector = vec![5, 7, 6, 3, 2, 1, 4, 0, 11];
        test_contains_simd::<u64>(vector, vec![9, 8]);
    }

    /// Checks that no value in `not_present` is found in `vector`, and that every
    /// element of `vector` is found.
    ///
    /// The caller must ensure that `not_present` and `vector` share no values.
    fn test_contains_simd<T>(vector: Vec<T>, not_present: Vec<T>)
    where
        T: ContainsSimd + Copy,
    {
        for &item in &not_present {
            assert!(!T::contains_simd(&vector, item));
        }

        for &item in &vector {
            assert!(T::contains_simd(&vector, item));
        }
    }

    /// Number of random slices generated by each fuzz test.
    const NUM_TRIALS: usize = 1000;

    /// Fuzz test: pick an element that is known to be in a random slice and check that
    /// `contains_simd` finds it.
    ///
    /// Slice lengths are drawn uniformly from `1..1000`. Elements are drawn from
    /// `StandardUniform`, i.e. uniformly over all values of `T`.
    fn fuzz_item_is_present<T>(seed: u64)
    where
        T: ContainsSimd + Copy,
        StandardUniform: Distribution<T>,
    {
        // Distribution of slice lengths. The lower bound of 1 guarantees that there is
        // always at least one element to pick.
        let dim_dist: Uniform<usize> = Uniform::new(1, 1000).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        for _ in 0..NUM_TRIALS {
            let dim = dim_dist.sample(&mut rng);
            let v: Vec<T> = (0..dim)
                .map(|_| StandardUniform {}.sample(&mut rng))
                .collect();
            let item = v[rng.random_range(0..v.len())];

            assert!(T::contains_simd(&v, item));
        }
    }

    /// Fuzz test: generate a value that is known NOT to be in a random slice and check
    /// that `contains_simd` does not find it.
    ///
    /// The scalar `slice::contains` is used as the reference to reject candidate values
    /// that happen to be present.
    fn fuzz_item_is_not_present<T>(seed: u64)
    where
        T: ContainsSimd + Copy + PartialEq,
        StandardUniform: Distribution<T>,
    {
        // Distribution of slice lengths.
        let dim_dist: Uniform<usize> = Uniform::new(1, 1000).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        for _ in 0..NUM_TRIALS {
            let dim = dim_dist.sample(&mut rng);
            let v: Vec<T> = (0..dim)
                .map(|_| StandardUniform {}.sample(&mut rng))
                .collect();

            let mut item: T = StandardUniform {}.sample(&mut rng);
            while v.contains(&item) {
                item = StandardUniform {}.sample(&mut rng);
            }

            assert!(!T::contains_simd(&v, item));
        }
    }

    #[test]
    fn contains_works_when_item_is_present() {
        fuzz_item_is_present::<u32>(42);
    }

    #[test]
    fn contains_works_when_item_is_present_u64() {
        fuzz_item_is_present::<u64>(42);
    }

    #[test]
    fn contains_works_when_item_is_not_present() {
        fuzz_item_is_not_present::<u32>(42);
    }

    #[test]
    fn contains_works_when_item_is_not_present_u64() {
        fuzz_item_is_not_present::<u64>(42);
    }
}