#[cfg(all(test, feature = "arbitrary"))]
mod conformance;
use crate::{sha256::Sha256, Hasher};
use bytes::{Buf, BufMut};
use commonware_codec::{
codec::{Read, Write},
error::Error as CodecError,
EncodeSize, FixedSize,
};
use commonware_utils::bitmap::BitMap;
use core::{
marker::PhantomData,
num::{NonZeroU64, NonZeroU8, NonZeroUsize},
};
#[cfg(feature = "std")]
use {
commonware_utils::rational::BigRationalExt,
num_rational::BigRational,
num_traits::{One, ToPrimitive, Zero},
};
/// Rational stand-in for ln(2), stored as `(numerator, denominator)`.
///
/// 14397 / 20769 ≈ 0.69320. Used to convert a base-2 logarithm into a natural logarithm
/// and to derive the hasher count from the bits-per-item ratio.
#[cfg(feature = "std")]
const LN2: (u64, u64) = (14397, 20769);
/// Rational stand-in for 1 / ln(2), stored as `(numerator, denominator)`.
///
/// 29145 / 20201 ≈ 1.44275. Used when sizing the bit array from a target false positive rate.
#[cfg(feature = "std")]
const LN2_INV: (u64, u64) = (29145, 20201);
/// A Bloom filter whose probe positions are derived from a single digest produced by `H`.
///
/// Every constructor and decoder in this file produces a bit array whose length is a power
/// of two, which lets probe positions be reduced with a bit mask.
#[derive(Clone, Debug)]
pub struct BloomFilter<H: Hasher = Sha256> {
// Number of positions probed per item; never zero.
hashers: u8,
// Backing bit array.
bits: BitMap,
// Ties the filter to its hasher type without storing a hasher value.
_marker: PhantomData<H>,
}
/// Two filters are equal when they probe the same number of positions and hold identical bits.
///
/// Written by hand because `#[derive(PartialEq)]` would add an `H: PartialEq` bound, even
/// though `H` only appears inside `PhantomData`.
impl<H: Hasher> PartialEq for BloomFilter<H> {
    fn eq(&self, other: &Self) -> bool {
        (self.hashers, &self.bits) == (other.hashers, &other.bits)
    }
}

impl<H: Hasher> Eq for BloomFilter<H> {}
impl<H: Hasher> BloomFilter<H> {
// Compile-time guard: `indices` reads digest bytes 0..16, so the digest must be at least
// 16 bytes long. The constant is referenced from `indices` so that the assertion is
// evaluated for every concrete `H` the filter is used with.
const _ASSERT_DIGEST_AT_LEAST_16_BYTES: () = assert!(
<H::Digest as FixedSize>::SIZE >= 16,
"digest must be at least 128 bits (16 bytes)"
);
/// Creates an empty filter that probes `hashers` positions per item and holds at least
/// `bits` bits.
///
/// The requested bit count is rounded up to the next power of two. If that rounding would
/// overflow `usize`, the largest power of two representable in a `usize` is used instead.
pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
    let capacity = match bits.get().checked_next_power_of_two() {
        Some(rounded) => rounded,
        None => 1 << (usize::BITS - 1),
    };
    Self {
        hashers: hashers.get(),
        bits: BitMap::zeroes(capacity as u64),
        _marker: PhantomData,
    }
}
/// Creates an empty filter sized for `expected_items` insertions at the target false
/// positive rate `fp_rate`.
///
/// The bit count comes from `optimal_bits` and the hasher count from `optimal_hashers`.
///
/// # Panics
///
/// Panics if `fp_rate` is not strictly between 0 and 1 (asserted by `optimal_bits`).
#[cfg(feature = "std")]
pub fn with_rate(expected_items: NonZeroUsize, fp_rate: BigRational) -> Self {
    let items = expected_items.get();
    let bit_count = Self::optimal_bits(items, &fp_rate);
    Self {
        hashers: Self::optimal_hashers(items, bit_count),
        bits: BitMap::zeroes(bit_count as u64),
        _marker: PhantomData,
    }
}
/// Returns the number of positions probed per item.
pub const fn hashers(&self) -> NonZeroU8 {
    match NonZeroU8::new(self.hashers) {
        Some(count) => count,
        None => panic!("hashers is never zero"),
    }
}
/// Returns the length of the backing bit array.
pub const fn bits(&self) -> NonZeroUsize {
    let len = self.bits.len() as usize;
    match NonZeroUsize::new(len) {
        Some(count) => count,
        None => panic!("bits is never zero"),
    }
}
/// Yields the bit positions to probe for `item`, one per hasher.
///
/// The item is hashed once; digest bytes 0..8 form a base value and bytes 8..16 form a
/// stride. Position `i` is `base + i * stride` (wrapping), masked to the bit-array length.
fn indices(&self, item: &[u8]) -> impl Iterator<Item = u64> {
    // Referencing the constant forces the digest-size assertion to be evaluated for `H`.
    #[allow(path_statements)]
    Self::_ASSERT_DIGEST_AT_LEAST_16_BYTES;

    let digest = H::hash(item);
    let base = u64::from_be_bytes(digest[0..8].try_into().unwrap());
    // Force the stride odd: an odd stride is coprime with a power-of-two length.
    let stride = u64::from_be_bytes(digest[8..16].try_into().unwrap()) | 1;
    // Relies on the bit-array length being a power of two.
    let mask = self.bits.len() - 1;
    let rounds = self.hashers as u64;
    (0..rounds).map(move |round| base.wrapping_add(round.wrapping_mul(stride)) & mask)
}
/// Adds `item` to the filter by setting every bit position derived from it.
pub fn insert(&mut self, item: &[u8]) {
    self.indices(item).for_each(|index| {
        self.bits.set(index, true);
    });
}
/// Returns `true` if every bit position derived from `item` is set.
///
/// A `false` result means `item` was never inserted; a `true` result may be a false
/// positive, because other items can have set the same bits.
pub fn contains(&self, item: &[u8]) -> bool {
    self.indices(item).all(|index| self.bits.get(index))
}
/// Estimates the current false positive rate as `(set bits / total bits) ^ hashers`.
#[cfg(feature = "std")]
pub fn estimated_false_positive_rate(&self) -> BigRational {
    let set_bits = self.bits.count_ones();
    let total_bits = self.bits.len();
    BigRational::new(set_bits.into(), total_bits.into()).pow(i32::from(self.hashers))
}
/// Estimates how many distinct items have been inserted.
///
/// With `m` total bits, `x` set bits and `k` hashers, this evaluates
/// `-(m / k) * ln(1 - x / m)`, where the natural logarithm is computed as
/// `log2_floor(..) * LN2`. When every bit is set the estimate is reported as `usize::MAX`.
#[cfg(feature = "std")]
pub fn estimated_count(&self) -> BigRational {
    let total = self.bits.len();
    let set = self.bits.count_ones();
    // A full filter has no unset fraction to take a logarithm of.
    if set >= total {
        return BigRational::from_usize(usize::MAX);
    }

    let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1);
    let unset_fraction = BigRational::new((total - set).into(), total.into());
    // NOTE(review): the `16` is handed to `log2_floor` unchanged; presumably a precision
    // setting, confirm against `BigRationalExt`.
    let ln_unset = unset_fraction.log2_floor(16) * ln2;
    let bits_per_hasher = BigRational::new(total.into(), u64::from(self.hashers).into());
    -bits_per_hasher * ln_unset
}
/// Computes the hasher count for a filter of `bits` bits expected to hold `expected_items`.
///
/// Evaluates `(bits / expected_items) * LN2`, truncates to an integer and clamps the result
/// to `1..=16`. Returns 1 when `expected_items` is zero.
#[cfg(feature = "std")]
pub fn optimal_hashers(expected_items: usize, bits: usize) -> u8 {
    if expected_items == 0 {
        return 1;
    }

    let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1);
    let bits_per_item = BigRational::from_usize(bits) / BigRational::from_usize(expected_items);
    match (bits_per_item * ln2).to_integer().to_u8() {
        Some(k) => k.clamp(1, 16),
        // Too large for a `u8`: use the upper bound.
        None => 16,
    }
}
/// Computes the bit count for a filter expected to hold `expected_items` items at the
/// target false positive rate `fp_rate`.
///
/// Evaluates `-expected_items * log2(fp_rate) * LN2_INV`, rounds up, and then rounds up
/// again to the next power of two. The result is always a power of two and at least 1.
/// If the requirement does not fit in a `usize`, the largest power of two representable in
/// a `usize` is returned.
///
/// # Panics
///
/// Panics if `fp_rate` is not strictly between 0 and 1.
#[cfg(feature = "std")]
pub fn optimal_bits(expected_items: usize, fp_rate: &BigRational) -> usize {
    assert!(
        fp_rate > &BigRational::zero() && fp_rate < &BigRational::one(),
        "false positive rate must be in (0, 1)"
    );
    let log2_p = fp_rate.log2_floor(16);
    let n = BigRational::from_usize(expected_items);
    let ln2_inv = BigRational::from_frac_u64(LN2_INV.0, LN2_INV.1);
    // `log2_p` is negative for rates below 1, so negate to obtain a positive bit count.
    let bits_rational = -(&n * &log2_p * &ln2_inv);
    let required = bits_rational.ceil_to_u128().unwrap_or(1);
    // Saturate rather than cast with `as`: a narrowing cast would drop the high bits of an
    // oversized requirement and could silently produce a tiny filter. Saturating makes
    // `checked_next_power_of_two` return `None`, which selects the fallback below.
    let required = usize::try_from(required).unwrap_or(usize::MAX);
    required
        .max(1)
        .checked_next_power_of_two()
        .unwrap_or(1 << (usize::BITS - 1))
}
}
/// Encodes the filter as the hasher count followed by the bit array.
impl<H: Hasher> Write for BloomFilter<H> {
fn write(&self, buf: &mut impl BufMut) {
// Field order must match the order in which `read_cfg` consumes them.
self.hashers.write(buf);
self.bits.write(buf);
}
}
/// Decodes a filter and checks it against the expected hasher count and bit length.
impl<H: Hasher> Read for BloomFilter<H> {
    /// Expected `(hashers, bits)`; `bits` must be a power of two.
    type Cfg = (NonZeroU8, NonZeroU64);

    fn read_cfg(
        buf: &mut impl Buf,
        (hashers_cfg, bits_cfg): &Self::Cfg,
    ) -> Result<Self, CodecError> {
        let invalid = |reason: &'static str| CodecError::Invalid("BloomFilter", reason);

        // Reject an unusable configuration before consuming any input.
        let expected_bits = bits_cfg.get();
        if !expected_bits.is_power_of_two() {
            return Err(invalid("bits must be a power of 2"));
        }

        let hashers = u8::read_cfg(buf, &())?;
        if hashers != hashers_cfg.get() {
            return Err(invalid("hashers doesn't match config"));
        }

        // The expected length is also the configuration handed to the bitmap decoder;
        // the exact length is then enforced here.
        let bits = BitMap::read_cfg(buf, &expected_bits)?;
        if bits.len() != expected_bits {
            return Err(invalid("bitmap length doesn't match config"));
        }

        Ok(Self {
            hashers,
            bits,
            _marker: PhantomData,
        })
    }
}
/// Encoded size is the sum of the encoded sizes of the two fields that `write` emits.
impl<H: Hasher> EncodeSize for BloomFilter<H> {
fn encode_size(&self) -> usize {
self.hashers.encode_size() + self.bits.encode_size()
}
}
/// Generates a filter with a non-zero hasher count and a power-of-two bit array of
/// arbitrary contents (at most 65536 bits).
#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for BloomFilter<H> {
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Zero hashers is not a valid filter; bump it to one.
        let hashers = match u8::arbitrary(u)? {
            0 => 1,
            count => count,
        };
        // `next_power_of_two` maps 0 to 1, so the bit array is never empty.
        let bits_len = u
            .int_in_range(0..=u64::from(u16::MAX))?
            .next_power_of_two();
        let mut bits = BitMap::with_capacity(bits_len);
        (0..bits_len).try_for_each(|_| -> arbitrary::Result<()> {
            bits.push(u.arbitrary::<bool>()?);
            Ok(())
        })?;
        Ok(Self {
            hashers,
            bits,
            _marker: PhantomData,
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use commonware_codec::{Decode, Encode};
use commonware_utils::{NZUsize, NZU64, NZU8};
// Inserted items are reported as present; an item that was never inserted is not.
#[test]
fn test_insert_and_contains() {
    let mut bf = BloomFilter::<Sha256>::new(NZU8!(10), NZUsize!(1000));
    let inserted: [&[u8]; 2] = [b"hello", b"world"];
    for item in inserted {
        bf.insert(item);
    }
    for item in inserted {
        assert!(bf.contains(item));
    }
    assert!(!bf.contains(b"bloomfilter"));
}
// A freshly created filter has no bits set, so no lookup can succeed.
#[test]
fn test_empty() {
let bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(100));
assert!(!bf.contains(b"anything"));
}
// A small, heavily loaded filter never misses an inserted item, and reports some but not
// all unseen items as present.
#[test]
fn test_false_positives() {
    let mut bf = BloomFilter::<Sha256>::new(NZU8!(10), NZUsize!(100));
    (0..10usize).for_each(|i| bf.insert(&i.to_be_bytes()));

    // No false negatives.
    for i in 0..10usize {
        assert!(bf.contains(&i.to_be_bytes()));
    }

    // Probe 1000 values that were never inserted.
    let false_positives = (100..1100usize)
        .filter(|i| bf.contains(&i.to_be_bytes()))
        .count();
    assert!(false_positives > 0);
    assert!(false_positives < 1000);
}
// Encoding then decoding with a matching config reproduces a populated filter exactly.
#[test]
fn test_codec_roundtrip() {
let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
bf.insert(b"test1");
bf.insert(b"test2");
// The config must name the same hasher count and bit length as the encoded filter.
let cfg = (NZU8!(5), NZU64!(128));
let encoded = bf.encode();
let decoded = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg).unwrap();
assert_eq!(bf, decoded);
}
// Encoding then decoding also round-trips a filter with no bits set.
#[test]
fn test_codec_empty() {
let bf = BloomFilter::<Sha256>::new(NZU8!(4), NZUsize!(128));
let cfg = (NZU8!(4), NZU64!(128));
let encoded = bf.encode();
let decoded = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg).unwrap();
assert_eq!(bf, decoded);
}
// Decoding fails when the configured hasher count differs from the encoded one, whether
// the configured value is larger or smaller.
#[test]
fn test_codec_with_invalid_hashers() {
    let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
    bf.insert(b"test1");
    let encoded = bf.encode();

    // The filter was encoded with 5 hashers; try one config above and one below.
    for wrong_hashers in [NZU8!(10), NZU8!(4)] {
        let cfg = (wrong_hashers, NZU64!(128));
        let decoded = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(
            decoded,
            Err(CodecError::Invalid(
                "BloomFilter",
                "hashers doesn't match config"
            ))
        ));
    }
}
// Decoding fails when the configured bit length does not match the encoded filter.
#[test]
fn test_codec_with_invalid_bits() {
let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
bf.insert(b"test1");
let encoded = bf.encode();
// Config smaller than the encoded length: the error comes from the bitmap decoder itself.
let cfg = (NZU8!(5), NZU64!(64));
let result = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
assert!(matches!(result, Err(CodecError::InvalidLength(128))));
// Config larger than the encoded length: the bitmap decodes, then the exact-length check fails.
let cfg = (NZU8!(5), NZU64!(256));
let result = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
assert!(matches!(
result,
Err(CodecError::Invalid(
"BloomFilter",
"bitmap length doesn't match config"
))
));
// Config that is not a power of two: rejected before any input is read.
let cfg = (NZU8!(5), NZU64!(100));
let result = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg);
assert!(matches!(
result,
Err(CodecError::Invalid(
"BloomFilter",
"bits must be a power of 2"
))
));
}
// The count and false-positive estimates are zero for an empty filter and land in a
// plausible range after 100 insertions.
#[test]
fn test_statistics() {
let mut bf = BloomFilter::<Sha256>::new(NZU8!(7), NZUsize!(1024));
// Empty filter: both estimates are exactly zero.
assert_eq!(bf.estimated_count(), BigRational::zero());
assert_eq!(bf.estimated_false_positive_rate(), BigRational::zero());
for i in 0..100usize {
bf.insert(&i.to_be_bytes());
}
// The count estimate is approximate; accept anything within 25 of the true count.
let estimated = bf.estimated_count();
let lower = BigRational::from_usize(75);
let upper = BigRational::from_usize(125);
assert!(estimated > lower && estimated < upper);
// A partially filled filter has a rate strictly between 0 and 1.
assert!(bf.estimated_false_positive_rate() > BigRational::zero());
assert!(bf.estimated_false_positive_rate() < BigRational::one());
}
// `with_rate` uses the sizes reported by `optimal_bits` / `optimal_hashers`, and the
// resulting filter keeps false positives low at the designed load.
#[test]
fn test_with_rate() {
    let fp_rate = BigRational::from_frac_u64(1, 100);
    let mut bf = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate.clone());

    let want_bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_rate);
    let want_hashers = BloomFilter::<Sha256>::optimal_hashers(1000, want_bits);
    assert_eq!(bf.bits().get(), want_bits);
    assert_eq!(bf.hashers().get(), want_hashers);

    (0..1000usize).for_each(|i| bf.insert(&i.to_be_bytes()));

    // No false negatives.
    for i in 0..1000usize {
        assert!(bf.contains(&i.to_be_bytes()));
    }

    // Probe 1000 values that were never inserted.
    let false_positives = (1000..2000usize)
        .filter(|i| bf.contains(&i.to_be_bytes()))
        .count();
    assert!(false_positives < 20);
}
// `optimal_hashers` truncates, clamps to 1..=16, and handles zero or extreme inputs.
#[test]
fn test_optimal_hashers() {
    // (expected_items, bits, expected hasher count)
    let cases: [(usize, usize, u8); 6] = [
        (1000, 10000, 6),
        (100, 1000, 6),
        // Fewer bits than items: clamped up to 1.
        (1000, 100, 1),
        // Far more bits than items: clamped down to 16.
        (100, 100000, 16),
        // Zero expected items short-circuits to 1.
        (0, 1000, 1),
        (1 << 48, 1000, 1),
    ];
    for (items, bits, expected) in cases {
        assert_eq!(
            BloomFilter::<Sha256>::optimal_hashers(items, bits),
            expected
        );
    }

    let k = BloomFilter::<Sha256>::optimal_hashers(usize::MAX, usize::MAX);
    assert!((1..=16).contains(&k));
}
// `optimal_bits` returns the expected power-of-two sizes for two reference configurations.
#[test]
fn test_optimal_bits() {
// 1000 items at a rate of 1/100.
let fp_1pct = BigRational::from_frac_u64(1, 100);
let bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_1pct);
assert_eq!(bits, 16384);
assert!(bits.is_power_of_two());
// 10000 items at a rate of 1/100_000.
let fp_001pct = BigRational::from_frac_u64(1, 100_000);
let bits_lower_fp = BloomFilter::<Sha256>::optimal_bits(10000, &fp_001pct);
assert_eq!(bits_lower_fp, 262144);
assert!(bits_lower_fp.is_power_of_two());
}
// `optimal_bits` always returns a non-zero power of two, even for huge or zero item counts.
#[test]
fn test_bits_extreme_values() {
    let fp_1pct = BigRational::from_frac_u64(1, 100);
    let fp_tiny = BigRational::from_frac_u64(1, 10_000);

    for items in [usize::MAX / 2, 1_000_000_000] {
        let bits = BloomFilter::<Sha256>::optimal_bits(items, &fp_tiny);
        assert!(bits.is_power_of_two());
        assert!(bits > 0);
    }

    // Zero expected items still yields the minimum size of one bit.
    let bits = BloomFilter::<Sha256>::optimal_bits(0, &fp_1pct);
    assert!(bits.is_power_of_two());
    assert_eq!(bits, 1);
}
// Building two filters from the same inputs yields the same bit and hasher counts.
#[test]
fn test_with_rate_deterministic() {
let fp_rate = BigRational::from_frac_u64(1, 100);
let bf1 = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate.clone());
let bf2 = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate);
assert_eq!(bf1.bits(), bf2.bits());
assert_eq!(bf1.hashers(), bf2.hashers());
}
// Pins the bit count for 1000 items at a rate of 1/100 (same case as in `test_optimal_bits`).
#[test]
fn test_optimal_bits_matches_formula() {
let fp_rate = BigRational::from_frac_u64(1, 100);
let bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_rate);
assert_eq!(bits, 16384);
}
}