use rustc_hash::FxHasher;
use std::hash::{Hash, Hasher};
/// Probabilistic set-membership filter over byte strings.
///
/// `insert`/`insert_bytes` set one bit per hash probe, and `might_contain`/`might_contain_bytes`
/// report a hit only if every probed bit is set. As a result, queries can give false positives
/// (distinct inputs may share bits) but never false negatives for inputs that were inserted
/// into this filter.
///
/// NOTE(review): with the serde derive below, fields are (de)serialized in declaration
/// order and deserialization does not re-run the constructors' clamping/rounding; keep the
/// field names and order stable if filters are persisted.
#[derive(Debug, Clone)]
#[cfg_attr(
    any(feature = "serialization", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct BloomFilter {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Number of addressable bits; the constructors always set this to `bits.len() * 64`.
    bit_count: usize,
    /// Number of seeded hash probes per inserted element (constructors ensure at least 1).
    hash_count: usize,
}
impl BloomFilter {
    /// Bits stored in each word of `bits`.
    const WORD_BITS: usize = 64;

    /// Largest whole-word bit count representable in `usize`. Requests above it are clamped
    /// so that rounding up to a word boundary cannot overflow.
    const MAX_BITS: usize = usize::MAX / Self::WORD_BITS * Self::WORD_BITS;

    /// Creates a filter sized for about `expected_elements` insertions.
    ///
    /// Allocates 10 bits per expected element (minimum 64, rounded up to a multiple of 64)
    /// and uses 3 hash probes per element. Very large `expected_elements` saturate rather
    /// than overflow, but the resulting allocation may still fail.
    pub fn new(expected_elements: usize) -> Self {
        Self::with_params(expected_elements.saturating_mul(10), 3)
    }

    /// Creates a filter with at least `bit_count` bits and `hash_count` probes per element.
    ///
    /// `bit_count` is clamped to at least 64 and rounded up to a multiple of 64; read the
    /// final value back with [`capacity`](Self::capacity). `hash_count` is clamped to at
    /// least 1.
    pub fn with_params(bit_count: usize, hash_count: usize) -> Self {
        // Clamp before rounding up: the old `(n + 63) / 64` overflowed for n near usize::MAX,
        // which `new`'s saturating multiply can produce.
        let requested = bit_count.clamp(Self::WORD_BITS, Self::MAX_BITS);
        let word_count =
            requested / Self::WORD_BITS + usize::from(requested % Self::WORD_BITS != 0);
        BloomFilter {
            bits: vec![0u64; word_count],
            // Every bit of the final word is usable, so expose the rounded-up count.
            bit_count: word_count * Self::WORD_BITS,
            hash_count: hash_count.max(1),
        }
    }

    /// Records `term` in the filter. It is equivalent to `insert_bytes(term.as_bytes())`.
    #[inline]
    pub fn insert(&mut self, term: &str) {
        self.insert_bytes(term.as_bytes());
    }

    /// Records `bytes` in the filter by setting the bit selected by each hash probe.
    #[inline]
    pub fn insert_bytes(&mut self, bytes: &[u8]) {
        for (word, mask) in Self::probes(bytes, self.hash_count, self.addressable_bits()) {
            self.bits[word] |= mask;
        }
    }

    /// Returns `false` if `term` was definitely never inserted and `true` if it may have been.
    ///
    /// It is equivalent to `might_contain_bytes(term.as_bytes())`.
    #[inline]
    #[must_use]
    pub fn might_contain(&self, term: &str) -> bool {
        self.might_contain_bytes(term.as_bytes())
    }

    /// Returns `false` if `bytes` was definitely never inserted, or `true` if every probed
    /// bit is set (the value may have been inserted).
    #[inline]
    #[must_use]
    pub fn might_contain_bytes(&self, bytes: &[u8]) -> bool {
        // `all` stops at the first unset bit, matching the original early return.
        Self::probes(bytes, self.hash_count, self.addressable_bits())
            .all(|(word, mask)| (self.bits[word] & mask) != 0)
    }

    /// Resets every bit to zero, keeping the filter's size and probe count.
    pub fn clear(&mut self) {
        self.bits.fill(0);
    }

    /// Returns the number of bits in the filter. This is not the number of elements it holds.
    pub fn capacity(&self) -> usize {
        self.bit_count
    }

    /// Returns the number of hash probes used per element.
    pub fn hash_count(&self) -> usize {
        self.hash_count
    }

    /// Returns the number of bits that may be probed: `bit_count`, capped at what `bits`
    /// actually stores.
    ///
    /// For constructor-built filters both values are equal, so bit positions are unchanged.
    /// The cap only matters for deserialized values whose fields disagree; without it, such
    /// a value would index past the end of `bits`.
    fn addressable_bits(&self) -> u64 {
        let stored = (self.bits.len() as u64).saturating_mul(Self::WORD_BITS as u64);
        (self.bit_count as u64).min(stored)
    }

    /// Yields `(word index, bit mask)` for each of `hash_count` seeded probes of `bytes`.
    ///
    /// Seeds run from `0` to `hash_count - 1`. Each probe's bit index is the seeded hash
    /// reduced modulo `modulus`. When `modulus` is 0 (no addressable bits) the iterator is
    /// empty. In that case inserts are no-ops and queries report "may contain" instead of
    /// panicking on `% 0`.
    fn probes(
        bytes: &[u8],
        hash_count: usize,
        modulus: u64,
    ) -> impl Iterator<Item = (usize, u64)> + '_ {
        let rounds = if modulus == 0 { 0 } else { hash_count };
        (0..rounds).map(move |seed| {
            let bit_index = (Self::hash_with_seed(bytes, seed as u64) % modulus) as usize;
            (bit_index / Self::WORD_BITS, 1u64 << (bit_index % Self::WORD_BITS))
        })
    }

    /// Hashes `seed` followed by `bytes` with `FxHasher`.
    ///
    /// Bit positions derive from this value, and the bit array can be persisted through
    /// the optional serde derive. Changing the hashed inputs or their order would
    /// therefore invalidate existing filters.
    /// NOTE(review): FxHasher output is presumably not guaranteed stable across rustc-hash
    /// versions or pointer widths — confirm before relying on persisted filters.
    #[inline]
    fn hash_with_seed(bytes: &[u8], seed: u64) -> u64 {
        let mut hasher = FxHasher::default();
        seed.hash(&mut hasher);
        bytes.hash(&mut hasher);
        hasher.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bloom_filter_basic() {
        let mut bloom = BloomFilter::new(100);
        bloom.insert("hello");
        bloom.insert("world");
        bloom.insert("test");
        assert!(bloom.might_contain("hello"));
        assert!(bloom.might_contain("world"));
        assert!(bloom.might_contain("test"));
    }

    #[test]
    fn test_bloom_filter_no_false_negatives() {
        let mut bloom = BloomFilter::new(1000);
        let terms: Vec<String> = (0..100).map(|i| format!("term{}", i)).collect();
        for term in &terms {
            bloom.insert(term);
        }
        for term in &terms {
            assert!(bloom.might_contain(term), "False negative for: {}", term);
        }
    }

    #[test]
    fn test_bloom_filter_clear() {
        let mut bloom = BloomFilter::new(100);
        bloom.insert("hello");
        assert!(bloom.might_contain("hello"));
        bloom.clear();
        let all_zeros = bloom.bits.iter().all(|&chunk| chunk == 0);
        assert!(all_zeros, "Bloom filter not fully cleared");
        // With every bit zero and at least one probe, the term must now be reported absent.
        assert!(!bloom.might_contain("hello"));
    }

    #[test]
    fn test_bloom_filter_bytes() {
        let mut bloom = BloomFilter::new(100);
        bloom.insert_bytes(&[0x10, 0x20, 0x30]);
        assert!(bloom.might_contain_bytes(&[0x10, 0x20, 0x30]));
    }

    #[test]
    fn test_bloom_filter_custom_params() {
        let bloom = BloomFilter::with_params(256, 5);
        assert_eq!(bloom.capacity(), 256);
        assert_eq!(bloom.hash_count(), 5);
    }

    #[test]
    fn test_bloom_filter_fresh_filter_reports_nothing() {
        let bloom = BloomFilter::new(100);
        assert!(!bloom.might_contain("hello"));
        assert!(!bloom.might_contain(""));
        assert!(!bloom.might_contain_bytes(&[]));
    }

    #[test]
    fn test_bloom_filter_sizing_rounds_and_clamps() {
        // Bit counts round up to whole 64-bit words, with a 64-bit minimum.
        assert_eq!(BloomFilter::with_params(100, 2).capacity(), 128);
        assert_eq!(BloomFilter::with_params(0, 2).capacity(), 64);
        assert_eq!(BloomFilter::new(0).capacity(), 64);
        // 100 elements * 10 bits = 1000 bits -> 16 words -> 1024 bits.
        assert_eq!(BloomFilter::new(100).capacity(), 1024);
        // At least one probe is always used.
        assert_eq!(BloomFilter::with_params(64, 0).hash_count(), 1);
        assert_eq!(BloomFilter::new(10).hash_count(), 3);
        // The stored words back exactly the reported capacity.
        let bloom = BloomFilter::with_params(300, 4);
        assert_eq!(bloom.bits.len() * 64, bloom.capacity());
    }
}