bloomur/
filter.rs

1use smallvec::SmallVec;
2
3use super::{hasher::SimMurmur, BloomHasher};
4
5use core::f64::consts::LN_2;
6use std::vec::Vec;
7
8const CACHE_LINE_SIZE: usize = 64;
9const CACHE_LINE_BITS: usize = CACHE_LINE_SIZE * 8;
10
/// Derives the number of probes performed per key from the bits-per-key budget.
///
/// The product `bits_per_key * 0.69` (0.69 ~= ln(2)) is truncated toward zero on
/// purpose, which reduces the probing cost a little bit. The result is then kept
/// within `1..=30`.
#[inline]
const fn calculate_probes(bits_per_key: usize) -> u32 {
  let scaled = (bits_per_key as f64 * 0.69) as u32;
  match scaled {
    // Always probe at least once.
    0 => 1,
    1..=30 => scaled,
    // Never probe more than 30 times.
    _ => 30,
  }
}
25
26/// Returns the bits per key required by bloomfilter based on
27/// the false positive rate.
28#[inline]
29pub fn bits_per_key(num_entries: usize, fp: f64) -> usize {
30  use libm::{ceil, log, pow};
31  let size = -1.0 * num_entries as f64 * log(fp) / pow(LN_2, 2.0);
32  ceil(LN_2 * size / num_entries as f64) as usize
33}
34
/// A bloom filter builder.
///
/// Keys are hashed as they are inserted and only their hashes are buffered, in
/// blocks of `N` entries. The filter bytes are produced when the builder is
/// finalized.
#[derive(Debug, Clone)]
pub struct Filter<const N: usize = 128, S = SimMurmur> {
  /// Bit budget per inserted key; it sizes the filter and determines the
  /// number of probes.
  bits_per_key: usize,

  /// Number of hashes buffered so far.
  num_hashes: usize,

  /// Hash of the most recently buffered key, used to skip a key whose hash
  /// repeats the previous one.
  last_hash: u32,

  // We store the hashes in blocks; each block is allocated with `N` slots.
  blocks: SmallVec<[Vec<u32>; 2]>,

  /// Hasher used to turn keys into 32-bit hashes.
  hasher: S,
}
49
50impl<const N: usize> Filter<N> {
51  /// Creates a new filter builder.
52  ///
53  /// ## Example
54  ///
55  /// ```rust
56  /// use bloomur::Filter;
57  ///
58  /// let f = Filter::<512>::new(1000, 0.01);
59  /// ```
60  #[inline]
61  pub fn new(num_entries: usize, fp: f64) -> Self {
62    let bpk = bits_per_key(num_entries, fp);
63    Self {
64      bits_per_key: bpk,
65      num_hashes: 0,
66      last_hash: 0,
67      blocks: SmallVec::new_const(),
68      hasher: SimMurmur::new(),
69    }
70  }
71
72  /// Creates a new filter builder.
73  ///
74  /// ## Example
75  ///
76  /// ```rust
77  /// use bloomur::Filter;
78  ///
79  /// let f = Filter::<512>::with_bits_per_key(10);
80  /// ```
81  #[inline]
82  pub const fn with_bits_per_key(bits_per_key: usize) -> Self {
83    Self {
84      bits_per_key,
85      num_hashes: 0,
86      last_hash: 0,
87      blocks: SmallVec::new_const(),
88      hasher: SimMurmur::new(),
89    }
90  }
91}
92
93impl<const N: usize, S> Filter<N, S> {
94  /// Creates a new filter builder.
95  ///
96  /// ## Example
97  ///
98  /// ```rust
99  /// use bloomur::{Filter, hasher::SimMurmur};
100  ///
101  /// let f = Filter::<512, SimMurmur>::with_hasher(1000, 0.01, SimMurmur::new());
102  /// ```
103  #[inline]
104  pub fn with_hasher(num_entries: usize, fp: f64, hasher: S) -> Self {
105    let bpk = bits_per_key(num_entries, fp);
106    Self {
107      bits_per_key: bpk,
108      num_hashes: 0,
109      last_hash: 0,
110      blocks: SmallVec::new_const(),
111      hasher,
112    }
113  }
114
115  /// Creates a new filter builder.
116  ///
117  /// ## Example
118  ///
119  /// ```rust
120  /// use bloomur::{Filter, hasher::SimMurmur};
121  ///
122  /// let f = Filter::<512, SimMurmur>::with_bits_per_key_and_hasher(10, SimMurmur::new());
123  /// ```
124  #[inline]
125  pub const fn with_bits_per_key_and_hasher(bits_per_key: usize, hasher: S) -> Self {
126    Self {
127      bits_per_key,
128      num_hashes: 0,
129      last_hash: 0,
130      blocks: SmallVec::new_const(),
131      hasher,
132    }
133  }
134}
135
136impl<const N: usize, S> Filter<N, S>
137where
138  S: BloomHasher,
139{
140  /// Adds a key to the filter.
141  pub fn insert(&mut self, key: &[u8]) {
142    let h = self.hasher.hash_one(key);
143    if self.num_hashes != 0 && h == self.last_hash {
144      return;
145    }
146
147    let ofs = self.num_hashes % N;
148    if ofs == 0 {
149      // Time for a new block
150      self.blocks.push(std::vec![0; N]);
151    }
152
153    self
154      .blocks
155      .last_mut()
156      .expect("blocks cannot be empty")
157      .insert(ofs, h);
158    self.last_hash = h;
159    self.num_hashes += 1;
160  }
161
162  /// Returns the length of the final filter.
163  #[inline]
164  pub const fn filter_length(&self) -> usize {
165    let n_lines = self.n_lines();
166    // +5: 4 bytes for n_lines and 1 byte for n_probes
167    n_lines * CACHE_LINE_SIZE + 5
168  }
169
170  const fn n_lines(&self) -> usize {
171    let mut n_lines = 0;
172    if self.num_hashes != 0 {
173      n_lines = (self.num_hashes * self.bits_per_key).div_ceil(CACHE_LINE_BITS);
174      // Make n_lines an odd number to make sure more bits are involved when
175      // determining which block.
176      if n_lines % 2 == 0 {
177        n_lines += 1;
178      }
179    }
180
181    // +5: 4 bytes for n_lines and 1 byte for n_probes
182    n_lines
183  }
184
185  /// Finalize to the given buffer.
186  ///
187  /// ## Returns
188  ///
189  /// - Returns `Ok(usize)` the number of bytes written to the buffer.
190  /// - Returns `Err(usize)` when the buf does not large enough to hold the filter, the number of bytes required to write the filter.
191  ///
192  /// ## Example
193  ///
194  /// ```rust
195  /// use bloomur::Filter;
196  ///
197  /// let mut f = Filter::<512>::with_bits_per_key(10);
198  /// f.insert(b"hello");
199  /// f.insert(b"world");
200  ///
201  /// let mut buf = vec![0; f.filter_length()];
202  /// let written = f.finalize_to(&mut buf).unwrap();
203  /// ```
204  pub fn finalize_to(self, buf: &mut [u8]) -> Result<usize, usize> {
205    let n_lines = self.n_lines();
206    let n_bytes = n_lines * CACHE_LINE_SIZE;
207    let written = n_bytes + 5;
208    if buf.len() < written {
209      return Err(written);
210    }
211
212    self.finalize_in(n_lines, n_bytes, buf);
213    Ok(written)
214  }
215
216  /// Finalizes the filter.
217  pub fn finalize(self) -> std::vec::Vec<u8> {
218    let n_lines = self.n_lines();
219    let n_bytes = n_lines * CACHE_LINE_SIZE;
220    // +5: 4 bytes for n_lines and 1 byte for n_probes
221    let mut filter = std::vec![0; n_bytes + 5];
222    self.finalize_in(n_lines, n_bytes, &mut filter);
223    filter
224  }
225
226  fn finalize_in(mut self, n_lines: usize, n_bytes: usize, filter: &mut [u8]) {
227    if n_lines != 0 {
228      let n_probes = calculate_probes(self.bits_per_key);
229      let num_blocks = self.blocks.len();
230      for (bidx, b) in self.blocks.iter_mut().enumerate() {
231        let mut length = N;
232        if bidx == num_blocks - 1 && self.num_hashes % N != 0 {
233          length = self.num_hashes % N;
234        }
235
236        for h in &mut b[..length] {
237          let delta = h.rotate_left(15); // rotate right 17 bits
238          let b = (*h % n_lines as u32) * CACHE_LINE_BITS as u32;
239
240          for _ in 0..n_probes {
241            let bit_pos = b + (*h % CACHE_LINE_BITS as u32);
242            filter[(bit_pos / 8) as usize] |= 1 << (bit_pos % 8);
243            *h = h.wrapping_add(delta);
244          }
245        }
246      }
247
248      filter[n_bytes] = n_probes as u8;
249      filter[n_bytes + 1..n_bytes + 5].copy_from_slice((n_lines as u32).to_le_bytes().as_slice());
250    }
251  }
252}
253
#[cfg(test)]
mod tests {
  #[cfg(feature = "xxhash3")]
  use crate::hasher::Xxh3;
  #[cfg(feature = "xxhash32")]
  use crate::hasher::Xxh32;

  use super::*;
  use crate::FrozenFilter;

  /// Builds and finalizes a filter over `keys` with hasher `S` and blocks of
  /// 512 hashes.
  fn new_filter<'a, S: BloomHasher + Default>(
    bits_per_key: usize,
    keys: impl Iterator<Item = &'a [u8]>,
  ) -> std::vec::Vec<u8> {
    let mut builder =
      Filter::<512, S>::with_bits_per_key_and_hasher(bits_per_key, Default::default());
    for key in keys {
      builder.insert(key);
    }

    builder.finalize()
  }

  /// Renders `src` as rows of eight bytes, most significant bit first, using
  /// `1` for set bits and `.` for clear bits.
  fn filter_to_string(src: &[u8]) -> String {
    let mut buf = String::new();

    for (i, x) in src.iter().enumerate() {
      if i > 0 {
        if i % 8 == 0 {
          buf.push('\n');
        } else {
          buf.push_str("  ");
        }
      }

      for j in 0..8 {
        if *x & (1 << (7 - j)) != 0 {
          buf.push('1');
        } else {
          buf.push('.');
        }
      }
    }

    buf.push('\n');
    buf
  }

  /// Checks membership answers of a filter built from "hello" and "world".
  fn small_bloomfilter<S: BloomHasher + Default>(f: &[u8]) {
    let m = &[
      ("hello", true),
      ("world", true),
      ("x", false),
      ("foo", false),
    ];

    let f = FrozenFilter::with_hasher(f, S::default());
    for (key, want) in m {
      let got = f.may_contain(key.as_bytes());
      assert_eq!(got, *want);
    }
  }

  #[test]
  fn small_bloomfilter_simmurur() {
    let f = new_filter::<SimMurmur>(10, [b"hello", b"world"].iter().map(|e| e.as_slice()));

    let want = r###"
........  ........  ........  .......1  ........  ........  ........  ........
........  .1......  ........  .1......  ........  ........  ........  ........
...1....  ........  ........  ........  ........  ........  ........  ........
........  ........  ........  ........  ........  ........  ........  ...1....
........  ........  ........  ........  .....1..  ........  ........  ........
.......1  ........  ........  ........  ........  ........  .1......  ........
........  ........  ........  ........  ........  ...1....  ........  ........
.......1  ........  ........  ........  .1...1..  ........  ........  ........
.....11.  .......1  ........  ........  ........
"###;

    let want = want.trim_start();
    let got = filter_to_string(&f);
    for i in 0..want.len() {
      let goti = got.as_bytes()[i];
      let wanti = want.as_bytes()[i];
      assert_eq!(goti, wanti, "idx={i}");
    }

    small_bloomfilter::<SimMurmur>(&f);
  }

  #[test]
  #[cfg(feature = "xxhash32")]
  fn small_bloomfilter_xxhash32() {
    let f = new_filter::<Xxh32>(10, [b"hello", b"world"].iter().map(|e| e.as_slice()));
    small_bloomfilter::<Xxh32>(&f);
  }

  #[test]
  #[cfg(feature = "xxhash3")]
  fn small_bloomfilter_xxh3() {
    let f = new_filter::<Xxh3>(10, [b"hello", b"world"].iter().map(|e| e.as_slice()));
    small_bloomfilter::<Xxh3>(&f);
  }

  /// Builds filters of increasing size with hasher `S` and checks their
  /// length, that every inserted key matches, and the false positive rate.
  fn bloom_filter_in<S: BloomHasher + Default>() {
    let next_length = |x: usize| -> usize {
      if x < 10 {
        return x + 1;
      }

      if x < 100 {
        return x + 10;
      }

      if x < 1000 {
        return x + 100;
      }

      x + 1000
    };

    // Little-endian encoding of the low 32 bits of `i`.
    let le32 = |i: usize| -> [u8; 4] { (i as u32).to_le_bytes() };

    let (mut n_mediocre_filters, mut n_good_filters) = (0, 0);

    // Every check below is a hard assertion. A bare `continue` in this loop
    // would skip the `length` update at the bottom and spin forever.
    let mut length = 1;
    while length <= 10_000 {
      let keys = (0..length).map(&le32).collect::<std::vec::Vec<_>>();

      let f = new_filter::<S>(10, keys.iter().map(|b| b.as_slice()));
      // The size of the table bloom filter is measured in multiples of the
      // cache line size. The '+2' contribution captures the rounding up in the
      // length division plus preferring an odd number of cache lines. As such,
      // this formula isn't exact, but the exact formula is hard to read.
      let max_len = 5 + ((length * 10) / CACHE_LINE_BITS + 2) * CACHE_LINE_SIZE;
      assert!(
        f.len() <= max_len,
        "length={}: f.len()={} > max len {}",
        length,
        f.len(),
        max_len
      );

      let f = FrozenFilter::with_hasher(f.as_slice(), S::default());
      // All added keys must match.
      for key in keys.iter() {
        assert!(
          f.may_contain(key),
          "length={}: did not contain key {:?}",
          length,
          key
        );
      }

      // Check false positive rate.
      let mut n_false_positive = 0f64;
      for i in 0..10_000 {
        if f.may_contain(le32((1e9f64 + i as f64) as usize).as_slice()) {
          n_false_positive += 1f64;
        }
      }

      assert!(
        n_false_positive <= 200f64,
        "length={}: n_false_positive={} > 0.02 * 10_000",
        length,
        n_false_positive
      );

      if n_false_positive > 125f64 {
        n_mediocre_filters += 1;
      } else {
        n_good_filters += 1;
      }

      length = next_length(length);
    }

    // Reported as a diagnostic only; this ratio is not enforced.
    if n_mediocre_filters > n_good_filters / 5 {
      #[cfg(feature = "std")]
      std::eprintln!(
        "{} mediocre filters but only {} good filters",
        n_mediocre_filters, n_good_filters
      );
    }
  }

  #[test]
  fn bloom_filter_sim_murur() {
    bloom_filter_in::<SimMurmur>();
  }

  #[test]
  #[cfg(feature = "xxhash32")]
  fn bloom_filter_xxh32() {
    bloom_filter_in::<Xxh32>();
  }

  #[test]
  #[cfg(feature = "xxhash3")]
  fn bloom_filter_xxh3() {
    bloom_filter_in::<Xxh3>();
  }
}