use std::collections::HashMap;
use parking_lot::RwLock;
use super::segment::SegmentId;
use crate::storage::primitives::bloom::BloomFilter;
/// Number of expected keys used to size each segment's filter when the registry is
/// built with `BloomFilterRegistry::new` (passed through to `BloomFilter::with_capacity`).
const DEFAULT_EXPECTED_ENTITIES: usize = 100_000;
/// False-positive rate used for each segment's filter when the registry is built with
/// `BloomFilterRegistry::new` (passed through to `BloomFilter::with_capacity`).
const DEFAULT_FP_RATE: f64 = 0.01;
/// Bloom filter holding the keys added for a single segment, plus bookkeeping
/// about how many keys were accepted and whether further additions are allowed.
pub struct SegmentBloom {
/// The underlying probabilistic membership filter.
pub filter: BloomFilter,
/// Number of `add` calls accepted while unfrozen; duplicate keys are counted each time.
pub key_count: usize,
/// When `true`, `add` ignores new keys and leaves `filter`/`key_count` unchanged.
pub frozen: bool,
}
impl SegmentBloom {
pub fn new(expected_elements: usize, fp_rate: f64) -> Self {
Self {
filter: BloomFilter::with_capacity(expected_elements, fp_rate),
key_count: 0,
frozen: false,
}
}
pub fn add(&mut self, key: &[u8]) {
if !self.frozen {
self.filter.insert(key);
self.key_count += 1;
}
}
pub fn might_contain(&self, key: &[u8]) -> bool {
self.filter.contains(key)
}
pub fn freeze(&mut self) {
self.frozen = true;
}
pub fn estimated_fp_rate(&self) -> f64 {
self.filter.estimate_fp_rate(self.key_count)
}
pub fn memory_bytes(&self) -> usize {
self.filter.byte_size() + std::mem::size_of::<Self>()
}
}
/// Registry of per-segment Bloom filters keyed by `(collection name, segment id)`.
///
/// All state sits behind a single `RwLock`, so every method takes `&self`.
pub struct BloomFilterRegistry {
/// Segment filters keyed by collection name and segment id.
blooms: RwLock<HashMap<(String, SegmentId), SegmentBloom>>,
/// Capacity passed to `SegmentBloom::new` for each newly registered segment.
expected_elements: usize,
/// False-positive rate passed to `SegmentBloom::new` for each newly registered segment.
fp_rate: f64,
}
impl BloomFilterRegistry {
    /// Creates an empty registry whose segment filters are sized with
    /// `DEFAULT_EXPECTED_ENTITIES` and `DEFAULT_FP_RATE`.
    pub fn new() -> Self {
        Self::with_config(DEFAULT_EXPECTED_ENTITIES, DEFAULT_FP_RATE)
    }

    /// Creates an empty registry whose segment filters are built with
    /// `SegmentBloom::new(expected_elements, fp_rate)`.
    pub fn with_config(expected_elements: usize, fp_rate: f64) -> Self {
        Self {
            blooms: RwLock::new(HashMap::new()),
            expected_elements,
            fp_rate,
        }
    }

    /// Installs a fresh, empty, writable filter for `(collection, segment_id)`.
    ///
    /// If the segment was already registered, its filter is replaced and the keys
    /// recorded so far are discarded.
    pub fn register_segment(&self, collection: &str, segment_id: SegmentId) {
        // Allocate the filter before taking the write lock to keep the critical section short.
        let bloom = SegmentBloom::new(self.expected_elements, self.fp_rate);
        self.blooms
            .write()
            .insert((collection.to_string(), segment_id), bloom);
    }

    /// Records `key` in the segment's filter.
    ///
    /// Does nothing if the segment is not registered or its filter is frozen.
    pub fn add_key(&self, collection: &str, segment_id: SegmentId, key: &[u8]) {
        let mut blooms = self.blooms.write();
        if let Some(bloom) = blooms.get_mut(&(collection.to_string(), segment_id)) {
            bloom.add(key);
        }
    }

    /// Returns the ids of all segments in `collection` whose filter reports that it
    /// might contain `key`. The order of the result is unspecified (map iteration order).
    pub fn candidate_segments(&self, collection: &str, key: &[u8]) -> Vec<SegmentId> {
        let blooms = self.blooms.read();
        blooms
            .iter()
            .filter_map(|((coll, seg_id), bloom)| {
                if coll == collection && bloom.might_contain(key) {
                    Some(*seg_id)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Asks the segment's filter whether `key` might be present.
    ///
    /// Unregistered segments answer `true`: without a filter the key cannot be ruled out.
    pub fn might_contain(&self, collection: &str, segment_id: SegmentId, key: &[u8]) -> bool {
        let blooms = self.blooms.read();
        match blooms.get(&(collection.to_string(), segment_id)) {
            Some(bloom) => bloom.might_contain(key),
            None => true,
        }
    }

    /// Freezes the segment's filter so later `add_key` calls for it are ignored.
    /// Unregistered segments are ignored.
    pub fn freeze_segment(&self, collection: &str, segment_id: SegmentId) {
        let mut blooms = self.blooms.write();
        if let Some(bloom) = blooms.get_mut(&(collection.to_string(), segment_id)) {
            bloom.freeze();
        }
    }

    /// Drops the segment's filter, if any. Afterwards `might_contain` answers `true` for it.
    pub fn remove_segment(&self, collection: &str, segment_id: SegmentId) {
        self.blooms
            .write()
            .remove(&(collection.to_string(), segment_id));
    }

    /// Combines the filters of `seg_a` and `seg_b` (both in `collection`) via
    /// `BloomFilter::merge` and registers the result, frozen, as `new_seg_id`.
    ///
    /// The merged `key_count` is the sum of the two sources, so keys present in both
    /// are counted twice. The source segments are left registered.
    ///
    /// Returns `false` and leaves the registry unchanged when either source is not
    /// registered in `collection`, or when `BloomFilter::merge` returns `None`
    /// (presumably incompatible filter parameters; confirm against `BloomFilter`).
    pub fn merge_segments(
        &self,
        collection: &str,
        seg_a: SegmentId,
        seg_b: SegmentId,
        new_seg_id: SegmentId,
    ) -> bool {
        // One write guard for lookup, merge and insert. The filters and their key
        // counts are read from the same snapshot, and no concurrent remove/add can
        // interleave between reading the sources and publishing the merged filter.
        let mut blooms = self.blooms.write();
        let key_a = (collection.to_string(), seg_a);
        let key_b = (collection.to_string(), seg_b);
        let merged = match (blooms.get(&key_a), blooms.get(&key_b)) {
            (Some(a), Some(b)) => match a.filter.merge(&b.filter) {
                Some(filter) => SegmentBloom {
                    filter,
                    key_count: a.key_count.saturating_add(b.key_count),
                    frozen: true,
                },
                None => return false,
            },
            _ => return false,
        };
        blooms.insert((collection.to_string(), new_seg_id), merged);
        true
    }

    /// Aggregates segment count, accepted key count and approximate memory
    /// (`SegmentBloom::memory_bytes`) across every registered segment.
    pub fn stats(&self) -> BloomRegistryStats {
        let blooms = self.blooms.read();
        let (total_keys, total_memory_bytes) =
            blooms.values().fold((0, 0), |(keys, memory), bloom| {
                (keys + bloom.key_count, memory + bloom.memory_bytes())
            });
        BloomRegistryStats {
            segment_count: blooms.len(),
            total_keys,
            total_memory_bytes,
        }
    }
}
impl Default for BloomFilterRegistry {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of registry-wide totals returned by `BloomFilterRegistry::stats`.
///
/// All fields are plain counts, so the type also derives `PartialEq`/`Eq` (for
/// comparisons in callers and tests) and `Default` (an all-zero snapshot).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct BloomRegistryStats {
    /// Number of registered segments across all collections.
    pub segment_count: usize,
    /// Sum of every segment's accepted key count (duplicates included).
    pub total_keys: usize,
    /// Sum of every segment's approximate memory footprint in bytes.
    pub total_memory_bytes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_basic() {
        let registry = BloomFilterRegistry::new();
        registry.register_segment("users", 1);
        registry.add_key("users", 1, b"alice");
        registry.add_key("users", 1, b"bob");
        assert!(registry.might_contain("users", 1, b"alice"));
        assert!(registry.might_contain("users", 1, b"bob"));
        assert!(!registry.might_contain("users", 1, b"charlie"));
    }

    #[test]
    fn test_candidate_segments() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("users", 1);
        registry.register_segment("users", 2);
        registry.register_segment("users", 3);
        registry.add_key("users", 1, b"alice");
        registry.add_key("users", 2, b"bob");
        registry.add_key("users", 3, b"charlie");
        let candidates = registry.candidate_segments("users", b"alice");
        assert!(candidates.contains(&1));
    }

    #[test]
    fn test_candidate_segments_ignores_other_collections() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("users", 1);
        registry.register_segment("orders", 2);
        registry.add_key("users", 1, b"alice");
        registry.add_key("orders", 2, b"alice");
        assert_eq!(registry.candidate_segments("users", b"alice"), vec![1]);
        assert!(registry.candidate_segments("missing", b"alice").is_empty());
    }

    #[test]
    fn test_freeze_segment() {
        let registry = BloomFilterRegistry::new();
        registry.register_segment("data", 1);
        registry.add_key("data", 1, b"before_freeze");
        registry.freeze_segment("data", 1);
        registry.add_key("data", 1, b"after_freeze");
        assert!(registry.might_contain("data", 1, b"before_freeze"));
    }

    #[test]
    fn test_merge_segments() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("data", 1);
        registry.register_segment("data", 2);
        registry.add_key("data", 1, b"from_seg1");
        registry.add_key("data", 2, b"from_seg2");
        assert!(registry.merge_segments("data", 1, 2, 3));
        assert!(registry.might_contain("data", 3, b"from_seg1"));
        assert!(registry.might_contain("data", 3, b"from_seg2"));
    }

    #[test]
    fn test_merge_segments_missing_source() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("data", 1);
        registry.register_segment("other", 2);
        assert!(!registry.merge_segments("data", 1, 99, 3));
        // Segment 2 belongs to a different collection, so it is not a valid source.
        assert!(!registry.merge_segments("data", 1, 2, 3));
        // A failed merge must not register anything.
        assert_eq!(registry.stats().segment_count, 2);
    }

    #[test]
    fn test_merged_segment_is_frozen_with_summed_key_count() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("data", 1);
        registry.register_segment("data", 2);
        registry.add_key("data", 1, b"a");
        registry.add_key("data", 1, b"b");
        registry.add_key("data", 2, b"c");
        assert!(registry.merge_segments("data", 1, 2, 3));
        // Ignored because the merged segment is frozen.
        registry.add_key("data", 3, b"ignored");
        let blooms = registry.blooms.read();
        let merged = blooms
            .get(&("data".to_string(), 3))
            .expect("merged segment must be registered");
        assert!(merged.frozen);
        assert_eq!(merged.key_count, 3);
    }

    #[test]
    fn test_remove_segment() {
        let registry = BloomFilterRegistry::new();
        registry.register_segment("data", 1);
        registry.add_key("data", 1, b"key");
        assert!(registry.might_contain("data", 1, b"key"));
        registry.remove_segment("data", 1);
        // Unknown segments are conservatively reported as possibly containing the key.
        assert!(registry.might_contain("data", 1, b"key"));
    }

    #[test]
    fn test_stats() {
        let registry = BloomFilterRegistry::with_config(100, 0.01);
        registry.register_segment("a", 1);
        registry.register_segment("b", 2);
        registry.add_key("a", 1, b"x");
        registry.add_key("a", 1, b"y");
        registry.add_key("b", 2, b"z");
        let stats = registry.stats();
        assert_eq!(stats.segment_count, 2);
        assert_eq!(stats.total_keys, 3);
        assert!(stats.total_memory_bytes > 0);
    }

    #[test]
    fn test_stats_empty_registry() {
        let stats = BloomFilterRegistry::new().stats();
        assert_eq!(stats.segment_count, 0);
        assert_eq!(stats.total_keys, 0);
        assert_eq!(stats.total_memory_bytes, 0);
    }

    #[test]
    fn test_registry_recovers_from_poisoned_lock() {
        // parking_lot locks are released on unwind without poisoning, so the
        // registry must stay usable after a panic while a guard was held.
        let registry = BloomFilterRegistry::new();
        registry.register_segment("users", 1);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = registry.blooms.write();
            panic!("poison bloom registry");
        }));
        registry.add_key("users", 1, b"alice");
        assert!(registry.might_contain("users", 1, b"alice"));
    }
}