use alloc::vec;
use alloc::vec::Vec;
use core::hash::{Hash, Hasher};
use std::hash::DefaultHasher;
use crate::error::{RcfError, RcfResult};
/// False-positive rate used by [`BloomFilter::with_capacity`] (1%).
pub const DEFAULT_FALSE_POSITIVE_RATE: f64 = 1e-2;
/// Largest number of probes per key accepted when building or decoding a filter.
pub const MAX_HASHES: u32 = 64;
/// Largest bit-bank size accepted when building or decoding a filter
/// (2^30 bits, i.e. 128 MiB of `u64` words).
pub const MAX_NUM_BITS: usize = 1_usize << 30;
/// Fixed-size Bloom filter backed by a packed `u64` bit bank.
///
/// Each key sets `num_hashes` bits chosen by double hashing. A lookup reports the key
/// as present only if all of those bits are set. Lookups can therefore return false
/// positives, but a key inserted into this instance keeps testing positive until
/// [`BloomFilter::reset`] clears the bank.
///
/// With the `serde` feature, deserialization goes through `BloomFilterShadow`. Its
/// `TryFrom` impl validates `num_bits`, `num_hashes` and the bit-bank length before a
/// decoded filter is accepted. Field order is part of the serialized layout and must
/// stay in sync with `BloomFilterShadow`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "BloomFilterShadow"))]
pub struct BloomFilter {
    /// Bit bank: bit `i` lives in word `i / 64` at position `i % 64`. Holds exactly
    /// `num_bits.div_ceil(64)` words.
    bits: Vec<u64>,
    /// Number of addressable bits; probe positions are reduced modulo this value.
    num_bits: usize,
    /// Number of bits set (on insert) or checked (on lookup) per key.
    num_hashes: u32,
    /// Count of insert calls (duplicates included, saturating). `union` adds the other
    /// filter's count.
    total_added: u64,
}
/// Unvalidated wire form of [`BloomFilter`], used as the serde deserialization target.
///
/// Its fields must mirror `BloomFilter`'s names, types and order exactly. A serialized
/// `BloomFilter` is decoded into this shape and then checked by the
/// `TryFrom<BloomFilterShadow>` impl before it is accepted.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
#[allow(clippy::missing_docs_in_private_items)]
struct BloomFilterShadow {
    /// Raw bit bank as decoded; its length is checked against `num_bits`.
    bits: Vec<u64>,
    /// Raw bit count; must lie in `1..=MAX_NUM_BITS` to be accepted.
    num_bits: usize,
    /// Raw probe count; must lie in `1..=MAX_HASHES` to be accepted.
    num_hashes: u32,
    /// Insert counter, carried over without validation.
    total_added: u64,
}
#[cfg(feature = "serde")]
impl TryFrom<BloomFilterShadow> for BloomFilter {
    type Error = RcfError;

    /// Validates a freshly decoded shadow and promotes it to a `BloomFilter`.
    ///
    /// Applies the same bounds as `BloomFilter::with_params`. It also requires the bit
    /// bank to hold exactly `num_bits.div_ceil(64)` words, so bit indices below
    /// `num_bits` always land inside `bits`.
    ///
    /// # Errors
    /// Returns `RcfError::InvalidConfig` for the first violated check, in this order:
    /// `num_bits` range, `num_hashes` range, bit-bank length.
    fn try_from(raw: BloomFilterShadow) -> Result<Self, Self::Error> {
        let BloomFilterShadow {
            bits,
            num_bits,
            num_hashes,
            total_added,
        } = raw;
        let invalid = |msg: alloc::string::String| -> RcfError { RcfError::InvalidConfig(msg.into()) };

        if !(1..=MAX_NUM_BITS).contains(&num_bits) {
            return Err(invalid(alloc::format!(
                "BloomFilter: num_bits {num_bits} out of (0, {MAX_NUM_BITS}]"
            )));
        }
        if !(1..=MAX_HASHES).contains(&num_hashes) {
            return Err(invalid(alloc::format!(
                "BloomFilter: num_hashes {num_hashes} out of (0, {MAX_HASHES}]"
            )));
        }

        let expected_words = num_bits.div_ceil(64);
        let words = bits.len();
        if words != expected_words {
            return Err(invalid(alloc::format!(
                "BloomFilter: bit-bank length {words} != expected {expected_words} for num_bits {num_bits}"
            )));
        }

        Ok(Self {
            bits,
            num_bits,
            num_hashes,
            total_added,
        })
    }
}
impl BloomFilter {
    /// Builds a filter sized for `capacity` keys at target false-positive rate `fpr`.
    ///
    /// Sizing uses the textbook optimal formulas
    /// `m = ceil(-n * ln(fpr) / ln(2)^2)` bits and `k = round(m / n * ln(2))` probes,
    /// each clamped to at least 1. It then delegates to [`Self::with_params`].
    ///
    /// # Errors
    /// Returns [`RcfError::InvalidConfig`] in three cases:
    /// - `capacity == 0`;
    /// - `fpr` is not a finite value strictly inside `(0, 1)`;
    /// - the derived `(num_bits, num_hashes)` fall outside the bounds enforced by
    ///   [`Self::with_params`], e.g. a very large capacity or a very small `fpr`.
    pub fn new(capacity: usize, fpr: f64) -> RcfResult<Self> {
        if capacity == 0 {
            return Err(RcfError::InvalidConfig(
                alloc::string::ToString::to_string("BloomFilter: capacity must be > 0").into(),
            ));
        }
        if !fpr.is_finite() || fpr <= 0.0 || fpr >= 1.0 {
            return Err(RcfError::InvalidConfig(
                alloc::format!("BloomFilter: fpr {fpr} must be in (0, 1)").into(),
            ));
        }
        let ln2 = core::f64::consts::LN_2;
        // Rounding of capacities above 2^53 only perturbs the sizing estimate slightly.
        #[allow(clippy::cast_precision_loss)]
        let n_f = capacity as f64;
        let m_f = (-n_f * fpr.ln() / (ln2 * ln2)).ceil().max(1.0);
        let k_f = ((m_f / n_f) * ln2).round().max(1.0);
        // Both values are finite and >= 1. Float-to-int `as` saturates, and any oversized
        // result is rejected by `with_params` below.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let num_bits = m_f as usize;
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let num_hashes = k_f as u32;
        Self::with_params(num_bits, num_hashes)
    }

    /// Builds a filter for `capacity` keys at [`DEFAULT_FALSE_POSITIVE_RATE`].
    ///
    /// # Errors
    /// Same as [`Self::new`].
    pub fn with_capacity(capacity: usize) -> RcfResult<Self> {
        Self::new(capacity, DEFAULT_FALSE_POSITIVE_RATE)
    }

    /// Builds an empty filter with an explicit bit count and probe count.
    ///
    /// The bit bank is allocated as `num_bits.div_ceil(64)` zeroed `u64` words.
    ///
    /// # Errors
    /// Returns [`RcfError::InvalidConfig`] if `num_bits` is not in `1..=MAX_NUM_BITS`
    /// or `num_hashes` is not in `1..=MAX_HASHES`.
    pub fn with_params(num_bits: usize, num_hashes: u32) -> RcfResult<Self> {
        if num_bits == 0 || num_bits > MAX_NUM_BITS {
            return Err(RcfError::InvalidConfig(
                alloc::format!("BloomFilter: num_bits {num_bits} out of (0, {MAX_NUM_BITS}]")
                    .into(),
            ));
        }
        if num_hashes == 0 || num_hashes > MAX_HASHES {
            return Err(RcfError::InvalidConfig(
                alloc::format!("BloomFilter: num_hashes {num_hashes} out of (0, {MAX_HASHES}]")
                    .into(),
            ));
        }
        let words = num_bits.div_ceil(64);
        Ok(Self {
            bits: vec![0_u64; words],
            num_bits,
            num_hashes,
            total_added: 0,
        })
    }

    /// Number of addressable bits in the filter.
    #[must_use]
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    /// Number of bits set or checked per key.
    #[must_use]
    pub fn num_hashes(&self) -> u32 {
        self.num_hashes
    }

    /// Number of insertions recorded (duplicates included, saturating at `u64::MAX`).
    #[must_use]
    pub fn total_added(&self) -> u64 {
        self.total_added
    }

    /// Bytes occupied by the bit-bank words (excludes the `Vec` header and spare capacity).
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        self.bits.len() * core::mem::size_of::<u64>()
    }

    /// Returns `true` when [`Self::total_added`] is zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.total_added == 0
    }

    /// Estimates the current false-positive probability as `(1 - e^(-k*n/m))^k`.
    ///
    /// `n` is [`Self::total_added`]. Because that counter includes duplicate inserts
    /// and the counts merged by [`Self::union`], the estimate can overstate the rate
    /// relative to the number of distinct keys.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn effective_fpr(&self) -> f64 {
        let n = self.total_added as f64;
        let m = self.num_bits as f64;
        let k = f64::from(self.num_hashes);
        (1.0 - (-k * n / m).exp()).powf(k)
    }

    /// Inserts any hashable value.
    #[inline]
    pub fn insert<T: Hash + ?Sized>(&mut self, value: &T) {
        let (h1, h2) = double_hash(value);
        self.insert_hash(h1, h2);
    }

    /// Inserts a raw byte key.
    #[inline]
    pub fn insert_bytes(&mut self, key: &[u8]) {
        let (h1, h2) = double_hash(key);
        self.insert_hash(h1, h2);
    }

    /// Inserts a key given its precomputed base hashes.
    ///
    /// Sets the `num_hashes` bits at positions `(h1 + i*h2) mod num_bits` and increments
    /// the insertion counter (saturating). Lookups must use the same `(h1, h2)` pair via
    /// [`Self::contains_hash`].
    #[inline]
    pub fn insert_hash(&mut self, h1: u64, h2: u64) {
        self.total_added = self.total_added.saturating_add(1);
        for i in 0..self.num_hashes {
            let idx = self.combined_index(h1, h2, i);
            self.set_bit(idx);
        }
    }

    /// Tests whether a hashable value may have been inserted (false positives possible).
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    #[inline]
    pub fn contains<T: Hash + ?Sized>(&self, value: &T) -> bool {
        let (h1, h2) = double_hash(value);
        self.contains_hash(h1, h2)
    }

    /// Tests whether a raw byte key may have been inserted (false positives possible).
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    #[inline]
    pub fn contains_bytes(&self, key: &[u8]) -> bool {
        let (h1, h2) = double_hash(key);
        self.contains_hash(h1, h2)
    }

    /// Returns `true` iff every probe bit for `(h1, h2)` is set.
    ///
    /// Stops at the first clear bit.
    #[must_use = "detector output should be checked — dropping it silently usually indicates a logic bug"]
    #[inline]
    pub fn contains_hash(&self, h1: u64, h2: u64) -> bool {
        (0..self.num_hashes).all(|i| self.get_bit(self.combined_index(h1, h2, i)))
    }

    /// Merges `other` into `self` by OR-ing the bit banks.
    ///
    /// Afterwards `self` reports every key that tested positive in either filter. The
    /// insertion counter becomes the saturating sum of both counts.
    ///
    /// # Errors
    /// Returns [`RcfError::InvalidConfig`] if the filters differ in `num_bits` or
    /// `num_hashes`; `self` is left unchanged in that case.
    pub fn union(&mut self, other: &Self) -> RcfResult<()> {
        if self.num_bits != other.num_bits || self.num_hashes != other.num_hashes {
            return Err(RcfError::InvalidConfig(
                alloc::format!(
                    "BloomFilter::union: shape mismatch ({}/{} vs {}/{})",
                    self.num_bits,
                    self.num_hashes,
                    other.num_bits,
                    other.num_hashes,
                )
                .into(),
            ));
        }
        for (a, b) in self.bits.iter_mut().zip(other.bits.iter()) {
            *a |= *b;
        }
        self.total_added = self.total_added.saturating_add(other.total_added);
        Ok(())
    }

    /// Clears every bit and the insertion counter, keeping the existing allocation.
    pub fn reset(&mut self) {
        self.bits.fill(0);
        self.total_added = 0;
    }

    /// Bit position of the `i`-th probe under double hashing.
    ///
    /// The position is `(h1 + i*h2) mod num_bits`, computed with wrapping 64-bit arithmetic.
    // The remainder is < num_bits <= MAX_NUM_BITS, so narrowing it to usize is lossless.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    fn combined_index(&self, h1: u64, h2: u64, i: u32) -> usize {
        let combined = h1.wrapping_add(u64::from(i).wrapping_mul(h2));
        (combined % (self.num_bits as u64)) as usize
    }

    /// Sets bit `idx` (word `idx / 64`, bit `idx % 64`).
    #[inline]
    fn set_bit(&mut self, idx: usize) {
        let (w, b) = (idx >> 6, idx & 63);
        self.bits[w] |= 1_u64 << b;
    }

    /// Reads bit `idx` (word `idx / 64`, bit `idx % 64`).
    #[inline]
    fn get_bit(&self, idx: usize) -> bool {
        let (w, b) = (idx >> 6, idx & 63);
        (self.bits[w] >> b) & 1 == 1
    }
}
/// Derives the two base hashes used for double hashing from one `DefaultHasher` pass.
///
/// The first hash is the raw 64-bit digest. The second is that digest rotated by
/// 32 bits and then multiplied (wrapping) by the 64-bit golden-ratio constant.
///
/// NOTE(review): std documents `DefaultHasher`'s algorithm as unspecified and not to be
/// relied upon across Rust releases. `BloomFilter` can be serialized, so a filter written
/// by a binary built with one toolchain and read by another may probe different bits.
/// That would show up as false negatives. Confirm this is acceptable, or move to a fixed,
/// documented hash before relying on cross-build persistence (doing so changes bit
/// positions for already-serialized filters).
fn double_hash<T: Hash + ?Sized>(value: &T) -> (u64, u64) {
    /// Odd multiplier floor(2^64 / golden ratio), used to scramble the second hash.
    const MIX: u64 = 0x9E37_79B9_7F4A_7C15;
    let digest = {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    };
    let second = digest.rotate_left(32).wrapping_mul(MIX);
    (digest, second)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_invalid_params() {
        assert!(BloomFilter::new(0, 0.01).is_err());
        assert!(BloomFilter::new(1_000, 0.0).is_err());
        assert!(BloomFilter::new(1_000, 1.0).is_err());
        assert!(BloomFilter::new(1_000, f64::NAN).is_err());
        assert!(BloomFilter::new(1_000, -0.1).is_err());
    }

    #[test]
    fn with_params_rejects_zero_and_oversized_k() {
        assert!(BloomFilter::with_params(0, 4).is_err());
        assert!(BloomFilter::with_params(1_024, 0).is_err());
        assert!(BloomFilter::with_params(1_024, MAX_HASHES + 1).is_err());
    }

    #[test]
    fn with_params_rejects_oversized_num_bits() {
        assert!(BloomFilter::with_params(MAX_NUM_BITS + 1, 4).is_err());
        assert!(BloomFilter::with_params(usize::MAX, 4).is_err());
    }

    #[test]
    fn with_params_accepts_boundary_values() {
        // Smallest bank and the largest permitted probe count are both valid.
        assert!(BloomFilter::with_params(1, 1).is_ok());
        assert!(BloomFilter::with_params(1, MAX_HASHES).is_ok());
    }

    #[test]
    fn memory_bytes_rounds_up_to_whole_words() {
        assert_eq!(BloomFilter::with_params(64, 1).unwrap().memory_bytes(), 8);
        assert_eq!(BloomFilter::with_params(65, 1).unwrap().memory_bytes(), 16);
    }

    #[test]
    fn sizing_matches_optimal_formulas() {
        let bf = BloomFilter::new(10_000, 0.01).unwrap();
        assert!((95_000..=96_000).contains(&bf.num_bits()));
        assert_eq!(bf.num_hashes(), 7);
    }

    #[test]
    fn no_false_negatives_on_inserted_keys() {
        let mut bf = BloomFilter::new(1_000, 0.01).unwrap();
        for i in 0..1_000_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        for i in 0..1_000_u32 {
            assert!(bf.contains_bytes(&i.to_le_bytes()));
        }
    }

    #[test]
    fn contains_before_insert_is_false() {
        let bf = BloomFilter::new(1_000, 0.01).unwrap();
        assert!(!bf.contains_bytes(b"never-inserted"));
    }

    #[test]
    fn insert_hash_roundtrips_through_contains_hash() {
        let mut bf = BloomFilter::with_params(128, 3).unwrap();
        assert!(!bf.contains_hash(17, 5));
        bf.insert_hash(17, 5);
        assert!(bf.contains_hash(17, 5));
        assert_eq!(bf.total_added(), 1);
    }

    #[test]
    fn false_positive_rate_within_budget() {
        let target = 0.01_f64;
        let mut bf = BloomFilter::new(10_000, target).unwrap();
        for i in 0..10_000_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        let mut hits = 0_u32;
        for i in 10_000_u32..20_000 {
            if bf.contains_bytes(&i.to_le_bytes()) {
                hits += 1;
            }
        }
        let fpr = f64::from(hits) / 10_000.0;
        assert!(fpr < target * 3.0, "fpr={fpr}");
    }

    #[test]
    fn union_matches_either_insert() {
        let mut a = BloomFilter::new(1_000, 0.01).unwrap();
        let mut b = BloomFilter::new(1_000, 0.01).unwrap();
        a.insert_bytes(b"alpha");
        b.insert_bytes(b"beta");
        a.union(&b).unwrap();
        assert!(a.contains_bytes(b"alpha"));
        assert!(a.contains_bytes(b"beta"));
    }

    #[test]
    fn union_sums_insert_counts() {
        let mut a = BloomFilter::with_params(1_024, 4).unwrap();
        let mut b = BloomFilter::with_params(1_024, 4).unwrap();
        a.insert_bytes(b"x");
        b.insert_bytes(b"y");
        b.insert_bytes(b"z");
        a.union(&b).unwrap();
        assert_eq!(a.total_added(), 3);
        assert!(!a.is_empty());
    }

    #[test]
    fn union_rejects_shape_mismatch() {
        let mut a = BloomFilter::new(1_000, 0.01).unwrap();
        let b = BloomFilter::new(2_000, 0.01).unwrap();
        assert!(a.union(&b).is_err());
    }

    #[test]
    fn reset_clears_bits_but_keeps_capacity() {
        let mut bf = BloomFilter::new(1_000, 0.01).unwrap();
        for i in 0..100_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        bf.reset();
        assert!(bf.is_empty());
        assert!(!bf.contains_bytes(&0_u32.to_le_bytes()));
        bf.insert_bytes(b"fresh");
        assert!(bf.contains_bytes(b"fresh"));
    }

    #[test]
    fn generic_hash_and_byte_paths_agree() {
        let mut a = BloomFilter::new(1_000, 0.01).unwrap();
        let mut b = BloomFilter::new(1_000, 0.01).unwrap();
        let key = b"same-key";
        a.insert(&key.as_slice());
        b.insert_bytes(key);
        assert_eq!(a.bits, b.bits);
    }

    #[test]
    fn effective_fpr_is_zero_when_empty() {
        let bf = BloomFilter::new(100, 0.01).unwrap();
        assert!(bf.effective_fpr().abs() < f64::EPSILON);
    }

    #[test]
    fn effective_fpr_grows_with_load() {
        let mut bf = BloomFilter::new(1_000, 0.01).unwrap();
        let empty = bf.effective_fpr();
        for i in 0..500_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        let half = bf.effective_fpr();
        for i in 500..1_000_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        let full = bf.effective_fpr();
        assert!(empty < half && half < full);
        // At the sized capacity the estimate should sit close to the 1% target.
        assert!(full < 0.015);
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn postcard_roundtrip_preserves_membership() {
        let mut bf = BloomFilter::new(1_000, 0.01).unwrap();
        for i in 0..500_u32 {
            bf.insert_bytes(&i.to_le_bytes());
        }
        let bytes = postcard::to_allocvec(&bf).expect("serde ok");
        let back: BloomFilter = postcard::from_bytes(&bytes).expect("serde ok");
        for i in 0..500_u32 {
            assert!(back.contains_bytes(&i.to_le_bytes()));
        }
        assert_eq!(bf.num_bits(), back.num_bits());
        assert_eq!(bf.num_hashes(), back.num_hashes());
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_oversized_num_hashes() {
        let mut bf = BloomFilter::with_params(1_024, 7).unwrap();
        bf.insert_bytes(b"x");
        let bad = BloomFilterShadow {
            bits: bf.bits.clone(),
            num_bits: bf.num_bits,
            num_hashes: MAX_HASHES + 1,
            total_added: bf.total_added,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<BloomFilter, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err(), "oversized num_hashes must be rejected");
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_bit_bank_length_mismatch() {
        let bf = BloomFilter::with_params(1_024, 4).unwrap();
        let bad = BloomFilterShadow {
            // 1_024 bits need 16 words; 3 is deliberately wrong.
            bits: vec![0_u64; 3],
            num_bits: bf.num_bits,
            num_hashes: bf.num_hashes,
            total_added: 0,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<BloomFilter, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err(), "bit-bank length mismatch must be rejected");
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_zero_num_bits() {
        let bad = BloomFilterShadow {
            bits: alloc::vec::Vec::new(),
            num_bits: 0,
            num_hashes: 4,
            total_added: 0,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<BloomFilter, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err(), "zero num_bits must be rejected");
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn deserialize_rejects_oversized_num_bits() {
        // Rejected by the range check before any bank-length comparison, so no large
        // allocation is needed to exercise it.
        let bad = BloomFilterShadow {
            bits: alloc::vec::Vec::new(),
            num_bits: MAX_NUM_BITS + 1,
            num_hashes: 4,
            total_added: 0,
        };
        let bytes = postcard::to_allocvec(&bad).unwrap();
        let back: Result<BloomFilter, _> = postcard::from_bytes(&bytes);
        assert!(back.is_err(), "oversized num_bits must be rejected");
    }
}