use crate::{
common::{BloomParams, HashIndexIterator},
utils::HexFieldDebug,
};
use bitvec::{prelude::Lsb0, view::BitView};
use std::fmt::Debug;
/// A bloom filter whose byte size and number of hashes are chosen at runtime.
///
/// Equality and ordering are derived, so two filters compare by `k_hashes`
/// first and then by their raw bytes.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct BloomFilter {
/// Number of bit indices derived from each item on insert and lookup.
k_hashes: usize,
/// Backing storage for the filter's bits; accessed as a bit slice in `Lsb0` order.
bytes: Box<[u8]>,
}
impl BloomFilter {
    /// Builds a filter with all bits cleared, using the byte size and hash
    /// count from `params`.
    fn with_zeroed_bits(params: BloomParams) -> Self {
        Self {
            k_hashes: params.k_hashes,
            // `into_boxed_slice` reuses the vector's allocation instead of
            // copying the zeroes into a second buffer.
            bytes: vec![0u8; params.byte_size].into_boxed_slice(),
        }
    }

    /// Creates a filter with all bits cleared, sized by
    /// [`BloomParams::new_from_fpr`] for `n_elems` expected elements and a
    /// target false positive rate of `fpr`.
    pub fn new_from_fpr(n_elems: u64, fpr: f64) -> Self {
        Self::with_zeroed_bits(BloomParams::new_from_fpr(n_elems, fpr))
    }

    /// Creates a filter with all bits cleared, sized by
    /// [`BloomParams::new_from_fpr_po2`] for `n_elems` expected elements and a
    /// target false positive rate of `fpr`.
    ///
    /// NOTE(review): the `_po2` suffix presumably means the size is rounded to
    /// a power of two; confirm against `BloomParams`.
    pub fn new_from_fpr_po2(n_elems: u64, fpr: f64) -> Self {
        Self::with_zeroed_bits(BloomParams::new_from_fpr_po2(n_elems, fpr))
    }

    /// Creates a filter with all bits cleared, with parameters computed by
    /// [`BloomParams::new_from_size`] from a byte budget of `bloom_bytes` and
    /// `n_elems` expected elements.
    pub fn new_from_size(bloom_bytes: usize, n_elems: u64) -> Self {
        Self::with_zeroed_bits(BloomParams::new_from_size(bloom_bytes, n_elems))
    }

    /// Builds a filter directly from a hash count and existing bit storage,
    /// for example to restore a filter from [`Self::as_bytes`] and
    /// [`Self::hash_count`]. No validation is performed on either argument.
    pub fn new_with(k_hashes: usize, bytes: Box<[u8]>) -> Self {
        Self { k_hashes, bytes }
    }

    /// Returns the parameters (hash count and byte size) of this filter.
    pub fn get_bloom_params(&self) -> BloomParams {
        BloomParams {
            k_hashes: self.k_hashes,
            byte_size: self.bytes.len(),
        }
    }

    /// Returns the false positive rate that [`BloomParams`] predicts for this
    /// filter's parameters once `n_elems` elements have been inserted.
    pub fn false_positive_rate_at(&self, n_elems: u64) -> f64 {
        self.get_bloom_params().false_positive_rate_at(n_elems)
    }

    /// Estimates the false positive rate from the bits that are currently set:
    /// the fraction of set bits raised to the power of the hash count.
    ///
    /// A filter with zero bytes yields `1.0`, since a filter without any bits
    /// cannot reject an item.
    pub fn current_false_positive_rate(&self) -> f64 {
        if self.bytes.is_empty() {
            // Avoids 0.0 / 0.0 = NaN below.
            return 1.0;
        }
        let total_bits = (self.bytes.len() * 8) as f64;
        let load = self.count_ones() as f64 / total_bits;
        // Clamp before narrowing so an oversized hash count cannot wrap into a
        // negative exponent (which would produce a "rate" above 1).
        let exponent = self.hash_count().min(i32::MAX as usize) as i32;
        load.powi(exponent)
    }

    /// Counts the bits that are set in the filter.
    pub fn count_ones(&self) -> usize {
        self.bytes.view_bits::<Lsb0>().count_ones()
    }

    /// Inserts `item` by setting every bit index produced by
    /// [`Self::hash_indices`] for it.
    pub fn insert(&mut self, item: &impl AsRef<[u8]>) {
        // The index iterator is created first; it does not keep `self`
        // borrowed, so the bits can be borrowed mutably afterwards.
        let indices = self.hash_indices(item);
        let bits = self.bytes.view_bits_mut::<Lsb0>();
        for index in indices {
            bits.set(index, true);
        }
    }

    /// Returns `true` if every bit index produced by [`Self::hash_indices`]
    /// for `item` is set, and `false` as soon as one unset bit is found.
    ///
    /// A `true` result may be a false positive.
    pub fn contains(&self, item: &impl AsRef<[u8]>) -> bool {
        let bits = self.bytes.view_bits::<Lsb0>();
        self.hash_indices(item).all(|index| bits[index])
    }

    /// Returns the number of bit indices used per item.
    pub fn hash_count(&self) -> usize {
        self.k_hashes
    }

    /// Returns the raw bytes backing the filter.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Returns the first [`Self::hash_count`] indices that
    /// `HashIndexIterator` yields for `item`, given this filter's size in bits.
    pub fn hash_indices<'a>(&self, item: &'a impl AsRef<[u8]>) -> impl Iterator<Item = usize> + 'a {
        HashIndexIterator::new(item, self.bytes.len() * 8).take(self.hash_count())
    }
}
impl Debug for BloomFilter {
    /// Formats the filter as a `BloomFilter` struct with its hash count and
    /// its bytes wrapped in `HexFieldDebug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex_bytes = HexFieldDebug(&self.bytes);
        let mut out = f.debug_struct("BloomFilter");
        out.field("k_hashes", &self.k_hashes);
        out.field("bytes", &hex_bytes);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::BloomFilter;

    /// A filter rebuilt from its raw bytes and hash count must equal the
    /// original and answer membership queries the same way.
    #[test]
    fn serialization_round_trip() {
        let mut original = BloomFilter::new_from_fpr(100, 0.001);
        original.insert(b"Hello");
        original.insert(b"World!");

        // "Serialize" into an owned copy of the bytes plus the hash count.
        let k = original.hash_count();
        let raw: Box<[u8]> = original.as_bytes().into();
        let restored = BloomFilter::new_with(k, raw);

        assert!(restored.contains(b"Hello"));
        assert!(!restored.contains(b"abc"));
        assert_eq!(restored, original);
    }

    /// A filter without any bytes is expected to report every item as
    /// contained.
    #[test]
    fn empty_bloom_filter() {
        let empty = BloomFilter::new_with(3, Box::new([]));
        let probe: [u8; 3] = [1, 2, 3];
        assert!(empty.contains(&probe));
    }
}
#[cfg(test)]
mod proptests {
    use super::BloomFilter;
    use proptest::prop_assert;
    use test_strategy::proptest;

    /// Every inserted item must be reported as contained (no false negatives).
    #[proptest]
    fn inserted_always_contained(items: Vec<u64>, #[strategy(100usize..10_000)] size: usize) {
        // Clamp the expected element count to at least one.
        let expected_elems = (items.len() as u64).max(1);
        let mut filter = BloomFilter::new_from_size(size, expected_elems);

        items
            .iter()
            .for_each(|value| filter.insert(&value.to_le_bytes()));

        for item in items.iter() {
            prop_assert!(filter.contains(&item.to_le_bytes()));
        }
    }

    /// After inserting `n_elems` items, the false positive rate measured on
    /// items that were never inserted must lie within 1.5e-3 of the target.
    #[proptest]
    fn false_positive_rate_as_predicted(
        #[strategy(100u64..1_000)] n_elems: u64,
        #[strategy(100.0..10_000.0)] inv_fpr: f64,
    ) {
        let fpr = inv_fpr.recip();
        let mut filter = BloomFilter::new_from_fpr(n_elems, fpr);
        (0..n_elems).for_each(|i| filter.insert(&i.to_le_bytes()));

        // Probe with values from a range disjoint from the inserted ones, so
        // every hit is a false positive.
        let measurements: u64 = 100_000;
        let false_positives = (n_elems..n_elems + measurements)
            .filter(|i| filter.contains(&i.to_le_bytes()))
            .count();

        let computed_fpr = false_positives as f64 / measurements as f64;
        prop_assert!((computed_fpr - fpr).abs() < 1.5e-3);
    }
}