use std::hash::{Hash, Hasher};
/// False-positive rate expressed as a probability (for example `0.01` for 1%).
pub type FPR = f64;
/// A 1% false-positive rate; not referenced by any other code visible in this file.
pub const DEFAULT_FPR: FPR = 0.01;
/// Factory for one family of membership filters.
///
/// A policy hands out builders that serialize a filter into bytes, and turns
/// such bytes back into a queryable [`Filter`].
pub trait FilterPolicy: Send + Sync + std::fmt::Debug {
/// Short identifier of the filter family (the implementations in this file
/// return `"bloom"`, `"ribbon"` and `"xor"`).
fn name(&self) -> &str;
/// Creates a builder; `num_keys` is the number of keys the caller expects to add.
fn create_builder(&self, num_keys: usize) -> Box<dyn FilterBuilder>;
/// Reconstructs a queryable filter from serialized bytes. The implementations
/// in this file parse the layout written by their own builder's `finish`.
fn create_filter(&self, data: Vec<u8>) -> Box<dyn Filter>;
/// Bits of filter state budgeted per key by this policy.
fn bits_per_key(&self) -> f64;
/// False-positive rate associated with this policy's configuration.
fn target_fpr(&self) -> FPR;
}
/// Incrementally collects keys and serializes them into filter bytes.
pub trait FilterBuilder: Send {
/// Records `key` in the filter under construction.
fn add_key(&mut self, key: &[u8]);
/// Serializes the filter built so far and returns the bytes.
fn finish(&mut self) -> Vec<u8>;
/// Number of keys added so far.
fn num_keys(&self) -> usize;
}
/// Queryable, read-only membership filter.
pub trait Filter: Send + Sync {
/// Returns `false` when the filter rules `key` out and `true` when the key
/// may be present (false positives are possible).
fn may_contain(&self, key: &[u8]) -> bool;
/// Size of the filter in bytes, as reported by the implementation.
fn size_bytes(&self) -> usize;
}
/// Reader for a byte-oriented bit-array filter whose final byte stores the
/// number of probes per lookup.
#[derive(Debug)]
pub struct FilterReader {
    /// Raw filter bytes: the bit array followed by the trailing probe-count byte.
    data: Vec<u8>,
    /// Probes performed per lookup (taken from the trailing byte).
    num_hashes: usize,
}

impl FilterReader {
    /// Parses a serialized filter.
    ///
    /// Returns `None` when `data` is shorter than two bytes or when the
    /// trailing probe count is outside `1..=30`. The bytes are copied.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        if data.len() < 2 {
            return None;
        }
        let num_hashes = usize::from(*data.last()?);
        (1..=30).contains(&num_hashes).then(|| Self {
            data: data.to_vec(),
            num_hashes,
        })
    }

    /// Returns `false` if any probed bit for `key` is clear, `true` otherwise.
    ///
    /// A reader holding fewer than two bytes has no bit array and answers `true`.
    pub fn may_contain(&self, key: &[u8]) -> bool {
        // Everything before the trailing probe-count byte is the bit array.
        let bit_array = match self.data.split_last() {
            Some((_, body)) if !body.is_empty() => body,
            _ => return true,
        };
        let total_bits = (bit_array.len() * 8) as u64;

        // Two polynomial rolling hashes (multipliers 31 and 37); the second
        // one additionally mixes in the byte position.
        let (h1, h2) = key
            .iter()
            .enumerate()
            .fold((0u64, 0u64), |(a, b), (pos, &byte)| {
                let byte = u64::from(byte);
                (
                    a.wrapping_mul(31).wrapping_add(byte),
                    b.wrapping_mul(37)
                        .wrapping_add(byte)
                        .wrapping_add(pos as u64),
                )
            });

        // Double hashing: probe `i` inspects bit `(h1 + i * h2) mod total_bits`,
        // which always lies inside `bit_array`.
        (0..self.num_hashes as u64).all(|probe| {
            let bit = h1.wrapping_add(h2.wrapping_mul(probe)) % total_bits;
            bit_array[(bit / 8) as usize] & (1u8 << (bit % 8)) != 0
        })
    }

    /// Total number of bytes held, including the trailing probe-count byte.
    pub fn size_bytes(&self) -> usize {
        self.data.len()
    }
}
/// Policy that produces Bloom filters, sized either from a target
/// false-positive rate or from an explicit bits-per-key budget.
#[derive(Debug)]
pub struct BloomFilterPolicy {
    /// Target (or estimated) false-positive rate.
    fpr: FPR,
    /// Bits of filter state per key.
    bits_per_key: f64,
    /// Probes per key, always within `1..=30`.
    num_hashes: usize,
}

impl BloomFilterPolicy {
    /// Sizes the filter for the requested false-positive rate.
    ///
    /// Uses `bits_per_key = -log2(fpr) / ln 2` and
    /// `num_hashes = round(bits_per_key * ln 2)`, clamped to `1..=30`.
    pub fn new(fpr: FPR) -> Self {
        let ln2 = 2.0_f64.ln();
        let bits_per_key = -fpr.log2() / ln2;
        let probes = (bits_per_key * ln2).round() as usize;
        Self {
            fpr,
            bits_per_key,
            num_hashes: probes.clamp(1, 30),
        }
    }

    /// Sizes the filter from an explicit bits-per-key budget.
    ///
    /// The stored false-positive rate is estimated as `0.6185^bits_per_key`;
    /// the probe count is `round(bits_per_key * ln 2)`, clamped to `1..=30`.
    pub fn with_bits_per_key(bits_per_key: f64) -> Self {
        let probes = (bits_per_key * 2.0_f64.ln()).round() as usize;
        Self {
            fpr: 0.6185_f64.powf(bits_per_key),
            bits_per_key,
            num_hashes: probes.clamp(1, 30),
        }
    }
}

impl FilterPolicy for BloomFilterPolicy {
    /// Always `"bloom"`.
    fn name(&self) -> &str {
        "bloom"
    }

    /// Builds a Bloom filter builder using this policy's bit budget and probe count.
    fn create_builder(&self, num_keys: usize) -> Box<dyn FilterBuilder> {
        let builder = BloomFilterBuilder::new(num_keys, self.bits_per_key, self.num_hashes);
        Box::new(builder)
    }

    /// Opens serialized Bloom filter bytes; this policy's probe count is
    /// passed along as the reader's default.
    fn create_filter(&self, data: Vec<u8>) -> Box<dyn Filter> {
        let filter = BloomFilter::new(data, self.num_hashes);
        Box::new(filter)
    }

    /// Bits of filter state per key.
    fn bits_per_key(&self) -> f64 {
        self.bits_per_key
    }

    /// False-positive rate this policy was configured with (or estimated for).
    fn target_fpr(&self) -> FPR {
        self.fpr
    }
}
/// Accumulates keys into an in-memory bit array and serializes it as a Bloom filter.
///
/// Serialized layout (all little-endian): the bit array as `u64` words,
/// followed by an 8-byte trailer `[num_hashes: u32][num_bits: u32]`.
struct BloomFilterBuilder {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Number of addressable bits (the modulus for probe positions).
    /// Always within `64..=u32::MAX` so it survives the `u32` trailer field.
    num_bits: usize,
    /// Probes set per key.
    num_hashes: usize,
    /// Keys added so far.
    num_keys: usize,
}

impl BloomFilterBuilder {
    /// Sizes the bit array for `expected_keys * bits_per_key` bits.
    ///
    /// The bit count is clamped to `64..=u32::MAX`. The upper bound matters
    /// because the trailer stores the count as a `u32`: a larger value would be
    /// truncated on write, so a reader would probe with a different modulus
    /// than the builder and report false negatives. It also bounds the
    /// allocation when `bits_per_key` is infinite or absurdly large.
    fn new(expected_keys: usize, bits_per_key: f64, num_hashes: usize) -> Self {
        // Float-to-int casts saturate; NaN and negative products become 0.
        let requested = (expected_keys as f64 * bits_per_key).ceil() as usize;
        let num_bits = requested.clamp(64, u32::MAX as usize);
        let num_words = num_bits.div_ceil(64);
        Self {
            bits: vec![0; num_words],
            num_bits,
            num_hashes,
            num_keys: 0,
        }
    }
}

impl FilterBuilder for BloomFilterBuilder {
    /// Sets `num_hashes` bits for `key` using double hashing:
    /// probe `i` targets bit `(hash1 + i * hash2) mod num_bits`.
    fn add_key(&mut self, key: &[u8]) {
        let h1 = hash1(key);
        let h2 = hash2(key);
        for i in 0..self.num_hashes {
            let h = h1.wrapping_add((i as u64).wrapping_mul(h2));
            let bit_idx = (h as usize) % self.num_bits;
            self.bits[bit_idx / 64] |= 1u64 << (bit_idx % 64);
        }
        self.num_keys += 1;
    }

    /// Serializes the bit array followed by the `[num_hashes][num_bits]` trailer.
    fn finish(&mut self) -> Vec<u8> {
        let mut result = Vec::with_capacity(self.bits.len() * 8 + 8);
        for word in &self.bits {
            result.extend_from_slice(&word.to_le_bytes());
        }
        // Both trailer fields are 32-bit. `num_bits` is already bounded by
        // `new`; `num_hashes` is saturated rather than silently wrapped.
        let num_hashes = self.num_hashes.min(u32::MAX as usize) as u32;
        let num_bits = self.num_bits.min(u32::MAX as usize) as u32;
        result.extend_from_slice(&num_hashes.to_le_bytes());
        result.extend_from_slice(&num_bits.to_le_bytes());
        result
    }

    /// Number of keys added so far.
    fn num_keys(&self) -> usize {
        self.num_keys
    }
}
/// Queryable Bloom filter decoded from serialized bytes.
///
/// Expected layout (all little-endian): the bit array as `u64` words,
/// followed by an 8-byte trailer `[num_hashes: u32][num_bits: u32]`.
struct BloomFilter {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Modulus for probe positions; `0` marks a filter that cannot rule anything out.
    num_bits: usize,
    /// Probes checked per lookup.
    num_hashes: usize,
}

impl BloomFilter {
    /// Decodes a serialized filter.
    ///
    /// Input shorter than the 8-byte trailer yields a pass-through filter
    /// (every lookup answers `true`) that carries `default_num_hashes`.
    /// A trailer that declares more bits than the payload actually holds is
    /// treated the same way: the data is corrupt or truncated, and answering
    /// "maybe" is the safe response, whereas reporting "absent" would be a
    /// false negative. Trailing payload bytes that do not form a complete
    /// 8-byte word are ignored.
    fn new(data: Vec<u8>, default_num_hashes: usize) -> Self {
        const TRAILER_LEN: usize = 8;
        if data.len() < TRAILER_LEN {
            return Self {
                bits: Vec::new(),
                num_bits: 0,
                num_hashes: default_num_hashes,
            };
        }

        let (payload, trailer) = data.split_at(data.len() - TRAILER_LEN);
        let read_u32 = |at: usize| {
            u32::from_le_bytes([trailer[at], trailer[at + 1], trailer[at + 2], trailer[at + 3]])
                as usize
        };
        let num_hashes = read_u32(0);
        let declared_bits = read_u32(4);

        let bits: Vec<u64> = payload
            .chunks_exact(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                u64::from_le_bytes(word)
            })
            .collect();

        // Only trust the declared bit count when the payload can back it.
        let capacity_bits = bits.len().saturating_mul(64);
        let num_bits = if declared_bits <= capacity_bits {
            declared_bits
        } else {
            0
        };

        Self {
            bits,
            num_bits,
            num_hashes,
        }
    }
}

impl Filter for BloomFilter {
    /// Returns `false` if any probed bit for `key` is clear, `true` otherwise.
    ///
    /// Probe `i` inspects bit `(hash1 + i * hash2) mod num_bits`. A filter
    /// with no usable bit array answers `true` for every key.
    fn may_contain(&self, key: &[u8]) -> bool {
        if self.bits.is_empty() || self.num_bits == 0 {
            return true;
        }
        let h1 = hash1(key);
        let h2 = hash2(key);
        (0..self.num_hashes as u64).all(|i| {
            let h = h1.wrapping_add(i.wrapping_mul(h2));
            let bit_idx = (h as usize) % self.num_bits;
            // `new` guarantees `num_bits <= bits.len() * 64`, so the word exists.
            self.bits[bit_idx / 64] & (1u64 << (bit_idx % 64)) != 0
        })
    }

    /// Serialized size: the decoded words plus the 8-byte trailer.
    fn size_bytes(&self) -> usize {
        self.bits.len() * 8 + 8
    }
}
/// Policy registered under the name `"ribbon"`.
///
/// Budgets `-log2(fpr)` bits per key and delegates to `RibbonFilterBuilder`
/// and `RibbonFilter`.
#[derive(Debug)]
pub struct RibbonFilterPolicy {
    /// Target false-positive rate.
    fpr: FPR,
    /// Bits per key derived from `fpr`.
    bits_per_key: f64,
}

impl RibbonFilterPolicy {
    /// Creates a policy for the given false-positive rate.
    pub fn new(fpr: FPR) -> Self {
        Self {
            fpr,
            bits_per_key: -fpr.log2(),
        }
    }
}

impl FilterPolicy for RibbonFilterPolicy {
    /// Always `"ribbon"`.
    fn name(&self) -> &str {
        "ribbon"
    }

    /// Creates a builder sized for `num_keys` keys with this policy's bit budget.
    fn create_builder(&self, num_keys: usize) -> Box<dyn FilterBuilder> {
        let builder = RibbonFilterBuilder::new(num_keys, self.bits_per_key);
        Box::new(builder)
    }

    /// Opens serialized filter bytes for querying.
    fn create_filter(&self, data: Vec<u8>) -> Box<dyn Filter> {
        let filter = RibbonFilter::new(data);
        Box::new(filter)
    }

    /// Bits of filter state per key.
    fn bits_per_key(&self) -> f64 {
        self.bits_per_key
    }

    /// False-positive rate this policy was configured with.
    fn target_fpr(&self) -> FPR {
        self.fpr
    }
}
/// Collects one masked 64-bit fingerprint per key and serializes them.
///
/// Serialized layout (all little-endian):
/// `[key_count: u32][fingerprint_bits: u32]` followed by one `u64` per key.
struct RibbonFilterBuilder {
    /// Masked fingerprints in insertion order.
    fingerprints: Vec<u64>,
    /// Number of low hash bits kept per fingerprint; always `<= 64`.
    fingerprint_bits: usize,
    /// `ceil(expected_keys * 1.05)`; computed here but not read by any code
    /// visible in this file.
    num_slots: usize,
}

impl RibbonFilterBuilder {
    /// Creates a builder that keeps `ceil(bits_per_key)` bits per fingerprint.
    ///
    /// The width is capped at 64 because fingerprints are stored in a `u64`.
    /// Without the cap, a very small (or zero) false-positive rate produces a
    /// width of 64 or more, and the mask computation would shift out of range.
    fn new(expected_keys: usize, bits_per_key: f64) -> Self {
        // Float-to-int casts saturate; NaN and negative values become 0.
        let fingerprint_bits = (bits_per_key.ceil() as usize).min(64);
        let num_slots = ((expected_keys as f64) * 1.05).ceil() as usize;
        Self {
            fingerprints: Vec::with_capacity(expected_keys),
            fingerprint_bits,
            num_slots,
        }
    }

    /// Mask selecting the low `fingerprint_bits` bits of a hash.
    ///
    /// `1u64 << 64` is an overflow (panic in debug builds, wrong mask in
    /// release), so the full-width case is handled explicitly.
    fn fingerprint_mask(&self) -> u64 {
        if self.fingerprint_bits >= 64 {
            u64::MAX
        } else {
            (1u64 << self.fingerprint_bits) - 1
        }
    }
}

impl FilterBuilder for RibbonFilterBuilder {
    /// Stores the low `fingerprint_bits` bits of `hash1(key)`.
    fn add_key(&mut self, key: &[u8]) {
        let fingerprint = hash1(key) & self.fingerprint_mask();
        self.fingerprints.push(fingerprint);
    }

    /// Serializes the header followed by every fingerprint.
    fn finish(&mut self) -> Vec<u8> {
        let mut result = Vec::with_capacity(8 + self.fingerprints.len() * 8);
        // The header count is a 32-bit field; saturate instead of wrapping.
        let key_count = self.fingerprints.len().min(u32::MAX as usize) as u32;
        result.extend_from_slice(&key_count.to_le_bytes());
        result.extend_from_slice(&(self.fingerprint_bits as u32).to_le_bytes());
        for fp in &self.fingerprints {
            result.extend_from_slice(&fp.to_le_bytes());
        }
        result
    }

    /// Number of keys added so far.
    fn num_keys(&self) -> usize {
        self.fingerprints.len()
    }
}
/// Queryable fingerprint list decoded from serialized bytes.
///
/// Expected layout (all little-endian):
/// `[key_count: u32][fingerprint_bits: u32]` followed by `u64` fingerprints.
struct RibbonFilter {
    /// Stored fingerprints.
    fingerprints: Vec<u64>,
    /// Number of low hash bits compared per lookup, as read from the header.
    fingerprint_bits: usize,
}

impl RibbonFilter {
    /// Decodes a serialized filter.
    ///
    /// Input shorter than the 8-byte header yields an empty filter, which
    /// answers `true` for every key. The fingerprint vector is sized from the
    /// payload that is actually present, not from the header's key count, so a
    /// corrupt header cannot request a huge allocation. Trailing bytes that do
    /// not form a complete 8-byte word are ignored.
    fn new(data: Vec<u8>) -> Self {
        const HEADER_LEN: usize = 8;
        if data.len() < HEADER_LEN {
            return Self {
                fingerprints: Vec::new(),
                fingerprint_bits: 0,
            };
        }

        let (header, payload) = data.split_at(HEADER_LEN);
        // header[0..4] holds the key count; the payload length is authoritative.
        let fingerprint_bits =
            u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;

        let fingerprints = payload
            .chunks_exact(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                u64::from_le_bytes(word)
            })
            .collect();

        Self {
            fingerprints,
            fingerprint_bits,
        }
    }

    /// Mask selecting the low `fingerprint_bits` bits of a hash.
    ///
    /// The width comes from untrusted bytes, and shifting a `u64` by 64 or
    /// more is an overflow (panic in debug builds, wrong mask in release), so
    /// widths of 64 and above select the whole hash.
    fn fingerprint_mask(&self) -> u64 {
        if self.fingerprint_bits >= 64 {
            u64::MAX
        } else {
            (1u64 << self.fingerprint_bits) - 1
        }
    }
}

impl Filter for RibbonFilter {
    /// Returns `true` when the masked `hash1(key)` equals any stored
    /// fingerprint, or when no fingerprints are stored at all.
    ///
    /// The lookup is a linear scan over the stored fingerprints.
    fn may_contain(&self, key: &[u8]) -> bool {
        if self.fingerprints.is_empty() {
            return true;
        }
        let fingerprint = hash1(key) & self.fingerprint_mask();
        self.fingerprints.contains(&fingerprint)
    }

    /// Serialized size: the 8-byte header plus 8 bytes per fingerprint.
    fn size_bytes(&self) -> usize {
        8 + self.fingerprints.len() * 8
    }
}
/// Policy registered under the name `"xor"`.
///
/// Budgets `-log2(fpr) * 1.23` bits per key and delegates to
/// `XorFilterBuilder` and `XorFilter`.
#[derive(Debug)]
pub struct XorFilterPolicy {
    /// Target false-positive rate.
    fpr: FPR,
    /// Bits per key derived from `fpr`.
    bits_per_key: f64,
}

impl XorFilterPolicy {
    /// Creates a policy for the given false-positive rate.
    pub fn new(fpr: FPR) -> Self {
        Self {
            fpr,
            bits_per_key: -fpr.log2() * 1.23,
        }
    }
}

impl FilterPolicy for XorFilterPolicy {
    /// Always `"xor"`.
    fn name(&self) -> &str {
        "xor"
    }

    /// Creates a builder with this policy's bit budget.
    fn create_builder(&self, num_keys: usize) -> Box<dyn FilterBuilder> {
        let builder = XorFilterBuilder::new(num_keys, self.bits_per_key);
        Box::new(builder)
    }

    /// Opens serialized filter bytes for querying.
    fn create_filter(&self, data: Vec<u8>) -> Box<dyn Filter> {
        let filter = XorFilter::new(data);
        Box::new(filter)
    }

    /// Bits of filter state per key.
    fn bits_per_key(&self) -> f64 {
        self.bits_per_key
    }

    /// False-positive rate this policy was configured with.
    fn target_fpr(&self) -> FPR {
        self.fpr
    }
}
/// Buffers raw keys and, on `finish`, serializes one masked fingerprint per key.
///
/// Serialized layout (all little-endian):
/// `[key_count: u32][fingerprint_bits: u32]` followed by one `u64` per key.
struct XorFilterBuilder {
    /// Copies of every key added so far.
    keys: Vec<Vec<u8>>,
    /// Number of low hash bits kept per fingerprint; always `<= 64`.
    fingerprint_bits: usize,
}

impl XorFilterBuilder {
    /// Creates a builder that keeps `ceil(bits_per_key)` bits per fingerprint.
    ///
    /// The width is capped at 64 because fingerprints are stored in a `u64`.
    /// Without the cap, a very small (or zero) false-positive rate produces a
    /// width of 64 or more, and the mask computation would shift out of range.
    /// The expected key count is currently unused.
    fn new(_expected_keys: usize, bits_per_key: f64) -> Self {
        Self {
            keys: Vec::new(),
            // Float-to-int casts saturate; NaN and negative values become 0.
            fingerprint_bits: (bits_per_key.ceil() as usize).min(64),
        }
    }

    /// Mask selecting the low `fingerprint_bits` bits of a hash.
    ///
    /// `1u64 << 64` is an overflow (panic in debug builds, wrong mask in
    /// release), so the full-width case is handled explicitly.
    fn fingerprint_mask(&self) -> u64 {
        if self.fingerprint_bits >= 64 {
            u64::MAX
        } else {
            (1u64 << self.fingerprint_bits) - 1
        }
    }
}

impl FilterBuilder for XorFilterBuilder {
    /// Stores a copy of `key`; hashing is deferred until `finish`.
    fn add_key(&mut self, key: &[u8]) {
        self.keys.push(key.to_vec());
    }

    /// Serializes the header followed by the masked `hash1` of every key.
    fn finish(&mut self) -> Vec<u8> {
        let mask = self.fingerprint_mask();
        let mut result = Vec::with_capacity(8 + self.keys.len() * 8);
        // The header count is a 32-bit field; saturate instead of wrapping.
        let key_count = self.keys.len().min(u32::MAX as usize) as u32;
        result.extend_from_slice(&key_count.to_le_bytes());
        result.extend_from_slice(&(self.fingerprint_bits as u32).to_le_bytes());
        for key in &self.keys {
            let fp = hash1(key) & mask;
            result.extend_from_slice(&fp.to_le_bytes());
        }
        result
    }

    /// Number of keys added so far.
    fn num_keys(&self) -> usize {
        self.keys.len()
    }
}
/// Queryable fingerprint list decoded from serialized bytes.
///
/// Expected layout (all little-endian):
/// `[key_count: u32][fingerprint_bits: u32]` followed by `u64` fingerprints.
struct XorFilter {
    /// Stored fingerprints.
    fingerprints: Vec<u64>,
    /// Number of low hash bits compared per lookup, as read from the header.
    fingerprint_bits: usize,
}

impl XorFilter {
    /// Decodes a serialized filter.
    ///
    /// Input shorter than the 8-byte header yields an empty filter, which
    /// answers `true` for every key. The fingerprint vector is sized from the
    /// payload that is actually present, not from the header's key count, so a
    /// corrupt header cannot request a huge allocation. Trailing bytes that do
    /// not form a complete 8-byte word are ignored.
    fn new(data: Vec<u8>) -> Self {
        const HEADER_LEN: usize = 8;
        if data.len() < HEADER_LEN {
            return Self {
                fingerprints: Vec::new(),
                fingerprint_bits: 0,
            };
        }

        let (header, payload) = data.split_at(HEADER_LEN);
        // header[0..4] holds the key count; the payload length is authoritative.
        let fingerprint_bits =
            u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;

        let fingerprints = payload
            .chunks_exact(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                u64::from_le_bytes(word)
            })
            .collect();

        Self {
            fingerprints,
            fingerprint_bits,
        }
    }

    /// Mask selecting the low `fingerprint_bits` bits of a hash.
    ///
    /// The width comes from untrusted bytes, and shifting a `u64` by 64 or
    /// more is an overflow (panic in debug builds, wrong mask in release), so
    /// widths of 64 and above select the whole hash.
    fn fingerprint_mask(&self) -> u64 {
        if self.fingerprint_bits >= 64 {
            u64::MAX
        } else {
            (1u64 << self.fingerprint_bits) - 1
        }
    }
}

impl Filter for XorFilter {
    /// Returns `true` when the masked `hash1(key)` equals any stored
    /// fingerprint, or when no fingerprints are stored at all.
    ///
    /// The lookup is a linear scan over the stored fingerprints.
    fn may_contain(&self, key: &[u8]) -> bool {
        if self.fingerprints.is_empty() {
            return true;
        }
        let fingerprint = hash1(key) & self.fingerprint_mask();
        self.fingerprints.contains(&fingerprint)
    }

    /// Serialized size: the 8-byte header plus 8 bytes per fingerprint.
    fn size_bytes(&self) -> usize {
        8 + self.fingerprints.len() * 8
    }
}
/// Primary 64-bit hash of `key`; delegates to `twox_hash::xxh3::hash64`.
fn hash1(key: &[u8]) -> u64 {
twox_hash::xxh3::hash64(key)
}
/// Secondary 64-bit hash of `key`, computed by feeding the slice through a
/// freshly constructed standard-library `DefaultHasher`.
///
/// NOTE(review): the standard library does not specify `DefaultHasher`'s
/// algorithm and may change it between releases. The Bloom filter code in this
/// file bakes this value into serialized bits, so confirm that filters are
/// never read back by a binary built with a different toolchain.
fn hash2(key: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    Hasher::finish(&state)
}
/// Ordered stack of filters that a key must pass one after another.
pub struct FilterCascade {
    /// Levels in the order they were added.
    filters: Vec<Box<dyn Filter>>,
    /// Product of the false-positive rates supplied with each level.
    combined_fpr: FPR,
}

impl FilterCascade {
    /// Creates a cascade with no levels and a combined rate of `1.0`.
    pub fn new() -> Self {
        FilterCascade {
            combined_fpr: 1.0,
            filters: Vec::new(),
        }
    }

    /// Appends `filter` as the next level and multiplies `fpr` into the
    /// combined false-positive rate.
    pub fn add_level(&mut self, filter: Box<dyn Filter>, fpr: FPR) {
        self.combined_fpr *= fpr;
        self.filters.push(filter);
    }

    /// Returns `false` as soon as one level rules `key` out; `true` when every
    /// level (or no level at all) lets it through.
    pub fn may_contain(&self, key: &[u8]) -> bool {
        for level in &self.filters {
            if !level.may_contain(key) {
                return false;
            }
        }
        true
    }

    /// Product of the per-level false-positive rates supplied so far.
    pub fn combined_fpr(&self) -> FPR {
        self.combined_fpr
    }

    /// Number of levels in the cascade.
    pub fn num_levels(&self) -> usize {
        self.filters.len()
    }
}

impl Default for FilterCascade {
    /// Same as [`FilterCascade::new`].
    fn default() -> Self {
        FilterCascade::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key used throughout the tests: `"key"` followed by the index.
    fn nth_key(i: usize) -> String {
        format!("key{}", i)
    }

    /// Every inserted key must be found, and fewer than 50 of 1000 absent keys
    /// may be reported as present for a 1% target rate.
    #[test]
    fn test_bloom_filter_basic() {
        let policy = BloomFilterPolicy::new(0.01);
        let mut builder = policy.create_builder(1000);
        for i in 0..1000 {
            builder.add_key(nth_key(i).as_bytes());
        }
        let filter = policy.create_filter(builder.finish());

        for i in 0..1000 {
            assert!(filter.may_contain(nth_key(i).as_bytes()));
        }

        let false_positives = (1000..2000)
            .filter(|&i| filter.may_contain(nth_key(i).as_bytes()))
            .count();
        assert!(
            false_positives < 50,
            "Too many false positives: {}",
            false_positives
        );
    }

    /// An explicit bits-per-key budget is reported back unchanged.
    #[test]
    fn test_bloom_bits_per_key() {
        let policy = BloomFilterPolicy::with_bits_per_key(10.0);
        let delta = (policy.bits_per_key() - 10.0).abs();
        assert!(delta < 0.01);
    }

    /// Two 10% levels combine to roughly 1%, and inserted keys pass both levels.
    #[test]
    fn test_filter_cascade() {
        let first_policy = BloomFilterPolicy::new(0.1);
        let second_policy = BloomFilterPolicy::new(0.1);
        let mut first_builder = first_policy.create_builder(100);
        let mut second_builder = second_policy.create_builder(100);
        for i in 0..100 {
            let key = nth_key(i);
            first_builder.add_key(key.as_bytes());
            second_builder.add_key(key.as_bytes());
        }
        let first_filter = first_policy.create_filter(first_builder.finish());
        let second_filter = second_policy.create_filter(second_builder.finish());

        let mut cascade = FilterCascade::new();
        cascade.add_level(first_filter, 0.1);
        cascade.add_level(second_filter, 0.1);

        assert!((cascade.combined_fpr() - 0.01).abs() < 0.001);
        for i in 0..100 {
            assert!(cascade.may_contain(nth_key(i).as_bytes()));
        }
    }

    /// A filter opened from no bytes cannot rule anything out.
    #[test]
    fn test_empty_filter() {
        let filter = BloomFilterPolicy::new(0.01).create_filter(Vec::new());
        assert!(filter.may_contain(b"any_key"));
    }
}