use core::hash::{BuildHasher, Hash};
use core::marker::PhantomData;
use std::fmt;
#[cfg(feature = "serde")] use serde::{Deserialize, Serializer, Deserializer, ser::SerializeStruct};
#[cfg(feature = "rayon")] use rayon::prelude::*;
use crate::{bitset::BitSet, hashing, math};
/// A space-efficient probabilistic set: membership queries may return
/// false positives but never false negatives (for items inserted with the
/// same parameters and hasher). `S` is the hash-builder type, defaulting
/// to the standard library's randomized `RandomState`.
#[derive(Clone)]
pub struct BloomFilter<S = std::collections::hash_map::RandomState> {
/// Underlying array of `m` bits.
bits: BitSet,
/// `m`: total bits; `k`: probes per item; `items`: saturating count of
/// inserts (becomes approximate/stale after union, intersection, or the
/// byte-level deserialization, which resets it to 0).
m: usize, k: u32, items: usize,
/// Builds the hashers used to derive the two base hashes per item.
hasher_builder: S,
/// NOTE(review): redundant — `S` is already owned through
/// `hasher_builder`, so `PhantomData<S>` adds nothing; it could be
/// removed together with the `_marker: PhantomData` initializers in the
/// constructors elsewhere in this file.
_marker: PhantomData<S>,
}
impl<S> fmt::Debug for BloomFilter<S>
where
S: BuildHasher + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BloomFilter")
.field("m(bits)", &self.m)
.field("k", &self.k)
.field("items", &self.items)
.finish()
}
}
impl BloomFilter<std::collections::hash_map::RandomState> {
    /// Creates a filter with `m` bits and `k` hash functions, backed by a
    /// freshly seeded `RandomState` hasher.
    pub fn new(m: usize, k: u32) -> Self {
        let builder = std::collections::hash_map::RandomState::new();
        Self::with_hasher(m, k, builder)
    }

    /// Sizes a filter for an expected `n` items at target false-positive
    /// rate `p`, delegating the parameter math to `math::optimal_m` /
    /// `math::optimal_k`.
    pub fn new_for_capacity(n: usize, p: f64) -> Self {
        let bits = math::optimal_m(n, p);
        Self::new(bits, math::optimal_k(bits, n))
    }
}
impl<S> BloomFilter<S>
where
    S: BuildHasher + Clone,
{
    /// Builds a filter with `m` bits, `k` hash functions, and the given
    /// hasher builder.
    ///
    /// # Panics
    /// Panics if `m == 0` (would cause modulo-by-zero on every probe) or
    /// `k == 0` (the filter would accept everything).
    pub fn with_hasher(m: usize, k: u32, hasher_builder: S) -> Self {
        assert!(m > 0, "bloom filter needs at least one bit");
        assert!(k > 0, "bloom filter needs at least one hash function");
        Self {
            bits: BitSet::new(m),
            m,
            k,
            items: 0,
            hasher_builder,
            _marker: PhantomData,
        }
    }

    /// Index of the `i`-th probe for an element whose two base hashes are
    /// `(h1, h2)`: `(h1 + i * h2) mod m` with wrapping arithmetic
    /// (Kirsch–Mitzenmacher double hashing).
    #[inline]
    fn bit_index(&self, h1: u64, h2: u64, i: u32) -> usize {
        let combined = h1.wrapping_add(u64::from(i).wrapping_mul(h2));
        (combined % (self.m as u64)) as usize
    }

    /// Inserts `item` by setting its `k` probe bits; the approximate item
    /// counter grows by one (saturating).
    pub fn insert<T: Hash>(&mut self, item: &T) {
        let (h1, h2) = hashing::hash2(&self.hasher_builder, item);
        for i in 0..self.k {
            let idx = self.bit_index(h1, h2, i);
            self.bits.set(idx);
        }
        self.items = self.items.saturating_add(1);
    }

    /// Inserts every element of `items`.
    ///
    /// Hashing runs in parallel; the resulting deduplicated bit indices
    /// are then applied sequentially, since `BitSet` is only reachable
    /// through `&mut self`.
    #[cfg(feature = "rayon")]
    pub fn insert_batch<T>(&mut self, items: impl IntoParallelIterator<Item = T>)
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        use std::collections::HashSet;
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Copy the fields the parallel closure needs so it only borrows
        // `hasher_builder` (shared), never the whole `&mut self`.
        let m = self.m as u64;
        let k = self.k;
        let hasher_builder = &self.hasher_builder;
        let inserted = AtomicUsize::new(0);
        // Lock-free replacement for the previous `Mutex<HashSet>` funnel:
        // rayon deduplicates into the HashSet while combining results.
        let indices: HashSet<usize> = items
            .into_par_iter()
            .flat_map_iter(|item| {
                inserted.fetch_add(1, Ordering::Relaxed);
                let (h1, h2) = hashing::hash2(hasher_builder, &item);
                (0..k).map(move |i| {
                    let combined = h1.wrapping_add(u64::from(i).wrapping_mul(h2));
                    (combined % m) as usize
                })
            })
            .collect();
        for idx in indices {
            self.bits.set(idx);
        }
        // BUG FIX: the counter must grow by the number of items inserted.
        // The previous code added `indices.len()` — the number of distinct
        // bit positions touched — which undercounts whenever probes
        // collide and made the count depend on filter load.
        self.items = self.items.saturating_add(inserted.load(Ordering::Relaxed));
    }

    /// Membership test: `true` iff every probe bit for `item` is set.
    /// May yield false positives; never false negatives for items inserted
    /// into this filter (same parameters and hasher).
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        let (h1, h2) = hashing::hash2(&self.hasher_builder, item);
        (0..self.k).all(|i| self.bits.get(self.bit_index(h1, h2, i)))
    }

    /// `true` iff `contains` holds for every item; rayon stops scheduling
    /// further work once a miss is found.
    #[cfg(feature = "rayon")]
    pub fn contains_all<T>(&self, items: impl IntoParallelIterator<Item = T>) -> bool
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        items.into_par_iter().all(|item| self.contains(&item))
    }

    /// One membership verdict per input item, in input order.
    #[cfg(feature = "rayon")]
    pub fn contains_batch<T>(&self, items: impl IntoParallelIterator<Item = T>) -> Vec<bool>
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        items.into_par_iter().map(|item| self.contains(&item)).collect()
    }

    /// Bitwise-ORs `other` into `self`: afterwards `self` answers
    /// membership for the union of both sets.
    ///
    /// # Panics
    /// Panics on `m` or `k` mismatch. NOTE(review): the two filters must
    /// also share the same hasher seed for the result to be meaningful —
    /// that cannot be checked here. `approximate_items` is not updated and
    /// becomes stale after this call.
    pub fn union_inplace(&mut self, other: &Self) {
        assert_eq!(self.m, other.m, "m mismatch for union");
        assert_eq!(self.k, other.k, "k mismatch for union");
        self.bits.or_with(&other.bits);
    }

    /// Bitwise-ANDs `other` into `self` (approximate intersection — the
    /// result may over-report membership relative to a filter built from
    /// the true intersection).
    ///
    /// # Panics
    /// Panics on `m` or `k` mismatch. NOTE(review): as with union, the
    /// hasher seeds must match, and `approximate_items` becomes stale.
    pub fn intersect_inplace(&mut self, other: &Self) {
        assert_eq!(self.m, other.m, "m mismatch for intersection");
        assert_eq!(self.k, other.k, "k mismatch for intersection");
        self.bits.and_with(&other.bits);
    }

    /// Resets all bits and the item counter; parameters and hasher are kept.
    pub fn clear(&mut self) {
        self.bits.clear();
        self.items = 0;
    }

    /// Saturating count of `insert`/`insert_batch` calls. Approximate:
    /// not adjusted by union/intersection, and reset to 0 by the
    /// byte-level decoders below.
    pub fn approximate_items(&self) -> usize {
        self.items
    }

    /// Serializes as `words (le u64s) ++ m (le u64) ++ k (le u32)`.
    /// The `items` counter is NOT encoded; decoders restore it as 0.
    pub fn to_bytes(&self) -> Vec<u8> {
        let words = self.bits.words_slice();
        let mut out = Vec::with_capacity(words.len() * 8 + 12);
        for w in words {
            out.extend_from_slice(&w.to_le_bytes());
        }
        out.extend_from_slice(&(self.m as u64).to_le_bytes());
        out.extend_from_slice(&self.k.to_le_bytes());
        out
    }

    /// Decodes the `to_bytes` layout, pairing it with `hasher_builder`
    /// (which must be the same builder/seed the filter was built with, or
    /// lookups against the decoded filter are meaningless).
    ///
    /// Returns `None` for truncated or inconsistent input, for `m` values
    /// that do not fit `usize`, and for `m == 0` / `k == 0` — parameters
    /// `with_hasher` asserts against; accepting them would merely defer
    /// the failure to a modulo-by-zero during insert/lookup.
    pub fn from_bytes_hasher(data: &[u8], hasher_builder: S) -> Option<Self> {
        const META_LEN: usize = 12; // trailing u64 `m` + u32 `k`
        let meta_offset = data.len().checked_sub(META_LEN)?;
        let mut m_bytes = [0u8; 8];
        m_bytes.copy_from_slice(&data[meta_offset..meta_offset + 8]);
        let m64 = u64::from_le_bytes(m_bytes);
        if m64 > usize::MAX as u64 {
            return None; // would truncate on 32-bit targets
        }
        let m = m64 as usize;
        let mut k_bytes = [0u8; 4];
        k_bytes.copy_from_slice(&data[meta_offset + 8..meta_offset + 12]);
        let k = u32::from_le_bytes(k_bytes);
        if m == 0 || k == 0 {
            return None;
        }
        let words_expected = m.div_ceil(64);
        // checked_mul: guards against overflow for absurd `m` values.
        if meta_offset != words_expected.checked_mul(8)? {
            return None;
        }
        let mut words = Vec::with_capacity(words_expected);
        for chunk in data[..meta_offset].chunks_exact(8) {
            let mut wb = [0u8; 8];
            wb.copy_from_slice(chunk);
            words.push(u64::from_le_bytes(wb));
        }
        let mut bitset = BitSet::new(m);
        bitset.words_mut().copy_from_slice(&words);
        Some(Self {
            bits: bitset,
            m,
            k,
            items: 0, // `to_bytes` does not record the item count
            hasher_builder,
            _marker: PhantomData,
        })
    }

    /// Convenience decoder that builds the hasher from a freshly seeded
    /// `RandomState`.
    ///
    /// NOTE(review): `RandomState::new()` uses a fresh random seed, so a
    /// filter decoded here will NOT answer queries consistently with the
    /// instance that produced the bytes — confirm this is the intended
    /// semantics, or require the caller to supply the original hasher via
    /// `from_bytes_hasher`. (The previous trivially-true
    /// `RandomState: Clone` bound has been dropped.)
    pub fn from_bytes(data: &[u8]) -> Option<Self>
    where
        S: From<std::collections::hash_map::RandomState>,
    {
        let builder: S = std::collections::hash_map::RandomState::new().into();
        Self::from_bytes_hasher(data, builder)
    }
}
#[cfg(feature = "serde")]
/// Serializes the filter as a 4-field struct: `m`, `k`, `items`, `words`.
/// The hasher builder carries no serializable state and is omitted, so no
/// bounds on `S` are required (the former `BuildHasher + Clone + Default`
/// bound needlessly restricted which filters could be serialized).
impl<S> serde::Serialize for BloomFilter<S> {
    fn serialize<Se: Serializer>(&self, serializer: Se) -> Result<Se::Ok, Se::Error> {
        let mut st = serializer.serialize_struct("BloomFilter", 4)?;
        st.serialize_field("m", &self.m)?;
        st.serialize_field("k", &self.k)?;
        st.serialize_field("items", &self.items)?;
        st.serialize_field("words", self.bits.words_slice())?;
        st.end()
    }
}
#[cfg(feature = "serde")]
/// Deserializes the 4-field struct written by the `Serialize` impl.
/// Only `S: Default` is actually used (to rebuild the hasher), so the
/// former `BuildHasher + Clone` bounds have been dropped.
///
/// NOTE(review): the hasher is recreated via `S::default()`, so a
/// round-tripped filter only answers queries consistently if
/// `S::default()` yields the same seed every time — not true for
/// `RandomState`. Confirm the intended usage.
impl<'de, S> serde::Deserialize<'de> for BloomFilter<S>
where
    S: Default,
{
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Mirror of the on-wire layout produced by `Serialize`.
        #[derive(Deserialize)]
        struct BFHelper {
            m: usize,
            k: u32,
            items: usize,
            words: Vec<u64>,
        }
        let helper = BFHelper::deserialize(deserializer)?;
        // Enforce the invariants `with_hasher` asserts; otherwise a
        // deserialized filter would panic (modulo by zero) on first use.
        if helper.m == 0 || helper.k == 0 {
            return Err(serde::de::Error::custom("m and k must be non-zero"));
        }
        if helper.words.len() != helper.m.div_ceil(64) {
            return Err(serde::de::Error::custom("words length mismatch"));
        }
        Ok(Self {
            bits: BitSet::from_words(helper.m, helper.words),
            m: helper.m,
            k: helper.k,
            items: helper.items,
            hasher_builder: S::default(),
            _marker: PhantomData,
        })
    }
}