use serde::{Deserialize, Serialize};
use crate::entry::Hash;
/// Bloom filter over [`Hash`] values: a fixed-size bit set that answers
/// "possibly present" / "definitely absent" membership queries.
///
/// The type derives `Serialize`/`Deserialize`, so field names and order are
/// part of the encoded representation.
// NOTE(review): rmp_serde's default struct encoding is presumably positional —
// confirm before reordering or renaming fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BloomFilter {
/// Bit set packed into 64-bit words; bit `i` lives in word `i / 64` at
/// position `i % 64`.
bits: Vec<u64>,
/// Number of addressable bits; probe positions are reduced modulo this value.
num_bits: usize,
/// Number of bit positions probed (and set) per hash.
num_hashes: u32,
/// Number of `insert` calls recorded (summed across filters on `merge`).
count: usize,
}
impl BloomFilter {
    /// Largest number of probes per item. `validate` rejects anything above
    /// this, so `new` must never produce more.
    const MAX_HASHES: u32 = 32;

    /// Checks the structural invariants that `insert` and `contains` rely on.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `num_bits` is zero, when `bits`
    /// has too few words to hold `num_bits` bits, or when `num_hashes` is
    /// outside `[1, 32]`.
    pub fn validate(&self) -> Result<(), String> {
        if self.num_bits == 0 {
            return Err("bloom filter num_bits must be > 0".into());
        }
        // Compare in words rather than bits so the check cannot overflow `usize`.
        if self.bits.len() < self.num_bits.div_ceil(64) {
            return Err(format!(
                "bloom filter bits array too small: {} words for {} bits",
                self.bits.len(),
                self.num_bits
            ));
        }
        if !(1..=Self::MAX_HASHES).contains(&self.num_hashes) {
            return Err(format!(
                "bloom filter num_hashes {} out of range [1, 32]",
                self.num_hashes
            ));
        }
        Ok(())
    }

    /// Creates an empty filter sized for `expected_items` insertions at a
    /// target false-positive rate of `fp_rate`.
    ///
    /// The bit count is never below 64, and the probe count is clamped to
    /// `[1, 32]` so that every filter built here passes [`Self::validate`].
    ///
    /// # Panics
    ///
    /// Panics if `expected_items` is zero or if `fp_rate` is not strictly
    /// between 0 and 1 (NaN included).
    pub fn new(expected_items: usize, fp_rate: f64) -> Self {
        assert!(expected_items > 0, "expected_items must be > 0");
        // Both bounds are exclusive: a rate of exactly 0 gives ln(0) = -inf,
        // which would saturate the bit count to `usize::MAX`.
        assert!(fp_rate > 0.0 && fp_rate < 1.0, "fp_rate must be in (0, 1)");

        let n = expected_items as f64;
        let ln2 = std::f64::consts::LN_2;

        // Optimal size m = -n * ln(p) / (ln 2)^2, with a floor of one word.
        let num_bits = (((-n * fp_rate.ln()) / (ln2 * ln2)).ceil() as usize).max(64);
        // Optimal probe count k = (m / n) * ln 2, kept inside the validated range.
        let num_hashes =
            (((num_bits as f64 / n) * ln2).ceil() as u32).clamp(1, Self::MAX_HASHES);

        Self {
            bits: vec![0u64; num_bits.div_ceil(64)],
            num_bits,
            num_hashes,
            count: 0,
        }
    }

    /// Sets every probe bit for `hash` and increments the insertion counter.
    ///
    /// The counter is bumped on every call, including repeated inserts of the
    /// same hash.
    pub fn insert(&mut self, hash: &Hash) {
        for idx in self.indices(hash) {
            self.bits[idx / 64] |= 1u64 << (idx % 64);
        }
        self.count += 1;
    }

    /// Returns `true` when every probe bit for `hash` is set, and `false` as
    /// soon as one is found clear.
    pub fn contains(&self, hash: &Hash) -> bool {
        self.indices(hash)
            .into_iter()
            .all(|idx| self.bits[idx / 64] & (1u64 << (idx % 64)) != 0)
    }

    /// Number of insertions recorded by this filter.
    pub fn count(&self) -> usize {
        self.count
    }

    /// ORs the bits of `other` into `self` and adds its insertion count.
    ///
    /// Counts are summed (saturating at `usize::MAX`), so an item inserted
    /// into both filters is counted twice.
    ///
    /// # Panics
    ///
    /// Panics if the two filters differ in `num_bits` or `num_hashes`.
    pub fn merge(&mut self, other: &BloomFilter) {
        assert_eq!(self.num_bits, other.num_bits, "bloom filter size mismatch");
        assert_eq!(
            self.num_hashes, other.num_hashes,
            "bloom filter hash count mismatch"
        );
        for (dst, src) in self.bits.iter_mut().zip(&other.bits) {
            *dst |= *src;
        }
        self.count = self.count.saturating_add(other.count);
    }

    /// Serializes the filter with `rmp_serde`.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails.
    pub fn to_bytes(&self) -> Vec<u8> {
        rmp_serde::to_vec(self).expect("bloom filter serialization should not fail")
    }

    /// Decodes a filter previously produced by [`Self::to_bytes`].
    ///
    /// This only decodes; it does not call [`Self::validate`]. `insert` and
    /// `contains` index `bits` and reduce modulo `num_bits` without further
    /// checks, so validate filters decoded from untrusted bytes before use.
    ///
    /// # Errors
    ///
    /// Returns the `rmp_serde` decode error when `bytes` cannot be decoded.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
        rmp_serde::from_slice(bytes)
    }

    /// Computes the `num_hashes` bit positions probed for `hash`.
    ///
    /// Two little-endian `u64` values are read from bytes `0..8` and `8..16`
    /// of the hash; probe `i` is `(h1 + i * h2 + i * i) mod num_bits`, using
    /// wrapping arithmetic.
    fn indices(&self, hash: &Hash) -> Vec<usize> {
        let h1 = u64::from_le_bytes(hash[0..8].try_into().unwrap());
        let h2 = u64::from_le_bytes(hash[8..16].try_into().unwrap());
        let modulus = self.num_bits as u64;
        (0..u64::from(self.num_hashes))
            .map(|i| {
                let probe = h1
                    .wrapping_add(i.wrapping_mul(h2))
                    .wrapping_add(i.wrapping_mul(i));
                (probe % modulus) as usize
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic hash from a 32-byte buffer whose first byte is `seed`.
    fn make_hash(seed: u8) -> Hash {
        let mut input = [0u8; 32];
        input[0] = seed;
        *blake3::hash(&input).as_bytes()
    }

    /// Hashes the little-endian `u64` encoding of `index`.
    fn index_hash(index: usize) -> Hash {
        *blake3::hash(&(index as u64).to_le_bytes()).as_bytes()
    }

    /// Inserted hashes are reported present; an uninserted one is not.
    #[test]
    fn bloom_insert_and_check() {
        let (first, second, absent) = (make_hash(1), make_hash(2), make_hash(3));
        let mut filter = BloomFilter::new(100, 0.01);
        filter.insert(&first);
        filter.insert(&second);

        assert!(filter.contains(&first));
        assert!(filter.contains(&second));
        assert!(!filter.contains(&absent));
    }

    /// A fresh filter reports no hash as present.
    #[test]
    fn bloom_empty_contains_nothing() {
        let filter = BloomFilter::new(100, 0.01);
        for seed in u8::MIN..=u8::MAX {
            assert!(!filter.contains(&make_hash(seed)));
        }
    }

    /// The observed false-positive rate stays under twice the 1% target.
    #[test]
    fn bloom_false_positive_rate() {
        let inserted = 1000;
        let probes = 10_000;

        let mut filter = BloomFilter::new(inserted, 0.01);
        for index in 0..inserted {
            filter.insert(&index_hash(index));
        }

        // Probe only indices that were never inserted; every hit is a false positive.
        let false_positives = (inserted..inserted + probes)
            .filter(|&index| filter.contains(&index_hash(index)))
            .count();

        let fpr = false_positives as f64 / probes as f64;
        assert!(
            fpr < 0.02,
            "false positive rate {fpr:.4} exceeds 2% threshold"
        );
    }

    /// After a merge, the receiver contains the union of both filters.
    #[test]
    fn bloom_merge_union() {
        let hashes = [make_hash(1), make_hash(2), make_hash(3)];
        let mut left = BloomFilter::new(100, 0.01);
        let mut right = BloomFilter::new(100, 0.01);

        left.insert(&hashes[0]);
        left.insert(&hashes[1]);
        right.insert(&hashes[1]);
        right.insert(&hashes[2]);

        left.merge(&right);

        for hash in &hashes {
            assert!(left.contains(hash));
        }
    }

    /// A serialize/deserialize round trip preserves membership, count and parameters.
    #[test]
    fn bloom_serialization_roundtrip() {
        let (first, second) = (make_hash(1), make_hash(2));
        let mut original = BloomFilter::new(100, 0.01);
        original.insert(&first);
        original.insert(&second);

        let decoded = BloomFilter::from_bytes(&original.to_bytes()).unwrap();

        assert!(decoded.contains(&first));
        assert!(decoded.contains(&second));
        assert!(!decoded.contains(&make_hash(3)));
        assert_eq!(decoded.count(), 2);
        assert_eq!(decoded.num_bits, original.num_bits);
        assert_eq!(decoded.num_hashes, original.num_hashes);
    }

    /// `count` starts at zero and grows by one per insert.
    #[test]
    fn bloom_count_tracks_inserts() {
        let mut filter = BloomFilter::new(100, 0.01);
        assert_eq!(filter.count(), 0);
        for (already_inserted, seed) in [1u8, 2].into_iter().enumerate() {
            filter.insert(&make_hash(seed));
            assert_eq!(filter.count(), already_inserted + 1);
        }
    }
}