use serde::{Deserialize, Serialize};
/// Version of the hashing scheme (`hash_pair`, `fxhash64` and the probe layout) that
/// newly built filters are stamped with.
///
/// A filter whose stored `hash_version` differs from this value is never probed:
/// `BloomFilter::may_contain` answers `true` for every key. Bump this whenever the bit
/// positions computed for a key change; otherwise filters serialized under the old scheme
/// would be probed at the wrong positions and could report false negatives.
pub(crate) const CURRENT_HASH_VERSION: u8 = 2;
/// Bloom filter over byte-string keys.
///
/// `may_contain` can report false positives but, while `hash_version` matches
/// `CURRENT_HASH_VERSION`, never a false negative for a key passed to `insert`.
/// Derives serde `Serialize`/`Deserialize`; the field names below are the serialized keys,
/// so renaming them changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct BloomFilter {
    /// Bit array packed into 64-bit words; the filter has `bits.len() * 64` bits.
    bits: Vec<u64>,
    /// Number of bit positions set by `insert` and tested by `may_contain` per key.
    num_hashes: u8,
    /// Hashing-scheme version the bits were built with. When the field is missing from
    /// serialized input, `#[serde(default)]` yields 0, which differs from
    /// `CURRENT_HASH_VERSION` (currently 2) and makes `may_contain` return `true` for
    /// every key.
    #[serde(default)]
    hash_version: u8,
}
impl BloomFilter {
    /// Creates an empty filter sized for `expected_keys` keys at `bits_per_key` bits each.
    ///
    /// The number of hash probes is `ceil(0.69 * bits_per_key)` (0.69 ≈ ln 2, the
    /// standard choice that minimises false positives for a given bits-per-key), clamped
    /// to `2..=30`. The bit array holds at least 64 bits and is rounded up to whole
    /// 64-bit words. The filter is stamped with `CURRENT_HASH_VERSION`.
    pub fn with_bits_per_key(expected_keys: usize, bits_per_key: usize) -> Self {
        // Integer form of ceil(bits_per_key * 0.69); avoids floating point.
        let num_hashes = bits_per_key.saturating_mul(69).div_ceil(100).clamp(2, 30) as u8;
        // Saturate (like `num_hashes` above) so an absurd size request fails loudly at
        // allocation instead of silently wrapping to a tiny filter in release builds.
        let num_bits = expected_keys.saturating_mul(bits_per_key).max(64);
        let num_words = num_bits.div_ceil(64);
        Self {
            bits: vec![0u64; num_words],
            num_hashes,
            hash_version: CURRENT_HASH_VERSION,
        }
    }

    /// Builds a filter sized for `keys.len()` keys and inserts every key.
    pub fn from_keys_with_bits(keys: &[Vec<u8>], bits_per_key: usize) -> Self {
        let mut bf = Self::with_bits_per_key(keys.len(), bits_per_key);
        for key in keys {
            bf.insert(key);
        }
        bf
    }

    /// Returns the hashing-scheme version this filter's bits were built with.
    pub fn hash_version(&self) -> u8 {
        self.hash_version
    }

    /// Stamps the filter as built with the current hashing scheme.
    ///
    /// Only call this after the bits have been (re)populated with the current scheme:
    /// once stamped, `may_contain` trusts the bits and can return false negatives for
    /// keys that were inserted under an older scheme.
    // NOTE(review): the lint is presumably suppressed because only tests call this
    // today — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn mark_current(&mut self) {
        self.hash_version = CURRENT_HASH_VERSION;
    }

    /// Derives the two base hashes used for double hashing.
    ///
    /// Both come from `fxhash64` over the whole key with different seeds. The resulting
    /// bit positions are what serialized filters contain, so any change here requires
    /// bumping `CURRENT_HASH_VERSION`.
    // NOTE(review): when the second hash is 0, every probe of that key lands on the same
    // bit. Forcing it odd would fix that but moves bit positions, so it needs a version bump.
    fn hash_pair(data: &[u8]) -> (u64, u64) {
        const SEED_A: u64 = 0x5365_7244_424c_4f4f;
        const SEED_B: u64 = 0x5744_425f_464c_4f57;
        let hash1 = fxhash64(data, SEED_A ^ SEED_B);
        let hash2 = fxhash64(data, !SEED_A ^ SEED_B);
        (hash1, hash2)
    }

    /// Yields `(word index, bit mask)` for each of the `num_hashes` probe positions of
    /// `key` in a bit array of `num_words` 64-bit words.
    ///
    /// Probe `i` sits at `(h1 + i * h2) mod (num_words * 64)` using wrapping arithmetic.
    /// Yields nothing when `num_words` is 0 (e.g. a deserialized filter with
    /// `"bits": []`), so callers never take a modulo by zero.
    ///
    /// Takes plain values rather than `&self` so the returned iterator holds no borrow of
    /// the filter and `insert` can mutate `bits` while iterating.
    fn probes(
        key: &[u8],
        num_words: usize,
        num_hashes: u8,
    ) -> impl Iterator<Item = (usize, u64)> {
        let (h1, h2) = Self::hash_pair(key);
        let num_bits = num_words as u64 * 64;
        let count = if num_bits == 0 { 0 } else { num_hashes };
        (0..count).map(move |i| {
            let bit_pos = h1.wrapping_add(u64::from(i).wrapping_mul(h2)) % num_bits;
            ((bit_pos / 64) as usize, 1u64 << (bit_pos % 64))
        })
    }

    /// Sets the probe bits for `key`.
    ///
    /// No-op on a filter with an empty bit array. `may_contain` then answers `true` for
    /// every key, so skipping the insert cannot cause a false negative.
    pub fn insert(&mut self, key: &[u8]) {
        for (word, mask) in Self::probes(key, self.bits.len(), self.num_hashes) {
            self.bits[word] |= mask;
        }
    }

    /// Returns `false` only if `key` was definitely never inserted.
    ///
    /// Returns `true` without consulting the bits when the filter was built with a
    /// different hashing scheme (`hash_version != CURRENT_HASH_VERSION`) or has an empty
    /// bit array, because such bits cannot be interpreted safely.
    pub fn may_contain(&self, key: &[u8]) -> bool {
        if self.hash_version != CURRENT_HASH_VERSION {
            return true;
        }
        Self::probes(key, self.bits.len(), self.num_hashes)
            .all(|(word, mask)| self.bits[word] & mask != 0)
    }
}
/// FxHash-style 64-bit hash of `data`, starting from `seed`.
///
/// The input is consumed as little-endian 8-byte words, each mixed in by
/// rotate-left-5, xor, multiply. A trailing partial word is zero-padded, so the length
/// itself is not mixed in (e.g. `b"a"` and `b"a\0"` hash identically). Not
/// cryptographic. The output determines serialized Bloom filter bit positions, so it
/// must stay bit-for-bit stable for a given `CURRENT_HASH_VERSION`.
fn fxhash64(data: &[u8], seed: u64) -> u64 {
    const MULT: u64 = 0x517c_c1b7_2722_0a95;

    // One mixing round.
    let mix = |acc: u64, word: u64| (acc.rotate_left(5) ^ word).wrapping_mul(MULT);
    // Little-endian load of up to 8 bytes, zero-filling the missing high bytes.
    let load = |bytes: &[u8]| {
        let mut word = [0u8; 8];
        word[..bytes.len()].copy_from_slice(bytes);
        u64::from_le_bytes(word)
    };

    let words = data.chunks_exact(8);
    let tail = words.remainder();
    let acc = words.fold(seed, |acc, chunk| mix(acc, load(chunk)));
    if tail.is_empty() {
        acc
    } else {
        mix(acc, load(tail))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the keys `name(0), name(1), …, name(count - 1)` as raw bytes.
    fn make_keys(count: usize, name: impl Fn(usize) -> String) -> Vec<Vec<u8>> {
        (0..count).map(|i| name(i).into_bytes()).collect()
    }

    #[test]
    fn test_bloom_no_false_negatives() {
        let inserted = make_keys(1000, |i| format!("key_{:06}", i));
        let filter = BloomFilter::from_keys_with_bits(&inserted, 10);
        for k in inserted.iter() {
            assert!(filter.may_contain(k), "false negative");
        }
    }

    #[test]
    fn test_bloom_serialization() {
        let inserted = make_keys(100, |i| format!("k{}", i));
        let original = BloomFilter::from_keys_with_bits(&inserted, 10);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: BloomFilter = serde_json::from_str(&encoded).unwrap();
        // A JSON round trip must keep every inserted key visible.
        assert!(inserted.iter().all(|k| decoded.may_contain(k)));
    }

    #[test]
    fn test_bloom_version_is_current_after_build() {
        let single = [b"k1".to_vec()];
        let filter = BloomFilter::from_keys_with_bits(&single, 10);
        assert_eq!(filter.hash_version(), CURRENT_HASH_VERSION);
    }

    #[test]
    fn test_bloom_legacy_version_loads_as_zero() {
        // A payload lacking `hash_version` must deserialize with the
        // `#[serde(default)]` value 0.
        let legacy_json = r#"{"bits":[18446744073709551615],"num_hashes":2}"#;
        let filter: BloomFilter = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(filter.hash_version(), 0);
    }

    #[test]
    fn test_bloom_stale_version_returns_true_safely() {
        let mut filter = BloomFilter::with_bits_per_key(10, 10);
        filter.hash_version = 0;
        // All bits are zero, so only the stale-version guard can make this `true`.
        assert!(filter.may_contain(b"never-inserted"));
    }

    #[test]
    fn test_bloom_mark_current_after_rebuild() {
        let inserted = make_keys(50, |i| format!("k{}", i));
        let mut filter = BloomFilter::from_keys_with_bits(&inserted, 10);
        filter.hash_version = 0;
        filter.mark_current();
        assert_eq!(filter.hash_version(), CURRENT_HASH_VERSION);
        assert!(inserted.iter().all(|k| filter.may_contain(k)));
    }

    #[test]
    fn test_bloom_false_positive_rate_is_bounded() {
        let inserted = make_keys(10000, |i| format!("key_{:06}", i));
        let filter = BloomFilter::from_keys_with_bits(&inserted, 10);
        let probe_count = 10000;
        // Probe keys come from a disjoint namespace, so every hit is a false positive.
        let false_positives = (0..probe_count)
            .filter(|i| filter.may_contain(format!("other_{:06}", i).as_bytes()))
            .count();
        let fp_rate = false_positives as f64 / probe_count as f64;
        assert!(fp_rate < 0.05, "false positive rate too high: {}", fp_rate);
        assert!(
            fp_rate < 0.03,
            "expected improved FPR with optimal num_hashes, got {}",
            fp_rate
        );
    }
}