use super::varint::{decode_varint, encode_varint};
use crate::error::{InterpretError, TauqError};
use std::hash::Hasher;
/// A Bloom filter over string values.
///
/// `insert` sets a fixed number of bits per value and `might_contain` checks that all of
/// them are set, so an inserted value is always reported as possibly present, while a value
/// that was never inserted may still be reported as present (a false positive).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomFilter {
    /// The bit array; bit `i` is stored at `bits[i / 8]`, mask `1 << (i % 8)`.
    bits: Vec<u8>,
    /// Number of bit positions probed per value.
    hash_functions: u8,
    /// Number of `insert` calls made so far (saturating), duplicates included.
    num_items: u32,
    /// Seed mixed into every hash; written by `encode` alongside the bits.
    seed: u64,
}
/// Seeded 64-bit FNV-1a hasher finished with MurmurHash3's `fmix64` avalanche step.
///
/// Encoded filters are meant to be decoded later, possibly by another process, build or
/// machine, so bit positions must be a pure function of `(seed, value)`. General-purpose
/// hashers such as `ahash` or std's `DefaultHasher` do not promise stable output across
/// processes, versions or platforms, so this one is defined entirely here.
///
/// Only [`Hasher::write`] should be used: the trait's default `write_u64` & co. feed
/// native-endian bytes, which would make the output platform dependent.
struct StableHasher {
    state: u64,
}

impl StableHasher {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Starts a hash whose initial state depends on `seed`.
    fn with_seed(seed: u64) -> Self {
        Self {
            state: Self::FNV_OFFSET_BASIS ^ fmix64(seed),
        }
    }
}

impl Hasher for StableHasher {
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.state = (self.state ^ u64::from(byte)).wrapping_mul(Self::FNV_PRIME);
        }
    }

    fn finish(&self) -> u64 {
        fmix64(self.state)
    }
}

/// MurmurHash3 64-bit finalizer: a bijective mix spreading every input bit over the output.
fn fmix64(mut x: u64) -> u64 {
    x ^= x >> 33;
    x = x.wrapping_mul(0xff51_afd7_ed55_8ccd);
    x ^= x >> 33;
    x = x.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
    x ^= x >> 33;
    x
}

impl BloomFilter {
    /// Creates an empty filter sized for `num_items` insertions at roughly
    /// `false_positive_rate`.
    ///
    /// The bit count is `m = -n·ln(p) / ln(2)²` (at least 64, rounded up to whole bytes) and
    /// the probe count is `k = (m / n)·ln(2)`, clamped to `1..=4`. The clamp bounds per-lookup
    /// work. When the optimal `k` is larger than 4, the achieved false-positive rate can end
    /// up above the target.
    ///
    /// A non-positive or NaN rate is treated as the smallest positive `f32`, because
    /// `ln(0) = -inf` would otherwise request an unbounded bit array. Rates `>= 1.0` produce
    /// the 64-bit minimum.
    ///
    /// The hash seed is derived from the current time. The seed is stored by [`Self::encode`].
    #[must_use]
    pub fn new(num_items: u32, false_positive_rate: f32) -> Self {
        let fpr = if false_positive_rate > 0.0 {
            false_positive_rate
        } else {
            f32::MIN_POSITIVE
        };
        let ln2 = std::f32::consts::LN_2;
        let optimal_bits = -(num_items as f32) * fpr.ln() / (ln2 * ln2);
        let num_bytes = (optimal_bits as usize).max(64).div_ceil(8);
        let num_bits = num_bytes * 8;
        let optimal_k = (num_bits as f32 / num_items as f32) * ln2;
        let hash_functions = (optimal_k.round() as u8).clamp(1, 4);
        let seed = {
            use std::time::SystemTime;
            let t = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(0x517cc1b727220a95);
            t ^ (num_bytes as u64).wrapping_mul(0x9e3779b97f4a7c15)
        };
        Self {
            bits: vec![0; num_bytes],
            hash_functions,
            num_items: 0,
            seed,
        }
    }

    /// Reassembles a filter from its raw parts without validation.
    ///
    /// An empty `bits` yields a filter that answers `true` to every query.
    pub fn from_bytes(bits: Vec<u8>, hash_functions: u8, num_items: u32, seed: u64) -> Self {
        Self {
            bits,
            hash_functions,
            num_items,
            seed,
        }
    }

    /// Records `value` by setting each of its probe bits, and bumps the item counter
    /// (saturating at `u32::MAX`).
    pub fn insert(&mut self, value: &str) {
        let positions =
            Self::bit_positions(value, self.seed, self.hash_functions, self.num_bits());
        for pos in positions {
            self.bits[pos / 8] |= 1 << (pos % 8);
        }
        self.num_items = self.num_items.saturating_add(1);
    }

    /// Returns `false` only if `value` was definitely never inserted.
    ///
    /// `true` means the value may have been inserted (false positives are possible). A filter
    /// with no bits or zero probes returns `true` for every value.
    #[must_use]
    pub fn might_contain(&self, value: &str) -> bool {
        Self::bit_positions(value, self.seed, self.hash_functions, self.num_bits())
            .all(|pos| self.bits[pos / 8] & (1 << (pos % 8)) != 0)
    }

    /// Serializes the filter as
    /// `[hash_functions: u8][num_items: varint][seed: u64 LE][bits.len(): varint][bits]`.
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        // Header byte + seed + generous headroom for the two varints, then the bit array.
        let mut buffer = Vec::with_capacity(1 + 10 + 8 + 10 + self.bits.len());
        buffer.push(self.hash_functions);
        encode_varint(self.num_items as u64, &mut buffer);
        buffer.extend_from_slice(&self.seed.to_le_bytes());
        encode_varint(self.bits.len() as u64, &mut buffer);
        buffer.extend_from_slice(&self.bits);
        buffer
    }

    /// Parses a filter written by [`Self::encode`] from the start of `bytes`.
    ///
    /// Returns the filter and the number of bytes consumed. Trailing bytes are left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is empty or truncated, if a varint fails to decode, or
    /// if the stored item count does not fit in a `u32`.
    pub fn decode(bytes: &[u8]) -> Result<(Self, usize), TauqError> {
        let Some((&hash_functions, _)) = bytes.split_first() else {
            return Err(Self::decode_error("Cannot decode bloom filter: empty buffer"));
        };
        let mut offset = 1;

        let (num_items, size) = decode_varint(&bytes[offset..])?;
        offset += size;
        // The count is written from a `u32`; anything wider means the buffer is corrupt.
        let num_items = u32::try_from(num_items)
            .map_err(|_| Self::decode_error("Bloom filter item count does not fit in u32"))?;

        let seed_bytes: [u8; 8] = bytes
            .get(offset..offset + 8)
            .and_then(|slice| slice.try_into().ok())
            .ok_or_else(|| Self::decode_error("Not enough bytes to decode bloom filter seed"))?;
        let seed = u64::from_le_bytes(seed_bytes);
        offset += 8;

        let (bits_len, size) = decode_varint(&bytes[offset..])?;
        offset += size;
        // Checked conversion and slicing: a corrupt length must neither truncate on 32-bit
        // targets nor overflow `offset + len` and panic.
        let bits = usize::try_from(bits_len)
            .ok()
            .and_then(|len| bytes.get(offset..)?.get(..len))
            .ok_or_else(|| Self::decode_error("Not enough bytes to decode bloom filter"))?
            .to_vec();
        offset += bits.len();

        Ok((
            Self {
                bits,
                hash_functions,
                num_items,
                seed,
            },
            offset,
        ))
    }

    /// Number of `insert` calls made so far (saturating), duplicates included.
    #[must_use]
    pub fn num_items(&self) -> u32 {
        self.num_items
    }

    /// Total number of addressable bits.
    fn num_bits(&self) -> u64 {
        self.bits.len() as u64 * 8
    }

    /// Wraps a decode failure message in the crate's interpreter error type.
    fn decode_error(message: &'static str) -> TauqError {
        TauqError::Interpret(InterpretError::new(message))
    }

    /// Yields the `probes` bit positions that `value` maps to in a `num_bits`-bit array.
    ///
    /// This uses double hashing (`pos_i = h1 + i·h2 mod m`), so the value is hashed once
    /// regardless of the probe count. `h2` is forced odd, and the sizes produced here are
    /// multiples of 8, so the step is never zero. Positions advance incrementally modulo
    /// `num_bits`, which cannot overflow.
    ///
    /// The iterator borrows nothing from the filter, so `insert` can mutate `bits` while
    /// walking it. An empty bit array yields no positions instead of dividing by zero.
    fn bit_positions(
        value: &str,
        seed: u64,
        probes: u8,
        num_bits: u64,
    ) -> impl Iterator<Item = usize> {
        let (start, step, probes) = if num_bits == 0 {
            (0, 0, 0)
        } else {
            let mut hasher = StableHasher::with_seed(seed);
            hasher.write(value.as_bytes());
            let h1 = hasher.finish();
            let h2 = fmix64(h1 ^ 0x9e37_79b9_7f4a_7c15) | 1;
            (h1 % num_bits, h2 % num_bits, usize::from(probes))
        };
        std::iter::successors(Some(start), move |&pos| Some((pos + step) % num_bits))
            .take(probes)
            .map(|pos| pos as usize)
    }
}
/// Collects string items, then builds a `BloomFilter` sized for exactly that many items.
#[derive(Debug, Clone)]
pub struct BloomFilterBuilder {
    /// Items to insert; their count (duplicates included) sizes the filter.
    items: Vec<String>,
    /// Target false-positive rate passed to `BloomFilter::new`.
    target_fpr: f32,
}
impl BloomFilterBuilder {
    /// Creates an empty builder with a target false-positive rate of `0.01`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            items: Vec::new(),
            target_fpr: 0.01,
        }
    }

    /// Sets the target false-positive rate used to size the filter.
    #[must_use]
    pub fn with_fpr(mut self, fpr: f32) -> Self {
        self.target_fpr = fpr;
        self
    }

    /// Queues `item` for insertion when the filter is built.
    pub fn add_item(&mut self, item: impl Into<String>) {
        self.items.push(item.into());
    }

    /// Builds a filter sized for the number of queued items and inserts every one of them.
    #[must_use]
    pub fn build(self) -> BloomFilter {
        // Saturate instead of `as u32`, which would wrap a count above u32::MAX to a small
        // number and badly under-size the filter.
        let expected_items = u32::try_from(self.items.len()).unwrap_or(u32::MAX);
        let mut filter = BloomFilter::new(expected_items, self.target_fpr);
        for item in self.items {
            filter.insert(&item);
        }
        filter
    }
}
impl Default for BloomFilterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bloom_filter_insert_and_check() {
        let mut filter = BloomFilter::new(100, 0.01);
        filter.insert("alice");
        filter.insert("bob");
        filter.insert("charlie");
        assert!(filter.might_contain("alice"));
        assert!(filter.might_contain("bob"));
        assert!(filter.might_contain("charlie"));
        // Probabilistic, but with 3 items in a filter sized for 100 a false positive here
        // is vanishingly unlikely.
        assert!(!filter.might_contain("eve"));
    }

    #[test]
    fn test_bloom_filter_false_negatives() {
        let mut filter = BloomFilter::new(100, 0.01);
        filter.insert("value1");
        filter.insert("value2");
        filter.insert("value3");
        assert!(filter.might_contain("value1"));
        assert!(filter.might_contain("value2"));
        assert!(filter.might_contain("value3"));
    }

    #[test]
    fn test_bloom_filter_encode_decode() {
        let mut filter = BloomFilter::new(100, 0.01);
        filter.insert("alice");
        filter.insert("bob");
        filter.insert("charlie");
        let encoded = filter.encode();
        let (decoded, consumed) = BloomFilter::decode(&encoded).unwrap();
        assert_eq!(consumed, encoded.len());
        assert!(decoded.might_contain("alice"));
        assert!(decoded.might_contain("bob"));
        assert!(decoded.might_contain("charlie"));
        assert_eq!(decoded.num_items, filter.num_items);
        assert_eq!(decoded.bits, filter.bits);
        assert_eq!(decoded.seed, filter.seed);
        assert_eq!(decoded.hash_functions, filter.hash_functions);
    }

    #[test]
    fn test_bloom_filter_decode_reports_consumed_length() {
        let mut filter = BloomFilter::new(10, 0.01);
        filter.insert("alice");
        let encoded = filter.encode();
        let mut buffer = encoded.clone();
        buffer.extend_from_slice(&[0xAA, 0xBB, 0xCC]);
        let (decoded, consumed) = BloomFilter::decode(&buffer).unwrap();
        assert_eq!(consumed, encoded.len());
        assert!(decoded.might_contain("alice"));
    }

    #[test]
    fn test_bloom_filter_decode_rejects_malformed_input() {
        assert!(BloomFilter::decode(&[]).is_err());
        let encoded = BloomFilter::new(10, 0.01).encode();
        assert!(BloomFilter::decode(&encoded[..encoded.len() - 1]).is_err());
    }

    #[test]
    fn test_bloom_filter_builder() {
        let mut builder = BloomFilterBuilder::new().with_fpr(0.01);
        builder.add_item("alice");
        builder.add_item("bob");
        builder.add_item("charlie");
        let filter = builder.build();
        assert!(filter.might_contain("alice"));
        assert!(filter.might_contain("bob"));
        assert!(filter.might_contain("charlie"));
        assert_eq!(filter.num_items(), 3);
    }

    #[test]
    fn test_bloom_filter_cardinality() {
        let mut filter = BloomFilter::new(1000, 0.01);
        for i in 0..100 {
            filter.insert(&format!("item{}", i));
        }
        assert_eq!(filter.num_items, 100);
        for i in 0..100 {
            assert!(filter.might_contain(&format!("item{}", i)));
        }
    }

    #[test]
    fn test_bloom_filter_definitely_not_present() {
        let mut filter = BloomFilter::new(10, 0.01);
        filter.insert("engineer");
        filter.insert("sales");
        filter.insert("support");
        let negative_test_count = 100;
        let definitely_absent = (0..negative_test_count)
            .filter(|i| !filter.might_contain(&format!("not_present_{}", i)))
            .count();
        assert!(definitely_absent > negative_test_count / 2);
    }
}