lsm_tree/bloom/
mod.rs

1// Copyright (c) 2024-present, fjall-rs
2// This source code is licensed under both the Apache 2.0 and MIT License
3// (found in the LICENSE-* files in the repository)
4
5mod bit_array;
6
7use crate::{
8    coding::{Decode, DecodeError, Encode, EncodeError},
9    file::MAGIC_BYTES,
10};
11use bit_array::BitArray;
12use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{Read, Write};
14
/// Pair of 64-bit hashes `(h1, h2)` from which all bit positions of a key
/// are derived (double hashing), instead of computing `k` separate hashes.
pub type CompositeHash = (u64, u64);
17
18/// A standard bloom filter
19///
20/// Allows buffering the key hashes before actual filter construction
21/// which is needed to properly calculate the filter size, as the amount of items
22/// are unknown during segment construction.
23///
24/// The filter uses double hashing instead of `k` hash functions, see:
25/// <https://fjall-rs.github.io/post/bloom-filter-hash-sharing>
26#[derive(Debug, Eq, PartialEq)]
27#[allow(clippy::module_name_repetitions)]
28pub struct BloomFilter {
29    /// Raw bytes exposed as bit array
30    inner: BitArray,
31
32    /// Bit count
33    m: usize,
34
35    /// Number of hash functions
36    k: usize,
37}
38
39impl Encode for BloomFilter {
40    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), EncodeError> {
41        // Write header
42        writer.write_all(&MAGIC_BYTES)?;
43
44        // NOTE: Filter type (unused)
45        writer.write_u8(0)?;
46
47        // NOTE: Hash type (unused)
48        writer.write_u8(0)?;
49
50        writer.write_u64::<BigEndian>(self.m as u64)?;
51        writer.write_u64::<BigEndian>(self.k as u64)?;
52        writer.write_all(self.inner.bytes())?;
53
54        Ok(())
55    }
56}
57
58impl Decode for BloomFilter {
59    fn decode_from<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
60        // Check header
61        let mut magic = [0u8; MAGIC_BYTES.len()];
62        reader.read_exact(&mut magic)?;
63
64        if magic != MAGIC_BYTES {
65            return Err(DecodeError::InvalidHeader("BloomFilter"));
66        }
67
68        // NOTE: Filter type (unused)
69        let filter_type = reader.read_u8()?;
70        assert_eq!(0, filter_type, "Invalid filter type");
71
72        // NOTE: Hash type (unused)
73        let hash_type = reader.read_u8()?;
74        assert_eq!(0, hash_type, "Invalid bloom hash type");
75
76        let m = reader.read_u64::<BigEndian>()? as usize;
77        let k = reader.read_u64::<BigEndian>()? as usize;
78
79        let mut bytes = vec![0; m / 8];
80        reader.read_exact(&mut bytes)?;
81
82        Ok(Self::from_raw(m, k, bytes.into_boxed_slice()))
83    }
84}
85
86#[allow(clippy::len_without_is_empty)]
87impl BloomFilter {
88    /// Size of bloom filter in bytes.
89    #[must_use]
90    pub fn len(&self) -> usize {
91        self.inner.bytes().len()
92    }
93
94    fn from_raw(m: usize, k: usize, bytes: Box<[u8]>) -> Self {
95        Self {
96            inner: BitArray::from_bytes(bytes),
97            m,
98            k,
99        }
100    }
101
102    /// Constructs a bloom filter that can hold `n` items
103    /// while maintaining a certain false positive rate `fpr`.
104    #[must_use]
105    pub fn with_fp_rate(n: usize, fpr: f32) -> Self {
106        use std::f32::consts::LN_2;
107
108        assert!(n > 0);
109
110        // NOTE: Some sensible minimum
111        let fpr = fpr.max(0.000_001);
112
113        let m = Self::calculate_m(n, fpr);
114        let bpk = m / n;
115        let k = (((bpk as f32) * LN_2) as usize).max(1);
116
117        Self {
118            inner: BitArray::with_capacity(m / 8),
119            m,
120            k,
121        }
122    }
123
124    /// Constructs a bloom filter that can hold `n` items
125    /// with `bpk` bits per key.
126    ///
127    /// 10 bits per key is a sensible default.
128    #[must_use]
129    pub fn with_bpk(n: usize, bpk: u8) -> Self {
130        use std::f32::consts::LN_2;
131
132        assert!(bpk > 0);
133        assert!(n > 0);
134
135        let bpk = bpk as usize;
136
137        let m = n * bpk;
138        let k = (((bpk as f32) * LN_2) as usize).max(1);
139
140        // NOTE: Round up so we don't get too little bits
141        let bytes = (m as f32 / 8.0).ceil() as usize;
142
143        Self {
144            inner: BitArray::with_capacity(bytes),
145            m: bytes * 8,
146            k,
147        }
148    }
149
150    fn calculate_m(n: usize, fp_rate: f32) -> usize {
151        use std::f32::consts::LN_2;
152
153        let n = n as f32;
154        let ln2_squared = LN_2.powi(2);
155
156        let numerator = n * fp_rate.ln();
157        let m = -(numerator / ln2_squared);
158
159        // Round up to next byte
160        ((m / 8.0).ceil() * 8.0) as usize
161    }
162
163    /// Returns `true` if the hash may be contained.
164    ///
165    /// Will never have a false negative.
166    #[must_use]
167    pub fn contains_hash(&self, hash: CompositeHash) -> bool {
168        let (mut h1, mut h2) = hash;
169
170        for i in 0..(self.k as u64) {
171            let idx = h1 % (self.m as u64);
172
173            // NOTE: should be in bounds because of modulo
174            #[allow(clippy::expect_used)]
175            if !self.has_bit(idx as usize) {
176                return false;
177            }
178
179            h1 = h1.wrapping_add(h2);
180            h2 = h2.wrapping_add(i);
181        }
182
183        true
184    }
185
186    /// Returns `true` if the item may be contained.
187    ///
188    /// Will never have a false negative.
189    #[must_use]
190    pub fn contains(&self, key: &[u8]) -> bool {
191        self.contains_hash(Self::get_hash(key))
192    }
193
194    /// Adds the key to the filter.
195    pub fn set_with_hash(&mut self, (mut h1, mut h2): CompositeHash) {
196        for i in 0..(self.k as u64) {
197            let idx = h1 % (self.m as u64);
198
199            self.enable_bit(idx as usize);
200
201            h1 = h1.wrapping_add(h2);
202            h2 = h2.wrapping_add(i);
203        }
204    }
205
206    /// Returns `true` if the bit at `idx` is `1`.
207    fn has_bit(&self, idx: usize) -> bool {
208        self.inner.get(idx)
209    }
210
211    /// Sets the bit at the given index to `true`.
212    fn enable_bit(&mut self, idx: usize) {
213        self.inner.set(idx, true);
214    }
215
216    /// Gets the hash of a key.
217    #[must_use]
218    pub fn get_hash(key: &[u8]) -> CompositeHash {
219        let h0 = xxhash_rust::xxh3::xxh3_128(key);
220        let h1 = (h0 >> 64) as u64;
221        let h2 = h0 as u64;
222        (h1, h2)
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use test_log::test;

    #[test]
    fn bloom_serde_round_trip() -> crate::Result<()> {
        let dir = tempfile::tempdir()?;

        let path = dir.path().join("bf");
        let mut file = File::create(&path)?;

        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        let keys = &[
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ];

        for key in keys {
            filter.set_with_hash(BloomFilter::get_hash(*key));
        }

        for key in keys {
            assert!(filter.contains(&**key));
        }
        assert!(!filter.contains(b"asdasads"));
        assert!(!filter.contains(b"item10"));
        assert!(!filter.contains(b"cxycxycxy"));

        filter.encode_into(&mut file)?;
        file.sync_all()?;
        drop(file);

        let mut file = File::open(&path)?;
        let filter_copy = BloomFilter::decode_from(&mut file)?;

        assert_eq!(filter, filter_copy);

        // Query the decoded copy (not the original) so that lookups after
        // the round trip are actually exercised
        for key in keys {
            assert!(filter_copy.contains(&**key));
        }
        assert!(!filter_copy.contains(b"asdasads"));
        assert!(!filter_copy.contains(b"item10"));
        assert!(!filter_copy.contains(b"cxycxycxy"));

        Ok(())
    }

    #[test]
    fn bloom_calculate_m() {
        assert_eq!(9_592, BloomFilter::calculate_m(1_000, 0.01));
        assert_eq!(4_800, BloomFilter::calculate_m(1_000, 0.1));
        assert_eq!(4_792_536, BloomFilter::calculate_m(1_000_000, 0.1));
    }

    #[test]
    fn bloom_basic() {
        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        for key in [
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ] {
            assert!(!filter.contains(key));
            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));

            assert!(!filter.contains(b"asdasdasdasdasdasdasd"));
        }
    }

    #[test]
    fn bloom_bpk() {
        let item_count = 1_000;
        let bpk = 5;

        let mut filter = BloomFilter::with_bpk(item_count, bpk);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr() {
        let item_count = 100_000;
        let wanted_fpr = 0.1;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.05);
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr_2() {
        let item_count = 100_000;
        let wanted_fpr = 0.5;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.45);
        assert!(fpr < 0.55);
    }
}