use bitvec::bitvec;
use std::f64::consts::{E, LN_2};
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
use twox_hash::RandomXxHashBuilder;
/// `(ln 2)^2`, the denominator of the bit-count formula
/// `m = -n * ln(p) / (ln 2)^2` used by `BloomFilter::new`.
const LN2_SQUARED: f64 = LN_2 * LN_2;
/// A Bloom filter over items of type `T`: a fixed-size bit vector that answers
/// "possibly added" or "definitely not added".
///
/// Items themselves are never stored; only bits derived from their hashes are
/// set, which is why the methods only require `T: Hash`.
pub struct BloomFilter<T> {
/// Number of `add` calls so far; repeated insertions of the same item are
/// counted each time.
n: u64,
/// Number of bits in `bit_vec` (fixed at construction by `new`).
m: u64,
/// Number of bit positions derived from each item's hash.
k: u32,
/// The filter's bits, all zero at construction; `add` only ever sets bits.
bit_vec: bitvec::vec::BitVec,
/// Produces the hasher applied to every item. Created with
/// `RandomXxHashBuilder::default()`, which presumably picks a random seed per
/// filter, so bit positions differ between instances (verify against the
/// `twox_hash` docs).
build_hasher: RandomXxHashBuilder,
/// Ties the filter to item type `T` without storing any `T`.
_phantom: PhantomData<T>,
}
impl<T: Hash> BloomFilter<T> {
    /// Creates an empty filter sized for `estimated_items` insertions at the
    /// requested `false_positive_rate`.
    ///
    /// Uses the standard sizing formulas, each rounded up:
    /// `m = -n * ln(p) / (ln 2)^2` bits and `k = (m / n) * ln 2` hash indices.
    ///
    /// # Panics
    ///
    /// Panics if `false_positive_rate` is not strictly between 0 and 1 (NaN is
    /// rejected too, since both comparisons fail), or if `estimated_items` is 0.
    pub fn new(false_positive_rate: f64, estimated_items: usize) -> Self {
        assert!(
            false_positive_rate > 0_f64 && false_positive_rate < 1_f64,
            "False positive rate must be between 0 and 1 (non-inclusive)"
        );
        assert!(
            estimated_items > 0,
            "Number of estimated items must be greater than zero"
        );
        let num_bits = -(estimated_items as f64) * false_positive_rate.ln() / LN2_SQUARED;
        // k is derived from the *unrounded* bit count so that rounding m up
        // does not also inflate k.
        let num_hashes = (num_bits / estimated_items as f64) * LN_2;
        // ln(p) < 0 for p in (0, 1), so both values are positive and each
        // ceil() yields at least 1.
        let num_bits = num_bits.ceil() as u64;
        let num_hashes = num_hashes.ceil() as u32;
        BloomFilter {
            n: 0,
            m: num_bits,
            k: num_hashes,
            bit_vec: bitvec![0; num_bits as usize],
            build_hasher: RandomXxHashBuilder::default(),
            _phantom: PhantomData,
        }
    }

    /// Records `item` by setting each of its `k` bit positions.
    ///
    /// The insertion counter is bumped on every call, including repeated
    /// insertions of the same item, so it counts insertions, not distinct items.
    pub fn add(&mut self, item: &T) {
        let hash = split_hash(item, &self.build_hasher);
        for i in indices_for_hash(hash, self.m, self.k) {
            self.bit_vec.set(i, true);
        }
        self.n += 1;
    }

    /// Returns `false` if `item` was definitely never added, and `true` if it
    /// may have been added.
    ///
    /// False positives are possible. False negatives are not, because `add`
    /// only ever sets bits.
    pub fn might_contain(&self, item: &T) -> bool {
        let hash = split_hash(item, &self.build_hasher);
        // `all` stops at the first unset bit, like an early `return false`.
        indices_for_hash(hash, self.m, self.k).all(|i| self.bit_vec[i])
    }

    /// Estimates the current false-positive probability from the number of
    /// insertions so far, as `(1 - e^(-k * n / m))^k`.
    ///
    /// Returns `0.0` for a filter with no insertions.
    pub fn false_positive_rate(&self) -> f64 {
        let exponent = -f64::from(self.k) * self.n as f64 / self.m as f64;
        // k comes from `new` and is at most about log2(1 / f64::MIN_POSITIVE
        // subnormal) ≈ 1075, so the cast to i32 cannot truncate.
        (1_f64 - E.powf(exponent)).powi(self.k as i32)
    }
}
/// Hashes `item` once with a fresh hasher obtained from `hasher` and splits
/// the 64-bit digest into `(high 32 bits, low 32 bits)`.
fn split_hash<T: Hash>(item: &T, hasher: &impl BuildHasher) -> (u32, u32) {
    let digest = {
        let mut state = hasher.build_hasher();
        item.hash(&mut state);
        state.finish()
    };
    let high = (digest >> 32) as u32;
    let low = digest as u32;
    (high, low)
}
/// Derives `k` bit indices in `0..m` from a split 64-bit hash using double
/// hashing: `g_i = (h1 + i * h2) mod m` for `i` in `0..k`.
///
/// The sum is computed in `u64`. In `u32` it would wrap at 2^32 before the
/// `% m`. Filters with more than 2^32 bits could then never reach their upper
/// bits, and the wrap would skew the index distribution whenever `m` does not
/// divide 2^32.
///
/// `m` must be non-zero, since the modulo panics otherwise. `BloomFilter::new`
/// always produces `m >= 1`.
fn indices_for_hash(split_hash: (u32, u32), m: u64, k: u32) -> impl Iterator<Item = usize> {
    let (h1, h2) = (u64::from(split_hash.0), u64::from(split_hash.1));
    // With h1, h2 and i all below 2^32, h1 + h2 * i <= 2^64 - 2^32, so this
    // never actually overflows; the wrapping ops just rule out debug panics.
    // The result is below m, which is the bit vector's length.
    (0..u64::from(k)).map(move |i| (h1.wrapping_add(h2.wrapping_mul(i)) % m) as usize)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sizing for 1% false positives at 216,553 items must match the
    /// closed-form formulas used by `BloomFilter::new`.
    #[test]
    fn test_num_bits_and_hashes() {
        let filter = BloomFilter::<&str>::new(0.01_f64, 216553);
        assert_eq!(filter.m, 2_075_674);
        assert_eq!(filter.k, 7);
    }

    /// With no insertions, the estimated false-positive rate is exactly zero.
    #[test]
    fn test_false_positive_rate_empty() {
        let filter = BloomFilter::<&str>::new(0.01_f64, 216553);
        assert_eq!(filter.false_positive_rate(), 0_f64);
    }

    #[test]
    fn test_add() {
        let mut filter = BloomFilter::new(0.03_f64, 10);
        filter.add(&"Hello, world!");
        assert!(filter.false_positive_rate() > 0.0);
        assert!(filter.might_contain(&"Hello, world!"));
        // Probabilistic: the hasher comes from `RandomXxHashBuilder::default()`,
        // so bit positions vary between runs. After a single insertion into a
        // filter this size, a false positive is very unlikely but possible.
        assert!(!filter.might_contain(&"Dogs are cool!"));
    }

    /// Every inserted item must be reported as possibly present.
    #[test]
    fn test_no_false_negatives() {
        let mut filter = BloomFilter::new(0.01_f64, 100);
        for i in 0..100u32 {
            filter.add(&i);
        }
        assert!((0..100u32).all(|i| filter.might_contain(&i)));
    }

    #[test]
    #[should_panic(expected = "False positive rate must be between 0 and 1")]
    fn test_rejects_false_positive_rate_of_one() {
        let _ = BloomFilter::<&str>::new(1.0, 10);
    }

    #[test]
    #[should_panic(expected = "False positive rate must be between 0 and 1")]
    fn test_rejects_nan_false_positive_rate() {
        let _ = BloomFilter::<&str>::new(f64::NAN, 10);
    }

    #[test]
    #[should_panic(expected = "Number of estimated items must be greater than zero")]
    fn test_rejects_zero_items() {
        let _ = BloomFilter::<&str>::new(0.01, 0);
    }
}