use bit_vec::BitVec;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// A Bloom filter: a compact probabilistic set over byte-string keys.
///
/// Membership queries may report false positives, but a key that was inserted
/// into a filter with at least one bit is always reported as present.
#[derive(Clone, Debug)]
pub struct BloomFilter {
    /// Backing bit array; its length is the modulus used for every probe.
    bits: BitVec,
    /// Number of bit positions probed (and set) per key.
    k_num_hashes: u32,
}
impl BloomFilter {
pub fn new(num_elements: usize, false_positive_rate: f64) -> Self {
let (num_bits, k_num_hashes) =
Self::calculate_optimal_params(num_elements, false_positive_rate);
Self {
bits: BitVec::from_elem(num_bits, false),
k_num_hashes,
}
}
fn calculate_optimal_params(num_elements: usize, fpr: f64) -> (usize, u32) {
let num_bits =
(-(num_elements as f64) * fpr.ln() / (std::f64::consts::LN_2.powi(2))).ceil() as usize;
let mut k =
((num_bits as f64 / num_elements as f64) * std::f64::consts::LN_2).ceil() as u32;
if k == 0 {
k = 1;
}
(num_bits, k)
}
fn hash_key(key: &[u8]) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
pub fn set(&mut self, key: &[u8]) {
let hash = Self::hash_key(key);
let h1 = (hash & 0xFFFFFFFF) as u32;
let h2 = (hash >> 32) as u32;
let m = self.bits.len();
if m == 0 {
return;
}
for i in 0..self.k_num_hashes {
let bit_idx = (h1 as u64 + (i as u64).wrapping_mul(h2 as u64)) as usize % m;
self.bits.set(bit_idx, true);
}
}
pub fn contains(&self, key: &[u8]) -> bool {
let hash = Self::hash_key(key);
let h1 = (hash & 0xFFFFFFFF) as u32;
let h2 = (hash >> 32) as u32;
let m = self.bits.len();
if m == 0 {
return false;
}
for i in 0..self.k_num_hashes {
let bit_idx = (h1 as u64 + (i as u64).wrapping_mul(h2 as u64)) as usize % m;
if !self.bits.get(bit_idx).unwrap_or(false) {
return false;
}
}
true
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.extend_from_slice(&self.k_num_hashes.to_le_bytes());
bytes.extend_from_slice(&(self.bits.len() as u64).to_le_bytes());
bytes.extend_from_slice(&self.bits.to_bytes());
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 12 {
return None;
}
let k_num_hashes = u32::from_le_bytes(bytes[0..4].try_into().ok()?);
let num_bits = u64::from_le_bytes(bytes[4..12].try_into().ok()?) as usize;
let mut bits = BitVec::from_bytes(&bytes[12..]);
bits.truncate(num_bits);
Some(Self { bits, k_num_hashes })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserted keys are found; a couple of never-inserted keys are not.
    #[test]
    fn test_bloom_filter_basic() {
        let mut bf = BloomFilter::new(100, 0.01);
        bf.set(b"apple");
        bf.set(b"banana");
        bf.set(b"grape");
        assert!(bf.contains(b"apple"));
        assert!(bf.contains(b"banana"));
        assert!(bf.contains(b"grape"));
        assert!(!bf.contains(b"strawberry"));
        assert!(!bf.contains(b"missing"));
    }

    /// Fills the filter to its design capacity and checks that the observed
    /// false-positive rate on unseen keys stays near the requested 10%.
    #[test]
    fn test_bloom_filter_false_positives() {
        let mut bf = BloomFilter::new(1000, 0.1);
        for i in 0..1000 {
            let key = format!("key{}", i);
            bf.set(key.as_bytes());
        }
        let tests = 10000;
        // Keys 1000.. were never inserted, so every hit is a false positive.
        let false_positives = (1000..(1000 + tests))
            .filter(|i| bf.contains(format!("key{}", i).as_bytes()))
            .count();
        let actual_fpr = false_positives as f64 / tests as f64;
        assert!(
            actual_fpr < 0.15,
            "FPR was significantly higher than estimated: {}",
            actual_fpr
        );
    }

    /// A filter survives `to_bytes` -> `from_bytes` with identical contents.
    #[test]
    fn test_bloom_filter_serialization_roundtrip() {
        let keys: [&[u8]; 3] = [b"apple", b"banana", b"grape"];
        let mut bf = BloomFilter::new(100, 0.01);
        for &key in &keys {
            bf.set(key);
        }
        let encoded = bf.to_bytes();
        let restored =
            BloomFilter::from_bytes(&encoded).expect("bytes produced by to_bytes must decode");
        for &key in &keys {
            assert!(restored.contains(key));
        }
        assert!(!restored.contains(b"strawberry"));
        assert_eq!(restored.to_bytes(), encoded);
    }

    /// Inputs shorter than the 12-byte header are rejected.
    #[test]
    fn test_from_bytes_rejects_short_header() {
        assert!(BloomFilter::from_bytes(&[]).is_none());
        assert!(BloomFilter::from_bytes(&[0u8; 11]).is_none());
    }
}