use std::fmt;
use serde::{Deserialize, Deserializer, Serialize};
use xxhash_rust::xxh3::xxh3_128_with_seed;
/// Seed passed to `xxh3_128_with_seed` for every key hashed by `insert` and
/// `contains`. Bit positions are derived from that hash, so changing this
/// value changes which bits a key maps to; a filter populated under a
/// different seed would no longer answer membership queries correctly.
const BLOOM_HASH_SEED: u64 = 0xB100_F1AC_DEAD_CAFE;
/// Upper bound on `k`, the number of bit positions probed per key. The
/// constructors clamp to it and deserialization rejects larger values.
const MAX_K: u32 = 32;
/// Smallest permitted filter size in bits (one `u64` word). The constructors
/// clamp up to it and deserialization rejects smaller values.
const MIN_LEN_BITS: usize = 64;
/// Bloom filter over byte-string keys, with its bits packed into `u64` words.
///
/// Values built by `new`, `with_params`, or deserialization satisfy
/// `len_bits == bits.len() * 64`, `len_bits >= MIN_LEN_BITS`, and
/// `1 <= k <= MAX_K`.
#[derive(Clone, PartialEq, Eq)]
pub struct BloomFilter {
// Number of addressable bits; probe indices are reduced modulo this value.
len_bits: usize,
// Number of bit positions set (on insert) or tested (on lookup) per key.
k: u32,
// Bit storage: bit `idx` lives in word `idx / 64` at bit position `idx % 64`.
bits: Vec<u64>,
}
/// Compact diagnostic view: prints the sizing parameters and occupancy
/// statistics instead of dumping every word of the bit array.
impl fmt::Debug for BloomFilter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Count in u64: a filter with more than u32::MAX set bits would
        // overflow a u32 sum (panic in debug builds, wrap in release).
        let popcount: u64 = self
            .bits
            .iter()
            .map(|word| u64::from(word.count_ones()))
            .sum();
        // `.max(1)` guards the division for a hand-built zero-length filter.
        let fill_ratio = popcount as f64 / self.len_bits.max(1) as f64;
        f.debug_struct("BloomFilter")
            .field("len_bits", &self.len_bits)
            .field("k", &self.k)
            .field("popcount", &popcount)
            .field("fill_ratio", &fill_ratio)
            .finish()
    }
}
/// Rounds `n` up to the next multiple of 64.
///
/// The padding addition saturates instead of wrapping, so for inputs within
/// 63 of `usize::MAX` the result is the largest representable multiple of 64,
/// which can be smaller than `n`.
fn round_up_to_64(n: usize) -> usize {
    let padded = n.saturating_add(63);
    // Dropping the remainder is the same as masking off the low six bits.
    padded - padded % 64
}
impl BloomFilter {
    /// Builds an empty filter sized for `expected_items` keys at a target
    /// false-positive probability of `false_positive_rate`.
    ///
    /// Sizing uses the standard formulas `m = ceil(-n * ln(p) / ln(2)^2)` bits
    /// and `k = round((m / n) * ln(2))` probes. Degenerate inputs are clamped
    /// rather than rejected: `expected_items` is treated as at least 1, the
    /// rate is clamped to `[1e-9, 0.5]`, the bit count is rounded up to a
    /// multiple of 64 (minimum `MIN_LEN_BITS`), and `k` is clamped to
    /// `[1, MAX_K]`.
    ///
    /// The bit array is allocated eagerly, so an enormous `expected_items`
    /// requests a correspondingly enormous allocation.
    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
        let n = expected_items.max(1) as f64;
        let p = false_positive_rate.clamp(1e-9, 0.5);
        let ln2 = std::f64::consts::LN_2;
        let m_float = (-n * p.ln() / (ln2 * ln2)).ceil();
        // A non-finite or non-positive estimate (e.g. from a NaN rate) falls
        // back to the minimum size; otherwise the float-to-int cast saturates.
        let m_raw = if m_float.is_finite() && m_float > 0.0 {
            m_float as usize
        } else {
            MIN_LEN_BITS
        };
        let len_bits = round_up_to_64(m_raw).max(MIN_LEN_BITS);
        // Derive k from the rounded bit count so it matches the real array.
        let k_float = ((len_bits as f64 / n) * ln2).round();
        let k = (k_float as u32).clamp(1, MAX_K);
        Self::with_params(len_bits, k)
    }

    /// Builds an empty filter from explicit parameters.
    ///
    /// `len_bits` is rounded up to a multiple of 64 and raised to at least
    /// `MIN_LEN_BITS`; `k` is clamped to `[1, MAX_K]`. Use [`Self::len_bits`]
    /// and [`Self::hash_count`] to read back the values actually applied.
    pub fn with_params(len_bits: usize, k: u32) -> Self {
        let len_bits = round_up_to_64(len_bits).max(MIN_LEN_BITS);
        let k = k.clamp(1, MAX_K);
        Self {
            len_bits,
            k,
            bits: vec![0u64; len_bits / 64],
        }
    }

    /// Yields the `(word index, bit mask)` pair for each of the `k` probes of
    /// `key` in a filter of `len_bits` bits.
    ///
    /// One 128-bit hash is split into two 64-bit halves and probe `i` lands at
    /// `(h1 + i * h2) mod len_bits` (double hashing). `h2` is forced odd so
    /// the stride is never zero. `insert` and `contains` both go through this
    /// helper so they can never disagree about where a key lives.
    ///
    /// Takes the parameters by value (rather than `&self`) so the returned
    /// iterator does not borrow the filter while `insert` mutates it.
    fn probe_positions(key: &[u8], len_bits: usize, k: u32) -> impl Iterator<Item = (usize, u64)> {
        let h128 = xxh3_128_with_seed(key, BLOOM_HASH_SEED);
        let h1 = (h128 >> 64) as u64;
        let h2 = (h128 as u64) | 1;
        let m = len_bits as u64;
        (0..u64::from(k)).map(move |i| {
            let idx = (h1.wrapping_add(i.wrapping_mul(h2)) % m) as usize;
            (idx / 64, 1u64 << (idx % 64))
        })
    }

    /// Total number of set bits, counted in `u64` so it cannot overflow.
    fn set_bit_count(&self) -> u64 {
        self.bits
            .iter()
            .map(|word| u64::from(word.count_ones()))
            .sum()
    }

    /// Adds `key` to the filter by setting each of its `k` probe bits.
    ///
    /// Re-inserting a key that is already present leaves the bits unchanged.
    pub fn insert(&mut self, key: &[u8]) {
        for (word, mask) in Self::probe_positions(key, self.len_bits, self.k) {
            self.bits[word] |= mask;
        }
    }

    /// Returns `false` if `key` was definitely never inserted, and `true` if
    /// all of its probe bits are set (the key was inserted, or this is a
    /// false positive). Stops at the first unset probe bit.
    pub fn contains(&self, key: &[u8]) -> bool {
        Self::probe_positions(key, self.len_bits, self.k)
            .all(|(word, mask)| self.bits[word] & mask != 0)
    }

    /// Number of addressable bits in the filter.
    pub fn len_bits(&self) -> usize {
        self.len_bits
    }

    /// Number of bit positions probed per key (`k`).
    pub fn hash_count(&self) -> u32 {
        self.k
    }

    /// Number of bits currently set.
    ///
    /// Saturates at `u32::MAX` for filters holding more set bits than a `u32`
    /// can represent, instead of overflowing.
    pub fn popcount(&self) -> u32 {
        u32::try_from(self.set_bit_count()).unwrap_or(u32::MAX)
    }

    /// Estimates the current false-positive probability as `fill^k`, where
    /// `fill` is the fraction of bits that are set.
    pub fn estimated_false_positive_rate(&self) -> f64 {
        // `.max(1)` guards the division for a hand-built zero-length filter.
        let fill = self.set_bit_count() as f64 / self.len_bits.max(1) as f64;
        fill.powi(self.k as i32)
    }

    /// Returns `8 + 4 + 8 * word_count`: eight bytes per stored word plus a
    /// 12-byte header.
    ///
    /// NOTE(review): presumably the header is an 8-byte `len_bits` and a
    /// 4-byte `k` in some compact binary layout; this is not the length of
    /// the serde output (e.g. JSON). Confirm against callers.
    pub fn serialized_bytes(&self) -> usize {
        8 + 4 + self.bits.len() * 8
    }
}
/// Serde wire representation of a `BloomFilter`: defines the field names and
/// order of the serialized form. Deserialization parses into this type first
/// and validates the fields before a `BloomFilter` is built from them.
#[derive(Serialize, Deserialize)]
struct BloomFilterWire {
// Claimed number of addressable bits.
len_bits: usize,
// Claimed number of probes per key.
k: u32,
// Bit array packed into 64-bit words.
bits: Vec<u64>,
}
impl Serialize for BloomFilter {
    /// Serializes the filter as a three-field struct (`len_bits`, `k`,
    /// `bits`).
    ///
    /// The fields are written straight from `self`, so the bit array is
    /// borrowed rather than cloned into a temporary. The struct name, field
    /// names and field order mirror `BloomFilterWire`, which is what
    /// deserialization parses; keep the two in sync.
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let mut state = s.serialize_struct("BloomFilterWire", 3)?;
        state.serialize_field("len_bits", &self.len_bits)?;
        state.serialize_field("k", &self.k)?;
        state.serialize_field("bits", &self.bits)?;
        state.end()
    }
}
impl<'de> Deserialize<'de> for BloomFilter {
    /// Parses the wire form and validates it before building a filter.
    ///
    /// # Errors
    ///
    /// Returns a custom deserialization error when:
    /// - the word count times 64 does not fit in `usize`,
    /// - `len_bits` differs from `bits.len() * 64`,
    /// - `len_bits` is below `MIN_LEN_BITS`, or
    /// - `k` is outside `[1, MAX_K]`.
    ///
    /// These checks uphold the invariants `insert` and `contains` rely on
    /// when indexing into the bit array.
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        let wire = BloomFilterWire::deserialize(d)?;

        // Input is untrusted: compute the capacity with a checked multiply so
        // an oversized word count is reported as an error instead of
        // panicking (debug builds) or wrapping (release builds).
        let capacity_bits = match wire.bits.len().checked_mul(64) {
            Some(bits) => bits,
            None => {
                return Err(serde::de::Error::custom(format!(
                    "bloom filter bits.len()={} overflows the addressable bit count",
                    wire.bits.len(),
                )));
            }
        };
        if wire.len_bits != capacity_bits {
            return Err(serde::de::Error::custom(format!(
                "bloom filter wire mismatch: len_bits={}, bits.len()*64={}",
                wire.len_bits, capacity_bits,
            )));
        }
        if wire.len_bits < MIN_LEN_BITS {
            return Err(serde::de::Error::custom(format!(
                "bloom filter len_bits={} below minimum {}",
                wire.len_bits, MIN_LEN_BITS,
            )));
        }
        if wire.k < 1 || wire.k > MAX_K {
            return Err(serde::de::Error::custom(format!(
                "bloom filter k={} out of range [1, {}]",
                wire.k, MAX_K,
            )));
        }

        Ok(Self {
            len_bits: wire.len_bits,
            k: wire.k,
            bits: wire.bits,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A Bloom filter must never report a false negative.
    #[test]
    fn inserted_keys_always_pass_membership() {
        let mut bf = BloomFilter::new(1000, 0.01);
        for i in 0..1000u32 {
            bf.insert(&i.to_le_bytes());
        }
        for i in 0..1000u32 {
            assert!(bf.contains(&i.to_le_bytes()), "false negative on key {i}");
        }
    }

    /// With no bits set, every lookup fails on its first probe.
    #[test]
    fn empty_filter_rejects_all_keys() {
        let bf = BloomFilter::new(100, 0.01);
        assert!(!bf.contains(b"foo"));
        assert!(!bf.contains(b"bar"));
        assert!(!bf.contains(&[0u8; 32]));
        assert_eq!(bf.popcount(), 0);
    }

    /// At design capacity the observed false-positive rate should land in a
    /// band around the 1% target.
    #[test]
    fn empirical_false_positive_rate_near_target() {
        let n = 10_000usize;
        let p = 0.01;
        let mut bf = BloomFilter::new(n, p);
        for i in 0..n {
            bf.insert(format!("key-in-{i}").as_bytes());
        }
        let mut fp = 0;
        for i in 0..n {
            if bf.contains(format!("probe-out-{i}").as_bytes()) {
                fp += 1;
            }
        }
        let observed = fp as f64 / n as f64;
        assert!(
            (0.005..=0.025).contains(&observed),
            "observed FP rate {observed:.4} far from target {p}; \
             popcount={}, fill_ratio={:.3}",
            bf.popcount(),
            bf.popcount() as f64 / bf.len_bits() as f64,
        );
    }

    /// Guards the memory budget for the 10K-keys-at-1% configuration.
    #[test]
    fn ten_thousand_at_one_percent_under_500kb() {
        let bf = BloomFilter::new(10_000, 0.01);
        assert!(
            bf.serialized_bytes() <= 500 * 1024,
            "10K @ 1% sizing budget breached: {} bytes (target ≤ 500 KB)",
            bf.serialized_bytes(),
        );
    }

    /// Re-inserting a key only re-sets bits that are already set.
    #[test]
    fn insert_is_idempotent() {
        let mut bf = BloomFilter::new(100, 0.01);
        bf.insert(b"hello");
        let pop_after_first = bf.popcount();
        bf.insert(b"hello");
        assert_eq!(
            bf.popcount(),
            pop_after_first,
            "re-insert moved bits — should be a no-op",
        );
    }

    /// Zero items, an absurdly tight rate, and an absurdly loose rate must
    /// all produce a usable filter.
    #[test]
    fn sizing_degeneracies_clamp_to_safe_defaults() {
        let bf0 = BloomFilter::new(0, 0.01);
        assert!(bf0.len_bits() >= MIN_LEN_BITS);
        assert!(bf0.hash_count() >= 1);

        let bf_tiny = BloomFilter::new(100, 1e-12);
        assert!(bf_tiny.len_bits() >= MIN_LEN_BITS);
        assert!(bf_tiny.hash_count() <= MAX_K);

        let bf_loose = BloomFilter::new(100, 0.99);
        assert!(bf_loose.len_bits() >= MIN_LEN_BITS);
    }

    /// Explicit parameters are normalized, not trusted verbatim.
    #[test]
    fn with_params_normalizes_inputs() {
        let small = BloomFilter::with_params(1, 0);
        assert_eq!(small.len_bits(), MIN_LEN_BITS);
        assert_eq!(small.hash_count(), 1);

        let big = BloomFilter::with_params(65, 999);
        assert_eq!(big.len_bits(), 128);
        assert_eq!(big.hash_count(), MAX_K);
    }

    /// Near `usize::MAX` the helper saturates to the largest multiple of 64
    /// instead of wrapping around to a tiny value.
    #[test]
    fn round_up_to_64_saturates_at_usize_max() {
        assert_eq!(round_up_to_64(0), 0);
        assert_eq!(round_up_to_64(1), 64);
        assert_eq!(round_up_to_64(64), 64);
        assert_eq!(round_up_to_64(65), 128);
        assert_eq!(round_up_to_64(usize::MAX), usize::MAX & !63);
        assert_eq!(round_up_to_64(usize::MAX - 62), usize::MAX & !63);
        assert_eq!(round_up_to_64(usize::MAX & !63), usize::MAX & !63);
    }

    #[test]
    fn serde_round_trip_preserves_membership() {
        let mut bf = BloomFilter::new(500, 0.01);
        for i in 0..500u32 {
            bf.insert(&i.to_le_bytes());
        }
        let json = serde_json::to_string(&bf).unwrap();
        let restored: BloomFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(bf, restored);
        for i in 0..500u32 {
            assert!(restored.contains(&i.to_le_bytes()));
        }
    }

    /// `len_bits` claims two words but only one is supplied.
    #[test]
    fn serde_rejects_mismatched_len_bits() {
        let bad = serde_json::json!({
            "len_bits": 128,
            "k": 7,
            "bits": [0u64],
        });
        let result: Result<BloomFilter, _> = serde_json::from_value(bad);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        // Match the mismatch message specifically; merely containing "bits"
        // would also be satisfied by unrelated errors.
        assert!(
            err.contains("wire mismatch"),
            "error message must reference the mismatch: {err}",
        );
    }

    /// An internally consistent but empty payload is still below the minimum
    /// size and must be refused (it would otherwise make the probe modulus
    /// zero).
    #[test]
    fn serde_rejects_len_bits_below_minimum() {
        let empty = serde_json::json!({
            "len_bits": 0,
            "k": 7,
            "bits": [],
        });
        let result: Result<BloomFilter, _> = serde_json::from_value(empty);
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("below minimum"),
            "error message must reference the minimum size: {err}",
        );
    }

    #[test]
    fn serde_rejects_zero_or_excessive_k() {
        let zero_k = serde_json::json!({
            "len_bits": 64,
            "k": 0,
            "bits": [0u64],
        });
        assert!(serde_json::from_value::<BloomFilter>(zero_k).is_err());

        let huge_k = serde_json::json!({
            "len_bits": 64,
            "k": 999,
            "bits": [0u64],
        });
        assert!(serde_json::from_value::<BloomFilter>(huge_k).is_err());
    }

    /// `Debug` must summarize the filter, not dump the whole bit array.
    #[test]
    fn debug_output_is_compact() {
        let mut bf = BloomFilter::new(10_000, 0.01);
        bf.insert(b"some-key");
        let s = format!("{bf:?}");
        assert!(s.contains("BloomFilter"));
        assert!(s.contains("len_bits"));
        assert!(s.contains("popcount"));
        assert!(s.len() < 200, "debug output too long: {} chars", s.len());
    }
}