1mod bit_array;
6
7use crate::{
8 coding::{Decode, DecodeError, Encode, EncodeError},
9 file::MAGIC_BYTES,
10};
11use bit_array::BitArray;
12use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{Read, Write};
14
/// Pair of 64-bit hashes `(h1, h2)` describing one key.
///
/// [`BloomFilter::get_hash`] produces it from the upper and lower halves of a
/// 128-bit hash; the filter uses `h1` as the running probe value and `h2` as
/// the increment when deriving its `k` bit positions.
pub type CompositeHash = (u64, u64);
17
/// A Bloom filter over byte-string keys.
///
/// Keys are reduced to a [`CompositeHash`], from which `k` bit positions in
/// an `m`-bit array are derived. Inserting a key sets those bits; a lookup
/// reports the key as present only if all of them are set. Nothing in this
/// type ever clears a bit, so an inserted key is always found, while a key
/// that was never inserted may occasionally be reported as present.
#[derive(Debug, Eq, PartialEq)]
#[allow(clippy::module_name_repetitions)]
pub struct BloomFilter {
    /// Raw bit storage
    inner: BitArray,

    /// Number of bits in the filter; probe positions are taken modulo this value
    m: usize,

    /// Number of bit positions probed per key (the "hash function count")
    k: usize,
}
38
39impl Encode for BloomFilter {
40 fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), EncodeError> {
41 writer.write_all(&MAGIC_BYTES)?;
43
44 writer.write_u8(0)?;
46
47 writer.write_u8(0)?;
49
50 writer.write_u64::<BigEndian>(self.m as u64)?;
51 writer.write_u64::<BigEndian>(self.k as u64)?;
52 writer.write_all(self.inner.bytes())?;
53
54 Ok(())
55 }
56}
57
58impl Decode for BloomFilter {
59 fn decode_from<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
60 let mut magic = [0u8; MAGIC_BYTES.len()];
62 reader.read_exact(&mut magic)?;
63
64 if magic != MAGIC_BYTES {
65 return Err(DecodeError::InvalidHeader("BloomFilter"));
66 }
67
68 let filter_type = reader.read_u8()?;
70 assert_eq!(0, filter_type, "Invalid filter type");
71
72 let hash_type = reader.read_u8()?;
74 assert_eq!(0, hash_type, "Invalid bloom hash type");
75
76 let m = reader.read_u64::<BigEndian>()? as usize;
77 let k = reader.read_u64::<BigEndian>()? as usize;
78
79 let mut bytes = vec![0; m / 8];
80 reader.read_exact(&mut bytes)?;
81
82 Ok(Self::from_raw(m, k, bytes.into_boxed_slice()))
83 }
84}
85
86#[allow(clippy::len_without_is_empty)]
87impl BloomFilter {
    /// Returns the size of the underlying bit array in bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.bytes().len()
    }
93
    /// Returns `k`, the number of bit positions probed per key.
    #[must_use]
    pub fn hash_fn_count(&self) -> usize {
        self.k
    }
99
100 fn from_raw(m: usize, k: usize, bytes: Box<[u8]>) -> Self {
101 Self {
102 inner: BitArray::from_bytes(bytes),
103 m,
104 k,
105 }
106 }
107
108 #[must_use]
111 pub fn with_fp_rate(n: usize, fpr: f32) -> Self {
112 use std::f32::consts::LN_2;
113
114 assert!(n > 0);
115
116 let fpr = fpr.max(0.000_001);
118
119 let m = Self::calculate_m(n, fpr);
120 let bpk = m / n;
121 let k = (((bpk as f32) * LN_2) as usize).max(1);
122
123 Self {
124 inner: BitArray::with_capacity(m / 8),
125 m,
126 k,
127 }
128 }
129
130 #[must_use]
135 pub fn with_bpk(n: usize, bpk: u8) -> Self {
136 use std::f32::consts::LN_2;
137
138 assert!(bpk > 0);
139 assert!(n > 0);
140
141 let bpk = bpk as usize;
142
143 let m = n * bpk;
144 let k = (((bpk as f32) * LN_2) as usize).max(1);
145
146 let bytes = (m as f32 / 8.0).ceil() as usize;
148
149 Self {
150 inner: BitArray::with_capacity(bytes),
151 m: bytes * 8,
152 k,
153 }
154 }
155
156 fn calculate_m(n: usize, fp_rate: f32) -> usize {
157 use std::f32::consts::LN_2;
158
159 let n = n as f32;
160 let ln2_squared = LN_2.powi(2);
161
162 let numerator = n * fp_rate.ln();
163 let m = -(numerator / ln2_squared);
164
165 ((m / 8.0).ceil() * 8.0) as usize
167 }
168
169 #[must_use]
173 pub fn contains_hash(&self, (mut h1, mut h2): CompositeHash) -> bool {
174 for i in 0..(self.k as u64) {
175 let idx = h1 % (self.m as u64);
176
177 #[allow(clippy::expect_used)]
179 if !self.has_bit(idx as usize) {
180 return false;
181 }
182
183 h1 = h1.wrapping_add(h2);
184 h2 = h2.wrapping_add(i);
185 }
186
187 true
188 }
189
190 #[must_use]
194 pub fn contains(&self, key: &[u8]) -> bool {
195 self.contains_hash(Self::get_hash(key))
196 }
197
198 pub fn set_with_hash(&mut self, (mut h1, mut h2): CompositeHash) {
200 for i in 0..(self.k as u64) {
201 let idx = h1 % (self.m as u64);
202
203 self.enable_bit(idx as usize);
204
205 h1 = h1.wrapping_add(h2);
206 h2 = h2.wrapping_add(i);
207 }
208 }
209
    /// Reads the bit at position `idx` from the underlying bit array.
    fn has_bit(&self, idx: usize) -> bool {
        self.inner.get(idx)
    }
214
    /// Enables the bit at position `idx` in the underlying bit array.
    fn enable_bit(&mut self, idx: usize) {
        self.inner.enable(idx);
    }
219
220 #[must_use]
222 pub fn get_hash(key: &[u8]) -> CompositeHash {
223 let h0 = xxhash_rust::xxh3::xxh3_128(key);
224 let h1 = (h0 >> 64) as u64;
225 let h2 = h0 as u64;
226 (h1, h2)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use test_log::test;

    /// Inserts `item_count` random keys into `filter` (asserting each one is
    /// found right after insertion), then probes with `item_count` fresh
    /// random keys and returns the fraction that were reported as present.
    #[allow(clippy::cast_precision_loss)]
    fn observed_fpr(filter: &mut BloomFilter, item_count: usize) -> f32 {
        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let false_positives = (0..item_count)
            .map(|_| nanoid::nanoid!())
            .filter(|key| filter.contains(key.as_bytes()))
            .count();

        false_positives as f32 / item_count as f32
    }

    #[test]
    fn bloom_serde_round_trip() -> crate::Result<()> {
        let dir = tempfile::tempdir()?;

        let path = dir.path().join("bf");
        let mut file = File::create(&path)?;

        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        let keys = &[
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ];

        for key in keys {
            filter.set_with_hash(BloomFilter::get_hash(*key));
        }

        for key in keys {
            assert!(filter.contains(&**key));
        }
        assert!(!filter.contains(b"asdasads"));
        assert!(!filter.contains(b"item10"));
        assert!(!filter.contains(b"cxycxycxy"));

        filter.encode_into(&mut file)?;
        file.sync_all()?;
        drop(file);

        let mut file = File::open(&path)?;
        let filter_copy = BloomFilter::decode_from(&mut file)?;

        assert_eq!(filter, filter_copy);

        // Query the decoded copy, not the original, so the round trip itself
        // is what gets verified.
        for key in keys {
            assert!(filter_copy.contains(&**key));
        }
        assert!(!filter_copy.contains(b"asdasads"));
        assert!(!filter_copy.contains(b"item10"));
        assert!(!filter_copy.contains(b"cxycxycxy"));

        Ok(())
    }

    #[test]
    fn bloom_calculate_m() {
        assert_eq!(9_592, BloomFilter::calculate_m(1_000, 0.01));
        assert_eq!(4_800, BloomFilter::calculate_m(1_000, 0.1));
        assert_eq!(4_792_536, BloomFilter::calculate_m(1_000_000, 0.1));
    }

    #[test]
    fn bloom_basic() {
        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        for key in [
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ] {
            assert!(!filter.contains(key));
            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));

            assert!(!filter.contains(b"asdasdasdasdasdasdasd"));
        }
    }

    #[test]
    fn bloom_bpk() {
        let item_count = 1_000;
        let bpk = 5;

        let mut filter = BloomFilter::with_bpk(item_count, bpk);

        let fpr = observed_fpr(&mut filter, item_count);
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr() {
        let item_count = 100_000;
        let wanted_fpr = 0.1;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        let fpr = observed_fpr(&mut filter, item_count);
        assert!(fpr > 0.05);
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr_2() {
        let item_count = 100_000;
        let wanted_fpr = 0.5;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        let fpr = observed_fpr(&mut filter, item_count);
        assert!(fpr > 0.45);
        assert!(fpr < 0.55);
    }
}