use std::hash::Hash;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, RwLock};
use crate::hashing::HashGenerator;
/// A cheaply clonable, thread-safe handle to a shared filter `F`.
///
/// Every clone points at the same `RwLock`-protected value, so a write made
/// through one handle is visible through all the others. Readers may hold the
/// lock concurrently; writers get exclusive access.
pub struct ConcurrentBloomFilter<F> {
    inner: Arc<RwLock<F>>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add an
// `F: Clone` bound, but cloning a handle only bumps the `Arc` refcount and
// never copies `F`.
impl<F> Clone for ConcurrentBloomFilter<F> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<F> ConcurrentBloomFilter<F> {
    /// Wraps `filter` in a new shared handle (reference count 1).
    pub fn new(filter: F) -> Self {
        Self {
            inner: Arc::new(RwLock::new(filter)),
        }
    }

    /// Returns the number of live handles sharing this filter.
    ///
    /// This is a snapshot: other threads may clone or drop handles
    /// concurrently, so the value can be stale as soon as it is returned.
    pub fn ref_count(&self) -> usize {
        Arc::strong_count(&self.inner)
    }

    /// Runs `op` with shared (read) access to the filter and returns its result.
    ///
    /// The read lock is held only for the duration of `op`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned (a writer panicked while holding it).
    /// Locking this filter again from inside `op` on the same thread may
    /// deadlock or panic, as documented for [`RwLock`].
    pub fn read<R, OP: FnOnce(&F) -> R>(&self, op: OP) -> R {
        let guard = self.inner.read().expect("RwLock poisoned");
        op(&*guard)
    }

    /// Runs `op` with exclusive (write) access to the filter and returns its result.
    ///
    /// The write lock is held only for the duration of `op`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned. Locking this filter again from inside
    /// `op` on the same thread may deadlock or panic, as documented for
    /// [`RwLock`].
    pub fn write<R, OP: FnOnce(&mut F) -> R>(&self, op: OP) -> R {
        let mut guard = self.inner.write().expect("RwLock poisoned");
        op(&mut *guard)
    }

    /// Consumes this handle and returns the filter if this was the last handle.
    ///
    /// Returns `None` while other handles are alive; this handle is then
    /// simply dropped and the filter stays owned by the remaining handles.
    ///
    /// Uses [`Arc::into_inner`] rather than `Arc::try_unwrap(..).ok()`. If the
    /// final handles are consumed concurrently on different threads, exactly
    /// one call receives `Some`. With `try_unwrap`, every call could fail and
    /// the filter would be dropped with nobody getting it.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn into_inner(self) -> Option<F> {
        Arc::into_inner(self.inner)
            .map(|lock| lock.into_inner().expect("RwLock poisoned"))
    }
}
/// A fixed-size Bloom filter whose bits are stored in atomics.
///
/// `insert` and `contains` take `&self`, so a single filter can be shared
/// across threads without an external lock. Bits are packed eight per byte.
/// Only the first `size` bits are ever addressed, so any padding bits in the
/// last byte stay zero.
pub struct AtomicBloomFilter {
    /// Backing storage, `size.div_ceil(8)` bytes long.
    bytes: Vec<AtomicU8>,
    /// Number of addressable bits.
    size: usize,
    /// Number of bit positions probed per item.
    hashes: usize,
}

impl AtomicBloomFilter {
    /// Creates an empty filter with `size` bits and `hashes` probes per item.
    ///
    /// # Panics
    ///
    /// Panics if `size` or `hashes` is zero.
    pub fn new(size: usize, hashes: usize) -> Self {
        assert!(size > 0, "AtomicBloomFilter size must be > 0");
        assert!(hashes > 0, "Number of hash functions must be > 0");
        let bytes = (0..size.div_ceil(8)).map(|_| AtomicU8::new(0)).collect();
        Self {
            bytes,
            size,
            hashes,
        }
    }

    /// Creates an empty filter from the `(size, hashes)` pair returned by
    /// `config.parameters()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`AtomicBloomFilter::new`].
    pub fn from_config(config: &crate::BloomConfig) -> Self {
        let (size, hashes) = config.parameters();
        Self::new(size, hashes)
    }

    /// Maps a raw hash value to `(byte index, bit mask)` for its bit.
    #[inline]
    fn bit_position(&self, hash: u64) -> (usize, u8) {
        // `new` guarantees `size > 0`, so the modulo cannot divide by zero.
        // The result is `< size`, so the cast back to `usize` is lossless and
        // the byte index is always in bounds.
        let bit_idx = (hash % self.size as u64) as usize;
        (bit_idx / 8, 1u8 << (bit_idx % 8))
    }

    /// Adds `item` to the filter by setting each of its `hashes` bits.
    ///
    /// Uses `Ordering::Relaxed`. Each bit is set independently with an atomic
    /// OR, and no other data is published through these bits. A thread that
    /// must observe another thread's insert through `contains` needs a
    /// happens-before edge from elsewhere (e.g. a thread join or a channel).
    #[inline]
    pub fn insert<T: Hash>(&self, item: &T) {
        // Named `hasher` rather than `gen`, which is reserved in Rust 2024.
        let hasher = HashGenerator::new(item);
        for i in 0..self.hashes {
            let (byte_idx, mask) = self.bit_position(hasher.nth(i));
            self.bytes[byte_idx].fetch_or(mask, Ordering::Relaxed);
        }
    }

    /// Returns `false` if `item` was definitely not inserted, and `true` if it
    /// may have been. False positives are possible.
    ///
    /// This assumes `HashGenerator` yields the same probes for equal items.
    /// TODO: confirm in `crate::hashing`.
    #[inline]
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        let hasher = HashGenerator::new(item);
        (0..self.hashes).all(|i| {
            let (byte_idx, mask) = self.bit_position(hasher.nth(i));
            self.bytes[byte_idx].load(Ordering::Relaxed) & mask != 0
        })
    }

    /// Returns the fraction of the filter's `size` bits that are set, in
    /// `[0.0, 1.0]`.
    ///
    /// Bytes are loaded one at a time. Under concurrent inserts the result is
    /// therefore a best-effort snapshot, not an atomic view.
    pub fn fill_ratio(&self) -> f64 {
        let set_bits: u64 = self
            .bytes
            .iter()
            .map(|byte| u64::from(byte.load(Ordering::Relaxed).count_ones()))
            .sum();
        // Divide by `size`, not `bytes.len() * 8`. Padding bits in the last
        // byte can never be set, so counting them would understate the ratio
        // whenever `size` is not a multiple of 8.
        set_bits as f64 / self.size as f64
    }

    /// Number of addressable bits in the filter.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Number of bit positions probed per item.
    pub fn hashes(&self) -> usize {
        self.hashes
    }
}