use crate::util::{Error, Result, Serializable, var_int};
use hex;
use murmur3::murmur3_32;
use rand::Rng;
use std::fmt;
use std::io;
use std::io::{Cursor, Read, Write};
use std::num::Wrapping;
/// Maximum filter length in bytes.
///
/// Filters created by `BloomFilter::new` are capped at this size, and longer filters are
/// rejected by `BloomFilter::validate` and when deserializing.
pub const BLOOM_FILTER_MAX_FILTER_SIZE: usize = 36_000;
/// Maximum number of hash functions.
///
/// Filters created by `BloomFilter::new` are capped at this count, and larger counts are
/// rejected by `BloomFilter::validate` and when deserializing.
pub const BLOOM_FILTER_MAX_HASH_FUNCS: usize = 50;
/// Bloom filter: a bit field plus the parameters needed to hash elements into it.
///
/// Each element is hashed `num_hash_funcs` times with murmur3; every hash selects one bit
/// of `filter`.
#[derive(Default, PartialEq, Eq, Hash, Clone)]
pub struct BloomFilter {
/// Bit field stored as bytes; bit `n` lives in byte `n / 8` at bit position `n % 8`.
pub filter: Vec<u8>,
/// Number of hash functions evaluated for each element.
pub num_hash_funcs: usize,
/// Value added (with wrap-around) to the seed of every hash function.
pub tweak: u32,
}
impl BloomFilter {
    /// Creates a new, empty bloom filter sized for an expected workload.
    ///
    /// * `insert` - Number of items expected to be added to the filter
    /// * `pr_false_pos` - Desired false-positive probability, strictly between 0 and 1
    ///
    /// The filter length is capped at [`BLOOM_FILTER_MAX_FILTER_SIZE`] bytes and the number
    /// of hash functions at [`BLOOM_FILTER_MAX_HASH_FUNCS`]. The tweak is chosen at random.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadArgument` if `insert` is not a positive normal number, or if
    /// `pr_false_pos` is not a normal number in the open interval (0, 1). NaN and infinite
    /// values are rejected by both checks.
    pub fn new(insert: f64, pr_false_pos: f64) -> Result<BloomFilter> {
        if insert <= 0.0 || !insert.is_normal() {
            return Err(Error::BadArgument("Invalid insert value".to_string()));
        }
        if !(0.0 < pr_false_pos && pr_false_pos < 1.0) || !pr_false_pos.is_normal() {
            return Err(Error::BadArgument("Invalid pr_false_pos value".to_string()));
        }

        let ln2 = 2.0_f64.ln();
        // Filter length in bytes: -n * ln(p) / ln(2)^2 bits, divided by 8, then capped.
        let size = (-1.0 / ln2.powi(2) * insert * pr_false_pos.ln()) / 8.0;
        let size = size.min(BLOOM_FILTER_MAX_FILTER_SIZE as f64).ceil() as usize;
        // Hash function count: (bits / n) * ln(2), then capped.
        let num_hash_funcs = ((size as f64 * 8.0 / insert * ln2)
            .min(BLOOM_FILTER_MAX_HASH_FUNCS as f64))
        .ceil() as usize;

        let mut rng = rand::rng();
        let tweak = rng.random::<u32>();

        Ok(BloomFilter {
            filter: vec![0; size],
            num_hash_funcs,
            tweak,
        })
    }

    /// Adds `data` to the filter by setting one bit per hash function.
    ///
    /// Adding to a zero-length filter is a no-op: there are no bits to set, and such a
    /// filter already matches everything (see [`BloomFilter::contains`]).
    ///
    /// # Errors
    ///
    /// Returns `Error::BadArgument` if `data` is longer than 520 bytes.
    pub fn add(&mut self, data: &[u8]) -> Result<()> {
        if data.len() > 520 {
            return Err(Error::BadArgument(
                "Data too large for bloom add".to_string(),
            ));
        }
        let bit_count = self.bit_count();
        if bit_count == 0 {
            // Without this guard the modulo in `bit_position` would divide by zero.
            return Ok(());
        }
        for i in 0..self.num_hash_funcs {
            let (byte, mask) = self.bit_position(i, data, bit_count);
            self.filter[byte] |= mask;
        }
        Ok(())
    }

    /// Returns `true` if `data` may have been added to the filter, and `false` if it
    /// definitely was not.
    ///
    /// A zero-length filter, or one with zero hash functions, matches everything.
    #[must_use]
    pub fn contains(&self, data: &[u8]) -> bool {
        let bit_count = self.bit_count();
        if bit_count == 0 {
            // Without this guard the modulo in `bit_position` would divide by zero.
            return true;
        }
        (0..self.num_hash_funcs).all(|i| {
            let (byte, mask) = self.bit_position(i, data, bit_count);
            self.filter[byte] & mask != 0
        })
    }

    /// Checks that the filter respects the size and hash function limits.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadData` if the filter is longer than
    /// [`BLOOM_FILTER_MAX_FILTER_SIZE`] bytes or uses more than
    /// [`BLOOM_FILTER_MAX_HASH_FUNCS`] hash functions.
    pub fn validate(&self) -> Result<()> {
        if self.filter.len() > BLOOM_FILTER_MAX_FILTER_SIZE {
            return Err(Error::BadData("Filter too long".to_string()));
        }
        if self.num_hash_funcs > BLOOM_FILTER_MAX_HASH_FUNCS {
            return Err(Error::BadData("Too many hash funcs".to_string()));
        }
        Ok(())
    }

    /// Returns the number of bits in the filter.
    fn bit_count(&self) -> u64 {
        // Computed in 64 bits so the result is not truncated for oversized filters.
        self.filter.len() as u64 * 8
    }

    /// Returns the byte index and bit mask selected by hash function `hash_num` for `data`.
    ///
    /// `bit_count` must be the non-zero value returned by `bit_count()`.
    fn bit_position(&self, hash_num: usize, data: &[u8], bit_count: u64) -> (usize, u8) {
        // The seed arithmetic is deliberately performed modulo 2^32.
        let seed = Wrapping(hash_num as u32) * Wrapping(0xFBA4_C795) + Wrapping(self.tweak);
        let hash = murmur3_32(&mut Cursor::new(data), seed.0)
            .expect("reading from an in-memory cursor cannot fail");
        let bit = u64::from(hash) % bit_count;
        // `bit / 8` is smaller than `filter.len()`, so it always fits in a usize.
        ((bit / 8) as usize, 1u8 << (bit % 8))
    }
}
impl Serializable<BloomFilter> for BloomFilter {
fn read(reader: &mut dyn Read) -> Result<BloomFilter> {
let filter_len = var_int::read(reader)? as usize;
if filter_len > BLOOM_FILTER_MAX_FILTER_SIZE {
return Err(Error::BadData("Filter too long".to_string()));
}
let mut filter = vec![0; filter_len];
reader
.read_exact(&mut filter)
.map_err(|e| Error::IOError(e))?;
let mut num_hash_funcs = [0u8; 4];
reader
.read_exact(&mut num_hash_funcs)
.map_err(|e| Error::IOError(e))?;
let num_hash_funcs = u32::from_le_bytes(num_hash_funcs) as usize;
if num_hash_funcs > BLOOM_FILTER_MAX_HASH_FUNCS {
return Err(Error::BadData("Too many hash funcs".to_string()));
}
let mut tweak = [0u8; 4];
reader
.read_exact(&mut tweak)
.map_err(|e| Error::IOError(e))?;
let tweak = u32::from_le_bytes(tweak);
Ok(BloomFilter {
filter,
num_hash_funcs,
tweak,
})
}
fn write(&self, writer: &mut dyn Write) -> io::Result<()> {
var_int::write(self.filter.len() as u64, writer)?;
writer.write_all(&self.filter)?;
writer.write_all(&(self.num_hash_funcs as u32).to_le_bytes())?;
writer.write_all(&self.tweak.to_le_bytes())?;
Ok(())
}
}
impl fmt::Debug for BloomFilter {
    /// Formats the filter for diagnostics, showing the filter bytes as a hex string
    /// instead of a list of integers.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let filter_hex = hex::encode(&self.filter);
        let mut out = f.debug_struct("BloomFilter");
        out.field("filter", &filter_hex);
        out.field("num_hash_funcs", &self.num_hash_funcs);
        out.field("tweak", &self.tweak);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::io::Cursor;

    /// A filter survives a serialize/deserialize round trip unchanged.
    #[test]
    fn write_read() {
        let mut original = BloomFilter::new(20000.0, 0.001).unwrap();
        for byte in 0..5u8 {
            original.add(&[byte; 32]).unwrap();
        }
        let mut encoded = Vec::new();
        original.write(&mut encoded).unwrap();
        let decoded = BloomFilter::read(&mut Cursor::new(&encoded)).unwrap();
        assert_eq!(decoded, original);
    }

    /// Added data is reported as present; data never added is reported as absent.
    #[test]
    fn contains() {
        let present = [5u8; 32];
        let absent = [6u8; 32];
        let mut filter = BloomFilter::new(20000.0, 0.001).unwrap();
        filter.add(&present).unwrap();
        assert!(filter.contains(&present));
        assert!(!filter.contains(&absent));
    }

    /// Oversized data and out-of-range constructor arguments are rejected.
    #[test]
    fn invalid() {
        let mut filter = BloomFilter::new(15.0, 0.0001).unwrap();
        let oversized = [0u8; 521];
        let add_err = filter.add(&oversized).unwrap_err();
        assert_eq!(
            add_err.to_string(),
            "Bad argument: Data too large for bloom add"
        );

        let insert_err = BloomFilter::new(0.0, 0.5).unwrap_err();
        assert_eq!(insert_err.to_string(), "Bad argument: Invalid insert value");

        let zero_prob_err = BloomFilter::new(1.0, 0.0).unwrap_err();
        assert_eq!(
            zero_prob_err.to_string(),
            "Bad argument: Invalid pr_false_pos value"
        );

        let negative_prob_err = BloomFilter::new(1.0, -1.0).unwrap_err();
        assert_eq!(
            negative_prob_err.to_string(),
            "Bad argument: Invalid pr_false_pos value"
        );

        assert!(BloomFilter::new(1.0, f64::NAN).is_err());
        assert!(BloomFilter::new(f64::NAN, 0.5).is_err());
    }

    /// `validate` accepts a filter within limits and rejects each limit violation.
    #[test]
    fn validate() {
        let valid = BloomFilter {
            filter: vec![0, 1, 2, 3, 4, 5],
            num_hash_funcs: 30,
            tweak: 100,
        };
        assert!(valid.validate().is_ok());

        let too_long = BloomFilter {
            filter: vec![0; BLOOM_FILTER_MAX_FILTER_SIZE + 1],
            ..valid.clone()
        };
        assert_eq!(
            too_long.validate().unwrap_err().to_string(),
            "Bad data: Filter too long"
        );

        let too_many_funcs = BloomFilter {
            num_hash_funcs: BLOOM_FILTER_MAX_HASH_FUNCS + 1,
            ..valid.clone()
        };
        assert_eq!(
            too_many_funcs.validate().unwrap_err().to_string(),
            "Bad data: Too many hash funcs"
        );
    }

    /// Adding data one byte over the size limit fails with a descriptive error.
    #[test]
    fn add_too_large() {
        let mut filter = BloomFilter::new(20000.0, 0.001).unwrap();
        let oversized = [0u8; 521];
        let add_err = filter.add(&oversized).unwrap_err();
        assert_eq!(
            add_err.to_string(),
            "Bad argument: Data too large for bloom add"
        );
    }
}