use std::collections::HashMap;
/// A fixed-size Bloom filter over byte strings.
///
/// Membership queries may return false positives but never false negatives:
/// once an item has been inserted, [`BloomFilter::contains`] returns `true`
/// for it.
///
/// Each item is mapped to `n_hashes` bit positions by double hashing: a start
/// position from FNV-1a and a step from a folded djb2 hash, walked modulo
/// `n_bits`.
#[derive(Debug, Clone)]
pub struct BloomFilter {
    /// Bit array, packed 64 bits per word.
    bits: Vec<u64>,
    /// Number of addressable bits; always `bits.len() * 64` (set in `new`).
    n_bits: usize,
    /// Number of positions probed per item; `new` clamps this to `1..=30`.
    n_hashes: usize,
    /// Number of `insert` calls so far (re-inserting an item counts again).
    n_elements: u64,
}

impl BloomFilter {
    /// Creates a filter sized for `expected_elements` items at roughly
    /// `false_positive_rate` false-positive probability.
    ///
    /// Uses the standard sizing formulas `m = -n ln p / (ln 2)^2` and
    /// `k = (m / n) ln 2`. Inputs are sanitised rather than rejected:
    /// `expected_elements` is raised to at least 1, the rate is clamped into
    /// `[1e-15, 1 - 1e-15]`, the bit count is at least 64 and rounded up to a
    /// whole number of 64-bit words, and `k` is clamped to `1..=30`.
    pub fn new(expected_elements: usize, false_positive_rate: f64) -> Self {
        let n = expected_elements.max(1) as f64;
        let p = false_positive_rate.clamp(1e-15, 1.0 - 1e-15);
        let ln2 = std::f64::consts::LN_2;
        let m = (-(n * p.ln()) / (ln2 * ln2)).ceil() as usize;
        let m = m.max(64);
        let k = ((m as f64 / n) * ln2).round() as usize;
        let k = k.clamp(1, 30);
        let n_words = (m + 63) / 64;
        Self {
            bits: vec![0u64; n_words],
            n_bits: n_words * 64,
            n_hashes: k,
            n_elements: 0,
        }
    }

    /// Adds `item` to the filter by setting each of its probe bits.
    ///
    /// Re-inserting an item leaves the bits unchanged but still increments
    /// [`BloomFilter::n_elements`].
    pub fn insert(&mut self, item: &[u8]) {
        for (word, bit) in Self::probe_positions(item, self.n_bits, self.n_hashes) {
            self.bits[word] |= 1u64 << bit;
        }
        self.n_elements += 1;
    }

    /// Returns `false` if `item` was definitely never inserted, `true` if it
    /// may have been (true hit or false positive).
    pub fn contains(&self, item: &[u8]) -> bool {
        Self::probe_positions(item, self.n_bits, self.n_hashes)
            .all(|(word, bit)| self.bits[word] & (1u64 << bit) != 0)
    }

    /// Estimates the current false-positive probability as
    /// `(1 - e^(-k n / m))^k`.
    ///
    /// `n` is the insert count, so duplicate inserts make this an overestimate.
    pub fn false_positive_rate_estimate(&self) -> f64 {
        let k = self.n_hashes as f64;
        let n = self.n_elements as f64;
        let m = self.n_bits as f64;
        (1.0 - (-k * n / m).exp()).powf(k)
    }

    /// Number of `insert` calls made on this filter (duplicates included).
    pub fn n_elements(&self) -> u64 {
        self.n_elements
    }

    /// Size of the bit array in bits (a multiple of 64).
    pub fn n_bits(&self) -> usize {
        self.n_bits
    }

    /// Number of bit positions probed per item.
    pub fn n_hashes(&self) -> usize {
        self.n_hashes
    }

    /// Yields the `(word index, bit within word)` pairs probed for `item`.
    ///
    /// Shared by `insert` and `contains` so both always agree on positions.
    /// The step is forced odd. Since `n_bits` is a multiple of 64, an odd step
    /// is never congruent to 0 and the walk `start + i * step (mod n_bits)`
    /// cannot revisit a position within 64 steps. So all `n_hashes` (<= 30)
    /// positions are distinct, instead of possibly collapsing onto one bit.
    /// The walk is reduced modulo `n_bits` at every step, so no `u64`
    /// wrap-around happens before the reduction.
    fn probe_positions(
        item: &[u8],
        n_bits: usize,
        n_hashes: usize,
    ) -> impl Iterator<Item = (usize, u64)> {
        let m = n_bits as u64;
        let start = Self::fnv1a_hash(item) % m;
        let step = (Self::djb2_hash(item) | 1) % m;
        (0..n_hashes).scan(start, move |pos, _| {
            let bit_idx = *pos;
            // Both operands are < m, so the sum stays far below u64::MAX for
            // any filter that could actually be allocated.
            *pos += step;
            if *pos >= m {
                *pos -= m;
            }
            Some(((bit_idx / 64) as usize, bit_idx % 64))
        })
    }

    /// 64-bit FNV-1a hash (standard offset basis and prime).
    fn fnv1a_hash(data: &[u8]) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const PRIME: u64 = 0x100000001b3;
        let mut hash = OFFSET_BASIS;
        for &b in data {
            hash ^= u64::from(b);
            hash = hash.wrapping_mul(PRIME);
        }
        hash
    }

    /// djb2 (`hash * 33 + byte`, seeded with 5381) with the high 32 bits
    /// xor-folded into the low bits to spread short-input entropy.
    fn djb2_hash(data: &[u8]) -> u64 {
        let mut hash: u64 = 5381;
        for &b in data {
            hash = hash.wrapping_mul(33).wrapping_add(u64::from(b));
        }
        hash ^ (hash >> 32)
    }
}
/// Outcome of probing a column's Bloom filter for a value.
///
/// This is a small fieldless value, so it is `Copy` and `Hash`: callers can
/// pass it around and use it as a map/set key without cloning.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredicatePushdownResult {
    /// The value may have been inserted (true hit or false positive), or the
    /// column has no filter at all.
    MightContain,
    /// The column's filter proves the value was never inserted.
    DefinitelyAbsent,
}
/// A collection of per-column Bloom filters, keyed by column name.
///
/// Columns without a registered filter are treated conservatively:
/// lookups report [`PredicatePushdownResult::MightContain`] and inserts are
/// ignored.
#[derive(Debug, Clone)]
pub struct BloomColumnIndex {
    columns: HashMap<String, BloomFilter>,
}

impl BloomColumnIndex {
    /// Creates an index with no columns.
    pub fn new() -> Self {
        Self {
            columns: HashMap::new(),
        }
    }

    /// Registers a Bloom filter for `column_name` and returns it.
    ///
    /// If the column already has a filter, that existing filter is returned
    /// unchanged and `expected_elements` / `fpr` are ignored.
    pub fn add_column(
        &mut self,
        column_name: &str,
        expected_elements: usize,
        fpr: f64,
    ) -> &mut BloomFilter {
        self.columns
            .entry(column_name.to_string())
            .or_insert_with(|| BloomFilter::new(expected_elements, fpr))
    }

    /// Adds `value` to the filter of `column_name`.
    ///
    /// Does nothing if the column was never registered with
    /// [`BloomColumnIndex::add_column`].
    pub fn insert_bytes(&mut self, column_name: &str, value: &[u8]) {
        if let Some(filter) = self.columns.get_mut(column_name) {
            filter.insert(value);
        }
    }

    /// String convenience wrapper around [`BloomColumnIndex::insert_bytes`]
    /// (hashes the UTF-8 bytes of `value`).
    pub fn insert_string(&mut self, column_name: &str, value: &str) {
        self.insert_bytes(column_name, value.as_bytes());
    }

    /// Probes the filter of `column_name` for `value`.
    ///
    /// Returns `DefinitelyAbsent` only when the column has a filter and that
    /// filter rejects `value`. An unknown column yields `MightContain`.
    #[must_use]
    pub fn check(&self, column_name: &str, value: &[u8]) -> PredicatePushdownResult {
        match self.columns.get(column_name) {
            Some(filter) if !filter.contains(value) => PredicatePushdownResult::DefinitelyAbsent,
            _ => PredicatePushdownResult::MightContain,
        }
    }

    /// String convenience wrapper around [`BloomColumnIndex::check`].
    #[must_use]
    pub fn check_string(&self, column_name: &str, value: &str) -> PredicatePushdownResult {
        self.check(column_name, value.as_bytes())
    }

    /// Iterates over `(column name, filter)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &BloomFilter)> {
        self.columns.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Number of registered columns.
    pub fn len(&self) -> usize {
        self.columns.len()
    }

    /// Returns `true` if no column has been registered.
    pub fn is_empty(&self) -> bool {
        self.columns.is_empty()
    }
}

impl Default for BloomColumnIndex {
    /// Same as [`BloomColumnIndex::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bloom_no_false_negatives() {
        let mut bloom = BloomFilter::new(500, 0.01);
        let items: Vec<String> = (0..500).map(|i| format!("item_{i}")).collect();
        for item in &items {
            bloom.insert(item.as_bytes());
        }
        for item in &items {
            assert!(
                bloom.contains(item.as_bytes()),
                "false negative for {:?}",
                item
            );
        }
    }

    #[test]
    fn test_bloom_definitely_absent_for_unseen() {
        let mut bloom = BloomFilter::new(10, 0.001);
        bloom.insert(b"only_this");
        let unseen: Vec<String> = (1000..1100).map(|i| format!("never_{i}")).collect();
        let absent_count = unseen
            .iter()
            .filter(|s| !bloom.contains(s.as_bytes()))
            .count();
        assert!(
            absent_count > 50,
            "too many false positives: only {absent_count}/100 correctly absent"
        );
    }

    #[test]
    fn test_bloom_column_index_check() {
        let mut idx = BloomColumnIndex::new();
        idx.add_column("city", 100, 0.01);
        let cities = ["Berlin", "Paris", "Tokyo", "London"];
        for city in &cities {
            idx.insert_string("city", city);
        }
        for city in &cities {
            let result = idx.check_string("city", city);
            assert_eq!(
                result,
                PredicatePushdownResult::MightContain,
                "inserted city {city:?} should be MightContain"
            );
        }
        let mut found_absent = false;
        for i in 0..1000 {
            let candidate = format!("NOT_A_CITY_{i}_xyz_never_inserted");
            if idx.check_string("city", &candidate) == PredicatePushdownResult::DefinitelyAbsent {
                found_absent = true;
                break;
            }
        }
        assert!(found_absent, "expected at least one definitely-absent result");
    }

    #[test]
    fn test_bloom_column_index_missing_column() {
        let idx = BloomColumnIndex::new();
        let result = idx.check_string("nonexistent_col", "any_value");
        assert_eq!(result, PredicatePushdownResult::MightContain);
    }

    #[test]
    fn test_bloom_fpr_estimate_increases_with_load() {
        let mut bloom = BloomFilter::new(100, 0.01);
        let fpr0 = bloom.false_positive_rate_estimate();
        for i in 0u64..50 {
            bloom.insert(&i.to_le_bytes());
        }
        let fpr50 = bloom.false_positive_rate_estimate();
        for i in 50u64..200 {
            bloom.insert(&i.to_le_bytes());
        }
        let fpr200 = bloom.false_positive_rate_estimate();
        assert!(fpr0 <= fpr50, "fpr should increase: {fpr0} -> {fpr50}");
        assert!(fpr50 <= fpr200, "fpr should increase: {fpr50} -> {fpr200}");
    }

    #[test]
    fn test_bloom_new_params_reasonable() {
        let bloom = BloomFilter::new(1000, 0.01);
        assert!(bloom.n_bits() >= 1000, "filter too small");
        assert!(bloom.n_hashes() >= 1, "need at least 1 hash");
        assert!(bloom.n_hashes() <= 30, "too many hashes");
        assert_eq!(bloom.n_elements(), 0);
    }

    #[test]
    fn test_bloom_new_sanitises_degenerate_params() {
        let cases = [
            (0usize, 0.01),
            (1, 0.0),
            (1, 1.0),
            (1, -3.0),
            (1, f64::NAN),
        ];
        for &(n, p) in &cases {
            let bloom = BloomFilter::new(n, p);
            assert!(bloom.n_bits() >= 64, "too small for ({n}, {p})");
            assert_eq!(bloom.n_bits() % 64, 0, "not word-aligned for ({n}, {p})");
            assert!((1..=30).contains(&bloom.n_hashes()), "bad k for ({n}, {p})");
        }
    }

    #[test]
    fn test_bloom_empty_filter_rejects_and_estimates_zero() {
        let bloom = BloomFilter::new(100, 0.01);
        assert!(!bloom.contains(b"anything"));
        assert_eq!(bloom.false_positive_rate_estimate(), 0.0);
    }

    #[test]
    fn test_bloom_empty_item_roundtrip() {
        let mut bloom = BloomFilter::new(10, 0.01);
        bloom.insert(b"");
        assert!(bloom.contains(b""));
    }

    #[test]
    fn test_bloom_n_elements_counts_duplicates() {
        let mut bloom = BloomFilter::new(10, 0.01);
        bloom.insert(b"dup");
        bloom.insert(b"dup");
        assert_eq!(bloom.n_elements(), 2);
    }

    #[test]
    fn test_bloom_column_index_add_column_keeps_existing_filter() {
        let mut idx = BloomColumnIndex::new();
        let first = idx.add_column("c", 10, 0.01).n_bits();
        let second = idx.add_column("c", 100_000, 0.0001).n_bits();
        assert_eq!(first, second, "existing filter must not be replaced");
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_bloom_column_index_insert_into_unknown_column_is_noop() {
        let mut idx = BloomColumnIndex::new();
        idx.insert_string("ghost", "value");
        assert!(idx.is_empty());
        assert_eq!(
            idx.check_string("ghost", "value"),
            PredicatePushdownResult::MightContain
        );
    }

    #[test]
    fn test_bloom_column_index_iter_lists_columns() {
        let mut idx = BloomColumnIndex::new();
        idx.add_column("b", 10, 0.01);
        idx.add_column("a", 10, 0.01);
        let mut names: Vec<&str> = idx.iter().map(|(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(names, ["a", "b"]);
    }

    #[test]
    fn test_bloom_index_default() {
        let idx = BloomColumnIndex::default();
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }
}