lsm_tree/bloom/
mod.rs

1// Copyright (c) 2024-present, fjall-rs
2// This source code is licensed under both the Apache 2.0 and MIT License
3// (found in the LICENSE-* files in the repository)
4
5mod bit_array;
6
7use crate::{
8    coding::{Decode, DecodeError, Encode, EncodeError},
9    file::MAGIC_BYTES,
10};
11use bit_array::BitArray;
12use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{Read, Write};
14
/// Pair of 64-bit hashes `(h1, h2)` from which double hashing derives
/// every probe position of a key.
pub type CompositeHash = (u64, u64);
17
18/// A standard bloom filter
19///
20/// Allows buffering the key hashes before actual filter construction
21/// which is needed to properly calculate the filter size, as the amount of items
22/// are unknown during segment construction.
23///
24/// The filter uses double hashing instead of `k` hash functions, see:
25/// <https://fjall-rs.github.io/post/bloom-filter-hash-sharing>
26#[derive(Debug, Eq, PartialEq)]
27#[allow(clippy::module_name_repetitions)]
28pub struct BloomFilter {
29    /// Raw bytes exposed as bit array
30    inner: BitArray,
31
32    /// Bit count
33    m: usize,
34
35    /// Number of hash functions
36    k: usize,
37}
38
39impl Encode for BloomFilter {
40    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), EncodeError> {
41        // Write header
42        writer.write_all(&MAGIC_BYTES)?;
43
44        // NOTE: Filter type
45        writer.write_u8(0)?;
46
47        // NOTE: Hash type (unused)
48        writer.write_u8(0)?;
49
50        writer.write_u64::<BigEndian>(self.m as u64)?;
51        writer.write_u64::<BigEndian>(self.k as u64)?;
52        writer.write_all(self.inner.bytes())?;
53
54        Ok(())
55    }
56}
57
58impl Decode for BloomFilter {
59    fn decode_from<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
60        // Check header
61        let mut magic = [0u8; MAGIC_BYTES.len()];
62        reader.read_exact(&mut magic)?;
63
64        if magic != MAGIC_BYTES {
65            return Err(DecodeError::InvalidHeader("BloomFilter"));
66        }
67
68        // NOTE: Filter type
69        let filter_type = reader.read_u8()?;
70        assert_eq!(0, filter_type, "Invalid filter type");
71
72        // NOTE: Hash type (unused)
73        let hash_type = reader.read_u8()?;
74        assert_eq!(0, hash_type, "Invalid bloom hash type");
75
76        let m = reader.read_u64::<BigEndian>()? as usize;
77        let k = reader.read_u64::<BigEndian>()? as usize;
78
79        let mut bytes = vec![0; m / 8];
80        reader.read_exact(&mut bytes)?;
81
82        Ok(Self::from_raw(m, k, bytes.into_boxed_slice()))
83    }
84}
85
86#[allow(clippy::len_without_is_empty)]
87impl BloomFilter {
88    /// Returns the size of the bloom filter in bytes.
89    #[must_use]
90    pub fn len(&self) -> usize {
91        self.inner.bytes().len()
92    }
93
94    /// Returns the amount of hashes used per lookup.
95    #[must_use]
96    pub fn hash_fn_count(&self) -> usize {
97        self.k
98    }
99
100    fn from_raw(m: usize, k: usize, bytes: Box<[u8]>) -> Self {
101        Self {
102            inner: BitArray::from_bytes(bytes),
103            m,
104            k,
105        }
106    }
107
108    /// Constructs a bloom filter that can hold `n` items
109    /// while maintaining a certain false positive rate `fpr`.
110    #[must_use]
111    pub fn with_fp_rate(n: usize, fpr: f32) -> Self {
112        use std::f32::consts::LN_2;
113
114        assert!(n > 0);
115
116        // NOTE: Some sensible minimum
117        let fpr = fpr.max(0.000_001);
118
119        let m = Self::calculate_m(n, fpr);
120        let bpk = m / n;
121        let k = (((bpk as f32) * LN_2) as usize).max(1);
122
123        Self {
124            inner: BitArray::with_capacity(m / 8),
125            m,
126            k,
127        }
128    }
129
130    /// Constructs a bloom filter that can hold `n` items
131    /// with `bpk` bits per key.
132    ///
133    /// 10 bits per key is a sensible default.
134    #[must_use]
135    pub fn with_bpk(n: usize, bpk: u8) -> Self {
136        use std::f32::consts::LN_2;
137
138        assert!(bpk > 0);
139        assert!(n > 0);
140
141        let bpk = bpk as usize;
142
143        let m = n * bpk;
144        let k = (((bpk as f32) * LN_2) as usize).max(1);
145
146        // NOTE: Round up so we don't get too little bits
147        let bytes = (m as f32 / 8.0).ceil() as usize;
148
149        Self {
150            inner: BitArray::with_capacity(bytes),
151            m: bytes * 8,
152            k,
153        }
154    }
155
156    fn calculate_m(n: usize, fp_rate: f32) -> usize {
157        use std::f32::consts::LN_2;
158
159        let n = n as f32;
160        let ln2_squared = LN_2.powi(2);
161
162        let numerator = n * fp_rate.ln();
163        let m = -(numerator / ln2_squared);
164
165        // Round up to next byte
166        ((m / 8.0).ceil() * 8.0) as usize
167    }
168
169    /// Returns `true` if the hash may be contained.
170    ///
171    /// Will never have a false negative.
172    #[must_use]
173    pub fn contains_hash(&self, (mut h1, mut h2): CompositeHash) -> bool {
174        for i in 0..(self.k as u64) {
175            let idx = h1 % (self.m as u64);
176
177            // NOTE: should be in bounds because of modulo
178            #[allow(clippy::expect_used)]
179            if !self.has_bit(idx as usize) {
180                return false;
181            }
182
183            h1 = h1.wrapping_add(h2);
184            h2 = h2.wrapping_add(i);
185        }
186
187        true
188    }
189
190    /// Returns `true` if the item may be contained.
191    ///
192    /// Will never have a false negative.
193    #[must_use]
194    pub fn contains(&self, key: &[u8]) -> bool {
195        self.contains_hash(Self::get_hash(key))
196    }
197
198    /// Adds the key to the filter.
199    pub fn set_with_hash(&mut self, (mut h1, mut h2): CompositeHash) {
200        for i in 0..(self.k as u64) {
201            let idx = h1 % (self.m as u64);
202
203            self.enable_bit(idx as usize);
204
205            h1 = h1.wrapping_add(h2);
206            h2 = h2.wrapping_add(i);
207        }
208    }
209
210    /// Returns `true` if the bit at `idx` is `1`.
211    fn has_bit(&self, idx: usize) -> bool {
212        self.inner.get(idx)
213    }
214
215    /// Sets the bit at the given index to `true`.
216    fn enable_bit(&mut self, idx: usize) {
217        self.inner.enable(idx);
218    }
219
220    /// Gets the hash of a key.
221    #[must_use]
222    pub fn get_hash(key: &[u8]) -> CompositeHash {
223        let h0 = xxhash_rust::xxh3::xxh3_128(key);
224        let h1 = (h0 >> 64) as u64;
225        let h2 = h0 as u64;
226        (h1, h2)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use test_log::test;

    #[test]
    fn bloom_serde_round_trip() -> crate::Result<()> {
        let dir = tempfile::tempdir()?;

        let path = dir.path().join("bf");
        let mut file = File::create(&path)?;

        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        let keys = &[
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ];

        for key in keys {
            filter.set_with_hash(BloomFilter::get_hash(*key));
        }

        for key in keys {
            assert!(filter.contains(&**key));
        }
        assert!(!filter.contains(b"asdasads"));
        assert!(!filter.contains(b"item10"));
        assert!(!filter.contains(b"cxycxycxy"));

        filter.encode_into(&mut file)?;
        file.sync_all()?;
        drop(file);

        let mut file = File::open(&path)?;
        let filter_copy = BloomFilter::decode_from(&mut file)?;

        assert_eq!(filter, filter_copy);

        // Query the decoded copy (not the original) so the round trip itself is verified.
        for key in keys {
            assert!(filter_copy.contains(&**key));
        }
        assert!(!filter_copy.contains(b"asdasads"));
        assert!(!filter_copy.contains(b"item10"));
        assert!(!filter_copy.contains(b"cxycxycxy"));

        Ok(())
    }

    #[test]
    fn bloom_calculate_m() {
        assert_eq!(9_592, BloomFilter::calculate_m(1_000, 0.01));
        assert_eq!(4_800, BloomFilter::calculate_m(1_000, 0.1));
        assert_eq!(4_792_536, BloomFilter::calculate_m(1_000_000, 0.1));
    }

    #[test]
    fn bloom_basic() {
        let mut filter = BloomFilter::with_fp_rate(10, 0.0001);

        for key in [
            b"item0", b"item1", b"item2", b"item3", b"item4", b"item5", b"item6", b"item7",
            b"item8", b"item9",
        ] {
            assert!(!filter.contains(key));
            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));

            assert!(!filter.contains(b"asdasdasdasdasdasdasd"));
        }
    }

    #[test]
    fn bloom_bpk() {
        let item_count = 1_000;
        let bpk = 5;

        let mut filter = BloomFilter::with_bpk(item_count, bpk);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr() {
        let item_count = 100_000;
        let wanted_fpr = 0.1;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.05);
        assert!(fpr < 0.13);
    }

    #[test]
    fn bloom_fpr_2() {
        let item_count = 100_000;
        let wanted_fpr = 0.5;

        let mut filter = BloomFilter::with_fp_rate(item_count, wanted_fpr);

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            filter.set_with_hash(BloomFilter::get_hash(key));
            assert!(filter.contains(key));
        }

        let mut false_positives = 0;

        for key in (0..item_count).map(|_| nanoid::nanoid!()) {
            let key = key.as_bytes();

            if filter.contains(key) {
                false_positives += 1;
            }
        }

        #[allow(clippy::cast_precision_loss)]
        let fpr = false_positives as f32 / item_count as f32;
        assert!(fpr > 0.45);
        assert!(fpr < 0.55);
    }
}