use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use twox_hash::XxHash64;
/// A Bloom filter: a compact, probabilistic set that can report false
/// positives but never false negatives for items that were inserted.
///
/// `Clone`, `PartialEq` and `Eq` are derived so a filter can be duplicated
/// and compared structurally (for example, to check a serialization
/// round-trip).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomFilter {
    /// The filter's bit array, stored as one `bool` per bit.
    bits: Vec<bool>,
    /// How many hash probes are performed per item.
    num_hashes: usize,
    /// How many insertions have been recorded (duplicates are counted too).
    count: usize,
}
impl BloomFilter {
    /// Size in bytes of the fixed serialization header:
    /// bit count (`u64`) + hash count (`u32`) + insertion count (`u64`).
    const HEADER_LEN: usize = 20;

    /// Creates an empty filter sized for `expected_elements` insertions at
    /// roughly the requested `false_positive_rate`.
    ///
    /// Sizing uses the standard formulas `m = -n * ln(p) / (ln 2)^2` for the
    /// number of bits and `k = (m / n) * ln 2` for the number of hashes.
    ///
    /// Degenerate inputs are clamped so the filter is always usable:
    /// `expected_elements == 0` is treated as `1`, and the bit array and hash
    /// count are never smaller than `1` (e.g. for a rate `>= 1.0`).
    ///
    /// # Panics
    ///
    /// A `false_positive_rate` of `0.0` makes the computed bit count saturate
    /// at `usize::MAX`, and allocating that many bits cannot succeed.
    pub fn new(expected_elements: usize, false_positive_rate: f64) -> Self {
        // Guard against n == 0: it would produce a zero-length bit array and
        // make every later `hash % len` divide by zero.
        let n = expected_elements.max(1) as f64;
        let num_bits = (-n * false_positive_rate.ln() / (2.0_f64.ln().powi(2))).ceil() as usize;
        let num_bits = num_bits.max(1);
        let num_hashes = ((num_bits as f64 / n) * 2.0_f64.ln()).ceil() as usize;
        let num_hashes = num_hashes.max(1);
        Self {
            bits: vec![false; num_bits],
            num_hashes,
            count: 0,
        }
    }

    /// Adds `item` to the filter by setting one bit per hash function.
    ///
    /// The insertion counter is incremented on every call, so inserting the
    /// same item twice counts twice.
    pub fn insert<T: Hash + ?Sized>(&mut self, item: &T) {
        for seed in 0..self.num_hashes {
            let index = self.bit_index(item, seed);
            self.bits[index] = true;
        }
        self.count += 1;
    }

    /// Returns `false` if `item` was definitely never inserted, and `true` if
    /// every probed bit is set (the item may be present, or it may be a false
    /// positive).
    #[inline]
    pub fn contains<T: Hash + ?Sized>(&self, item: &T) -> bool {
        // `all` short-circuits on the first unset bit.
        (0..self.num_hashes).all(|seed| self.bits[self.bit_index(item, seed)])
    }

    /// Returns the number of insertions recorded so far (not the number of
    /// distinct items).
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if nothing has been inserted yet.
    // May be unused by the rest of the crate.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the approximate in-memory footprint in bytes.
    ///
    /// Each filter bit is stored as a one-byte `bool`, so this is the bit
    /// count plus the size of the struct itself. The serialized form produced
    /// by [`BloomFilter::to_bytes`] is smaller because it packs eight bits
    /// per byte.
    pub fn size_bytes(&self) -> usize {
        self.bits.len() + std::mem::size_of::<Self>()
    }

    /// Estimates the current false-positive probability as
    /// `(1 - e^(-k * n / m))^k`, where `k` is the hash count, `n` the number
    /// of recorded insertions and `m` the number of bits.
    ///
    /// Returns `0.0` for a filter with no insertions.
    // May be unused by the rest of the crate.
    #[allow(dead_code)]
    pub fn false_positive_rate(&self) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        let k = self.num_hashes as f64;
        let n = self.count as f64;
        let m = self.bits.len() as f64;
        (1.0 - (-k * n / m).exp()).powf(k)
    }

    /// Serializes the filter.
    ///
    /// Layout (all integers little-endian):
    /// bit count as `u64`, hash count as `u32`, insertion count as `u64`,
    /// followed by the bits packed eight per byte with bit `i` of each group
    /// stored at position `i` (least significant first).
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::HEADER_LEN + self.bits.len().div_ceil(8));
        bytes.extend_from_slice(&(self.bits.len() as u64).to_le_bytes());
        // The wire format stores the hash count in 32 bits.
        bytes.extend_from_slice(&(self.num_hashes as u32).to_le_bytes());
        bytes.extend_from_slice(&(self.count as u64).to_le_bytes());
        bytes.extend(self.bits.chunks(8).map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit) << i))
        }));
        bytes
    }

    /// Reconstructs a filter from the format written by
    /// [`BloomFilter::to_bytes`].
    ///
    /// Returns `None` when the input is shorter than the header, when a
    /// header value does not fit in `usize`, when the declared bit count is
    /// zero, or when the payload holds fewer bits than the header declares.
    /// Bytes beyond the declared bit count are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < Self::HEADER_LEN {
            return None;
        }
        let (header, packed_bits) = bytes.split_at(Self::HEADER_LEN);
        let num_bits = usize::try_from(u64::from_le_bytes(header[0..8].try_into().ok()?)).ok()?;
        let num_hashes =
            usize::try_from(u32::from_le_bytes(header[8..12].try_into().ok()?)).ok()?;
        let count = usize::try_from(u64::from_le_bytes(header[12..20].try_into().ok()?)).ok()?;

        // A zero-bit filter would divide by zero on the first lookup.
        if num_bits == 0 {
            return None;
        }
        // Validate the payload length BEFORE allocating, so a corrupt or
        // hostile header cannot request an arbitrarily large buffer.
        let needed_bytes = num_bits.div_ceil(8);
        if packed_bits.len() < needed_bytes {
            return None;
        }

        let bits: Vec<bool> = packed_bits[..needed_bytes]
            .iter()
            .flat_map(|&byte| (0..8).map(move |i| (byte & (1u8 << i)) != 0))
            // Drop the padding bits of the final byte.
            .take(num_bits)
            .collect();

        Some(Self {
            bits,
            num_hashes,
            count,
        })
    }

    /// Maps `item` to a position in the bit array for the given hash `seed`.
    #[inline]
    fn bit_index<T: Hash + ?Sized>(&self, item: &T, seed: usize) -> usize {
        (self.hash(item, seed) % self.bits.len() as u64) as usize
    }

    /// Computes the `seed`-th hash of `item`.
    ///
    /// Even seeds feed the seed and then the item into a `DefaultHasher`;
    /// odd seeds hash the item with an `XxHash64` seeded by `seed`.
    ///
    /// NOTE(review): the standard library does not guarantee that
    /// `DefaultHasher` produces the same values across Rust releases, so
    /// filters persisted with `to_bytes` may not be readable by a binary
    /// built with a different toolchain. Changing the hash here would also
    /// invalidate existing serialized filters, so it is left as is.
    #[inline]
    fn hash<T: Hash + ?Sized>(&self, item: &T, seed: usize) -> u64 {
        if seed.is_multiple_of(2) {
            let mut hasher = DefaultHasher::new();
            seed.hash(&mut hasher);
            item.hash(&mut hasher);
            hasher.finish()
        } else {
            let mut hasher = XxHash64::with_seed(seed as u64);
            item.hash(&mut hasher);
            hasher.finish()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Items that were inserted are found; one that was not is rejected.
    #[test]
    fn test_basic_operations() {
        let mut filter = BloomFilter::new(1000, 0.01);
        for word in ["hello", "world"] {
            filter.insert(&word);
        }
        filter.insert(&42);

        for word in ["hello", "world"] {
            assert!(filter.contains(&word));
        }
        assert!(filter.contains(&42));
        assert!(!filter.contains(&"not_there"));
    }

    /// Fills the filter to its design capacity and checks that the observed
    /// false-positive rate on 1000 unseen keys stays under twice the target.
    #[test]
    fn test_false_positive_rate() {
        let mut filter = BloomFilter::new(1000, 0.01);
        (0_i32..1000).for_each(|value| filter.insert(&value));

        let false_positives = (1000_i32..2000)
            .filter(|value| filter.contains(value))
            .count();

        let actual_fpr = false_positives as f64 / 1000.0;
        println!("Target FPR: 0.01, Actual FPR: {}", actual_fpr);
        assert!(actual_fpr < 0.02, "FPR too high: {}", actual_fpr);
    }

    /// The reported size of a freshly built filter is non-zero.
    #[test]
    fn test_size_calculation() {
        let size = BloomFilter::new(10000, 0.01).size_bytes();
        println!("Bloom filter size for 10k elements: {} bytes", size);
        assert!(size > 0);
    }

    /// A filter survives a `to_bytes` / `from_bytes` round-trip with its
    /// membership answers and insertion count intact.
    #[test]
    fn test_serialization() {
        let mut source = BloomFilter::new(1000, 0.01);
        for key in ["key1", "key2", "key3"] {
            source.insert(&key);
        }

        let encoded = source.to_bytes();
        let restored = BloomFilter::from_bytes(&encoded).unwrap();

        for key in ["key1", "key2", "key3"] {
            assert!(restored.contains(&key));
        }
        assert!(!restored.contains(&"key4"));
        assert_eq!(restored.len(), 3);
    }
}