use std::io::Cursor;
use std::{f64::consts::E, sync::atomic::AtomicU8};
use anyhow::{bail, Result};
use bitvec::prelude::*;
use murmur3::murmur3_x64_128 as murmur3hash;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone)]
/// A Bloom filter: a space-efficient probabilistic set.
/// Lookups never yield false negatives; false positives occur with a
/// probability governed by the sizing parameters below.
pub struct BloomFilter {
/// Target false-positive probability the filter was sized for.
fp_prob: f64,
/// Number of bits in `bitvec`; stored as `u128` so the 128-bit murmur3
/// digest can be reduced modulo the size without narrowing first.
size: u128,
/// Number of hash functions applied per item (murmur3 seeds `0..hash_count`).
hash_count: u32,
/// Underlying bit array; `AtomicU8` storage lets `add_aliased` write
/// through a shared `&self` reference.
bitvec: BitBox<AtomicU8, Msb0>,
}
impl BloomFilter {
pub fn new(
fp_prob: f64,
size: u64,
hash_count: u32,
bitvec: BitBox<AtomicU8, Msb0>,
) -> Result<Self> {
Ok(Self {
fp_prob,
hash_count,
bitvec,
size: size as u128,
})
}
pub fn get_fp_prob(&self) -> f64 {
self.fp_prob
}
pub fn get_size(&self) -> u128 {
self.size
}
pub fn get_hash_count(&self) -> u32 {
self.hash_count
}
pub fn get_bitvec(&self) -> &BitBox<AtomicU8, Msb0> {
&self.bitvec
}
pub fn new_by_item_count_and_fp_prob(items_count: u64, fp_prob: f64) -> Result<Self> {
let size = Self::calc_size(items_count, fp_prob);
let hash_count = Self::calc_hash_count(size, items_count)?;
let bitvec = bitvec!(AtomicU8, Msb0; 0; size as usize);
Self::new(fp_prob, size, hash_count, bitvec.into_boxed_bitslice())
}
pub fn new_by_size_and_fp_prob(size: u64, fp_prob: f64) -> Result<Self> {
let rounded_size = size + 8 - (size % 8);
let (_, hash_count) = Self::calc_item_size_and_hash_count(rounded_size, fp_prob);
let bitvec = bitvec!(AtomicU8, Msb0; 0; rounded_size as usize);
Self::new(
fp_prob,
rounded_size,
hash_count,
bitvec.into_boxed_bitslice(),
)
}
fn calc_item_position(&self, item: &str, seed: u32) -> Result<usize> {
Ok((murmur3hash(&mut Cursor::new(item), seed)? % self.size) as usize)
}
pub fn add(&mut self, item: &str) -> Result<()> {
for i in 0..self.hash_count {
let digest = self.calc_item_position(item, i)?;
self.bitvec.set(digest, true)
}
Ok(())
}
pub fn add_aliased(&self, item: &str) -> Result<()> {
for i in 0..self.hash_count {
let digest = self.calc_item_position(item, i)?;
self.bitvec.set_aliased(digest, true)
}
Ok(())
}
pub fn contains(&self, item: &str) -> Result<bool> {
for i in 0..self.hash_count {
let digest = self.calc_item_position(item, i)?;
if !self.bitvec[digest] {
return Ok(false);
}
}
Ok(true)
}
pub fn calc_size(n: u64, p: f64) -> u64 {
let mut m = (-(n as f64 * p.log(E)) / (2.0_f64.log(E).powi(2))) as u64;
m += 8 - (m % 8); m
}
pub fn calc_hash_count(m: u64, n: u64) -> Result<u32> {
let k = ((m as f64) / (n as f64)) * 2.0_f64.log(E);
if k > u32::MAX as f64 {
bail!("Hash count is too large");
}
Ok(k as u32)
}
pub fn calc_item_size_and_hash_count(size: u64, fp_prob: f64) -> (u64, u32) {
let size_f = size as f64;
let mut item_size: u64 = 0;
for i in 1..=u32::MAX {
let i_f = i as f64;
let temp_item_size =
(size_f / (-i_f / (1_f64 - (fp_prob.ln() / i_f).exp()).ln())).ceil() as u64;
if item_size > temp_item_size {
return (item_size, i - 1);
} else {
item_size = temp_item_size;
}
}
(item_size, u32::MAX)
}
#[cfg(feature = "hdf5")]
pub fn load_hdf5(path: &std::path::PathBuf) -> Result<Self> {
let file = hdf5::File::open(path)?;
let size = file.dataset("size")?.read_scalar::<u64>()?;
let hash_count = file.dataset("hash_count")?.read_scalar::<u32>()?;
let fp_prob = file.dataset("fp_prob")?.read_scalar::<f64>()?;
let bytes = match Self::decode_hex(
file.dataset("bit_array")?
.read_scalar::<hdf5::types::VarLenAscii>()?
.as_str(),
) {
Ok(bytes) => bytes,
Err(err) => bail!(format!("Error while decoding hex: {}", err)),
};
Self::new(
fp_prob,
size,
hash_count,
BitVec::<AtomicU8, Msb0>::from_slice(&bytes).into_boxed_bitslice(),
)
}
#[cfg(feature = "hdf5")]
pub fn save_hdf5(&self, path: &std::path::PathBuf) -> Result<()> {
let file = hdf5::File::create(path)?;
file.new_dataset::<u64>()
.create("size")?
.write_scalar(&(self.size as u64))?;
file.new_dataset::<u32>()
.create("hash_count")?
.write_scalar(&self.hash_count)?;
file.new_dataset::<f64>()
.create("fp_prob")?
.write_scalar(&self.fp_prob)?;
let s_ascii = Self::encode_hex(&self.bitvec)?
.iter()
.map(|b| format!("{:02X}", b))
.collect::<String>();
file.new_dataset::<hdf5::types::VarLenAscii>()
.create("bit_array")?
.write_scalar(&hdf5::types::VarLenAscii::from_ascii(&s_ascii)?)?;
Ok(())
}
#[cfg(feature = "hdf5")]
pub fn decode_hex(s: &str) -> Result<Vec<AtomicU8>, core::num::ParseIntError> {
(0..s.len())
.step_by(2)
.map(|i| match u8::from_str_radix(&s[i..i + 2], 16) {
Ok(b) => Ok(AtomicU8::new(b)),
Err(err) => Err(err),
})
.collect()
}
#[cfg(feature = "hdf5")]
pub fn encode_hex(bit_array: &BitBox<AtomicU8, Msb0>) -> Result<Vec<u8>> {
let mut bytes: Vec<u8> = Vec::with_capacity(bit_array.len() / 8);
for start in (0..bit_array.len()).step_by(8) {
bytes.push(bit_array[start..(start + 8)].load::<u8>());
}
Ok(bytes)
}
}
#[cfg(test)]
mod tests {
use std::fs::read_to_string;
use std::path::PathBuf;
use super::*;
// Inserts half of the fixture strings through the exclusive `add` path and
// half through the shared-reference `add_aliased` path, then verifies every
// inserted string is reported present (a Bloom filter has no false negatives).
#[test]
fn test_inserting_and_finding() {
let some_strings: Vec<String> =
read_to_string(PathBuf::from("test_data/10000_random_strings.txt"))
.unwrap()
.lines()
.map(String::from)
.collect();
let mut bloom_filter =
BloomFilter::new_by_item_count_and_fp_prob(some_strings.len() as u64, 0.01).unwrap();
let some_strings_split = some_strings.split_at(some_strings.len() / 2);
for a_string in some_strings_split.0.iter() {
bloom_filter.add(a_string).unwrap();
}
for a_string in some_strings_split.1.iter() {
bloom_filter.add_aliased(a_string).unwrap();
}
for a_string in some_strings.iter() {
assert!(bloom_filter.contains(a_string).unwrap());
}
}
// Round-trips a populated filter through an HDF5 file in the temp directory,
// then checks field-for-field equality (size, hash_count, fp_prob, bitvec)
// and that membership results are preserved by the reloaded filter.
#[cfg(feature = "hdf5")]
#[test]
fn test_save_and_load() {
let some_strings: Vec<String> =
read_to_string(PathBuf::from("test_data/10000_random_strings.txt"))
.unwrap()
.lines()
.map(String::from)
.collect();
let mut bloom_filter =
BloomFilter::new_by_item_count_and_fp_prob(some_strings.len() as u64, 0.01).unwrap();
for a_string in some_strings.iter() {
bloom_filter.add(a_string).unwrap();
}
let temp_file = std::env::temp_dir().join("bloom_filter.h5");
// Remove any leftover file from a previous run so `create` starts clean.
if temp_file.is_file() {
std::fs::remove_file(&temp_file).unwrap();
}
bloom_filter.save_hdf5(&temp_file).unwrap();
let read_bloom_filter = BloomFilter::load_hdf5(&temp_file).unwrap();
assert!(bloom_filter.size == read_bloom_filter.size);
assert!(bloom_filter.hash_count == read_bloom_filter.hash_count);
assert!(bloom_filter.fp_prob == read_bloom_filter.fp_prob);
assert!(bloom_filter.bitvec == read_bloom_filter.bitvec);
for a_string in some_strings.iter() {
assert!(read_bloom_filter.contains(a_string).unwrap());
}
if temp_file.is_file() {
std::fs::remove_file(&temp_file).unwrap();
}
}
// Round-trips a populated filter through MessagePack via serde (rmp_serde),
// then checks field-for-field equality and preserved membership results.
#[cfg(feature = "serde")]
#[test]
fn test_serde() {
use rmp_serde::{Deserializer, Serializer};
use serde::{Deserialize, Serialize};
let some_strings: Vec<String> =
read_to_string(PathBuf::from("test_data/10000_random_strings.txt"))
.unwrap()
.lines()
.map(String::from)
.collect();
let mut bloom_filter =
BloomFilter::new_by_item_count_and_fp_prob(some_strings.len() as u64, 0.01).unwrap();
for a_string in some_strings.iter() {
bloom_filter.add(a_string).unwrap();
}
let temp_file_path = std::env::temp_dir().join("bloom_filter.messagepack");
if temp_file_path.is_file() {
std::fs::remove_file(&temp_file_path).unwrap();
}
let mut temp_file = std::fs::File::create(&temp_file_path).unwrap();
let mut byte_writer = std::io::BufWriter::new(&mut temp_file);
bloom_filter
.serialize(&mut Serializer::new(&mut byte_writer))
.unwrap();
// Drop the writer before reopening so the BufWriter flushes to disk.
drop(byte_writer);
let mut temp_file = std::fs::File::open(&temp_file_path).unwrap();
let mut byte_reader = std::io::BufReader::new(&mut temp_file);
let read_bloom_filter =
BloomFilter::deserialize(&mut Deserializer::new(&mut byte_reader)).unwrap();
assert!(bloom_filter.size == read_bloom_filter.size);
assert!(bloom_filter.hash_count == read_bloom_filter.hash_count);
assert!(bloom_filter.fp_prob == read_bloom_filter.fp_prob);
assert!(bloom_filter.bitvec == read_bloom_filter.bitvec);
for a_string in some_strings.iter() {
assert!(read_bloom_filter.contains(a_string).unwrap());
}
if temp_file_path.is_file() {
std::fs::remove_file(&temp_file_path).unwrap();
}
}
}