use std::hash::Hash;
use std::hash::Hasher;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::assert::ensure_preamble_longs_in_range;
use crate::codec::assert::ensure_serial_version_is;
use crate::codec::assert::insufficient_data;
use crate::codec::family::Family;
use crate::error::Error;
use crate::hash::XxHash64;
/// Serial format version: written as the second header byte by `serialize`
/// and required to match by `deserialize`.
const SERIAL_VERSION: u8 = 1;
/// Bit 2 of the header flags byte; set by `serialize` when the filter has no
/// bits set, in which case no bit-count or bit-array payload follows.
const EMPTY_FLAG_MASK: u8 = 1 << 2;
/// A Bloom filter backed by an array of 64-bit words.
///
/// Items are hashed with a seeded `XxHash64` and mapped to `num_hashes` bit
/// positions; membership queries can return false positives but never false
/// negatives for items that were inserted.
#[derive(Debug, Clone, PartialEq)]
pub struct BloomFilter {
/// Seed for the first hash pass over an item (see `compute_hash`).
pub(super) seed: u64,
/// Number of bit positions probed or set per item.
pub(super) num_hashes: u16,
/// Running count of bits currently set in `bit_array`.
pub(super) num_bits_set: u64,
/// Bit storage; each word holds 64 bits, so the capacity is `len * 64`.
pub(super) bit_array: Box<[u64]>,
}
impl BloomFilter {
/// Tests whether `item` may have been inserted into this filter.
///
/// Returns `false` without hashing when no bits are set. Otherwise returns
/// `true` only if every bit position derived from the item is set, which
/// means "possibly present"; `false` means "definitely not inserted".
pub fn contains<T: Hash>(&self, item: &T) -> bool {
    if self.is_empty() {
        false
    } else {
        let (first, second) = self.compute_hash(item);
        self.check_bits(first, second)
    }
}
/// Inserts `item` and reports whether it already appeared to be present.
///
/// The item is hashed once; the membership check runs to completion before
/// any bit is set, so the return value reflects the state prior to this call.
pub fn contains_and_insert<T: Hash>(&mut self, item: &T) -> bool {
    let (first, second) = self.compute_hash(item);
    let already_present = self.check_bits(first, second);
    self.set_bits(first, second);
    already_present
}
/// Inserts `item` into the filter by setting all of its derived bit positions.
///
/// The item is taken by value and only hashed; it is not stored.
pub fn insert<T: Hash>(&mut self, item: T) {
    let (first, second) = self.compute_hash(&item);
    self.set_bits(first, second);
}
/// Clears every bit and zeroes the set-bit counter.
///
/// The array length, seed and hash count are left untouched.
pub fn reset(&mut self) {
    for word in self.bit_array.iter_mut() {
        *word = 0;
    }
    self.num_bits_set = 0;
}
/// Merges `other` into this filter with a word-wise bitwise OR.
///
/// The set-bit counter is recomputed from the resulting words.
///
/// # Panics
///
/// Panics if the filters are not compatible (see `is_compatible`).
pub fn union(&mut self, other: &BloomFilter) {
    assert!(
        self.is_compatible(other),
        "Cannot union incompatible Bloom filters"
    );
    self.num_bits_set = self
        .bit_array
        .iter_mut()
        .zip(other.bit_array.iter())
        .map(|(mine, theirs)| {
            *mine |= *theirs;
            mine.count_ones() as u64
        })
        .sum();
}
/// Reduces this filter to the word-wise bitwise AND of itself and `other`.
///
/// The set-bit counter is recomputed from the resulting words.
///
/// # Panics
///
/// Panics if the filters are not compatible (see `is_compatible`).
pub fn intersect(&mut self, other: &BloomFilter) {
    assert!(
        self.is_compatible(other),
        "Cannot intersect incompatible Bloom filters"
    );
    self.num_bits_set = self
        .bit_array
        .iter_mut()
        .zip(other.bit_array.iter())
        .map(|(mine, theirs)| {
            *mine &= *theirs;
            mine.count_ones() as u64
        })
        .sum();
}
/// Flips every bit in the filter.
///
/// The set-bit counter becomes the capacity minus the previous count, since
/// each previously clear bit is now set and vice versa.
pub fn invert(&mut self) {
    self.bit_array.iter_mut().for_each(|word| *word = !*word);
    let total_bits = self.capacity() as u64;
    self.num_bits_set = total_bits - self.num_bits_set;
}
/// Returns `true` when the set-bit counter is zero.
pub fn is_empty(&self) -> bool {
    self.bits_used() == 0
}
/// Returns the number of bits currently set, as tracked by the filter's counter.
pub fn bits_used(&self) -> u64 {
self.num_bits_set
}
/// Returns the total number of bits in the filter: 64 per backing word.
pub fn capacity(&self) -> usize {
    const BITS_PER_WORD: usize = 64;
    self.bit_array.len() * BITS_PER_WORD
}
/// Returns the number of bit positions probed or set per item.
pub fn num_hashes(&self) -> u16 {
self.num_hashes
}
/// Returns the hash seed this filter was configured with.
pub fn seed(&self) -> u64 {
self.seed
}
/// Returns the fraction of bits that are set: set-bit count divided by capacity.
pub fn load_factor(&self) -> f64 {
    let set_bits = self.num_bits_set as f64;
    let total_bits = self.capacity() as f64;
    set_bits / total_bits
}
/// Estimates the current false-positive probability as
/// `load_factor ^ num_hashes`, i.e. the chance that all probed bits of an
/// item that was never inserted happen to be set.
pub fn estimated_fpp(&self) -> f64 {
    self.load_factor().powf(f64::from(self.num_hashes))
}
/// Returns `true` when `other` has the same word count, hash count and seed,
/// which is what `union` and `intersect` require of their argument.
pub fn is_compatible(&self, other: &Self) -> bool {
    let mine = (self.bit_array.len(), self.num_hashes, self.seed);
    let theirs = (other.bit_array.len(), other.num_hashes, other.seed);
    mine == theirs
}
/// Serializes the filter into a little-endian byte vector.
///
/// Layout, in write order:
/// - byte 0: preamble longs (family minimum when empty, maximum otherwise)
/// - byte 1: serial version
/// - byte 2: family id
/// - byte 3: flags (`EMPTY_FLAG_MASK` when empty, else 0)
/// - bytes 4-5: `num_hashes`, bytes 6-7: zero padding
/// - bytes 8-15: seed
/// - bytes 16-19: number of 64-bit words (as `i32`), bytes 20-23: zero padding
/// - non-empty only: 8-byte set-bit count followed by every word of the array
pub fn serialize(&self) -> Vec<u8> {
    let empty = self.is_empty();
    let (preamble_longs, flags, payload_len) = if empty {
        (Family::BLOOMFILTER.min_pre_longs, EMPTY_FLAG_MASK, 0)
    } else {
        (
            Family::BLOOMFILTER.max_pre_longs,
            0,
            self.bit_array.len() * 8,
        )
    };

    let mut out = SketchBytes::with_capacity(8 * preamble_longs as usize + payload_len);

    // First header long: identification bytes plus hash count.
    out.write_u8(preamble_longs);
    out.write_u8(SERIAL_VERSION);
    out.write_u8(Family::BLOOMFILTER.id);
    out.write_u8(flags);
    out.write_u16_le(self.num_hashes);
    out.write_u16_le(0);

    // Second and third header longs: seed, then word count with padding.
    out.write_u64_le(self.seed);
    out.write_i32_le(self.bit_array.len() as i32);
    out.write_u32_le(0);

    // Payload is omitted entirely for an empty filter.
    if !empty {
        out.write_u64_le(self.num_bits_set);
        for word in self.bit_array.iter().copied() {
            out.write_u64_le(word);
        }
    }

    out.into_bytes()
}
/// Reconstructs a filter from bytes in the layout produced by `serialize`.
///
/// The header (preamble longs, serial version, family id, flags, hash count,
/// seed and word count) is read and validated first. If the empty flag is
/// set, an all-zero filter of the declared size is returned; otherwise the
/// stored set-bit count and the bit array are read. A stored count of
/// `u64::MAX` marks the count as dirty and triggers a recount from the words.
///
/// # Errors
///
/// Returns an error when the input is too short for any field, when the
/// family id, serial version or preamble length is not the expected one,
/// when `num_hashes` is outside `[1, i16::MAX]`, when the word count is not
/// positive, or when the stored set-bit count exceeds the filter capacity.
pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
    let mut cursor = SketchSlice::new(bytes);
    let preamble_longs = cursor
        .read_u8()
        .map_err(insufficient_data("preamble_longs"))?;
    let serial_version = cursor
        .read_u8()
        .map_err(insufficient_data("serial_version"))?;
    let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?;
    let flags = cursor.read_u8().map_err(insufficient_data("flags"))?;

    Family::BLOOMFILTER.validate_id(family_id)?;
    ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
    ensure_preamble_longs_in_range(
        Family::BLOOMFILTER.min_pre_longs..=Family::BLOOMFILTER.max_pre_longs,
        preamble_longs,
    )?;
    let is_empty = (flags & EMPTY_FLAG_MASK) != 0;

    let num_hashes = cursor
        .read_u16_le()
        .map_err(insufficient_data("num_hashes"))?;
    if num_hashes == 0 || num_hashes > i16::MAX as u16 {
        return Err(Error::deserial(format!(
            "invalid num_hashes: expected [1, {}], got {}",
            i16::MAX,
            num_hashes
        )));
    }
    let _unused = cursor
        .read_u16_le()
        .map_err(insufficient_data("unused_header"))?;
    let seed = cursor.read_u64_le().map_err(insufficient_data("seed"))?;
    let num_longs = cursor
        .read_i32_le()
        .map_err(insufficient_data("num_longs"))?;
    let _unused = cursor.read_u32_le().map_err(insufficient_data("unused"))?;
    if num_longs <= 0 {
        return Err(Error::deserial(format!(
            "invalid num_longs: expected at least 1, got {}",
            num_longs
        )));
    }
    // Positive i32, so this conversion is lossless.
    let num_words = num_longs as usize;

    let (bit_array, num_bits_set) = if is_empty {
        // An empty filter carries no payload, so its size comes from the
        // header alone.
        (vec![0u64; num_words].into_boxed_slice(), 0)
    } else {
        // Sentinel meaning "count unknown, recount from the bit array".
        const DIRTY_BITS_VALUE: u64 = 0xFFFFFFFFFFFFFFFF;

        let raw_num_bits_set = cursor
            .read_u64_le()
            .map_err(insufficient_data("num_bits_set"))?;

        // Never reserve more words than the input could possibly contain:
        // a truncated or hostile buffer must not be able to request a
        // multi-gigabyte allocation through the header's word count. For
        // well-formed input the cap equals `num_words`, so no regrowth occurs.
        let mut words = Vec::with_capacity(num_words.min(bytes.len() / 8));
        for _ in 0..num_words {
            let word = cursor
                .read_u64_le()
                .map_err(insufficient_data("bit_array"))?;
            words.push(word);
        }
        let bit_array = words.into_boxed_slice();

        let num_bits_set = if raw_num_bits_set == DIRTY_BITS_VALUE {
            bit_array.iter().map(|w| w.count_ones() as u64).sum()
        } else {
            // Compare in u64 so the bound is exact on 32-bit targets as well
            // (no truncating cast, no usize overflow in the multiplication).
            let capacity_bits = num_words as u64 * 64;
            if raw_num_bits_set > capacity_bits {
                return Err(Error::deserial(format!(
                    "invalid num_bits_set: expected <= {}, got {}",
                    capacity_bits, raw_num_bits_set
                )));
            }
            raw_num_bits_set
        };
        (bit_array, num_bits_set)
    };

    Ok(BloomFilter {
        seed,
        num_hashes,
        num_bits_set,
        bit_array,
    })
}
/// Produces the two base hashes used to derive an item's bit positions.
///
/// The first hash uses the filter's seed; the second re-hashes the same item
/// using the first hash as its seed.
fn compute_hash<T: Hash>(&self, item: &T) -> (u64, u64) {
    let digest_with = |seed: u64| -> u64 {
        let mut state = XxHash64::with_seed(seed);
        item.hash(&mut state);
        state.finish()
    };
    let first = digest_with(self.seed);
    let second = digest_with(first);
    (first, second)
}
/// Returns `true` only if every bit position derived from `(h0, h1)` for
/// rounds `1..=num_hashes` is set; stops at the first clear bit.
fn check_bits(&self, h0: u64, h1: u64) -> bool {
    (1..=self.num_hashes).all(|round| self.get_bit(self.compute_bit_index(h0, h1, round)))
}
/// Sets the bit position derived from `(h0, h1)` for each round in
/// `1..=num_hashes`.
fn set_bits(&mut self, h0: u64, h1: u64) {
    (1..=self.num_hashes).for_each(|round| {
        let position = self.compute_bit_index(h0, h1, round);
        self.set_bit(position);
    });
}
/// Maps round `i` of the double-hashing scheme to a bit position:
/// `((h0 + i * h1) >> 1) mod capacity`, with wrapping 64-bit arithmetic.
///
/// The whole computation is carried out in `u64` and only the final
/// remainder is narrowed to `usize`. Casting the combined hash to `usize`
/// first would drop its upper bits on 32-bit targets and make the same
/// item land on different bits than on 64-bit targets, so serialized filters
/// would not be portable. Results on 64-bit targets are unchanged.
fn compute_bit_index(&self, h0: u64, h1: u64, i: u16) -> usize {
    let combined = h0.wrapping_add(u64::from(i).wrapping_mul(h1));
    // The remainder is strictly less than `capacity()`, which is a usize,
    // so the narrowing cast cannot lose information.
    ((combined >> 1) % self.capacity() as u64) as usize
}
/// Reads the bit at `bit_index`: the word is `bit_index / 64` and the bit
/// within it is `bit_index % 64`. Panics if the word index is out of range.
fn get_bit(&self, bit_index: usize) -> bool {
    let word = self.bit_array[bit_index / 64];
    (word >> (bit_index % 64)) & 1 == 1
}
/// Sets the bit at `bit_index`, incrementing the set-bit counter only when
/// the bit was previously clear. Panics if the word index is out of range.
fn set_bit(&mut self, bit_index: usize) {
    let mask = 1u64 << (bit_index & 63);
    let slot = &mut self.bit_array[bit_index >> 6];
    if *slot & mask != 0 {
        // Already set: leave the counter alone.
        return;
    }
    *slot |= mask;
    self.num_bits_set += 1;
}
}
// Unit tests for `BloomFilter`, with filters constructed through
// `BloomFilterBuilder`.
#[cfg(test)]
mod tests {
use super::BloomFilter;
use crate::bloom::BloomFilterBuilder;
// Building from a target item count and false-positive rate yields an
// empty filter with the expected size and hash count.
#[test]
fn test_builder_with_accuracy() {
let filter = BloomFilterBuilder::with_accuracy(1000, 0.01).build();
assert!(filter.capacity() >= 9000);
assert_eq!(filter.num_hashes(), 7);
assert!(filter.is_empty());
}
// Building from an explicit bit count and hash count keeps both values.
#[test]
fn test_builder_with_size() {
let filter = BloomFilterBuilder::with_size(1024, 5).build();
assert_eq!(filter.capacity(), 1024);
assert_eq!(filter.num_hashes(), 5);
}
// A requested size below one word still yields a full 64-bit word.
#[test]
fn test_builder_with_size_rounds_to_word_boundary() {
let filter = BloomFilterBuilder::with_size(1, 3).build();
assert_eq!(filter.capacity(), 64);
assert_eq!(filter.num_hashes(), 3);
}
// An inserted item is reported as present; an empty filter reports nothing.
#[test]
fn test_insert_and_contains() {
let mut filter = BloomFilterBuilder::with_accuracy(100, 0.01).build();
assert!(!filter.contains(&"apple"));
filter.insert("apple");
assert!(filter.contains(&"apple"));
assert!(!filter.is_empty());
}
// The first call reports absence, the second reports presence.
#[test]
fn test_contains_and_insert() {
let mut filter = BloomFilterBuilder::with_accuracy(100, 0.01).build();
let was_present = filter.contains_and_insert(&42_u64);
assert!(!was_present);
let was_present = filter.contains_and_insert(&42_u64);
assert!(was_present);
}
// Resetting returns the filter to the empty state.
#[test]
fn test_reset() {
let mut filter = BloomFilterBuilder::with_accuracy(100, 0.01).build();
filter.insert("test");
assert!(!filter.is_empty());
filter.reset();
assert!(filter.is_empty());
assert!(!filter.contains(&"test"));
}
// After a union, items from both inputs are reported as present.
#[test]
fn test_union() {
let mut f1 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
let mut f2 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
f1.insert("a");
f2.insert("b");
f1.union(&f2);
assert!(f1.contains(&"a"));
assert!(f1.contains(&"b"));
}
// After an intersection, an item inserted into both inputs is still present.
#[test]
fn test_intersect() {
let mut f1 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
let mut f2 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
f1.insert("a");
f1.insert("b");
f2.insert("b");
f2.insert("c");
f1.intersect(&f2);
assert!(f1.contains(&"b"));
}
// An empty filter survives a serialize/deserialize round trip unchanged.
#[test]
fn test_serialize_deserialize_empty() {
let filter = BloomFilterBuilder::with_accuracy(100, 0.01).build();
let bytes = filter.serialize();
let restored = BloomFilter::deserialize(&bytes).unwrap();
assert_eq!(filter, restored);
}
// A populated filter round-trips and still reports its items as present.
#[test]
fn test_serialize_deserialize_with_data() {
let mut filter = BloomFilterBuilder::with_accuracy(100, 0.01).build();
filter.insert("test");
filter.insert(42_u64);
let bytes = filter.serialize();
let restored = BloomFilter::deserialize(&bytes).unwrap();
assert_eq!(filter, restored);
assert!(restored.contains(&"test"));
assert!(restored.contains(&42_u64));
}
// Bit count, load factor and estimated FPP are zero when empty and become
// positive after an insert.
#[test]
fn test_statistics() {
let mut filter = BloomFilterBuilder::with_size(1000, 5).build();
assert_eq!(filter.bits_used(), 0);
assert_eq!(filter.load_factor(), 0.0);
filter.insert("test");
assert!(filter.bits_used() > 0);
assert!(filter.load_factor() > 0.0);
assert!(filter.estimated_fpp() > 0.0);
}
// Filters built identically are compatible; a different seed breaks that.
#[test]
fn test_is_compatible() {
let f1 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
let f2 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(123)
.build();
let f3 = BloomFilterBuilder::with_accuracy(100, 0.01)
.seed(456)
.build();
assert!(f1.is_compatible(&f2));
assert!(!f1.is_compatible(&f3));
}
// The builder rejects a zero item count by panicking.
#[test]
#[should_panic(expected = "max_items must be greater than 0")]
fn test_invalid_max_items() {
BloomFilterBuilder::with_accuracy(0, 0.01);
}
// The builder rejects an out-of-range false-positive probability by panicking.
#[test]
#[should_panic(expected = "fpp must be between")]
fn test_invalid_fpp() {
BloomFilterBuilder::with_accuracy(100, 1.5);
}
}