1use serde::{Deserialize, Serialize};
2
3use crate::entry::Hash;
4
/// Fixed-size Bloom filter over [`Hash`] values.
///
/// Bits are only ever set (by `insert` and `merge`), never cleared, so a hash that
/// was inserted is always reported as present; absent hashes may be false positives.
///
/// NOTE(review): `to_bytes` encodes this struct with `rmp_serde::to_vec`, which as far
/// as I know writes struct fields positionally (array form, not a map). If so, the
/// field order and count below are part of the serialized format; verify before
/// reordering, renaming, or adding fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BloomFilter {
    /// Bit array packed into 64-bit words: bit `i` lives in `bits[i / 64]` at `i % 64`.
    bits: Vec<u64>,
    /// Number of addressable bits (`m`); probe positions are reduced modulo this value.
    num_bits: usize,
    /// Number of bit positions probed per item (`k`).
    num_hashes: u32,
    /// Number of `insert` calls plus counts absorbed via `merge`; not a count of
    /// distinct items (re-inserting the same hash still increments it).
    count: usize,
}
21
22impl BloomFilter {
23 pub fn validate(&self) -> Result<(), String> {
26 if self.num_bits == 0 {
27 return Err("bloom filter num_bits must be > 0".into());
28 }
29 if self.bits.len() * 64 < self.num_bits {
30 return Err(format!(
31 "bloom filter bits array too small: {} words for {} bits",
32 self.bits.len(),
33 self.num_bits
34 ));
35 }
36 if self.num_hashes == 0 || self.num_hashes > 32 {
37 return Err(format!(
38 "bloom filter num_hashes {} out of range [1, 32]",
39 self.num_hashes
40 ));
41 }
42 Ok(())
43 }
44
45 pub fn new(expected_items: usize, fp_rate: f64) -> Self {
52 assert!(expected_items > 0, "expected_items must be > 0");
53 assert!((0.0..1.0).contains(&fp_rate), "fp_rate must be in (0, 1)");
54
55 let n = expected_items as f64;
56 let ln2 = std::f64::consts::LN_2;
57 let ln2_sq = ln2 * ln2;
58
59 let num_bits = ((-n * fp_rate.ln()) / ln2_sq).ceil() as usize;
60 let num_bits = num_bits.max(64); let num_hashes = ((num_bits as f64 / n) * ln2).ceil() as u32;
62 let num_hashes = num_hashes.max(1);
63
64 let words = num_bits.div_ceil(64);
65 Self {
66 bits: vec![0u64; words],
67 num_bits,
68 num_hashes,
69 count: 0,
70 }
71 }
72
73 pub fn insert(&mut self, hash: &Hash) {
75 for idx in self.indices(hash) {
76 let word = idx / 64;
77 let bit = idx % 64;
78 self.bits[word] |= 1u64 << bit;
79 }
80 self.count += 1;
81 }
82
83 pub fn contains(&self, hash: &Hash) -> bool {
88 for idx in self.indices(hash) {
89 let word = idx / 64;
90 let bit = idx % 64;
91 if self.bits[word] & (1u64 << bit) == 0 {
92 return false;
93 }
94 }
95 true
96 }
97
98 pub fn count(&self) -> usize {
100 self.count
101 }
102
103 pub fn merge(&mut self, other: &BloomFilter) {
107 assert_eq!(self.num_bits, other.num_bits, "bloom filter size mismatch");
108 assert_eq!(
109 self.num_hashes, other.num_hashes,
110 "bloom filter hash count mismatch"
111 );
112 for (a, b) in self.bits.iter_mut().zip(other.bits.iter()) {
113 *a |= *b;
114 }
115 self.count += other.count;
116 }
117
118 pub fn to_bytes(&self) -> Vec<u8> {
120 rmp_serde::to_vec(self).expect("bloom filter serialization should not fail")
121 }
122
123 pub fn from_bytes(bytes: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
125 rmp_serde::from_slice(bytes)
126 }
127
128 fn indices(&self, hash: &Hash) -> Vec<usize> {
133 let h1 = u64::from_le_bytes(hash[0..8].try_into().unwrap());
135 let h2 = u64::from_le_bytes(hash[8..16].try_into().unwrap());
136 let m = self.num_bits as u64;
137
138 (0..self.num_hashes)
139 .map(|i| {
140 let i = i as u64;
141 let idx = h1
143 .wrapping_add(i.wrapping_mul(h2))
144 .wrapping_add(i.wrapping_mul(i));
145 (idx % m) as usize
146 })
147 .collect()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test hash: BLAKE3 of a 32-byte buffer whose first byte is `seed`.
    fn make_hash(seed: u8) -> Hash {
        let mut h = [0u8; 32];
        h[0] = seed;
        *blake3::hash(&h).as_bytes()
    }

    #[test]
    fn bloom_insert_and_check() {
        let mut bloom = BloomFilter::new(100, 0.01);
        let h1 = make_hash(1);
        let h2 = make_hash(2);
        let h3 = make_hash(3);

        bloom.insert(&h1);
        bloom.insert(&h2);

        assert!(bloom.contains(&h1));
        assert!(bloom.contains(&h2));
        assert!(!bloom.contains(&h3));
    }

    #[test]
    fn bloom_empty_contains_nothing() {
        let bloom = BloomFilter::new(100, 0.01);
        for i in 0..=255 {
            assert!(!bloom.contains(&make_hash(i)));
        }
    }

    #[test]
    fn bloom_false_positive_rate() {
        let n = 1000;
        let mut bloom = BloomFilter::new(n, 0.01);

        for i in 0..n {
            let h = *blake3::hash(&(i as u64).to_le_bytes()).as_bytes();
            bloom.insert(&h);
        }

        // Probe with keys disjoint from the inserted ones; any hit is a false positive.
        let test_count = 10_000;
        let mut false_positives = 0;
        for i in n..(n + test_count) {
            let h = *blake3::hash(&(i as u64).to_le_bytes()).as_bytes();
            if bloom.contains(&h) {
                false_positives += 1;
            }
        }

        let fpr = false_positives as f64 / test_count as f64;
        assert!(
            fpr < 0.02,
            "false positive rate {fpr:.4} exceeds 2% threshold"
        );
    }

    #[test]
    fn bloom_merge_union() {
        let mut bloom_a = BloomFilter::new(100, 0.01);
        let mut bloom_b = BloomFilter::new(100, 0.01);

        let h1 = make_hash(1);
        let h2 = make_hash(2);
        let h3 = make_hash(3);

        bloom_a.insert(&h1);
        bloom_a.insert(&h2);
        bloom_b.insert(&h2);
        bloom_b.insert(&h3);

        bloom_a.merge(&bloom_b);

        assert!(bloom_a.contains(&h1));
        assert!(bloom_a.contains(&h2));
        assert!(bloom_a.contains(&h3));
        // Counts are summed, so the shared h2 is counted twice.
        assert_eq!(bloom_a.count(), 4);
    }

    #[test]
    fn bloom_serialization_roundtrip() {
        let mut bloom = BloomFilter::new(100, 0.01);
        let h1 = make_hash(1);
        let h2 = make_hash(2);
        bloom.insert(&h1);
        bloom.insert(&h2);

        let bytes = bloom.to_bytes();
        let restored = BloomFilter::from_bytes(&bytes).unwrap();

        assert!(restored.contains(&h1));
        assert!(restored.contains(&h2));
        assert!(!restored.contains(&make_hash(3)));
        assert_eq!(restored.count(), 2);
        assert_eq!(restored.num_bits, bloom.num_bits);
        assert_eq!(restored.num_hashes, bloom.num_hashes);
        assert!(restored.validate().is_ok());
    }

    #[test]
    fn bloom_count_tracks_inserts() {
        let mut bloom = BloomFilter::new(100, 0.01);
        assert_eq!(bloom.count(), 0);
        bloom.insert(&make_hash(1));
        assert_eq!(bloom.count(), 1);
        bloom.insert(&make_hash(2));
        assert_eq!(bloom.count(), 2);
    }

    #[test]
    fn bloom_new_passes_validate() {
        for &(n, p) in &[(100, 0.01), (1000, 0.001), (10_000, 0.05)] {
            let bloom = BloomFilter::new(n, p);
            assert!(
                bloom.validate().is_ok(),
                "new({n}, {p}) produced an invalid filter"
            );
        }
    }

    #[test]
    fn bloom_validate_rejects_malformed() {
        let good = BloomFilter::new(100, 0.01);
        assert!(good.validate().is_ok());

        let zero_bits = BloomFilter {
            num_bits: 0,
            ..good.clone()
        };
        assert!(zero_bits.validate().is_err());

        let short_bits = BloomFilter {
            bits: vec![0; good.bits.len() - 1],
            ..good.clone()
        };
        assert!(short_bits.validate().is_err());

        let no_hashes = BloomFilter {
            num_hashes: 0,
            ..good.clone()
        };
        assert!(no_hashes.validate().is_err());

        let too_many_hashes = BloomFilter {
            num_hashes: 33,
            ..good.clone()
        };
        assert!(too_many_hashes.validate().is_err());
    }
}