use crate::error::{CoreError, CoreResult, ErrorContext};
use crate::error_context;
use super::DoubleHasher;
/// A fixed-size Bloom filter over byte-slice items.
///
/// Membership queries may return false positives but never false negatives
/// for items that were inserted. Bits are packed 64 per `u64` word.
#[derive(Clone)]
pub struct BloomFilter {
// Packed bit array; bit `pos` lives in word `pos / 64` at offset `pos % 64`.
bits: Vec<u64>,
// Logical number of addressable bits (the last word may be partially used).
num_bits: usize,
// Number of positions probed per item.
num_hashes: u32,
// Number of `insert` calls recorded; duplicates are counted each time.
count: usize,
// Source of the `(h1, h2)` hash pair from which probe positions are derived.
hasher: DoubleHasher,
}
impl BloomFilter {
    /// Number of bits packed into each backing `u64` word.
    const WORD_BITS: usize = 64;

    /// Creates an empty filter with `num_bits` bits and `num_hashes` probe
    /// positions per item.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `num_bits` or `num_hashes`
    /// is zero.
    pub fn new(num_bits: usize, num_hashes: u32) -> CoreResult<Self> {
        if num_bits == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("num_bits must be > 0"),
            ));
        }
        if num_hashes == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("num_hashes must be > 0"),
            ));
        }
        Ok(Self {
            bits: vec![0u64; Self::words_for(num_bits)],
            num_bits,
            num_hashes,
            count: 0,
            hasher: DoubleHasher::new(),
        })
    }

    /// Creates an empty filter sized for `expected_items` insertions at a
    /// target false-positive rate of `fpr`.
    ///
    /// The bit count and hash count come from `optimal_params`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `expected_items` is zero or
    /// if `fpr` is NaN or lies outside the open interval `(0, 1)`.
    pub fn with_rate(expected_items: usize, fpr: f64) -> CoreResult<Self> {
        if expected_items == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("expected_items must be > 0"),
            ));
        }
        // NaN compares false against everything, so without the explicit
        // `is_nan` test it would slip past both range checks and produce a
        // degenerate one-bit filter.
        if fpr.is_nan() || fpr <= 0.0 || fpr >= 1.0 {
            return Err(CoreError::InvalidArgument(
                error_context!("fpr must be in (0, 1)"),
            ));
        }
        let (m, k) = optimal_params(expected_items, fpr);
        Self::new(m, k)
    }

    /// Records `item` in the filter by setting each of its probe bits.
    ///
    /// The insertion counter is incremented on every call, including repeated
    /// insertions of the same item.
    pub fn insert(&mut self, item: &[u8]) {
        let (h1, h2) = self.hasher.hash_pair(item);
        for i in 0..self.num_hashes {
            let pos = DoubleHasher::position(h1, h2, i, self.num_bits);
            self.set_bit(pos);
        }
        self.count += 1;
    }

    /// Returns `true` if every probe bit for `item` is set.
    ///
    /// A `true` result may be a false positive; a `false` result means the
    /// item was not inserted since the last `clear`.
    pub fn contains(&self, item: &[u8]) -> bool {
        let (h1, h2) = self.hasher.hash_pair(item);
        (0..self.num_hashes)
            .all(|i| self.get_bit(DoubleHasher::position(h1, h2, i, self.num_bits)))
    }

    /// Returns a new filter whose bits are the bitwise OR of both operands.
    ///
    /// The resulting count is the (saturating) sum of both counts, so items
    /// inserted into both filters are counted twice. The result carries a
    /// clone of `self`'s hasher.
    ///
    /// NOTE(review): only `num_bits` and `num_hashes` are compared; the
    /// hashers are not. Presumably both filters must share the same hasher
    /// configuration (for example by deriving one from the other with
    /// `empty_clone`) for the union to be meaningful — confirm against
    /// `DoubleHasher`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionError`] if the filters differ in
    /// `num_bits` or `num_hashes`.
    pub fn union(&self, other: &BloomFilter) -> CoreResult<BloomFilter> {
        if self.num_bits != other.num_bits || self.num_hashes != other.num_hashes {
            return Err(CoreError::DimensionError(
                error_context!("Bloom filters must have the same num_bits and num_hashes for union"),
            ));
        }
        let bits: Vec<u64> = self
            .bits
            .iter()
            .zip(other.bits.iter())
            .map(|(a, b)| a | b)
            .collect();
        Ok(BloomFilter {
            bits,
            num_bits: self.num_bits,
            num_hashes: self.num_hashes,
            // Saturate instead of overflowing when both counts are huge.
            count: self.count.saturating_add(other.count),
            hasher: self.hasher.clone(),
        })
    }

    /// Estimates how many items the two filters have in common.
    ///
    /// Each cardinality is estimated from a set-bit count `x` as
    /// `-(m / k) * ln(1 - x / m)`, and the intersection is obtained by
    /// inclusion–exclusion: `|A| + |B| - |A ∪ B|`, clamped at zero.
    /// A completely full bit array has no finite estimate, so `m` is used
    /// in its place.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionError`] if the filters differ in
    /// `num_bits` or `num_hashes`.
    pub fn intersection_estimate(&self, other: &BloomFilter) -> CoreResult<f64> {
        if self.num_bits != other.num_bits || self.num_hashes != other.num_hashes {
            return Err(CoreError::DimensionError(
                error_context!(
                    "Bloom filters must have the same num_bits and num_hashes for intersection estimate"
                ),
            ));
        }
        let m = self.num_bits as f64;
        let k = self.num_hashes as f64;
        let cardinality = |set_bits: usize| -> f64 {
            let ratio = set_bits as f64 / m;
            if ratio >= 1.0 {
                // ln(0) would be -inf; fall back to the bit count.
                return m;
            }
            -m / k * (1.0 - ratio).ln()
        };
        let union_bits: usize = self
            .bits
            .iter()
            .zip(other.bits.iter())
            .map(|(a, b)| (a | b).count_ones() as usize)
            .sum();
        let est_a = cardinality(self.count_set_bits());
        let est_b = cardinality(other.count_set_bits());
        let est_ab = cardinality(union_bits);
        let intersection = est_a + est_b - est_ab;
        Ok(intersection.max(0.0))
    }

    /// Returns the number of insertions recorded (not the number of distinct
    /// items).
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if nothing has been inserted since creation or the last
    /// `clear`.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the number of addressable bits in the filter.
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    /// Returns the number of probe positions used per item.
    pub fn num_hashes(&self) -> u32 {
        self.num_hashes
    }

    /// Estimates the current false-positive rate as
    /// `(set_bits / num_bits) ^ num_hashes`.
    pub fn estimated_fpr(&self) -> f64 {
        let set_bits = self.count_set_bits() as f64;
        let fill_ratio = set_bits / self.num_bits as f64;
        // A plain `as i32` would wrap to a negative exponent for hash counts
        // above `i32::MAX`; clamp so the exponent is always non-negative.
        let exponent = self.num_hashes.min(i32::MAX as u32) as i32;
        fill_ratio.powi(exponent)
    }

    /// Returns an empty filter with the same dimensions and a clone of this
    /// filter's hasher.
    pub fn empty_clone(&self) -> Self {
        Self {
            bits: vec![0u64; self.bits.len()],
            num_bits: self.num_bits,
            num_hashes: self.num_hashes,
            count: 0,
            hasher: self.hasher.clone(),
        }
    }

    /// Resets every bit and the insertion counter, keeping the allocation.
    pub fn clear(&mut self) {
        self.bits.fill(0);
        self.count = 0;
    }

    /// Number of `u64` words needed to hold `num_bits` bits.
    ///
    /// Written as quotient plus remainder flag rather than
    /// `(num_bits + 63) / 64` so the arithmetic cannot overflow.
    fn words_for(num_bits: usize) -> usize {
        num_bits / Self::WORD_BITS + usize::from(num_bits % Self::WORD_BITS != 0)
    }

    /// Sets the bit at index `pos`.
    #[inline]
    fn set_bit(&mut self, pos: usize) {
        let word = pos / Self::WORD_BITS;
        let bit = pos % Self::WORD_BITS;
        self.bits[word] |= 1u64 << bit;
    }

    /// Returns whether the bit at index `pos` is set.
    #[inline]
    fn get_bit(&self, pos: usize) -> bool {
        let word = pos / Self::WORD_BITS;
        let bit = pos % Self::WORD_BITS;
        (self.bits[word] >> bit) & 1 == 1
    }

    /// Counts the set bits across the whole bit array.
    fn count_set_bits(&self) -> usize {
        self.bits.iter().map(|w| w.count_ones() as usize).sum()
    }
}
impl std::fmt::Debug for BloomFilter {
    /// Formats the filter's dimensions, insertion count and number of set
    /// bits; the raw bit array and hasher are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let set_bits = self.count_set_bits();
        let mut out = formatter.debug_struct("BloomFilter");
        out.field("num_bits", &self.num_bits);
        out.field("num_hashes", &self.num_hashes);
        out.field("count", &self.count);
        out.field("set_bits", &set_bits);
        out.finish()
    }
}
/// A Bloom filter variant that keeps a small counter per bucket so that
/// items can be removed as well as inserted.
///
/// Counters are 4 bits wide and stored two per byte.
#[derive(Clone)]
pub struct CountingBloomFilter {
// Packed 4-bit counters: even bucket indices use the low nibble of
// byte `pos / 2`, odd indices use the high nibble.
counters: Vec<u8>,
// Logical number of buckets (the last byte may hold one unused nibble).
num_buckets: usize,
// Number of buckets probed per item.
num_hashes: u32,
// Insertions minus removals, never dropping below zero.
count: usize,
// Source of the `(h1, h2)` hash pair from which bucket positions are derived.
hasher: DoubleHasher,
}
impl CountingBloomFilter {
    /// Largest value a 4-bit counter can hold; increments saturate here.
    const COUNTER_MAX: u8 = 0x0F;

    /// Creates an empty counting filter with `num_buckets` 4-bit counters and
    /// `num_hashes` probe positions per item.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `num_buckets` or
    /// `num_hashes` is zero.
    pub fn new(num_buckets: usize, num_hashes: u32) -> CoreResult<Self> {
        if num_buckets == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("num_buckets must be > 0"),
            ));
        }
        if num_hashes == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("num_hashes must be > 0"),
            ));
        }
        // Two counters per byte, rounded up. Quotient plus remainder avoids
        // the overflow that `(num_buckets + 1) / 2` has at `usize::MAX`.
        let n_bytes = num_buckets / 2 + num_buckets % 2;
        Ok(Self {
            counters: vec![0u8; n_bytes],
            num_buckets,
            num_hashes,
            count: 0,
            hasher: DoubleHasher::new(),
        })
    }

    /// Creates an empty counting filter sized for `expected_items` insertions
    /// at a target false-positive rate of `fpr`.
    ///
    /// The bucket count and hash count come from `optimal_params`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `expected_items` is zero or
    /// if `fpr` is NaN or lies outside the open interval `(0, 1)`.
    pub fn with_rate(expected_items: usize, fpr: f64) -> CoreResult<Self> {
        if expected_items == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("expected_items must be > 0"),
            ));
        }
        // NaN compares false against everything, so it must be rejected
        // explicitly or it would pass both range checks.
        if fpr.is_nan() || fpr <= 0.0 || fpr >= 1.0 {
            return Err(CoreError::InvalidArgument(
                error_context!("fpr must be in (0, 1)"),
            ));
        }
        let (m, k) = optimal_params(expected_items, fpr);
        Self::new(m, k)
    }

    /// Records `item` by incrementing each of its bucket counters.
    ///
    /// Counters saturate at 15 and are not incremented further.
    pub fn insert(&mut self, item: &[u8]) {
        let (h1, h2) = self.hasher.hash_pair(item);
        for i in 0..self.num_hashes {
            let pos = DoubleHasher::position(h1, h2, i, self.num_buckets);
            self.increment_counter(pos);
        }
        self.count += 1;
    }

    /// Removes one occurrence of `item` by decrementing each of its bucket
    /// counters.
    ///
    /// If any of the item's buckets is already zero the item is definitely
    /// not present, so the call is a no-op: decrementing the remaining
    /// buckets would only corrupt counters shared with other items and make
    /// `len` drift. Removing an item that merely *appears* present (a false
    /// positive) can still introduce false negatives for other items.
    ///
    /// Counters that saturated at 15 are decremented like any other, so an
    /// item inserted more often than a counter can record may disappear
    /// before all of its insertions have been removed.
    pub fn remove(&mut self, item: &[u8]) {
        let (h1, h2) = self.hasher.hash_pair(item);
        let present = (0..self.num_hashes).all(|i| {
            self.get_counter(DoubleHasher::position(h1, h2, i, self.num_buckets)) != 0
        });
        if !present {
            return;
        }
        for i in 0..self.num_hashes {
            let pos = DoubleHasher::position(h1, h2, i, self.num_buckets);
            self.decrement_counter(pos);
        }
        self.count = self.count.saturating_sub(1);
    }

    /// Returns `true` if every bucket counter for `item` is non-zero.
    ///
    /// A `true` result may be a false positive.
    pub fn contains(&self, item: &[u8]) -> bool {
        let (h1, h2) = self.hasher.hash_pair(item);
        (0..self.num_hashes).all(|i| {
            self.get_counter(DoubleHasher::position(h1, h2, i, self.num_buckets)) != 0
        })
    }

    /// Returns the tracked item count (insertions minus effective removals).
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if the tracked item count is zero.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the number of counter buckets.
    pub fn num_buckets(&self) -> usize {
        self.num_buckets
    }

    /// Returns the number of probe positions used per item.
    pub fn num_hashes(&self) -> u32 {
        self.num_hashes
    }

    /// Resets every counter and the item count, keeping the allocation.
    pub fn clear(&mut self) {
        self.counters.fill(0);
        self.count = 0;
    }

    /// Bit offset of bucket `pos` within its byte: low nibble for even
    /// indices, high nibble for odd ones.
    #[inline]
    fn nibble_shift(pos: usize) -> u32 {
        if pos % 2 == 0 {
            0
        } else {
            4
        }
    }

    /// Reads the 4-bit counter for bucket `pos`.
    #[inline]
    fn get_counter(&self, pos: usize) -> u8 {
        (self.counters[pos / 2] >> Self::nibble_shift(pos)) & Self::COUNTER_MAX
    }

    /// Overwrites the 4-bit counter for bucket `pos` with `value`, leaving
    /// the neighbouring nibble untouched. `value` must fit in 4 bits.
    #[inline]
    fn set_counter(&mut self, pos: usize, value: u8) {
        let shift = Self::nibble_shift(pos);
        let byte = &mut self.counters[pos / 2];
        *byte = (*byte & !(Self::COUNTER_MAX << shift)) | (value << shift);
    }

    /// Adds one to the counter at `pos` unless it is already saturated.
    #[inline]
    fn increment_counter(&mut self, pos: usize) {
        let current = self.get_counter(pos);
        if current < Self::COUNTER_MAX {
            self.set_counter(pos, current + 1);
        }
    }

    /// Subtracts one from the counter at `pos` unless it is already zero.
    #[inline]
    fn decrement_counter(&mut self, pos: usize) {
        let current = self.get_counter(pos);
        if current > 0 {
            self.set_counter(pos, current - 1);
        }
    }
}
impl std::fmt::Debug for CountingBloomFilter {
    /// Formats the filter's dimensions and item count; the raw counters and
    /// hasher are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("CountingBloomFilter");
        out.field("num_buckets", &self.num_buckets);
        out.field("num_hashes", &self.num_hashes);
        out.field("count", &self.count);
        out.finish()
    }
}
/// A Bloom filter that grows by appending additional fixed-capacity slices
/// as items are inserted.
///
/// An item is considered present if any slice reports it as present.
#[derive(Clone)]
pub struct ScalableBloomFilter {
// Slices in creation order; new items always go into the last one.
slices: Vec<BloomFilter>,
// Overall false-positive budget from which per-slice rates are derived.
target_fpr: f64,
// Insertion count at which the newest slice is considered full.
slice_capacity: usize,
// Factor in (0, 1) by which each successive slice's error rate is scaled.
ratio: f64,
// Number of `insert` calls since creation or the last `clear`.
total_count: usize,
}
impl ScalableBloomFilter {
    /// Creates a scalable filter with one initial slice.
    ///
    /// Slice `i` (starting at 0) is built for a false-positive rate of
    /// `target_fpr * (1 - ratio) * ratio^i`, a geometric series whose sum is
    /// `target_fpr`. Every slice holds `initial_capacity` insertions before a
    /// new one is added.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `target_fpr` or `ratio` is
    /// NaN or outside the open interval `(0, 1)`, or if `initial_capacity`
    /// is zero. Errors from building the first slice are propagated.
    pub fn new(target_fpr: f64, initial_capacity: usize, ratio: f64) -> CoreResult<Self> {
        // NaN compares false against everything, so each floating-point
        // argument is tested for NaN explicitly before the range checks.
        if target_fpr.is_nan() || target_fpr <= 0.0 || target_fpr >= 1.0 {
            return Err(CoreError::InvalidArgument(
                error_context!("target_fpr must be in (0, 1)"),
            ));
        }
        if initial_capacity == 0 {
            return Err(CoreError::InvalidArgument(
                error_context!("initial_capacity must be > 0"),
            ));
        }
        if ratio.is_nan() || ratio <= 0.0 || ratio >= 1.0 {
            return Err(CoreError::InvalidArgument(
                error_context!("ratio must be in (0, 1)"),
            ));
        }
        let fpr0 = target_fpr * (1.0 - ratio);
        let first_slice = BloomFilter::with_rate(initial_capacity, fpr0)?;
        Ok(Self {
            slices: vec![first_slice],
            target_fpr,
            slice_capacity: initial_capacity,
            ratio,
            total_count: 0,
        })
    }

    /// Inserts `item` into the newest slice, first appending a fresh slice if
    /// the newest one has reached its capacity.
    ///
    /// The per-slice rate is clamped to at least `1e-15`. If a new slice
    /// cannot be constructed, the item is inserted into the existing newest
    /// slice instead.
    pub fn insert(&mut self, item: &[u8]) {
        // An empty slice list is not expected (construction and `clear` both
        // keep at least one slice) but is treated as "needs a slice".
        let needs_new_slice = self
            .slices
            .last()
            .map_or(true, |newest| newest.len() >= self.slice_capacity);
        if needs_new_slice {
            let slice_idx = self.slices.len();
            // Clamp before narrowing so a huge index cannot wrap into a
            // negative exponent.
            let exponent = slice_idx.min(i32::MAX as usize) as i32;
            let slice_fpr = self.target_fpr * (1.0 - self.ratio) * self.ratio.powi(exponent);
            let clamped_fpr = slice_fpr.max(1e-15);
            if let Ok(new_slice) = BloomFilter::with_rate(self.slice_capacity, clamped_fpr) {
                self.slices.push(new_slice);
            }
        }
        if let Some(last) = self.slices.last_mut() {
            last.insert(item);
        }
        self.total_count += 1;
    }

    /// Returns `true` if any slice reports `item` as present.
    ///
    /// A `true` result may be a false positive.
    pub fn contains(&self, item: &[u8]) -> bool {
        self.slices.iter().any(|s| s.contains(item))
    }

    /// Returns the number of insertions recorded across all slices.
    pub fn len(&self) -> usize {
        self.total_count
    }

    /// Returns `true` if nothing has been inserted since creation or the last
    /// `clear`.
    pub fn is_empty(&self) -> bool {
        self.total_count == 0
    }

    /// Returns the number of slices currently allocated.
    pub fn num_slices(&self) -> usize {
        self.slices.len()
    }

    /// Clears every slice and resets the insertion count.
    ///
    /// All allocated slices are retained, so `num_slices` is unchanged.
    pub fn clear(&mut self) {
        self.slices.iter_mut().for_each(BloomFilter::clear);
        self.total_count = 0;
    }
}
impl std::fmt::Debug for ScalableBloomFilter {
    /// Formats the slice count, total insertion count and target
    /// false-positive rate; the slices themselves are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let slice_count = self.slices.len();
        let mut out = formatter.debug_struct("ScalableBloomFilter");
        out.field("num_slices", &slice_count);
        out.field("total_count", &self.total_count);
        out.field("target_fpr", &self.target_fpr);
        out.finish()
    }
}
/// Computes Bloom filter dimensions for `n` expected items at false-positive
/// rate `p`.
///
/// Returns `(m, k)` where `m = ceil(-n * ln(p) / (ln 2)^2)` is the number of
/// bits and `k = ceil((m / n) * ln 2)` is the number of hash functions. Both
/// values are raised to at least 1.
///
/// Inputs are not validated here; the callers in this file check `n > 0` and
/// the range of `p` before calling.
fn optimal_params(n: usize, p: f64) -> (usize, u32) {
    const LN_2: f64 = std::f64::consts::LN_2;

    let items = n as f64;

    let raw_bits = -items * p.ln() / (LN_2 * LN_2);
    let bits = std::cmp::max(raw_bits.ceil() as usize, 1);

    let raw_hashes = (bits as f64 / items) * LN_2;
    let hashes = std::cmp::max(raw_hashes.ceil() as u32, 1);

    (bits, hashes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every key that was inserted must be reported as present.
    #[test]
    fn test_bloom_no_false_negatives() {
        let mut filter = BloomFilter::with_rate(1000, 0.01).expect("valid params");
        let keys: Vec<Vec<u8>> = (0..500u64).map(|n| n.to_le_bytes().to_vec()).collect();
        keys.iter().for_each(|key| filter.insert(key));
        for key in keys.iter() {
            assert!(filter.contains(key), "False negative detected for {:?}", key);
        }
    }

    /// The observed false-positive rate stays below twice the target.
    #[test]
    fn test_bloom_fpr_within_bounds() {
        let n = 10_000usize;
        let target_fpr = 0.05;
        let mut filter = BloomFilter::with_rate(n, target_fpr).expect("valid params");
        (0..n as u64).for_each(|key| filter.insert(&key.to_le_bytes()));

        let test_count = 10_000usize;
        let first_probe = n as u64;
        let false_positives = (first_probe..first_probe + test_count as u64)
            .filter(|probe| filter.contains(&probe.to_le_bytes()))
            .count();

        let observed_fpr = false_positives as f64 / test_count as f64;
        assert!(
            observed_fpr < target_fpr * 2.0,
            "FPR too high: {observed_fpr} (target: {target_fpr})"
        );
    }

    /// A union contains the members of both operands.
    #[test]
    fn test_bloom_union() {
        let mut left = BloomFilter::new(1000, 5).expect("valid");
        let mut right = left.empty_clone();
        left.insert(b"alpha");
        right.insert(b"beta");

        let merged = left.union(&right).expect("compatible");
        assert!(merged.contains(b"alpha"));
        assert!(merged.contains(b"beta"));
    }

    /// Two filters overlapping in 100 keys yield a plausible estimate.
    #[test]
    fn test_bloom_intersection_estimate() {
        let mut left = BloomFilter::with_rate(1000, 0.01).expect("valid");
        let mut right = left.empty_clone();
        (0..200u64).for_each(|key| left.insert(&key.to_le_bytes()));
        (100..300u64).for_each(|key| right.insert(&key.to_le_bytes()));

        let est = left.intersection_estimate(&right).expect("compatible");
        assert!(est > 20.0, "Intersection estimate too low: {est}");
        assert!(est < 250.0, "Intersection estimate too high: {est}");
    }

    /// A fresh filter is empty and contains nothing.
    #[test]
    fn test_bloom_empty() {
        let filter = BloomFilter::with_rate(100, 0.01).expect("valid");
        assert!(filter.is_empty());
        assert_eq!(filter.len(), 0);
        assert!(!filter.contains(b"anything"));
    }

    /// Clearing removes previously inserted items.
    #[test]
    fn test_bloom_clear() {
        let mut filter = BloomFilter::with_rate(100, 0.01).expect("valid");
        filter.insert(b"hello");
        assert!(filter.contains(b"hello"));

        filter.clear();
        assert!(filter.is_empty());
        assert!(!filter.contains(b"hello"));
    }

    /// Out-of-range constructor arguments are rejected.
    #[test]
    fn test_bloom_invalid_params() {
        assert!(BloomFilter::new(0, 5).is_err());
        assert!(BloomFilter::new(100, 0).is_err());
        assert!(BloomFilter::with_rate(0, 0.01).is_err());
        assert!(BloomFilter::with_rate(100, 0.0).is_err());
        assert!(BloomFilter::with_rate(100, 1.0).is_err());
        assert!(BloomFilter::with_rate(100, -0.5).is_err());
    }

    /// Inserting and then removing an item leaves it absent.
    #[test]
    fn test_counting_bloom_insert_remove_roundtrip() {
        let mut counting = CountingBloomFilter::with_rate(1000, 0.01).expect("valid");
        counting.insert(b"hello");
        assert!(counting.contains(b"hello"));

        counting.remove(b"hello");
        assert!(!counting.contains(b"hello"));
    }

    /// Every key that was inserted must be reported as present.
    #[test]
    fn test_counting_bloom_no_false_negatives() {
        let mut counting = CountingBloomFilter::with_rate(1000, 0.01).expect("valid");
        (0..500u64).for_each(|key| counting.insert(&key.to_le_bytes()));
        for key in 0..500u64 {
            assert!(counting.contains(&key.to_le_bytes()));
        }
    }

    /// Counters saturate without wrapping and can be drained back to zero.
    #[test]
    fn test_counting_bloom_counter_overflow() {
        let mut counting = CountingBloomFilter::new(100, 3).expect("valid");
        (0..20).for_each(|_| counting.insert(b"overflow_test"));
        assert!(counting.contains(b"overflow_test"));

        (0..20).for_each(|_| counting.remove(b"overflow_test"));
        assert!(!counting.contains(b"overflow_test"));
    }

    /// Removing one item leaves the others in place.
    #[test]
    fn test_counting_bloom_multiple_items() {
        let mut counting = CountingBloomFilter::with_rate(1000, 0.01).expect("valid");
        counting.insert(b"apple");
        counting.insert(b"banana");
        counting.insert(b"cherry");
        assert!(counting.contains(b"apple"));
        assert!(counting.contains(b"banana"));
        assert!(counting.contains(b"cherry"));

        counting.remove(b"banana");
        assert!(counting.contains(b"apple"));
        assert!(!counting.contains(b"banana"));
        assert!(counting.contains(b"cherry"));
    }

    /// Clearing removes previously inserted items.
    #[test]
    fn test_counting_bloom_clear() {
        let mut counting = CountingBloomFilter::new(100, 3).expect("valid");
        counting.insert(b"data");
        counting.clear();
        assert!(counting.is_empty());
        assert!(!counting.contains(b"data"));
    }

    /// Growth across several slices never loses an inserted key.
    #[test]
    fn test_scalable_bloom_no_false_negatives() {
        let mut scalable = ScalableBloomFilter::new(0.01, 500, 0.5).expect("valid");
        (0..2000u64).for_each(|key| scalable.insert(&key.to_le_bytes()));
        for i in 0..2000u64 {
            assert!(
                scalable.contains(&i.to_le_bytes()),
                "False negative at {i}"
            );
        }
    }

    /// Exceeding the slice capacity adds further slices.
    #[test]
    fn test_scalable_bloom_grows() {
        let mut scalable = ScalableBloomFilter::new(0.01, 100, 0.5).expect("valid");
        assert_eq!(scalable.num_slices(), 1);

        (0..500u64).for_each(|key| scalable.insert(&key.to_le_bytes()));
        assert!(
            scalable.num_slices() > 1,
            "Expected growth, got {} slices",
            scalable.num_slices()
        );
    }

    /// The observed false-positive rate stays within a loose bound.
    #[test]
    fn test_scalable_bloom_fpr_reasonable() {
        let mut scalable = ScalableBloomFilter::new(0.05, 1000, 0.5).expect("valid");
        (0..1000u64).for_each(|key| scalable.insert(&key.to_le_bytes()));

        let test_count = 10_000usize;
        let fp = (1000u64..(1000 + test_count as u64))
            .filter(|probe| scalable.contains(&probe.to_le_bytes()))
            .count();

        let observed = fp as f64 / test_count as f64;
        assert!(
            observed < 0.15,
            "Scalable bloom FPR too high: {observed}"
        );
    }

    /// Out-of-range constructor arguments are rejected.
    #[test]
    fn test_scalable_bloom_invalid_params() {
        assert!(ScalableBloomFilter::new(0.0, 100, 0.5).is_err());
        assert!(ScalableBloomFilter::new(0.01, 0, 0.5).is_err());
        assert!(ScalableBloomFilter::new(0.01, 100, 0.0).is_err());
        assert!(ScalableBloomFilter::new(0.01, 100, 1.0).is_err());
    }

    /// A single insertion is visible and counted once.
    #[test]
    fn test_scalable_bloom_single_element() {
        let mut scalable = ScalableBloomFilter::new(0.01, 100, 0.5).expect("valid");
        scalable.insert(b"only_one");
        assert!(scalable.contains(b"only_one"));
        assert_eq!(scalable.len(), 1);
    }
}