// blewm/lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3
4use std::hash::{Hash, BuildHasher, RandomState};
5use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
6
7
/// Width of one `usize` word, in bits (64 on 64-bit targets).
const USIZE_BITS: usize = usize::BITS as usize;
/// Right-shift that turns a bit position into a word (slot) index;
/// equals `log2(USIZE_BITS)`, e.g. 6 on 64-bit targets.
const INDEX_1_SHIFT: u32 = USIZE_BITS.ilog2();
/// Mask selecting the bit-within-word part of a hash.
/// `USIZE_BITS` is a power of two, so this equals `(1 << INDEX_1_SHIFT) - 1`.
const INDEX_2_MASK: usize = USIZE_BITS - 1;
11
/// A fast, concurrent Bloom filter.
///
/// The bit array is stored as atomic words, so both inserts and lookups
/// work through a shared reference (`&self`) and can run from many threads.
#[derive(Debug)]
pub struct Filter<H = RandomState> {
    // Underlying bit array, grouped into atomic machine words.
    // The slot count is always a power of two (see `with_hasher`).
    bits: Box<[AtomicUsize]>,
    // Number of hash probes (the classic Bloom-filter `k`) per element.
    num_hashes: usize,
    // Produces the single base 64-bit hash from which all probe
    // positions are derived.
    hasher: H,
}
19
20impl<H> Filter<H> {
21    /// Creates a new concurrent Bloom filter that does not exceed the specified
22    /// false positive rate when at most `capacity` elements are inserted.
23    ///
24    /// ## Panics
25    ///
26    /// Panics if `false_positive_rate` is not in the range `(0, 1)`.
27    pub fn with_hasher(capacity: usize, false_positive_rate: f64, hasher: H) -> Self {
28        assert!(0.0 < false_positive_rate && false_positive_rate < 1.0);
29
30        // compute optimal number of hashes and bits
31        let log2_eps = f64::log2(false_positive_rate);
32        let num_hashes = 1usize.max(f64::round(-log2_eps) as usize);
33        let min_num_bits = f64::round(capacity as f64 * -log2_eps / std::f64::consts::LN_2) as usize;
34        let num_bits = min_num_bits.next_power_of_two();
35        let num_slots = 1usize.max(num_bits / USIZE_BITS);
36        let bits: Box<[_]> = (0..num_slots).map(|_| AtomicUsize::new(0)).collect();
37
38        Filter { bits, num_hashes, hasher }
39    }
40
41    #[inline]
42    pub const fn as_bits(&self) -> &[AtomicUsize] {
43        &self.bits
44    }
45
46    #[inline]
47    pub const fn as_mut_bits(&mut self) -> &mut [AtomicUsize] {
48        &mut self.bits
49    }
50
51    #[inline]
52    pub fn into_bits(self) -> Box<[AtomicUsize]> {
53        self.bits
54    }
55
56    #[inline]
57    pub const fn num_slots(&self) -> usize {
58        self.bits.len()
59    }
60
61    #[inline]
62    pub const fn num_bits(&self) -> usize {
63        self.bits.len() * USIZE_BITS
64    }
65
66    #[inline]
67    pub const fn num_hashes(&self) -> usize {
68        self.num_hashes
69    }
70
71    #[inline]
72    const fn slot_mask(&self) -> usize {
73        self.num_slots() - 1
74    }
75
76    /// Inserts an element based on its hash into the Bloom filter.
77    /// Returns `true` if the element already existed (subject to
78    /// the usual caveats regarding false positives).
79    pub fn insert_hash(&self, mut hash: u64) -> bool {
80        let slot_mask = self.slot_mask();
81        let aux = aux_hash(hash);
82        let mut exists = true;
83
84        for _ in 0..self.num_hashes {
85            let h = hash as usize;
86            let slot_idx = (h >> INDEX_1_SHIFT) & slot_mask;
87            let bit_idx = h & INDEX_2_MASK;
88            let bit_mask = 1 << bit_idx;
89            let slot_bits = self.bits[slot_idx].fetch_or(bit_mask, Relaxed);
90
91            hash = next_hash(hash, aux);
92            exists &= (slot_bits & bit_mask) != 0;
93        }
94
95        exists
96    }
97
98    /// Returns `false` if the hash of an element is not in the set,
99    /// and `true` if it _may_ be in the set with high probability.
100    pub fn contains_hash(&self, mut hash: u64) -> bool {
101        let slot_mask = self.slot_mask();
102        let aux = aux_hash(hash);
103
104        (0..self.num_hashes).all(|_| {
105            let h = hash as usize;
106            let slot_idx = (h >> INDEX_1_SHIFT) & slot_mask;
107            let bit_idx = h & INDEX_2_MASK;
108            let bit_mask = 1 << bit_idx;
109            let slot_bits = self.bits[slot_idx].load(Relaxed);
110
111            hash = next_hash(hash, aux);
112            slot_bits & bit_mask != 0
113        })
114    }
115
116    /// Removes all elements.
117    ///
118    /// ## Examples
119    ///
120    /// ```
121    /// # use blewm::Filter;
122    /// let filter = Filter::new(1000, 0.0001);
123    ///
124    /// for x in 1000..2000 {
125    ///     assert_eq!(filter.insert(x), false);
126    /// }
127    ///
128    /// for x in 1000..2000 {
129    ///     assert!(filter.contains(x));
130    /// }
131    ///
132    /// filter.clear();
133    ///
134    /// for x in 0000..3000 {
135    ///     assert_eq!(filter.contains(x), false);
136    /// }
137    /// ```
138    pub fn clear(&self) {
139        for slot in &self.bits {
140            slot.store(0, Relaxed);
141        }
142    }
143}
144
145impl<H: BuildHasher> Filter<H> {
146    /// Inserts an element into the Bloom filter.
147    /// Returns `true` if the element already existed (subject to
148    /// the usual caveats regarding false positives).
149    ///
150    /// ## Examples
151    ///
152    /// ```
153    /// # use blewm::Filter;
154    /// let filter = Filter::new(1000, 0.0001);
155    ///
156    /// filter.insert(1);
157    /// filter.insert(2);
158    /// filter.insert(3);
159    /// filter.insert(42);
160    ///
161    /// assert!(filter.contains(1));
162    /// assert!(filter.contains(2));
163    /// assert!(filter.contains(3));
164    ///
165    /// for x in 4..1000 {
166    ///     assert_eq!(filter.contains(x), x == 42);
167    /// }
168    /// ```
169    pub fn insert<T: Hash>(&self, element: T) -> bool {
170        let hash = self.hasher.hash_one(element);
171        self.insert_hash(hash)
172    }
173
174    pub fn contains<T: Hash>(&self, element: T) -> bool {
175        let hash = self.hasher.hash_one(element);
176        self.contains_hash(hash)
177    }
178
179    /// Returns the hash of the element, based on which it
180    /// can be inserted to or looked up in the Bloom filter.
181    #[inline]
182    pub fn hash_item<T: Hash>(&self, element: T) -> u64 {
183        self.hasher.hash_one(element)
184    }
185}
186
187impl<H: Default> Filter<H> {
188    /// Creates a new concurrent Bloom filter that does not exceed the specified
189    /// false positive rate when at most `capacity` elements are inserted.
190    ///
191    /// See the documentation of [`Filter::with_hasher`] for details.
192    #[inline]
193    pub fn with_default_hasher(capacity: usize, false_positive_rate: f64) -> Self {
194        Self::with_hasher(capacity, false_positive_rate, H::default())
195    }
196}
197
198impl Filter<RandomState> {
199    /// Creates a new concurrent Bloom filter that does not exceed the specified
200    /// false positive rate when at most `capacity` elements are inserted, using
201    /// the default hasher.
202    ///
203    /// See the documentation of [`Filter::with_hasher`] for details.
204    #[inline]
205    pub fn new(capacity: usize, false_positive_rate: f64) -> Self {
206        Self::with_default_hasher(capacity, false_positive_rate)
207    }
208}
209
210impl<H: Clone> Clone for Filter<H> {
211    fn clone(&self) -> Self {
212        Filter {
213            bits: self.bits.iter().map(|slot| AtomicUsize::new(slot.load(Relaxed))).collect(),
214            num_hashes: self.num_hashes,
215            hasher: self.hasher.clone(),
216        }
217    }
218}
219
220impl<H: PartialEq> PartialEq for Filter<H> {
221    fn eq(&self, other: &Self) -> bool {
222        self.num_hashes == other.num_hashes // necessary because different # of hashers produce distinct patterns
223            && self.hasher == other.hasher // necessary because different hashers may produce different patterns
224            && self.bits.len() == other.bits.len()
225            && self.bits.iter().map(|x| x.load(Relaxed)).eq(other.bits.iter().map(|y| y.load(Relaxed)))
226    }
227}
228
// Filter equality is a full equivalence relation whenever the hasher's is:
// the word-by-word bit comparison in `PartialEq` is already reflexive.
impl<H: Eq> Eq for Filter<H> {}
230
231impl<H> From<Filter<H>> for Box<[AtomicUsize]> {
232    fn from(filter: Filter<H>) -> Self {
233        filter.into_bits()
234    }
235}
236
237impl<H> AsRef<[AtomicUsize]> for Filter<H> {
238    fn as_ref(&self) -> &[AtomicUsize] {
239        self.as_bits()
240    }
241}
242
243impl<H> AsMut<[AtomicUsize]> for Filter<H> {
244    fn as_mut(&mut self) -> &mut [AtomicUsize] {
245        self.as_mut_bits()
246    }
247}
248
249/// Explicit impl for the immutable reference type,
250/// so that we don't have to have mutable references
251/// (which `Extend` would normally require).
252impl<H, T> Extend<T> for &Filter<H>
253where
254    H: BuildHasher,
255    T: Hash,
256{
257    /// ```
258    /// # use blewm::Filter;
259    /// let filter = Filter::new(1000, 1e-4);
260    /// let mut filter = &filter;
261    ///
262    /// let range = 7386..8386; // chosen by fair dice roll
263    /// for x in range.clone() {
264    ///     assert_eq!(filter.contains(x), false);
265    /// }
266    ///
267    /// let new_elems = [7922, 7685, 8313, 7426, 8118, 7394];
268    /// filter.extend(new_elems);
269    ///
270    /// for x in range {
271    ///     assert_eq!(filter.contains(x), new_elems.contains(&x));
272    /// }
273    ///
274    /// for x in (0..1000).chain(9000..10_000) {
275    ///     assert_eq!(filter.contains(x), false);
276    /// }
277    /// ```
278    fn extend<I>(&mut self, iter: I)
279    where
280        I: IntoIterator<Item = T>
281    {
282        for item in iter {
283            self.insert(item);
284        }
285    }
286}
287
288/// The non-reference impl is done in terms of the impl for `&Filter<H>`.
289impl<H, T> Extend<T> for Filter<H>
290where
291    H: BuildHasher,
292    T: Hash,
293{
294    /// ```
295    /// # use blewm::Filter;
296    /// let mut filter = Filter::new(1000, 1e-4);
297    ///
298    /// let range = 7386..8386; // chosen by fair dice roll
299    /// for x in range.clone() {
300    ///     assert_eq!(filter.contains(x), false);
301    /// }
302    ///
303    /// let new_elems = [7922, 7685, 8313, 7426, 8118, 7394];
304    /// filter.extend(new_elems);
305    ///
306    /// for x in range {
307    ///     assert_eq!(filter.contains(x), new_elems.contains(&x));
308    /// }
309    ///
310    /// for x in (0..1000).chain(9000..10_000) {
311    ///     assert_eq!(filter.contains(x), false);
312    /// }
313    /// ```
314    fn extend<I>(&mut self, iter: I)
315    where
316        I: IntoIterator<Item = T>
317    {
318        let mut this: &Self = &*self;
319        this.extend(iter);
320    }
321}
322
/// Derives the per-probe increment from the upper 32 bits of the base hash.
/// Adapted from the `fastbloom` crate.
#[inline]
fn aux_hash(hash: u64) -> u64 {
    // 0xffff_ffff_ffff_ffff / 0x517c_c1b7_2722_0a95 = π
    (hash >> 32).wrapping_mul(0x517c_c1b7_2722_0a95)
}
329
/// Advances the probe sequence: add the increment, then rotate the word.
/// Adapted from the `fastbloom` crate.
#[inline]
fn next_hash(hash: u64, aux: u64) -> u64 {
    let sum = hash.wrapping_add(aux);
    sum.rotate_left(5)
}
335
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use std::thread::{self, JoinHandle};
340    use ahash::RandomState as AHashRandomState;
341
342    fn generic_hasher<H>(desired_fpr: f64) -> JoinHandle<()>
343    where
344        H: Default + BuildHasher
345    {
346        thread::spawn(move || generic_hasher_impl::<H>(desired_fpr))
347    }
348
349    fn generic_hasher_impl<H>(desired_fpr: f64)
350    where
351        H: Default + BuildHasher
352    {
353        for log_cap in [0, 4, 8, 12, 16, 20, 24, 26] {
354            let cap = 1 << log_cap;
355            let len = cap;
356            let filter = Filter::<H>::with_default_hasher(cap, desired_fpr);
357            let mut fp = 0;
358
359            dbg!(log_cap, cap, len, filter.num_slots(), filter.num_bits(), filter.num_hashes());
360
361            for i in 0..len {
362                if i % 2 == 0 {
363                    fp += u64::from(filter.insert(i));
364                }
365            }
366
367            for i in 0..len {
368                if i % 2 == 0 {
369                    // there should be no false negatives, ever
370                    assert!(filter.contains(i));
371                } else {
372                    // count false positives
373                    fp += u64::from(filter.contains(i));
374                }
375            }
376
377            // The observed false positive ratio should be within a reasonable margin
378            // of the desired FPR.
379            let actual_fpr = fp as f64 / (len as f64 / 2.0); // divide by 2 because we omitted odd elements
380
381            dbg!(std::any::type_name::<H>(), actual_fpr, desired_fpr);
382            assert!(actual_fpr <= 3.0 * desired_fpr);
383        }
384    }
385
386    #[test]
387    fn works_with_default_hasher() {
388        let handles: Vec<_> = (1..=9)
389            .map(|neg_log_fpr| {
390                generic_hasher::<RandomState>(f64::powi(10.0, -neg_log_fpr))
391            })
392            .collect();
393
394        // explicitly join handles so we catch panics
395        for h in handles {
396            h.join().unwrap();
397        }
398    }
399
400    #[test]
401    fn works_with_3rd_party_hasher() {
402        let handles: Vec<_> = (1..=9)
403            .map(|neg_log_fpr| {
404                generic_hasher::<AHashRandomState>(f64::powi(10.0, -neg_log_fpr))
405            })
406            .collect();
407
408        // explicitly join handles so we catch panics
409        for h in handles {
410            h.join().unwrap();
411        }
412    }
413
414    #[test]
415    fn multi_threaded() {
416        let num_threads = thread::available_parallelism().unwrap().get();
417        let items_per_thread = 1_000_000;
418        let desired_fpr = 1.0e-6;
419        let filter = Filter::<AHashRandomState>::with_default_hasher(num_threads * items_per_thread, desired_fpr);
420
421        thread::scope(|scope| {
422            let mut handles = Vec::with_capacity(num_threads);
423
424            for i in 0..num_threads {
425                let filter = &filter;
426                let h = scope.spawn(move || {
427                    let start_idx = i * items_per_thread;
428                    let end_idx = start_idx + items_per_thread;
429                    let mut fp = 0;
430
431                    for j in start_idx..end_idx {
432                        if j % 3 == 1 {
433                            fp += usize::from(filter.insert(j));
434                            assert!(filter.contains(j));
435                        }
436                    }
437
438                    for j in start_idx..end_idx {
439                        if j % 3 == 1 {
440                            assert!(filter.contains(j));
441                        } else {
442                            fp += usize::from(filter.contains(j));
443                        }
444                    }
445
446                    // multiply by 2/3 because we inserted every 3rd item
447                    let actual_fpr = fp as f64 / (items_per_thread as f64 * 2.0 / 3.0);
448
449                    dbg!(i, actual_fpr);
450                    assert!(actual_fpr < 3.0 * desired_fpr);
451                });
452                handles.push(h);
453            }
454
455            // explicitly join threads to forward panics
456            for h in handles {
457                h.join().unwrap();
458            }
459        });
460    }
461}