use core::{hash::BuildHasher, marker::PhantomData};
use crate::{
bit_set::BitSet,
hash::{reduce, DefaultHashBuilder, HashPair},
Error,
};
/// Natural logarithm of 2, used by the optimal-sizing formulas.
const LN_2: f64 = core::f64::consts::LN_2;

/// A Bloom filter: a fixed-size probabilistic set of `T` values.
///
/// Membership queries may return false positives but never false negatives. Items
/// are hashed with the hash builder `S`; no `T` value is ever stored.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BloomFilter<T: ?Sized, S = DefaultHashBuilder> {
    /// The filter's bit array; its length is the filter's bit count.
    bits: BitSet,
    /// Number of bit positions probed per item.
    num_hashes: u32,
    /// Hash builder used to derive bit positions. It is skipped by serde, so a
    /// deserialized filter gets `S::default()`.
    #[cfg_attr(feature = "serde", serde(skip))]
    hasher: S,
    /// Associates the filter with the item type `T` without storing one. Using
    /// `fn(&T)` means `T` does not affect `Send`/`Sync`.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<fn(&T)>,
}

// `Clone` and `Debug` are written by hand: `#[derive]` would add `T: Clone` /
// `T: Debug` bounds even though no `T` is stored. Since `Clone` requires `Sized`,
// that would make e.g. `BloomFilter<str>` impossible to clone.
impl<T: ?Sized, S: Clone> Clone for BloomFilter<T, S> {
    fn clone(&self) -> Self {
        Self {
            bits: self.bits.clone(),
            num_hashes: self.num_hashes,
            hasher: self.hasher.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T: ?Sized, S: core::fmt::Debug> core::fmt::Debug for BloomFilter<T, S> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("BloomFilter")
            .field("bits", &self.bits)
            .field("num_hashes", &self.num_hashes)
            .field("hasher", &self.hasher)
            .finish()
    }
}
impl<T: ?Sized> BloomFilter<T, DefaultHashBuilder> {
    /// Creates a filter for `capacity` items at target false-positive `rate`,
    /// using the crate's default hash builder.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`BloomFilter::with_hasher`].
    pub fn new(capacity: usize, rate: f64) -> Result<Self, Error> {
        let hasher = DefaultHashBuilder;
        Self::with_hasher(capacity, rate, hasher)
    }

    /// Creates a filter with exactly `num_bits` bits and `num_hashes` probes per
    /// item, using the crate's default hash builder.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as
    /// [`BloomFilter::with_dimensions_and_hasher`].
    pub fn with_dimensions(num_bits: u64, num_hashes: u32) -> Result<Self, Error> {
        let hasher = DefaultHashBuilder;
        Self::with_dimensions_and_hasher(num_bits, num_hashes, hasher)
    }
}
impl<T: ?Sized, S: BuildHasher> BloomFilter<T, S> {
    /// Creates a filter sized to hold `capacity` items at a target false-positive
    /// `rate`, hashing items with `hasher`.
    ///
    /// The bit count is `m = ⌈-n·ln(p) / ln(2)²⌉` and the probe count is
    /// `k = round((m / n)·ln(2))`, each raised to one if smaller.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] if `capacity` is zero, or if `rate` is not
    /// a finite value strictly between `0.0` and `1.0`.
    pub fn with_hasher(capacity: usize, rate: f64, hasher: S) -> Result<Self, Error> {
        if capacity == 0 {
            return Err(Error::InvalidParameter {
                param: "capacity",
                reason: "must be greater than zero",
            });
        }
        if !(rate.is_finite() && rate > 0.0 && rate < 1.0) {
            return Err(Error::InvalidParameter {
                param: "rate",
                reason: "must be a finite value in the open interval (0.0, 1.0)",
            });
        }
        let num_bits = optimal_num_bits(capacity, rate);
        let num_hashes = optimal_num_hashes(num_bits, capacity);
        Self::with_dimensions_and_hasher(num_bits, num_hashes, hasher)
    }

    /// Creates a filter with exactly `num_bits` bits and `num_hashes` probes per
    /// item, hashing items with `hasher`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] if `num_bits` or `num_hashes` is zero.
    pub fn with_dimensions_and_hasher(
        num_bits: u64,
        num_hashes: u32,
        hasher: S,
    ) -> Result<Self, Error> {
        if num_bits == 0 {
            return Err(Error::InvalidParameter {
                param: "num_bits",
                reason: "must be greater than zero",
            });
        }
        if num_hashes == 0 {
            return Err(Error::InvalidParameter {
                param: "num_hashes",
                reason: "must be greater than zero",
            });
        }
        Ok(Self {
            bits: BitSet::new(num_bits),
            num_hashes,
            hasher,
            _marker: PhantomData,
        })
    }

    /// Adds `item` to the filter.
    ///
    /// Returns `true` if at least one of the item's probe bits was not yet set,
    /// meaning the item was definitely absent before this call. Returns `false` if
    /// every probe bit was already set, meaning the item was probably (not
    /// certainly) present already.
    pub fn insert(&mut self, item: &T) -> bool
    where
        T: core::hash::Hash,
    {
        let pair = HashPair::new(item, &self.hasher);
        let num_bits = self.bits.len();
        let mut newly_added = false;
        // Every probe bit must be set, so there is no early exit here.
        for i in 0..u64::from(self.num_hashes) {
            let index = reduce(pair.nth(i), num_bits);
            // NOTE(review): assumes `BitSet::set` returns the bit's previous value.
            if !self.bits.set(index) {
                newly_added = true;
            }
        }
        newly_added
    }

    /// Tests whether `item` may have been inserted.
    ///
    /// Returns `false` only if `item` was definitely never inserted since the last
    /// [`clear`](Self::clear). A `true` answer may be a false positive.
    #[must_use]
    pub fn contains(&self, item: &T) -> bool
    where
        T: core::hash::Hash,
    {
        let pair = HashPair::new(item, &self.hasher);
        let num_bits = self.bits.len();
        // Probes the same positions `insert` sets, so inserted items always pass.
        (0..u64::from(self.num_hashes)).all(|i| self.bits.get(reduce(pair.nth(i), num_bits)))
    }

    /// Removes every item by clearing the bit array. Dimensions are unchanged.
    pub fn clear(&mut self) {
        self.bits.clear();
    }

    /// Returns the number of bits in the filter.
    #[inline]
    #[must_use]
    pub fn num_bits(&self) -> u64 {
        self.bits.len()
    }

    /// Returns the number of bit positions probed per item.
    #[inline]
    #[must_use]
    pub fn num_hashes(&self) -> u32 {
        self.num_hashes
    }

    /// Returns how many bits are currently set.
    #[inline]
    #[must_use]
    pub fn count_ones(&self) -> u64 {
        self.bits.count_ones()
    }

    /// Returns `true` if no bit is set, i.e. nothing has been inserted since
    /// creation or the last [`clear`](Self::clear).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.bits.count_ones() == 0
    }

    /// Estimates how many distinct items have been inserted, based on how many bits are set.
    ///
    /// Uses the Swamidass–Baldi estimator `n ≈ -(m / k)·ln(1 - x / m)`, where `m` is
    /// the bit count, `k` the probe count and `x` the number of set bits. An empty
    /// filter yields 0.
    ///
    /// When every bit is set the estimator diverges. In that case the result is the
    /// larger of `num_bits` and the estimate for `num_bits - 1` set bits. This keeps
    /// the result finite, and it never decreases as more bits become set.
    #[must_use]
    pub fn estimated_len(&self) -> u64 {
        let set_bits = self.bits.count_ones();
        if set_bits == 0 {
            return 0;
        }
        let m = self.bits.len() as f64;
        let k = f64::from(self.num_hashes);
        // `log1p(-x / m)` == `ln(1 - x / m)`, but stays precise when x / m is tiny.
        let estimate_for = |x: f64| -(m / k) * libm::log1p(-x / m);
        let estimate = if set_bits >= self.bits.len() {
            // Saturated. Returning just `m` here could be smaller than the estimate
            // one bit earlier, so take the larger of the two.
            estimate_for(m - 1.0).max(m)
        } else {
            estimate_for(set_bits as f64)
        };
        libm::round(estimate) as u64
    }

    /// Estimates the probability that `contains` reports `true` for an item that
    /// was never inserted. It is computed as `(set bits / num_bits)^num_hashes`.
    #[must_use]
    pub fn estimated_false_positive_rate(&self) -> f64 {
        let m = self.bits.len() as f64;
        let k = f64::from(self.num_hashes);
        let fill = self.bits.count_ones() as f64 / m;
        libm::pow(fill, k)
    }

    /// Merges `other` into `self` by bitwise union. Afterwards, `self` reports as
    /// possibly present every item that either filter did.
    ///
    /// Both filters must map items to bits identically. This checks the probe count
    /// and bit-set compatibility, but it cannot verify that the two hashers produce
    /// the same hashes. Merging filters built with differently seeded hashers gives
    /// a meaningless result.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IncompatibleParameters`] if the probe counts differ or the
    /// bit sets are incompatible. In that case `self` is left unchanged.
    pub fn merge(&mut self, other: &Self) -> Result<(), Error> {
        if self.num_hashes != other.num_hashes || !self.bits.is_compatible(&other.bits) {
            return Err(Error::IncompatibleParameters);
        }
        self.bits.union_with(&other.bits);
        Ok(())
    }
}
/// Returns the minimum bit count for `capacity` items to reach false-positive
/// rate `rate`: `m = ⌈-n·ln(p) / ln(2)²⌉`. Results below one become one.
///
/// Float-to-integer `as` casts saturate, so a size beyond `u64::MAX` yields
/// `u64::MAX`.
fn optimal_num_bits(capacity: usize, rate: f64) -> u64 {
    let ln2_squared = LN_2 * LN_2;
    let ideal_bits = libm::ceil(-(capacity as f64) * libm::log(rate) / ln2_squared);
    match ideal_bits {
        bits if bits < 1.0 => 1,
        bits => bits as u64,
    }
}
/// Returns the probe count that minimises the false-positive rate for a filter
/// of `num_bits` bits holding `capacity` items: `k = round((m / n)·ln(2))`.
/// Results below one become one.
fn optimal_num_hashes(num_bits: u64, capacity: usize) -> u32 {
    let bits_per_item = num_bits as f64 / capacity as f64;
    match libm::round(bits_per_item * LN_2) {
        hashes if hashes < 1.0 => 1,
        hashes => hashes as u32,
    }
}
#[cfg(test)]
mod tests {
    #![allow(unused_results)]
    #![allow(clippy::unwrap_used)]

    use super::*;

    #[test]
    fn test_new_rejects_zero_capacity() {
        let expected = Error::InvalidParameter {
            param: "capacity",
            reason: "must be greater than zero",
        };
        let err = BloomFilter::<&str>::new(0, 0.01).unwrap_err();
        assert_eq!(err, expected);
    }

    #[test]
    fn test_new_rejects_out_of_range_rate() {
        // Both endpoints of the open interval, plus NaN.
        for bad_rate in [0.0, 1.0, f64::NAN] {
            assert!(matches!(
                BloomFilter::<&str>::new(10, bad_rate),
                Err(Error::InvalidParameter { .. })
            ));
        }
    }

    #[test]
    fn test_with_dimensions_rejects_zeros() {
        let bad_dimensions: [(u64, u32); 2] = [(0, 3), (64, 0)];
        for (bits, hashes) in bad_dimensions {
            assert!(matches!(
                BloomFilter::<u8>::with_dimensions(bits, hashes),
                Err(Error::InvalidParameter { .. })
            ));
        }
    }

    #[test]
    fn test_no_false_negatives() {
        let items = 0..1_000u32;
        let mut filter = BloomFilter::new(1_000, 0.01).unwrap();
        for i in items.clone() {
            filter.insert(&i);
        }
        for i in items {
            assert!(filter.contains(&i), "inserted item {i} reported absent");
        }
    }

    #[test]
    fn test_insert_reports_novelty() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        let first = filter.insert("alpha");
        let second = filter.insert("alpha");
        assert!(first);
        assert!(!second);
    }

    #[test]
    fn test_false_positive_rate_is_near_target() {
        let capacity = 10_000;
        let target = 0.01;
        let mut filter = BloomFilter::new(capacity, target).unwrap();
        let inserted = capacity as u64;
        for i in 0..inserted {
            filter.insert(&i);
        }
        // Probe values that were never inserted; every hit is a false positive.
        let trials = 100_000u64;
        let hits = (inserted..inserted + trials)
            .filter(|probe| filter.contains(probe))
            .count();
        let observed = hits as f64 / trials as f64;
        assert!(
            observed < target * 3.0,
            "observed FP rate {observed} far exceeds target {target}"
        );
    }

    #[test]
    fn test_clear_empties_filter() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        filter.insert("x");
        assert!(!filter.is_empty());

        filter.clear();
        assert!(filter.is_empty());
        assert!(!filter.contains("x"));
    }

    #[test]
    fn test_merge_unions_membership() {
        let mut left = BloomFilter::new(1_000, 0.01).unwrap();
        let mut right = BloomFilter::new(1_000, 0.01).unwrap();
        left.insert("a");
        right.insert("b");
        left.merge(&right).unwrap();
        for item in ["a", "b"] {
            assert!(left.contains(item));
        }
    }

    #[test]
    fn test_merge_rejects_incompatible() {
        let mut base = BloomFilter::<u32>::with_dimensions(1_024, 3).unwrap();
        let wider = BloomFilter::<u32>::with_dimensions(2_048, 3).unwrap();
        let more_hashes = BloomFilter::<u32>::with_dimensions(1_024, 4).unwrap();
        assert_eq!(base.merge(&wider), Err(Error::IncompatibleParameters));
        assert_eq!(base.merge(&more_hashes), Err(Error::IncompatibleParameters));
    }

    #[test]
    fn test_estimated_len_is_reasonable() {
        let mut filter = BloomFilter::new(10_000, 0.01).unwrap();
        for i in 0..1_000u32 {
            filter.insert(&i);
        }
        let estimate = filter.estimated_len();
        let within_tolerance = (900..=1_100).contains(&estimate);
        assert!(
            within_tolerance,
            "estimate {estimate} not within 10% of 1000"
        );
    }

    #[test]
    fn test_estimated_len_empty_is_zero() {
        let filter = BloomFilter::<u32>::new(1_000, 0.01).unwrap();
        assert_eq!(filter.estimated_len(), 0);
    }

    #[test]
    fn test_sizing_formulas() {
        let bits = optimal_num_bits(10_000, 0.01);
        assert!(
            (95_000..=96_500).contains(&bits),
            "unexpected bit count {bits}"
        );
        assert_eq!(optimal_num_hashes(bits, 10_000), 7);
    }
}