use std::hash::Hash;
use std::hash::Hasher;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::error::Error;
use crate::hash::XxHash64;
/// Preamble length, in 8-byte words, that `serialize` writes for a filter with no bits set.
const PREAMBLE_LONGS_EMPTY: u8 = 3;
/// Preamble length, in 8-byte words, that `serialize` writes for a filter with bits set.
const PREAMBLE_LONGS_STANDARD: u8 = 4;
/// Family identifier written into the header; `deserialize` rejects any other value.
const FAMILY_ID: u8 = 21;
/// Serialization format version written into the header; `deserialize` rejects any other value.
const SERIAL_VERSION: u8 = 1;
/// Bit of the header flags byte that marks a serialized filter as empty (bit 2).
const EMPTY_FLAG_MASK: u8 = 0b0000_0100;
/// A Bloom filter backed by a vector of 64-bit words.
///
/// Items are hashed with a seeded `XxHash64`; each item probes `num_hashes`
/// bit positions, all taken modulo `capacity_bits`.
#[derive(Debug, Clone, PartialEq)]
pub struct BloomFilter {
/// Seed given to the hasher when hashing items.
pub(super) seed: u64,
/// Number of bit positions probed (and set) per item.
pub(super) num_hashes: u16,
/// Size of the addressable bit space; bit indices are reduced modulo this value.
pub(super) capacity_bits: u64,
/// Running count of bits currently set in `bit_array`.
pub(super) num_bits_set: u64,
/// Bit storage; bit `i` lives in word `i >> 6` at offset `i & 63`.
pub(super) bit_array: Vec<u64>,
}
impl BloomFilter {
/// Tests whether `item` may be present in the filter.
///
/// Returns `false` without hashing when no bits are set. Otherwise returns
/// `true` only if every bit position derived from the item's hashes is set.
pub fn contains<T: Hash>(&self, item: &T) -> bool {
    if self.is_empty() {
        // Nothing is set, so no item can match; skip the hashing work.
        false
    } else {
        let (first, second) = self.compute_hash(item);
        self.check_bits(first, second)
    }
}
/// Inserts `item` and reports whether it already appeared to be present.
///
/// The item is hashed once; the membership check is performed before any bit
/// is set, and its result is returned after the item's bits have been set.
pub fn contains_and_insert<T: Hash>(&mut self, item: &T) -> bool {
    let hashes = self.compute_hash(item);
    // Check first: setting the bits would make the answer trivially `true`.
    let already_present = self.check_bits(hashes.0, hashes.1);
    self.set_bits(hashes.0, hashes.1);
    already_present
}
/// Inserts `item` into the filter by setting every bit position derived from its hashes.
pub fn insert<T: Hash>(&mut self, item: T) {
    let hashes = self.compute_hash(&item);
    self.set_bits(hashes.0, hashes.1);
}
/// Clears every bit and zeroes the set-bit counter, keeping the configuration
/// (seed, hash count, capacity) and the allocated storage.
pub fn reset(&mut self) {
    self.num_bits_set = 0;
    self.bit_array.iter_mut().for_each(|word| *word = 0);
}
/// Merges `other` into this filter with a word-wise bitwise OR and
/// recomputes the set-bit counter from the merged words.
///
/// # Panics
///
/// Panics if `other` is not compatible with this filter (see
/// [`BloomFilter::is_compatible`]).
pub fn union(&mut self, other: &BloomFilter) {
    assert!(
        self.is_compatible(other),
        "Cannot union incompatible Bloom filters"
    );
    self.num_bits_set = self
        .bit_array
        .iter_mut()
        .zip(other.bit_array.iter())
        .map(|(mine, theirs)| {
            *mine |= *theirs;
            u64::from(mine.count_ones())
        })
        .sum();
}
/// Reduces this filter to the word-wise bitwise AND with `other` and
/// recomputes the set-bit counter from the resulting words.
///
/// # Panics
///
/// Panics if `other` is not compatible with this filter (see
/// [`BloomFilter::is_compatible`]).
pub fn intersect(&mut self, other: &BloomFilter) {
    assert!(
        self.is_compatible(other),
        "Cannot intersect incompatible Bloom filters"
    );
    self.num_bits_set = self
        .bit_array
        .iter_mut()
        .zip(other.bit_array.iter())
        .map(|(mine, theirs)| {
            *mine &= *theirs;
            u64::from(mine.count_ones())
        })
        .sum();
}
/// Flips every bit of the filter and recomputes the set-bit counter.
///
/// When `capacity_bits` is not a multiple of 64, the bits of the last word
/// that lie beyond the capacity are cleared again after the flip so they are
/// never counted as set.
pub fn invert(&mut self) {
    // Flip each word and tally its population count in the same pass.
    let mut set_count: u64 = self
        .bit_array
        .iter_mut()
        .map(|word| {
            *word = !*word;
            u64::from(word.count_ones())
        })
        .sum();

    let tail_bits = self.capacity_bits % 64;
    if tail_bits != 0 {
        let last = self.bit_array.len() - 1;
        let tail = &mut self.bit_array[last];
        let before = u64::from(tail.count_ones());
        // Keep only the low `tail_bits` bits of the final word.
        *tail &= (1u64 << tail_bits) - 1;
        let after = u64::from(tail.count_ones());
        set_count = set_count - before + after;
    }

    self.num_bits_set = set_count;
}
/// Returns `true` when the set-bit counter is zero.
pub fn is_empty(&self) -> bool {
self.num_bits_set == 0
}
/// Returns the number of bits currently counted as set.
pub fn bits_used(&self) -> u64 {
self.num_bits_set
}
/// Returns the filter capacity in bits.
pub fn capacity(&self) -> u64 {
self.capacity_bits
}
/// Returns the number of bit positions probed per item.
pub fn num_hashes(&self) -> u16 {
self.num_hashes
}
/// Returns the seed used when hashing items.
pub fn seed(&self) -> u64 {
self.seed
}
/// Returns the fraction of the capacity that is currently set:
/// set-bit count divided by `capacity_bits`, as a float.
pub fn load_factor(&self) -> f64 {
    let set_bits = self.num_bits_set as f64;
    let total_bits = self.capacity_bits as f64;
    set_bits / total_bits
}
/// Estimates the current false-positive probability as the load factor
/// raised to the power of the number of hashes.
pub fn estimated_fpp(&self) -> f64 {
    self.load_factor().powf(f64::from(self.num_hashes))
}
/// Returns `true` when `other` has the same capacity, hash count and seed
/// as this filter. `union` and `intersect` require this.
pub fn is_compatible(&self, other: &BloomFilter) -> bool {
    let mine = (self.capacity_bits, self.num_hashes, self.seed);
    let theirs = (other.capacity_bits, other.num_hashes, other.seed);
    mine == theirs
}
/// Serializes the filter into a byte vector.
///
/// Fields are written in this order:
/// 1. preamble length, serial version, family id, flags (one byte each);
/// 2. `num_hashes` (`u16`) followed by a zero `u16`;
/// 3. the seed (`u64`);
/// 4. the word count `capacity_bits / 64` (`i32`) followed by a zero `u32`;
/// 5. for non-empty filters only: the set-bit count (`u64`) and then every
///    word of the bit array (`u64` each).
///
/// An empty filter sets the empty flag and uses the shorter preamble.
pub fn serialize(&self) -> Vec<u8> {
    let empty = self.is_empty();
    let (preamble_longs, flags, payload_len) = if empty {
        (PREAMBLE_LONGS_EMPTY, EMPTY_FLAG_MASK, 0)
    } else {
        (PREAMBLE_LONGS_STANDARD, 0, self.bit_array.len() * 8)
    };

    let mut out = SketchBytes::with_capacity(8 * preamble_longs as usize + payload_len);

    // First 8 bytes: format header.
    out.write_u8(preamble_longs);
    out.write_u8(SERIAL_VERSION);
    out.write_u8(FAMILY_ID);
    out.write_u8(flags);
    out.write_u16_le(self.num_hashes);
    out.write_u16_le(0);

    // Second 8 bytes: hash seed.
    out.write_u64_le(self.seed);

    // Third 8 bytes: word count plus padding.
    out.write_i32_le((self.capacity_bits / 64) as i32);
    out.write_u32_le(0);

    // Remaining bytes exist only when at least one bit is set.
    if !empty {
        out.write_u64_le(self.num_bits_set);
        for &word in &self.bit_array {
            out.write_u64_le(word);
        }
    }

    out.into_bytes()
}
pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
let mut cursor = SketchSlice::new(bytes);
let preamble_longs = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("preamble_longs"))?;
let serial_version = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("serial_version"))?;
let family_id = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("family_id"))?;
let flags = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("flags"))?;
if family_id != FAMILY_ID {
return Err(Error::invalid_family(FAMILY_ID, family_id, "BloomFilter"));
}
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
if preamble_longs != PREAMBLE_LONGS_EMPTY && preamble_longs != PREAMBLE_LONGS_STANDARD {
return Err(Error::invalid_preamble_longs(
PREAMBLE_LONGS_STANDARD,
preamble_longs,
));
}
let is_empty = (flags & EMPTY_FLAG_MASK) != 0;
let num_hashes = cursor
.read_u16_le()
.map_err(|_| Error::insufficient_data("num_hashes"))?;
let _unused = cursor
.read_u16_le()
.map_err(|_| Error::insufficient_data("unused_header"))?;
let seed = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("seed"))?;
let num_longs = cursor
.read_i32_le()
.map_err(|_| Error::insufficient_data("num_longs"))? as u64;
let _unused = cursor
.read_u32_le()
.map_err(|_| Error::insufficient_data("unused"))?;
let capacity_bits = num_longs * 64; let num_words = num_longs as usize;
let mut bit_array = vec![0u64; num_words];
let num_bits_set;
if is_empty {
num_bits_set = 0;
} else {
let raw_num_bits_set = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("num_bits_set"))?;
for word in &mut bit_array {
*word = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("bit_array"))?;
}
const DIRTY_BITS_VALUE: u64 = 0xFFFFFFFFFFFFFFFF;
if raw_num_bits_set == DIRTY_BITS_VALUE {
num_bits_set = bit_array.iter().map(|w| w.count_ones() as u64).sum();
} else {
num_bits_set = raw_num_bits_set;
}
}
Ok(BloomFilter {
seed,
num_hashes,
capacity_bits,
num_bits_set,
bit_array,
})
}
/// Derives the two base hashes for `item`.
///
/// The first hash uses the filter's seed; the second hashes the same item
/// again using the first hash as the seed.
fn compute_hash<T: Hash>(&self, item: &T) -> (u64, u64) {
    let digest_with = |seed: u64| -> u64 {
        let mut hasher = XxHash64::with_seed(seed);
        item.hash(&mut hasher);
        hasher.finish()
    };

    let first = digest_with(self.seed);
    let second = digest_with(first);
    (first, second)
}
/// Returns `true` when every bit position derived from `(h0, h1)` for rounds
/// `1..=num_hashes` is set; stops at the first unset bit.
fn check_bits(&self, h0: u64, h1: u64) -> bool {
    (1..=self.num_hashes).all(|round| self.get_bit(self.compute_bit_index(h0, h1, round)))
}
/// Sets every bit position derived from `(h0, h1)` for rounds `1..=num_hashes`.
fn set_bits(&mut self, h0: u64, h1: u64) {
    let rounds = self.num_hashes;
    for round in 1..=rounds {
        let position = self.compute_bit_index(h0, h1, round);
        self.set_bit(position);
    }
}
/// Maps round `i` of the hash pair to a bit index in `0..capacity_bits`.
///
/// Computes `h0 + i * h1` with wrapping arithmetic, drops the lowest bit,
/// and reduces the result modulo the capacity.
fn compute_bit_index(&self, h0: u64, h1: u64, i: u16) -> u64 {
    let stride = h1.wrapping_mul(u64::from(i));
    let combined = h0.wrapping_add(stride);
    (combined >> 1) % self.capacity_bits
}
/// Reads the bit at `bit_index`: word `bit_index / 64`, offset `bit_index % 64`.
///
/// Panics if the word index is outside `bit_array`.
fn get_bit(&self, bit_index: u64) -> bool {
    let word = self.bit_array[(bit_index / 64) as usize];
    ((word >> (bit_index % 64)) & 1) == 1
}
/// Sets the bit at `bit_index`, incrementing the set-bit counter only when
/// the bit was previously clear.
///
/// Panics if the word index is outside `bit_array`.
fn set_bit(&mut self, bit_index: u64) {
    let mask = 1u64 << (bit_index % 64);
    let slot = &mut self.bit_array[(bit_index / 64) as usize];
    if (*slot & mask) == 0 {
        *slot |= mask;
        self.num_bits_set += 1;
    }
}
}
#[cfg(test)]
mod tests {
    use super::BloomFilter;
    use crate::bloom::BloomFilterBuilder;

    /// Builds the small accuracy-configured filter (no explicit seed) that
    /// several tests start from.
    fn small_filter() -> BloomFilter {
        BloomFilterBuilder::with_accuracy(100, 0.01).build()
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_builder_with_accuracy() {
        let built = BloomFilterBuilder::with_accuracy(1000, 0.01).build();
        assert!(built.capacity() >= 9000);
        assert_eq!(built.num_hashes(), 7);
        assert!(built.is_empty());
    }

    #[test]
    fn test_builder_with_size() {
        let built = BloomFilterBuilder::with_size(1024, 5).build();
        assert_eq!(built.capacity(), 1024);
        assert_eq!(built.num_hashes(), 5);
    }

    // --- membership -------------------------------------------------------

    #[test]
    fn test_insert_and_contains() {
        let mut bloom = small_filter();
        assert!(!bloom.contains(&"apple"));
        bloom.insert("apple");
        assert!(bloom.contains(&"apple"));
        assert!(!bloom.is_empty());
    }

    #[test]
    fn test_contains_and_insert() {
        let mut bloom = small_filter();
        // First call inserts and reports the item as previously absent.
        let seen_before = bloom.contains_and_insert(&42_u64);
        assert!(!seen_before);
        // Second call finds the bits set by the first.
        let seen_before = bloom.contains_and_insert(&42_u64);
        assert!(seen_before);
    }

    #[test]
    fn test_reset() {
        let mut bloom = small_filter();
        bloom.insert("test");
        assert!(!bloom.is_empty());
        bloom.reset();
        assert!(bloom.is_empty());
        assert!(!bloom.contains(&"test"));
    }

    // --- set operations ---------------------------------------------------

    #[test]
    fn test_union() {
        let mut left = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        let mut right = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        left.insert("a");
        right.insert("b");
        left.union(&right);
        assert!(left.contains(&"a"));
        assert!(left.contains(&"b"));
    }

    #[test]
    fn test_intersect() {
        let mut left = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        let mut right = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        left.insert("a");
        left.insert("b");
        right.insert("b");
        right.insert("c");
        left.intersect(&right);
        // "b" was inserted into both sides, so its bits survive the AND.
        assert!(left.contains(&"b"));
    }

    // --- serialization round trips ----------------------------------------

    #[test]
    fn test_serialize_deserialize_empty() {
        let original = small_filter();
        let image = original.serialize();
        let decoded = BloomFilter::deserialize(&image).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_serialize_deserialize_with_data() {
        let mut original = small_filter();
        original.insert("test");
        original.insert(42_u64);
        let image = original.serialize();
        let decoded = BloomFilter::deserialize(&image).unwrap();
        assert_eq!(original, decoded);
        assert!(decoded.contains(&"test"));
        assert!(decoded.contains(&42_u64));
    }

    // --- statistics and compatibility ---------------------------------------

    #[test]
    fn test_statistics() {
        let mut bloom = BloomFilterBuilder::with_size(1000, 5).build();
        assert_eq!(bloom.bits_used(), 0);
        assert_eq!(bloom.load_factor(), 0.0);
        bloom.insert("test");
        assert!(bloom.bits_used() > 0);
        assert!(bloom.load_factor() > 0.0);
        assert!(bloom.estimated_fpp() > 0.0);
    }

    #[test]
    fn test_is_compatible() {
        let base = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        let same_seed = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(123)
            .build();
        let other_seed = BloomFilterBuilder::with_accuracy(100, 0.01)
            .seed(456)
            .build();
        assert!(base.is_compatible(&same_seed));
        assert!(!base.is_compatible(&other_seed));
    }

    // --- builder argument validation ----------------------------------------

    #[test]
    #[should_panic(expected = "max_items must be greater than 0")]
    fn test_invalid_max_items() {
        BloomFilterBuilder::with_accuracy(0, 0.01);
    }

    #[test]
    #[should_panic(expected = "fpp must be between")]
    fn test_invalid_fpp() {
        BloomFilterBuilder::with_accuracy(100, 1.5);
    }
}