use crate::error::ShardexError;
use crate::identifiers::DocumentId;
use bytemuck::{Pod, Zeroable};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Space-efficient probabilistic membership filter over [`DocumentId`]s.
///
/// May report false positives but never false negatives. Sized at
/// construction from a target capacity and false positive rate.
/// Field names are part of the serde serialization format — do not rename.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BloomFilter {
    // Bits packed into u64 words; length is ceil(bit_array_size / 64).
    bit_array: Vec<u64>,
    // Number of hash functions (k) applied per element; clamped to 1..=10.
    hash_functions: usize,
    // Expected element count (n) the filter was sized for.
    capacity: usize,
    // Count of insert() calls; duplicates are counted (bloom filters cannot detect them).
    inserted_count: usize,
    // Target false positive rate (p) supplied at construction.
    false_positive_rate: f64,
    // Logical bit count (m); hash values are reduced modulo this.
    bit_array_size: usize,
}
/// Fixed-size `#[repr(C)]` header describing a serialized [`BloomFilter`].
///
/// Layout is 32 bytes with no padding: six 4-byte fields followed by a
/// u64 at offset 24 (already 8-aligned), making it safe to treat as plain
/// old data for memory-mapped storage.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
#[allow(dead_code)] // NOTE(review): presumably part of a persisted on-disk format read elsewhere — confirm
pub struct BloomFilterHeader {
    // Number of hash functions (k).
    pub hash_functions: u32,
    // Expected element capacity (n).
    pub capacity: u32,
    // Number of insertions recorded at serialization time.
    pub inserted_count: u32,
    // Target false positive rate encoded as rate * 1_000_000.
    pub false_positive_rate_micros: u32,
    // Logical bit count (m) of the filter.
    pub bit_array_size: u32,
    // Byte length of the packed bit array (word count * 8).
    pub bit_array_bytes: u32,
    // Byte offset of the bit array in the backing storage (supplied by caller of `to_header`).
    pub bit_array_offset: u64,
}
/// Point-in-time snapshot of a [`BloomFilter`]'s configuration and health,
/// produced by `BloomFilter::stats`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BloomFilterStats {
    // Configured number of hash functions (k).
    pub hash_functions: usize,
    // Configured expected capacity (n).
    pub capacity: usize,
    // Number of insertions performed so far.
    pub inserted_count: usize,
    // Target false positive rate requested at construction.
    pub false_positive_rate: f64,
    // inserted_count / capacity; may exceed 1.0 when overloaded.
    pub load_factor: f64,
    // Estimated current false positive rate: (1 - e^(-kn/m))^k.
    pub actual_false_positive_rate: f64,
    // Heap bytes used by the bit array (word count * 8).
    pub memory_usage: usize,
    // Number of bits currently set to 1.
    pub bits_set: usize,
    // bits_set / bit_array_size, in [0.0, 1.0].
    pub bit_utilization: f64,
}
// SAFETY: BloomFilterHeader is #[repr(C)], all fields are Pod integers, and the
// layout has no padding bytes (six 4-byte fields occupy offsets 0..24, then the
// u64 sits at offset 24, which already satisfies its 8-byte alignment).
unsafe impl Pod for BloomFilterHeader {}
// SAFETY: the all-zero bit pattern is a valid value for every integer field,
// hence a valid BloomFilterHeader.
unsafe impl Zeroable for BloomFilterHeader {}
impl BloomFilter {
    /// Create a filter sized for `capacity` expected elements at the target
    /// `false_positive_rate`.
    ///
    /// # Errors
    ///
    /// Returns `ShardexError::Config` when `capacity` is zero or when
    /// `false_positive_rate` lies outside the open interval (0.0, 1.0).
    pub fn new(capacity: usize, false_positive_rate: f64) -> Result<Self, ShardexError> {
        if capacity == 0 {
            return Err(ShardexError::Config(
                "Bloom filter capacity must be greater than 0".to_string(),
            ));
        }
        if false_positive_rate <= 0.0 || false_positive_rate >= 1.0 {
            return Err(ShardexError::Config(
                "False positive rate must be between 0.0 and 1.0 (exclusive)".to_string(),
            ));
        }
        let (bit_array_size, hash_functions) = Self::calculate_parameters(capacity, false_positive_rate);
        // Pack the logical bits into u64 words, rounding the word count up.
        let bit_array = vec![0u64; (bit_array_size + 63) / 64];
        Ok(Self {
            bit_array,
            hash_functions,
            capacity,
            inserted_count: 0,
            false_positive_rate,
            bit_array_size,
        })
    }

    /// Derive (bit count m, hash count k) from (capacity n, rate p) using the
    /// standard bloom filter formulas m = -n*ln(p)/ln(2)^2 and k = (m/n)*ln(2).
    ///
    /// k is clamped to 1..=10 to bound per-operation hashing cost; for very
    /// aggressive target rates this makes the achieved rate slightly worse
    /// than requested.
    fn calculate_parameters(capacity: usize, false_positive_rate: f64) -> (usize, usize) {
        let capacity_f = capacity as f64;
        let ln2 = std::f64::consts::LN_2;
        let bit_array_size = (-capacity_f * false_positive_rate.ln() / (ln2 * ln2)).ceil() as usize;
        let hash_functions = ((bit_array_size as f64 / capacity_f) * ln2).ceil() as usize;
        let hash_functions = hash_functions.clamp(1, 10);
        (bit_array_size, hash_functions)
    }

    /// Map a hash value to its (word index, single-bit mask) pair within the
    /// packed bit array. Shared by `insert` and `contains` so both always use
    /// the same bit-addressing scheme.
    fn bit_slot(&self, hash_value: u64) -> (usize, u64) {
        let bit_index = (hash_value as usize) % self.bit_array_size;
        (bit_index / 64, 1u64 << (bit_index % 64))
    }

    /// Record a document ID in the filter.
    ///
    /// Duplicates are not detected; every call increments `inserted_count`.
    pub fn insert(&mut self, document_id: DocumentId) {
        for hash_value in self.hash_document_id(document_id) {
            let (word, mask) = self.bit_slot(hash_value);
            self.bit_array[word] |= mask;
        }
        self.inserted_count += 1;
    }

    /// Test membership.
    ///
    /// A `false` result is definitive; `true` may be a false positive.
    pub fn contains(&self, document_id: DocumentId) -> bool {
        // `all` short-circuits on the first unset bit, matching the behavior
        // of an early-return loop.
        self.hash_document_id(document_id).into_iter().all(|hash_value| {
            let (word, mask) = self.bit_slot(hash_value);
            self.bit_array[word] & mask != 0
        })
    }

    /// Merge `other` into `self` by bitwise OR, producing the union filter.
    ///
    /// # Errors
    ///
    /// Returns `ShardexError::Config` when the two filters were built with
    /// different bit array sizes or hash function counts (their bit layouts
    /// would be incompatible).
    #[allow(dead_code)]
    pub fn merge(&mut self, other: &BloomFilter) -> Result<(), ShardexError> {
        if self.bit_array_size != other.bit_array_size {
            return Err(ShardexError::Config(format!(
                "Cannot merge bloom filters with different bit array sizes: {} vs {}",
                self.bit_array_size, other.bit_array_size
            )));
        }
        if self.hash_functions != other.hash_functions {
            return Err(ShardexError::Config(format!(
                "Cannot merge bloom filters with different hash function counts: {} vs {}",
                self.hash_functions, other.hash_functions
            )));
        }
        // Equal bit_array_size implies equal word counts, so zip covers both
        // arrays completely.
        for (bits, other_bits) in self.bit_array.iter_mut().zip(&other.bit_array) {
            *bits |= other_bits;
        }
        // Elements present in both filters are double-counted; the merged
        // count is an upper bound on distinct insertions.
        self.inserted_count += other.inserted_count;
        Ok(())
    }

    /// Reset the filter to its empty state, keeping its configuration.
    pub fn clear(&mut self) {
        self.bit_array.fill(0);
        self.inserted_count = 0;
    }

    /// Compute a snapshot of configuration and occupancy statistics.
    pub fn stats(&self) -> BloomFilterStats {
        let load_factor = if self.capacity > 0 {
            self.inserted_count as f64 / self.capacity as f64
        } else {
            0.0
        };
        // Estimated current false positive rate: (1 - e^(-kn/m))^k.
        let actual_false_positive_rate = if self.inserted_count > 0 {
            let k = self.hash_functions as f64;
            let m = self.bit_array_size as f64;
            let n = self.inserted_count as f64;
            (1.0 - (-k * n / m).exp()).powf(k)
        } else {
            0.0
        };
        let bits_set = self.count_set_bits();
        let bit_utilization = if self.bit_array_size > 0 {
            bits_set as f64 / self.bit_array_size as f64
        } else {
            0.0
        };
        BloomFilterStats {
            hash_functions: self.hash_functions,
            capacity: self.capacity,
            inserted_count: self.inserted_count,
            false_positive_rate: self.false_positive_rate,
            load_factor,
            actual_false_positive_rate,
            memory_usage: self.bit_array.len() * 8,
            bits_set,
            bit_utilization,
        }
    }

    /// Build the fixed-size serialization header, with the bit array stored
    /// at `bit_array_offset` in the backing storage.
    ///
    /// Note: usize fields are narrowed with `as u32` and would silently wrap
    /// for filters larger than u32::MAX bits — acceptable for realistic
    /// sizes, but callers persisting extreme configurations should validate.
    pub fn to_header(&self, bit_array_offset: u64) -> BloomFilterHeader {
        BloomFilterHeader {
            hash_functions: self.hash_functions as u32,
            capacity: self.capacity as u32,
            inserted_count: self.inserted_count as u32,
            false_positive_rate_micros: (self.false_positive_rate * 1_000_000.0) as u32,
            bit_array_size: self.bit_array_size as u32,
            bit_array_bytes: (self.bit_array.len() * 8) as u32,
            bit_array_offset,
        }
    }

    /// True when the filter has absorbed at least as many insertions as it
    /// was sized for.
    pub fn is_at_capacity(&self) -> bool {
        self.inserted_count >= self.capacity
    }

    /// True when insertions exceed the configured capacity (false positive
    /// rate will degrade beyond the target).
    pub fn is_overloaded(&self) -> bool {
        self.inserted_count > self.capacity
    }

    /// Ratio of insertions to configured capacity; may exceed 1.0.
    pub fn load_factor(&self) -> f64 {
        if self.capacity > 0 {
            self.inserted_count as f64 / self.capacity as f64
        } else {
            0.0
        }
    }

    /// Number of insertions performed (duplicates included).
    pub fn inserted_count(&self) -> usize {
        self.inserted_count
    }

    /// Configured expected element capacity.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Target false positive rate supplied at construction.
    pub fn false_positive_rate(&self) -> f64 {
        self.false_positive_rate
    }

    /// Number of hash functions (k) in use.
    pub fn hash_functions(&self) -> usize {
        self.hash_functions
    }

    /// Logical number of bits (m) in the filter.
    pub fn bit_array_size(&self) -> usize {
        self.bit_array_size
    }

    /// Produce the k hash values for a document ID by hashing its bytes with
    /// the function index mixed in as a salt.
    ///
    /// NOTE(review): DefaultHasher's algorithm is not guaranteed stable
    /// across Rust releases; if bit arrays are persisted and reloaded under
    /// a different toolchain, membership answers could change — confirm the
    /// persistence story accounts for this.
    fn hash_document_id(&self, document_id: DocumentId) -> Vec<u64> {
        let mut hashes = Vec::with_capacity(self.hash_functions);
        let bytes = document_id.to_bytes();
        for i in 0..self.hash_functions {
            let mut hasher = DefaultHasher::new();
            bytes.hash(&mut hasher);
            i.hash(&mut hasher);
            hashes.push(hasher.finish());
        }
        hashes
    }

    /// Count bits currently set across the packed words. Used by `stats`.
    fn count_set_bits(&self) -> usize {
        self.bit_array
            .iter()
            .map(|&bits| bits.count_ones() as usize)
            .sum()
    }
}
impl BloomFilterHeader {
    /// Construct an all-zero header, e.g. as a placeholder before real
    /// values are written.
    #[allow(dead_code)]
    pub fn new_zero() -> Self {
        Self::zeroed()
    }

    /// A header is structurally valid when every sizing field is non-zero.
    #[allow(dead_code)]
    pub fn is_valid(&self) -> bool {
        let required_nonzero = [
            self.hash_functions,
            self.capacity,
            self.bit_array_size,
            self.bit_array_bytes,
        ];
        required_nonzero.iter().all(|&field| field > 0)
    }

    /// Decode the stored false positive rate from its micro-units encoding.
    #[allow(dead_code)]
    pub fn false_positive_rate(&self) -> f64 {
        f64::from(self.false_positive_rate_micros) / 1_000_000.0
    }
}
/// Test-only fluent builder for [`BloomFilter`] with default parameters.
#[cfg(test)]
#[derive(Debug, Clone)]
pub struct BloomFilterBuilder {
    // Defaults to 1000 (see `BloomFilterBuilder::new`).
    capacity: usize,
    // Defaults to 0.01.
    false_positive_rate: f64,
}
#[cfg(test)]
impl BloomFilterBuilder {
    /// Start a builder with defaults: capacity 1000, false positive rate 0.01.
    pub fn new() -> Self {
        BloomFilterBuilder {
            capacity: 1000,
            false_positive_rate: 0.01,
        }
    }

    /// Override the expected element capacity.
    pub fn capacity(self, capacity: usize) -> Self {
        Self { capacity, ..self }
    }

    /// Override the target false positive rate.
    pub fn false_positive_rate(self, rate: f64) -> Self {
        Self {
            false_positive_rate: rate,
            ..self
        }
    }

    /// Build the filter, propagating any parameter-validation error.
    pub fn build(self) -> Result<BloomFilter, ShardexError> {
        BloomFilter::new(self.capacity, self.false_positive_rate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_bloom_filter_creation() {
        let filter = BloomFilter::new(1000, 0.01).unwrap();
        assert_eq!(filter.capacity(), 1000);
        assert_eq!(filter.inserted_count(), 0);
        assert_eq!(filter.false_positive_rate(), 0.01);
        assert!(filter.hash_functions() >= 1);
        assert!(filter.bit_array_size() > 0);
    }

    #[test]
    fn test_invalid_parameters() {
        // Zero capacity is rejected.
        let result = BloomFilter::new(0, 0.01);
        assert!(result.is_err());
        assert!(matches!(result, Err(ShardexError::Config(_))));
        // Rate must lie strictly inside (0.0, 1.0).
        let result = BloomFilter::new(1000, 0.0);
        assert!(result.is_err());
        let result = BloomFilter::new(1000, 1.0);
        assert!(result.is_err());
        let result = BloomFilter::new(1000, 1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_parameter_calculation() {
        // For n=1000, p=0.01 the textbook formula gives m ~ 9585 bits.
        let (bits, hash_funcs) = BloomFilter::calculate_parameters(1000, 0.01);
        assert!((9000..=10000).contains(&bits), "Expected ~9600 bits, got {}", bits);
        assert!(
            (3..=10).contains(&hash_funcs),
            "Expected 3-10 hash functions, got {}",
            hash_funcs
        );
    }

    #[test]
    fn test_insert_and_contains() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        let doc_id = DocumentId::new();
        assert!(!filter.contains(doc_id));
        filter.insert(doc_id);
        assert!(filter.contains(doc_id));
        assert_eq!(filter.inserted_count(), 1);
    }

    #[test]
    fn test_no_false_negatives() {
        // Core bloom filter invariant: every inserted ID must be reported present.
        let mut filter = BloomFilter::new(1000, 0.05).unwrap();
        let mut inserted_ids = Vec::new();
        for _ in 0..500 {
            let doc_id = DocumentId::new();
            inserted_ids.push(doc_id);
            filter.insert(doc_id);
        }
        for doc_id in inserted_ids {
            assert!(filter.contains(doc_id), "False negative detected for document ID");
        }
    }

    #[test]
    fn test_false_positive_rate() {
        let mut filter = BloomFilter::new(1000, 0.05).unwrap();
        let mut inserted_ids = HashSet::new();
        for _ in 0..1000 {
            let doc_id = DocumentId::new();
            inserted_ids.insert(doc_id);
            filter.insert(doc_id);
        }
        // Probe with fresh IDs; allow headroom (0.10) over the 0.05 target to
        // keep the test statistically stable.
        let test_count = 10000;
        let mut false_positives = 0;
        for _ in 0..test_count {
            let test_id = DocumentId::new();
            if !inserted_ids.contains(&test_id) && filter.contains(test_id) {
                false_positives += 1;
            }
        }
        let actual_fp_rate = false_positives as f64 / test_count as f64;
        assert!(
            actual_fp_rate <= 0.10,
            "False positive rate too high: {}",
            actual_fp_rate
        );
    }

    #[test]
    fn test_merge_compatible_filters() {
        let mut filter1 = BloomFilter::new(100, 0.01).unwrap();
        let mut filter2 = BloomFilter::new(100, 0.01).unwrap();
        let doc1 = DocumentId::new();
        let doc2 = DocumentId::new();
        filter1.insert(doc1);
        filter2.insert(doc2);
        filter1.merge(&filter2).unwrap();
        // Merged filter is the union of both.
        assert!(filter1.contains(doc1));
        assert!(filter1.contains(doc2));
        assert_eq!(filter1.inserted_count(), 2);
    }

    #[test]
    fn test_merge_incompatible_filters() {
        // Different capacities produce different bit array sizes, so merging
        // must be rejected.
        let mut filter1 = BloomFilter::new(100, 0.01).unwrap();
        let filter2 = BloomFilter::new(200, 0.01).unwrap();
        let result = filter1.merge(&filter2);
        assert!(result.is_err());
        assert!(matches!(result, Err(ShardexError::Config(_))));
    }

    #[test]
    fn test_clear() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        let doc_id = DocumentId::new();
        filter.insert(doc_id);
        assert!(filter.contains(doc_id));
        assert_eq!(filter.inserted_count(), 1);
        filter.clear();
        assert!(!filter.contains(doc_id));
        assert_eq!(filter.inserted_count(), 0);
    }

    #[test]
    fn test_load_factor_tracking() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        assert_eq!(filter.load_factor(), 0.0);
        assert!(!filter.is_at_capacity());
        assert!(!filter.is_overloaded());
        for _ in 0..50 {
            filter.insert(DocumentId::new());
        }
        assert_eq!(filter.load_factor(), 0.5);
        assert!(!filter.is_at_capacity());
        for _ in 0..50 {
            filter.insert(DocumentId::new());
        }
        assert_eq!(filter.load_factor(), 1.0);
        assert!(filter.is_at_capacity());
        // One more insert pushes the filter past its design capacity.
        filter.insert(DocumentId::new());
        assert!(filter.load_factor() > 1.0);
        assert!(filter.is_overloaded());
    }

    #[test]
    fn test_statistics() {
        let mut filter = BloomFilter::new(1000, 0.01).unwrap();
        for _ in 0..500 {
            filter.insert(DocumentId::new());
        }
        let stats = filter.stats();
        assert_eq!(stats.capacity, 1000);
        assert_eq!(stats.inserted_count, 500);
        assert_eq!(stats.load_factor, 0.5);
        assert_eq!(stats.false_positive_rate, 0.01);
        assert!(stats.actual_false_positive_rate > 0.0);
        assert!(stats.memory_usage > 0);
        assert!(stats.bits_set > 0);
        assert!(stats.bit_utilization > 0.0 && stats.bit_utilization <= 1.0);
    }

    #[test]
    fn test_builder_pattern() {
        let filter = BloomFilterBuilder::new()
            .capacity(5000)
            .false_positive_rate(0.005)
            .build()
            .unwrap();
        assert_eq!(filter.capacity(), 5000);
        assert_eq!(filter.false_positive_rate(), 0.005);
    }

    #[test]
    fn test_header_creation() {
        let filter = BloomFilter::new(1000, 0.01).unwrap();
        let header = filter.to_header(2048);
        assert_eq!(header.capacity, 1000);
        assert_eq!(header.hash_functions, filter.hash_functions() as u32);
        assert_eq!(header.bit_array_offset, 2048);
        assert!(header.is_valid());
        assert_eq!(header.false_positive_rate(), 0.01);
    }

    #[test]
    fn test_header_bytemuck() {
        let header = BloomFilterHeader {
            hash_functions: 5,
            capacity: 1000,
            inserted_count: 500,
            false_positive_rate_micros: 10_000,
            bit_array_size: 9600,
            bit_array_bytes: 1200,
            bit_array_offset: 4096,
        };
        // Round-trip through raw bytes must preserve every field.
        let bytes: &[u8] = bytemuck::bytes_of(&header);
        assert!(!bytes.is_empty());
        let header_restored: BloomFilterHeader = bytemuck::pod_read_unaligned(bytes);
        assert_eq!(header, header_restored);
    }

    #[test]
    fn test_zero_initialized_header() {
        let zero_header = BloomFilterHeader::new_zero();
        assert!(!zero_header.is_valid());
        assert_eq!(zero_header.false_positive_rate(), 0.0);
        let zero_header_bytemuck: BloomFilterHeader = BloomFilterHeader::zeroed();
        assert_eq!(zero_header_bytemuck.hash_functions, 0);
        assert_eq!(zero_header_bytemuck.capacity, 0);
        assert_eq!(zero_header_bytemuck.bit_array_size, 0);
    }

    #[test]
    fn test_hash_function_independence() {
        let filter = BloomFilter::new(100, 0.01).unwrap();
        let doc_id = DocumentId::new();
        let hashes = filter.hash_document_id(doc_id);
        assert_eq!(hashes.len(), filter.hash_functions());
        // Pairwise-distinct hash values for a single input.
        for i in 0..hashes.len() {
            for j in (i + 1)..hashes.len() {
                assert_ne!(hashes[i], hashes[j], "Hash functions should produce different values");
            }
        }
    }

    #[test]
    fn test_consistent_hashing() {
        // Two independently built filters with equal parameters must hash
        // identically, otherwise merge/persistence would be unsound.
        let filter1 = BloomFilter::new(100, 0.01).unwrap();
        let filter2 = BloomFilter::new(100, 0.01).unwrap();
        let doc_id = DocumentId::new();
        let hashes1 = filter1.hash_document_id(doc_id);
        let hashes2 = filter2.hash_document_id(doc_id);
        assert_eq!(hashes1, hashes2);
    }

    #[test]
    fn test_bit_operations() {
        let mut filter = BloomFilter::new(100, 0.01).unwrap();
        let initial_bits_set = filter.count_set_bits();
        assert_eq!(initial_bits_set, 0);
        filter.insert(DocumentId::new());
        let after_insert = filter.count_set_bits();
        assert!(after_insert > 0);
        // Hash collisions may set fewer than k distinct bits, never more.
        assert!(after_insert <= filter.hash_functions());
    }

    #[test]
    fn test_serialization() {
        let mut filter = BloomFilter::new(1000, 0.01).unwrap();
        for _ in 0..100 {
            filter.insert(DocumentId::new());
        }
        let json = serde_json::to_string(&filter).unwrap();
        let filter_restored: BloomFilter = serde_json::from_str(&json).unwrap();
        assert_eq!(filter, filter_restored);
        // Smoke check: the restored filter is still queryable after the
        // original diverges (result is probabilistic, so no assertion).
        let doc_id = DocumentId::new();
        filter.insert(doc_id);
        filter_restored.contains(doc_id);
    }

    #[test]
    fn test_stats_serialization() {
        let mut filter = BloomFilter::new(1000, 0.01).unwrap();
        for _ in 0..100 {
            filter.insert(DocumentId::new());
        }
        let stats = filter.stats();
        let json = serde_json::to_string(&stats).unwrap();
        let stats_restored: BloomFilterStats = serde_json::from_str(&json).unwrap();
        assert_eq!(stats, stats_restored);
    }

    #[test]
    fn test_edge_cases() {
        // Minimum capacity.
        let mut small_filter = BloomFilter::new(1, 0.01).unwrap();
        let doc_id = DocumentId::new();
        small_filter.insert(doc_id);
        assert!(small_filter.contains(doc_id));
        // Near-1.0 target rate yields a tiny bit array but must still work.
        let mut high_fp_filter = BloomFilter::new(100, 0.99).unwrap();
        high_fp_filter.insert(doc_id);
        assert!(high_fp_filter.contains(doc_id));
        // Very strict target rate.
        let mut low_fp_filter = BloomFilter::new(100, 0.001).unwrap();
        low_fp_filter.insert(doc_id);
        assert!(low_fp_filter.contains(doc_id));
    }

    #[test]
    fn test_memory_layout() {
        use std::mem;
        // Guard the persisted #[repr(C)] layout: at least 32 bytes, a whole
        // multiple of its alignment, and 8-aligned for the u64 field.
        let header_size = mem::size_of::<BloomFilterHeader>();
        assert!(header_size >= 32);
        assert!(header_size % mem::align_of::<BloomFilterHeader>() == 0);
        assert!(mem::align_of::<BloomFilterHeader>() >= 8);
    }
}