concurrent_bloom/
bloom.rs

1use {
2  core::fmt, 
3  fnv::FnvHasher, 
4  rand::{rng, Rng}, 
5  serde::{Deserialize, Serialize}, 
6  std::{
7    cmp, 
8    hash::Hasher, 
9    marker::PhantomData, 
10    sync::atomic::{AtomicU64, Ordering},
11  }
12};
13
14fn hash<T: AsRef<[u8]>>(input: T, h_key: u64) -> u64 {
15  let mut hasher = FnvHasher::with_key(h_key);
16  hasher.write(input.as_ref());
17  hasher.finish()
18}
19
20#[derive(Serialize, Deserialize, Default)]
21pub struct Bloom<T: AsRef<[u8]>> {
22  n_bits: u64,
23  n_bits_set: AtomicU64,
24  hash_keys: Vec<u64>,
25  bits: Vec<AtomicU64>,
26  _marker: PhantomData<T>
27}
28
29impl<T: AsRef<[u8]>> fmt::Debug for Bloom<T> {
30  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31    write!(
32      f,
33      "Bloom {{ num_hash_keys: {}, num_bits: {}, num_bit_sets: {}, bits: ",
34      self.hash_keys.len(),
35      self.n_bits,
36      self.num_bits_set(),
37    )?;
38    let first = self.bits[0].load(Ordering::Relaxed);
39    let first_10_bits = first.reverse_bits() >> 54; // Reverse only 10 bits
40    write!(f, "{:010b}..", first_10_bits)?;
41    write!(f, " }}")
42  }
43}
44
45impl<T: AsRef<[u8]>> Bloom<T> {
46/// Creates a thread-safe Bloom filter with an optimal bit size and number of hash functions  
47/// based on the expected number of items and the desired false positive rate.
48pub fn new(n_items: usize, false_rate: f64) ->Self{
49    let n_items = cmp::max(1, n_items);
50    let mut m = (-(n_items as f64)*false_rate.ln()/(2f64.ln()*2f64.ln())).ceil();
51    m = cmp::max(1, m as u64) as f64; // make sure m >= 1
52    let k = (2f64.ln())*m/(n_items as f64).round();
53    let length = (m as u64 + 63)/64; // calculate the length of the AtomicU64 vector
54    let mut r = rng();
55    let hash_keys: Vec<u64> = (0..k as usize).map(|_| r.random()).collect();
56    Bloom { 
57      n_bits: length*64, 
58      n_bits_set: AtomicU64::new(0),
59      hash_keys,
60      bits: (0..length).map(|_| AtomicU64::new(0)).collect(),
61      _marker: PhantomData,
62    }
63  }
64  /// Computes the `u64` index and bitmask for a given input and hash key.
65  /// This is used to set or check the bit corresponding to the input.
66  fn bit_pos(&self, input: &T, h_key: u64) -> (usize, u64) {
67    let p = hash(input, h_key) % self.n_bits;
68    let idx = p>>6;
69    let mask  = 1u64 << (p &63);
70    (idx as usize, mask)
71  }
72  /// Sets the bit corresponding to the given input and hash key in the Bloom filter.
73  fn set_bit(&self, input: &T, h_key: u64) -> bool{
74    let (idx, mask) = self.bit_pos(input, h_key);
75    let prev = self.bits[idx].fetch_or(mask, Ordering::Relaxed);
76    let is_new = prev &mask == 0;
77    if is_new {
78      self.n_bits_set.fetch_add(1, Ordering::Relaxed);
79    }
80    is_new
81  }
82
83  /// Checks the bit corresponding to the given input and hash key in the Bloom filter.
84  fn check_bit(&self, input: &T, h_key: u64) -> bool{
85    let (idx, mask) = self.bit_pos(input, h_key);
86    let bit = self.bits[idx].load(Ordering::Relaxed) & mask;
87    bit > 0
88  }
89  /// Adds an item to Bloom filter
90  pub fn insert(&self, item: &T) {
91    for h_key in &self.hash_keys {
92      self.set_bit(item, *h_key);
93    }
94  }
95  /// Checks if an item is in Bloom filter
96  pub fn contains(&self, item: &T) -> bool {
97    for h_key in &self.hash_keys {
98      if !self.check_bit(item, *h_key) {
99        return false;
100      }
101    }
102    true
103  }
104  // clear all the bits
105  pub fn reset(&self) {
106    for n in &self.bits{
107      n.store(0, Ordering::Relaxed);
108    }
109  }
110  // get the number of bits set
111  pub fn num_bits_set(&self) ->u64 {
112    self.n_bits_set.load(Ordering::Relaxed)
113  }
114}
115
#[cfg(test)]
mod test {
  use {
    super::Bloom,
    rand::{rng, rngs::ThreadRng, Rng},
    rayon::prelude::*,
    std::sync::atomic::{AtomicU64, Ordering},
  };

  /// Alphabet that random test items are drawn from.
  const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                         abcdefghijklmnopqrstuvwxyz\
                         0123456789";

  /// Builds a random alphanumeric string of 1 to 63 characters.
  fn random_string(rng: &mut ThreadRng) -> String {
    let length = rng.random_range(1..64);
    let mut out = String::new();
    for _ in 0..length {
      let pick = rng.random_range(0..CHARSET.len());
      out.push(char::from(CHARSET[pick]));
    }
    out
  }

  /// The constructor rounds the bit count up to whole 64-bit words and picks
  /// three hash keys for a 0.1 false positive rate.
  #[test]
  fn test_bloom_constructor() {
    for (n_items, expected_bits) in [(0, 64), (10, 64), (100, 512)] {
      let bloom: Bloom<String> = Bloom::new(n_items, 0.1);
      assert_eq!(bloom.n_bits, expected_bits);
      assert_eq!(bloom.hash_keys.len(), 3);
    }
  }

  /// Two filters built with the same parameters get different hash keys.
  #[test]
  fn test_bloom_hash_keys_randomness() {
    let mut first: Bloom<String> = Bloom::new(10, 0.1);
    let mut second: Bloom<String> = Bloom::new(10, 0.1);
    assert_eq!(first.hash_keys.len(), second.hash_keys.len());
    first.hash_keys.sort_unstable();
    second.hash_keys.sort_unstable();
    assert_ne!(first.hash_keys, second.hash_keys);
  }

  /// Inserts items from many threads and checks that every inserted item is
  /// found and that the observed false positive counts stay under the bounds.
  #[test]
  fn test_bloom_insert_contains() {
    let bloom: Bloom<String> = Bloom::new(2100, 0.1);
    println!("{:?}", bloom);
    assert_eq!(10112, bloom.n_bits);
    assert_eq!(3, bloom.hash_keys.len());

    let mut r = rng();
    let mut items: Vec<String> = Vec::with_capacity(2000);
    for _ in 0..2000 {
      items.push(random_string(&mut r));
    }

    // Phase 1: concurrent insert, counting items reported present beforehand.
    let false_positives = AtomicU64::new(0);
    items.par_iter().for_each(|item| {
      let reported_before_insert = bloom.contains(item);
      if reported_before_insert {
        false_positives.fetch_add(1, Ordering::Relaxed);
      }
      bloom.insert(item);
      assert!(bloom.contains(item));
    });
    let observed = false_positives.load(Ordering::Relaxed);
    assert!(observed < 200, "false_positive: {}", observed);

    // Phase 2: test false positives more intensively with fresh random items.
    false_positives.store(0, Ordering::Relaxed);
    for _ in 0..10000 {
      let item = random_string(&mut r);
      if bloom.contains(&item) {
        false_positives.fetch_add(1, Ordering::Relaxed);
      }
    }
    let observed = false_positives.load(Ordering::Relaxed);
    assert!(observed < 2000, "false_positive: {}", observed);
  }
}