use crate::proto;
/// A bloom filter stored as a bit set of 64-bit words plus the number of probes (hash
/// functions) applied to every inserted 64-bit hash.
///
/// Probe positions are derived from one 64-bit hash by combining its two 32-bit halves (see
/// [`BloomFilter::add_hash`]); values are hashed with [`BloomFilter::hash_bytes`] or
/// [`BloomFilter::hash_long`] before probing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomFilter {
    /// Number of probes per value. The constructors in this impl never store zero.
    num_hash_functions: u32,
    /// Filter bits: bit `i` lives in word `i / 64` at position `i % 64`.
    bitset: Vec<u64>,
}

impl BloomFilter {
    /// Probe count used when the proto leaves `num_hash_functions` unset or zero.
    ///
    /// NOTE(review): this is a guess. If the writer actually used fewer than 3 probes, testing
    /// with 3 checks bits that were never set and can yield false negatives — confirm against
    /// the writer's behaviour.
    const DEFAULT_NUM_HASH_FUNCTIONS: u32 = 3;

    /// Decodes a bloom filter from its protobuf form.
    ///
    /// Bits come from `bitset` (already 64-bit words) when it is non-empty, otherwise from
    /// `utf8bitset`, whose bytes are read as consecutive little-endian `u64` words. A missing
    /// or zero `num_hash_functions` falls back to 3.
    ///
    /// Returns `None` — "no usable filter", so it must not be used for pruning — when:
    /// * neither `bitset` nor `utf8bitset` is populated;
    /// * both are populated, which is ambiguous (reported as `None` rather than panicking on
    ///   malformed file data);
    /// * `utf8bitset` is not a whole number of 8-byte words.
    pub fn try_from_proto(proto: &proto::BloomFilter) -> Option<Self> {
        let bitset = match (proto.bitset.is_empty(), proto.utf8bitset.as_deref()) {
            (false, None) => proto.bitset.clone(),
            (true, Some(bytes)) => Self::words_from_le_bytes(bytes)?,
            // Nothing encoded, or two competing encodings with no way to tell which one is
            // authoritative: either way there is no filter we can trust.
            (true, None) | (false, Some(_)) => return None,
        };
        let num_hash_functions = match proto.num_hash_functions() {
            0 => Self::DEFAULT_NUM_HASH_FUNCTIONS,
            n => n,
        };
        Some(Self {
            num_hash_functions,
            bitset,
        })
    }

    /// Reinterprets `bytes` as consecutive little-endian 64-bit words.
    ///
    /// Returns `None` when the length is not a multiple of 8. Zero-padding a trailing partial
    /// word would enlarge `bit_count`, shifting every `% bit_count` probe position away from
    /// where the writer set it and turning real members into false negatives.
    fn words_from_le_bytes(bytes: &[u8]) -> Option<Vec<u64>> {
        if bytes.len() % 8 != 0 {
            return None;
        }
        let words = bytes
            .chunks_exact(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                u64::from_le_bytes(word)
            })
            .collect();
        Some(words)
    }

    /// Test-only constructor from raw parts; a probe count of 0 is raised to 1.
    #[cfg(test)]
    pub fn from_parts(num_hash_functions: u32, bitset: Vec<u64>) -> Self {
        Self {
            num_hash_functions: num_hash_functions.max(1),
            bitset,
        }
    }

    /// Yields the bit indices probed for `hash64` in a filter of `bit_count` bits.
    ///
    /// With `low`/`high` the low/high 32-bit halves of `hash64` reinterpreted as `i32`, probe
    /// `i` (for `i` in `1..=num_hash_functions`) is `low + i * high` in wrapping `i32`
    /// arithmetic; a negative result is bit-inverted to make it non-negative, then reduced
    /// `% bit_count`.
    ///
    /// `bit_count` must be non-zero (callers check this first).
    fn bit_positions(
        hash64: u64,
        num_hash_functions: u32,
        bit_count: usize,
    ) -> impl Iterator<Item = usize> {
        let low = hash64 as u32 as i32;
        let high = (hash64 >> 32) as u32 as i32;
        let modulus = bit_count as u64;
        (1..=num_hash_functions).map(move |i| {
            let combined = low.wrapping_add((i as i32).wrapping_mul(high));
            // `!x == -x - 1`, so inverting maps every negative i32 into 0..=i32::MAX.
            let non_negative = if combined < 0 { !combined } else { combined };
            ((non_negative as u32 as u64) % modulus) as usize
        })
    }

    /// Sets every bit probed for `hash64`. A filter with no words is left unchanged.
    pub fn add_hash(&mut self, hash64: u64) {
        let bit_count = self.bit_count();
        if bit_count == 0 {
            return;
        }
        for idx in Self::bit_positions(hash64, self.num_hash_functions, bit_count) {
            self.bitset[idx / 64] |= 1u64 << (idx % 64);
        }
    }

    /// Returns `false` only if some bit probed for `hash64` is clear, i.e. the value was
    /// definitely never added. A filter with no words cannot rule anything out and always
    /// returns `true`.
    pub fn test_hash(&self, hash64: u64) -> bool {
        let bit_count = self.bit_count();
        if bit_count == 0 {
            return true;
        }
        // Indexing is in bounds: every probe index is `< bit_count == bitset.len() * 64`.
        Self::bit_positions(hash64, self.num_hash_functions, bit_count)
            .all(|idx| (self.bitset[idx / 64] & (1u64 << (idx % 64))) != 0)
    }

    /// Hashes a byte string with [`murmur3_64_orc`] for use with `add_hash` / `test_hash`.
    pub(crate) fn hash_bytes(value: &[u8]) -> u64 {
        murmur3_64_orc(value)
    }

    /// Hashes an integer for use with `add_hash` / `test_hash`.
    ///
    /// The shift/add/xor sequence appears to be Thomas Wang's 64-bit integer mix. The `>>`
    /// shifts operate on `i64` and are therefore arithmetic (sign-propagating); all additions
    /// and left shifts wrap.
    pub(crate) fn hash_long(value: i64) -> u64 {
        let mut key = value;
        key = (!key).wrapping_add(key.wrapping_shl(21));
        key ^= key >> 24;
        key = key
            .wrapping_add(key.wrapping_shl(3))
            .wrapping_add(key.wrapping_shl(8));
        key ^= key >> 14;
        key = key
            .wrapping_add(key.wrapping_shl(2))
            .wrapping_add(key.wrapping_shl(4));
        key ^= key >> 28;
        key = key.wrapping_add(key.wrapping_shl(31));
        key as u64
    }

    /// Number of probes applied per value.
    pub fn num_hash_functions(&self) -> u32 {
        self.num_hash_functions
    }

    /// Number of 64-bit words in the bit set.
    pub fn word_count(&self) -> usize {
        self.bitset.len()
    }

    /// Total number of bits in the filter (`word_count() * 64`).
    pub fn bit_count(&self) -> usize {
        self.bitset.len() * 64
    }

    /// Returns `false` only if `value` was definitely never added (hashed via `hash_bytes`).
    pub fn might_contain(&self, value: &[u8]) -> bool {
        self.test_hash(Self::hash_bytes(value))
    }
}
/// 64-bit MurmurHash3-style hash with seed `104_729`, named for (and intended to match) the
/// `Murmur3.hash64` routine used by ORC bloom filters.
///
/// The input is consumed as little-endian 64-bit blocks. A final group of 1–7 bytes is
/// zero-extended into one more little-endian word and scrambled into the state without the
/// per-block rotate/multiply/add step. The byte length is then xored in and the result passed
/// through [`fmix64`].
fn murmur3_64_orc(bytes: &[u8]) -> u64 {
    const C1: u64 = 0x87c3_7b91_1142_53d5;
    const C2: u64 = 0x4cf5_ad43_2745_937f;
    const R1: u32 = 31;
    const R2: u32 = 27;
    const M: u64 = 5;
    const N1: u64 = 1_390_208_809;
    const SEED: u64 = 104_729;

    // Key scrambling shared by full blocks and the tail.
    let scramble = |k: u64| k.wrapping_mul(C1).rotate_left(R1).wrapping_mul(C2);
    // Zero-extends up to 8 bytes into a little-endian word.
    let load_le = |chunk: &[u8]| {
        let mut word = [0u8; 8];
        word[..chunk.len()].copy_from_slice(chunk);
        u64::from_le_bytes(word)
    };

    let mut blocks = bytes.chunks_exact(8);
    let mut state = blocks.by_ref().fold(SEED, |acc, block| {
        (acc ^ scramble(load_le(block)))
            .rotate_left(R2)
            .wrapping_mul(M)
            .wrapping_add(N1)
    });

    let tail = blocks.remainder();
    if !tail.is_empty() {
        state ^= scramble(load_le(tail));
    }

    fmix64(state ^ bytes.len() as u64)
}
/// MurmurHash3 64-bit finalizer (`fmix64`): alternates a 33-bit xor-shift with two fixed
/// wrapping multiplications to spread (avalanche) the input bits across the output.
fn fmix64(k: u64) -> u64 {
    const MUL_A: u64 = 0xff51_afd7_ed55_8ccd;
    const MUL_B: u64 = 0xc4ce_b9fe_1a85_ec53;
    let xor_shift = |x: u64| x ^ (x >> 33);
    xor_shift(xor_shift(xor_shift(k).wrapping_mul(MUL_A)).wrapping_mul(MUL_B))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter of `bitset_words` zeroed words using `hash_funcs` probes, then inserts
    /// the `hash_bytes` hash of every entry in `values`.
    fn build_filter(values: &[&[u8]], bitset_words: usize, hash_funcs: u32) -> BloomFilter {
        let mut built = BloomFilter::from_parts(hash_funcs, vec![0u64; bitset_words]);
        values
            .iter()
            .map(|entry| BloomFilter::hash_bytes(entry))
            .for_each(|hashed| built.add_hash(hashed));
        built
    }

    #[test]
    fn test_bloom_filter_hit_and_miss() {
        let filter = build_filter(&[b"abc", b"def"], 2, 3);
        assert!(filter.test_hash(BloomFilter::hash_bytes(b"abc")));
        assert!(!filter.test_hash(BloomFilter::hash_bytes(b"xyz")));
    }

    #[test]
    fn test_try_from_proto_utf8_bitset() {
        // Round-trip a filter through the byte-oriented `utf8bitset` encoding.
        let source = build_filter(&[b"foo"], 1, 2);
        let message = proto::BloomFilter {
            num_hash_functions: Some(source.num_hash_functions),
            bitset: vec![],
            utf8bitset: Some(source.bitset.iter().flat_map(|word| word.to_le_bytes()).collect()),
        };
        let restored = BloomFilter::try_from_proto(&message).unwrap();
        assert!(restored.test_hash(BloomFilter::hash_bytes(b"foo")));
        assert!(!restored.test_hash(BloomFilter::hash_bytes(b"bar")));
    }

    #[test]
    fn test_might_contain_hash64() {
        let hashed = BloomFilter::hash_long(42);
        let mut filter = BloomFilter::from_parts(3, vec![0u64; 2]);
        filter.add_hash(hashed);
        assert!(filter.test_hash(hashed));
        assert!(!filter.test_hash(BloomFilter::hash_long(43)));
    }
}