1mod bit_array;
6
7use crate::{
8 coding::{Decode, DecodeError, Encode, EncodeError},
9 file::MAGIC_BYTES,
10};
11use bit_array::BitArray;
12use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{Read, Write};
14
15pub type CompositeHash = (u64, u64);
17
18#[derive(Debug, Eq, PartialEq)]
27#[allow(clippy::module_name_repetitions)]
28pub struct BloomFilter {
29 inner: BitArray,
31
32 m: usize,
34
35 k: usize,
37}
38
39impl Encode for BloomFilter {
40 fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), EncodeError> {
41 writer.write_all(&MAGIC_BYTES)?;
43
44 writer.write_u8(0)?;
46
47 writer.write_u8(0)?;
49
50 writer.write_u64::<BigEndian>(self.m as u64)?;
51 writer.write_u64::<BigEndian>(self.k as u64)?;
52 writer.write_all(self.inner.bytes())?;
53
54 Ok(())
55 }
56}
57
58impl Decode for BloomFilter {
59 fn decode_from<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
60 let mut magic = [0u8; MAGIC_BYTES.len()];
62 reader.read_exact(&mut magic)?;
63
64 if magic != MAGIC_BYTES {
65 return Err(DecodeError::InvalidHeader("BloomFilter"));
66 }
67
68 let filter_type = reader.read_u8()?;
70 assert_eq!(0, filter_type, "Invalid filter type");
71
72 let hash_type = reader.read_u8()?;
74 assert_eq!(0, hash_type, "Invalid bloom hash type");
75
76 let m = reader.read_u64::<BigEndian>()? as usize;
77 let k = reader.read_u64::<BigEndian>()? as usize;
78
79 let mut bytes = vec![0; m / 8];
80 reader.read_exact(&mut bytes)?;
81
82 Ok(Self::from_raw(m, k, bytes.into_boxed_slice()))
83 }
84}
85
86#[allow(clippy::len_without_is_empty)]
87impl BloomFilter {
88 #[must_use]
90 pub fn len(&self) -> usize {
91 self.inner.bytes().len()
92 }
93
94 fn from_raw(m: usize, k: usize, bytes: Box<[u8]>) -> Self {
95 Self {
96 inner: BitArray::from_bytes(bytes),
97 m,
98 k,
99 }
100 }
101
102 #[must_use]
105 pub fn with_fp_rate(n: usize, fpr: f32) -> Self {
106 use std::f32::consts::LN_2;
107
108 assert!(n > 0);
109
110 let fpr = fpr.max(0.000_001);
112
113 let m = Self::calculate_m(n, fpr);
114 let bpk = m / n;
115 let k = (((bpk as f32) * LN_2) as usize).max(1);
116
117 Self {
118 inner: BitArray::with_capacity(m / 8),
119 m,
120 k,
121 }
122 }
123
124 #[must_use]
129 pub fn with_bpk(n: usize, bpk: u8) -> Self {
130 use std::f32::consts::LN_2;
131
132 assert!(bpk > 0);
133 assert!(n > 0);
134
135 let bpk = bpk as usize;
136
137 let m = n * bpk;
138 let k = (((bpk as f32) * LN_2) as usize).max(1);
139
140 let bytes = (m as f32 / 8.0).ceil() as usize;
142
143 Self {
144 inner: BitArray::with_capacity(bytes),
145 m: bytes * 8,
146 k,
147 }
148 }
149
150 fn calculate_m(n: usize, fp_rate: f32) -> usize {
151 use std::f32::consts::LN_2;
152
153 let n = n as f32;
154 let ln2_squared = LN_2.powi(2);
155
156 let numerator = n * fp_rate.ln();
157 let m = -(numerator / ln2_squared);
158
159 ((m / 8.0).ceil() * 8.0) as usize
161 }
162
163 #[must_use]
167 pub fn contains_hash(&self, hash: CompositeHash) -> bool {
168 let (mut h1, mut h2) = hash;
169
170 for i in 0..(self.k as u64) {
171 let idx = h1 % (self.m as u64);
172
173 #[allow(clippy::expect_used)]
175 if !self.has_bit(idx as usize) {
176 return false;
177 }
178
179 h1 = h1.wrapping_add(h2);
180 h2 = h2.wrapping_add(i);
181 }
182
183 true
184 }
185
186 #[must_use]
190 pub fn contains(&self, key: &[u8]) -> bool {
191 self.contains_hash(Self::get_hash(key))
192 }
193
194 pub fn set_with_hash(&mut self, (mut h1, mut h2): CompositeHash) {
196 for i in 0..(self.k as u64) {
197 let idx = h1 % (self.m as u64);
198
199 self.enable_bit(idx as usize);
200
201 h1 = h1.wrapping_add(h2);
202 h2 = h2.wrapping_add(i);
203 }
204 }
205
206 fn has_bit(&self, idx: usize) -> bool {
208 self.inner.get(idx)
209 }
210
211 fn enable_bit(&mut self, idx: usize) {
213 self.inner.set(idx, true);
214 }
215
216 #[must_use]
218 pub fn get_hash(key: &[u8]) -> CompositeHash {
219 let h0 = xxhash_rust::xxh3::xxh3_128(key);
220 let h1 = (h0 >> 64) as u64;
221 let h2 = h0 as u64;
222 (h1, h2)
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use test_log::test;

    /// Writes a populated filter to disk, reads it back, and checks that the
    /// decoded copy equals the original and answers queries the same way.
    #[test]
    fn bloom_serde_round_trip() -> crate::Result<()> {
        let dir = tempfile::tempdir()?;

        let path = dir.path().join("bf");
        let mut file = File::create(&path)?;

        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        let keys = &[
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ];

        for key in keys {
            filter.set_with_hash(BloomFilter::get_hash(*key));
        }

        for key in keys {
            assert!(filter.contains(&**key));
        }
        assert!(!filter.contains(b"asdasads"));
        assert!(!filter.contains(b"item10"));
        assert!(!filter.contains(b"cxycxycxy"));

        filter.encode_into(&mut file)?;
        file.sync_all()?;
        drop(file);

        let mut file = File::open(&path)?;
        let filter_copy = BloomFilter::decode_from(&mut file)?;

        assert_eq!(filter, filter_copy);

        // Query the *decoded* filter; the original was already checked above.
        for key in keys {
            assert!(filter_copy.contains(&**key));
        }
        assert!(!filter_copy.contains(b"asdasads"));
        assert!(!filter_copy.contains(b"item10"));
        assert!(!filter_copy.contains(b"cxycxycxy"));

        Ok(())
    }

    #[test]
    fn bloom_calculate_m() {
        assert_eq!(9_592, BloomFilter::calculate_m(1_000, 0.01));
        assert_eq!(4_800, BloomFilter::calculate_m(1_000, 0.1));
        assert_eq!(4_792_536, BloomFilter::calculate_m(1_000_000, 0.1));
    }

    #[test]
    fn bloom_basic() {
        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        for key in [
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ] {
            assert!(!filter.contains(key));
            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));

            assert!(!filter.contains(b"asdasdasdasdasdasdasd"));
        }
    }

    #[test]
    fn bloom_bpk() {
        let item_count = 1_000;
        let bpk = 5;

        let mut filter = BloomFilter::with_bpk(item_count, bpk);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr() {
        let item_count = 100_000;
        let wanted_fpr = 0.1;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.05);
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr_2() {
        let item_count = 100_000;
        let wanted_fpr = 0.5;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.45);
        assert!(fpr < 0.55);
    }
}