//! bloomz 0.1.0
//!
//! A fast, flexible Bloom filter library for Rust with parallel operations support.
//!
//! Documentation
use core::hash::{BuildHasher, Hash};
use core::marker::PhantomData;

use std::fmt;
#[cfg(feature = "serde")] use serde::{Deserialize, Serializer, Deserializer, ser::SerializeStruct};
#[cfg(feature = "rayon")] use rayon::prelude::*;

use crate::{bitset::BitSet, hashing, math};
/// Bloom filter with configurable BuildHasher `S`.
///
/// `S` defaults to `std::collections::hash_map::RandomState` which uses SipHash (safe).
#[derive(Clone)]
pub struct BloomFilter<S = std::collections::hash_map::RandomState> {
    bits: BitSet, // underlying bit array of `m` bits
    m: usize, //number of bits
    k: u32,   //hash funcs
    items: usize, // incremented once per `insert` call (duplicates counted)
    hasher_builder: S, // produces the hashers used to derive bit indices
    // NOTE(review): `_marker` is redundant — `S` is already owned through
    // `hasher_builder`, so `PhantomData<S>` contributes nothing to variance
    // or drop checking. Removing it would require touching every constructor.
    _marker: PhantomData<S>,
}

impl<S> fmt::Debug for BloomFilter<S>
where
    S: BuildHasher + Clone,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BloomFilter")
            .field("m(bits)", &self.m)
            .field("k", &self.k)
            .field("items", &self.items)
            .finish()
    }
}

impl BloomFilter<std::collections::hash_map::RandomState> {
    /// Build a filter with `m` bits and `k` hash functions using the
    /// default SipHash-based `RandomState` builder.
    pub fn new(m: usize, k: u32) -> Self {
        let builder = std::collections::hash_map::RandomState::new();
        Self::with_hasher(m, k, builder)
    }

    /// Build a filter sized for `n` expected items at target false-positive
    /// rate `p`, using the default `RandomState` builder. Bit count and hash
    /// count are derived from the standard Bloom-filter optimality formulas.
    pub fn new_for_capacity(n: usize, p: f64) -> Self {
        let bits = math::optimal_m(n, p);
        let hashes = math::optimal_k(bits, n);
        Self::with_hasher(bits, hashes, std::collections::hash_map::RandomState::new())
    }
}

impl<S> BloomFilter<S>
where
    S: BuildHasher + Clone,
{
    /// create with explicit hasher builder (eg. ahash::AHasherBuilder or RandomState)
    pub fn with_hasher(m: usize, k: u32, hasher_builder: S) -> Self {
        assert!(m > 0 && k > 0);
        Self {
            bits: BitSet::new(m),
            m,
            k,
            items: 0,
            hasher_builder,
            _marker: PhantomData,
        }
    }

    /// Insert an item into the Bloom filter.
    ///
    /// Derives `k` bit positions by double hashing (`index_i = h1 + i*h2`,
    /// wrapping arithmetic, reduced mod `m`) and sets each of them. Every
    /// call bumps the internal `items` counter, even for values that were
    /// inserted before (no de-duplication is attempted).
    ///
    /// * `item` - The value to insert (any type implementing `Hash`).
    pub fn insert<T: Hash>(&mut self, item: &T) {
        let (h1, h2) = hashing::hash2(&self.hasher_builder, item);
        let modulus = self.m as u64;
        for round in 0..u64::from(self.k) {
            let pos = (h1.wrapping_add(round.wrapping_mul(h2)) % modulus) as usize;
            self.bits.set(pos);
        }
        self.items = self.items.saturating_add(1);
    }

    /// Parallel batch insert using rayon (requires "rayon" feature).
    ///
    /// Two-phase approach: the bit indices for every item are computed in
    /// parallel and merged via rayon's `reduce` (no shared lock), then the
    /// bits are set sequentially so the bit array is never mutated from
    /// multiple threads.
    ///
    /// The `items` counter is increased by the number of items inserted,
    /// matching what repeated calls to [`insert`](Self::insert) would do.
    #[cfg(feature = "rayon")]
    pub fn insert_batch<T>(&mut self, items: impl IntoParallelIterator<Item = T>)
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        // Phase 1: compute all indices in parallel. Each item contributes its
        // k indices plus an item count of 1; partial results are combined by
        // rayon instead of funneling through a shared Mutex<HashSet>.
        let (indices, inserted) = items
            .into_par_iter()
            .map(|item| {
                let (h1, h2) = hashing::hash2(&self.hasher_builder, &item);
                let idxs: Vec<usize> = (0..self.k)
                    .map(|i| {
                        let combined = h1.wrapping_add((i as u64).wrapping_mul(h2));
                        (combined % (self.m as u64)) as usize
                    })
                    .collect();
                (idxs, 1usize)
            })
            .reduce(
                || (Vec::new(), 0usize),
                |(mut acc, a), (idxs, b)| {
                    acc.extend(idxs);
                    (acc, a + b)
                },
            );

        // Phase 2: set bits sequentially (no data races possible).
        for idx in indices {
            self.bits.set(idx);
        }

        // Fix: the counter previously added the number of *distinct bit
        // indices* (post-dedup), not the number of items, which made
        // `approximate_items` wildly wrong after a batch insert.
        self.items = self.items.saturating_add(inserted);
    }

    /// Test whether an item is *probably* in the set.
    ///
    /// A clear bit at any of the `k` derived positions proves absence, so
    /// `false` is definitive. `true` means "probably present": false
    /// positives are possible, false negatives are not.
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        let (h1, h2) = hashing::hash2(&self.hasher_builder, item);
        let modulus = self.m as u64;
        (0..self.k).all(|round| {
            let pos =
                (h1.wrapping_add(u64::from(round).wrapping_mul(h2)) % modulus) as usize;
            self.bits.get(pos)
        })
    }

    /// Parallel batch contains check (requires "rayon" feature).
    ///
    /// Returns `true` only when every item is probably in the set; rayon
    /// short-circuits across threads as soon as one item is definitely absent.
    #[cfg(feature = "rayon")]
    pub fn contains_all<T>(&self, items: impl IntoParallelIterator<Item = T>) -> bool
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        !items.into_par_iter().any(|candidate| !self.contains(&candidate))
    }

    /// Parallel batch contains check returning a Vec of results (requires "rayon" feature).
    ///
    /// The output has the same length and order as the input; each entry is
    /// the membership verdict for the corresponding item.
    #[cfg(feature = "rayon")]
    pub fn contains_batch<T>(&self, items: impl IntoParallelIterator<Item = T>) -> Vec<bool>
    where
        T: Hash + Send + Sync,
        S: Send + Sync,
    {
        items
            .into_par_iter()
            .map(|candidate| self.contains(&candidate))
            .collect::<Vec<bool>>()
    }

    /// In‑place union (bitwise OR) with another filter.
    ///
    /// Both filters must have identical `m` and `k` parameters.
    ///
    /// The approximate item counter is combined (saturating), so
    /// `approximate_items` keeps reflecting inserts performed on either
    /// filter after the union.
    ///
    /// # Panics
    /// Panics if `m` or `k` differ between the two filters.
    pub fn union_inplace(&mut self, other: &Self) {
        assert_eq!(self.m, other.m, "m mismatch for union");
        assert_eq!(self.k, other.k, "k mismatch for union");
        self.bits.or_with(&other.bits);
        // Fix: previously the counter ignored `other`'s inserts, so the
        // approximation silently undercounted after a union.
        self.items = self.items.saturating_add(other.items);
    }

    /// In‑place intersection (bitwise AND) with another filter.
    ///
    /// Both filters must have identical `m` and `k` parameters.
    ///
    /// # Panics
    /// Panics if `m` or `k` differ between the two filters.
    ///
    /// NOTE(review): the `items` counter is left untouched, so after an
    /// intersection `approximate_items` can only overestimate the surviving
    /// contents — confirm whether callers rely on the counter here.
    pub fn intersect_inplace(&mut self, other: &Self) {
        assert_eq!(self.m, other.m, "m mismatch for intersection");
        assert_eq!(self.k, other.k, "k mismatch for intersection");
        self.bits.and_with(&other.bits);
    }

    /// Reset the filter to its empty state: zero the insert counter and
    /// clear every bit.
    pub fn clear(&mut self) {
        self.items = 0;
        self.bits.clear();
    }

    /// Approximate number of times `insert` was called.
    ///
    /// Note: duplicates are counted; this is not a distinct element count.
    /// The counter is reset by `clear`, and filters rebuilt via
    /// `from_bytes_hasher` start at zero (the byte layout stores no count).
    pub fn approximate_items(&self) -> usize {
        self.items
    }

    /// Serialize the filter into a byte vector.
    ///
    /// Layout (everything little‑endian):
    ///   raw bit words (u64 each) + m (u64) + k (u32)
    pub fn to_bytes(&self) -> Vec<u8> {
        let words = self.bits.words_slice();
        // Exact size up front: 8 bytes per word, plus 8 for m and 4 for k.
        let mut buf = Vec::with_capacity(words.len() * 8 + 12);
        buf.extend(words.iter().copied().flat_map(u64::to_le_bytes));
        buf.extend_from_slice(&(self.m as u64).to_le_bytes());
        buf.extend_from_slice(&self.k.to_le_bytes());
        buf
    }

    /// Deserialize from bytes with an explicit hasher builder.
    ///
    /// Returns `None` if the data length or internal layout is invalid, or
    /// if the encoded parameters are degenerate (`m == 0` would make
    /// insert/contains panic on `% 0`; `k == 0` would make `contains`
    /// always return true). `with_hasher` asserts the same invariants.
    ///
    /// The item counter is not part of the byte layout and restarts at zero.
    /// Queries are only meaningful if `hasher_builder` produces the same
    /// hash functions as the filter that produced `data`.
    pub fn from_bytes_hasher(data: &[u8], hasher_builder: S) -> Option<Self> {
        // Trailer: m (u64 LE) + k (u32 LE) = 12 bytes at the end.
        if data.len() < 12 { return None; }
        let meta_offset = data.len() - 12;

        let mut m_bytes = [0u8; 8];
        m_bytes.copy_from_slice(&data[meta_offset..meta_offset + 8]);
        let m = u64::from_le_bytes(m_bytes) as usize;

        let mut k_bytes = [0u8; 4];
        k_bytes.copy_from_slice(&data[meta_offset + 8..meta_offset + 12]);
        let k = u32::from_le_bytes(k_bytes);

        // Fix: reject degenerate parameters that with_hasher forbids; the
        // old code accepted them and deferred the failure to first use.
        if m == 0 || k == 0 { return None; }

        // The word section must hold exactly ceil(m / 64) u64 words.
        let words_expected = m.div_ceil(64);
        if meta_offset != words_expected * 8 { return None; }

        // Decode words straight into the bitset, skipping the intermediate Vec.
        let mut bitset = BitSet::new(m);
        for (dst, src) in bitset
            .words_mut()
            .iter_mut()
            .zip(data[..meta_offset].chunks_exact(8))
        {
            let mut wb = [0u8; 8];
            wb.copy_from_slice(src);
            *dst = u64::from_le_bytes(wb);
        }

        Some(Self {
            bits: bitset,
            m,
            k,
            items: 0,
            hasher_builder,
            _marker: PhantomData,
        })
    }

    /// Convenience wrapper that rebuilds using a default `RandomState`-derived builder.
    ///
    /// Returns `None` if the data length or internal layout is invalid.
    ///
    /// The previous `RandomState: Clone` where-clause was vacuous (always
    /// true) and has been dropped; only the `From` bound is actually used.
    pub fn from_bytes(data: &[u8]) -> Option<Self>
    where
        S: From<std::collections::hash_map::RandomState>,
    {
        let builder: S = std::collections::hash_map::RandomState::new().into();
        Self::from_bytes_hasher(data, builder)
    }
}

#[cfg(feature = "serde")]
impl<S> serde::Serialize for BloomFilter<S> {
    /// Serializes `m`, `k`, `items` and the raw bit words.
    ///
    /// The hasher builder is intentionally not serialized (it is generally
    /// not serializable). The previous `S: BuildHasher + Clone + Default`
    /// bound was unused by serialization and needlessly restricted which
    /// filters could be serialized; it has been removed.
    fn serialize<Se: Serializer>(&self, serializer: Se) -> Result<Se::Ok, Se::Error> {
        let mut st = serializer.serialize_struct("BloomFilter", 4)?;
        st.serialize_field("m", &self.m)?;
        st.serialize_field("k", &self.k)?;
        st.serialize_field("items", &self.items)?;
        st.serialize_field("words", self.bits.words_slice())?;
        st.end()
    }
}

#[cfg(feature = "serde")]
impl<'de, S> serde::Deserialize<'de> for BloomFilter<S>
where S: BuildHasher + Clone + Default {
    /// Rebuilds a filter from the fields written by the `Serialize` impl
    /// (`m`, `k`, `items`, `words`).
    ///
    /// The hasher builder is reconstructed via `S::default()` because it is
    /// not part of the serialized form.
    ///
    /// NOTE(review): if `S::default()` is randomly seeded (e.g. std's
    /// `RandomState`), the rebuilt filter hashes differently from the one
    /// that was serialized, so membership queries against the restored bits
    /// will be wrong — confirm serde is only used with fixed-seed hashers.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Mirror of the struct layout produced by `serialize`.
        #[derive(Deserialize)]
        struct BFHelper { m: usize, k: u32, items: usize, words: Vec<u64> }
        let helper = BFHelper::deserialize(deserializer)?;
        // The word vector must hold exactly ceil(m / 64) u64 words.
        let expected = helper.m.div_ceil(64);
        if helper.words.len() != expected {
            return Err(serde::de::Error::custom("words length mismatch"));
        }
        let bitset = BitSet::from_words(helper.m, helper.words);
        Ok(Self { bits: bitset, m: helper.m, k: helper.k, items: helper.items, hasher_builder: S::default(), _marker: PhantomData })
    }
}