Skip to main content

ix/
bloom.rs

//! Per-file bloom filters.
//!
//! 256-byte bitset, 5 hash probes per trigram; theoretical FPR ≈ 0.9% at 200 unique trigrams.
//! Used to rule out candidate files before their posting lists are decoded.
5
6use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// Per-file bloom filter for fast trigram membership testing.
///
/// The filter is a byte-backed bitset probed at `num_hashes` positions per
/// trigram. The configuration described in the module docs is a 256-byte
/// bitset with 5 probes.
///
/// All fields are public, so callers that mutate them directly are
/// responsible for keeping `size` equal to `bits.len()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BloomFilter {
    /// Size of the bit array in bytes.
    ///
    /// This is the value written into the serialized header, and the value
    /// `insert`/`contains` multiply by 8 to get the number of addressable bits.
    pub size: u16,
    /// Number of hash functions (k), i.e. bit positions probed per trigram.
    pub num_hashes: u8,
    /// The bit array backing the filter, one bit per probe position.
    pub bits: Vec<u8>,
}
24
25impl BloomFilter {
26    /// Create a new bloom filter with the given size and number of hash functions.
27    #[must_use]
28    pub fn new(size: usize, num_hashes: u8) -> Self {
29        Self {
30            size: u16::try_from(size).unwrap_or(0),
31            num_hashes,
32            bits: vec![0u8; size],
33        }
34    }
35
36    /// Insert a trigram into the bloom filter.
37    pub fn insert(&mut self, trigram: Trigram) {
38        let tri_bytes = trigram.to_le_bytes();
39        let h1 = Self::hash(&tri_bytes, 0);
40        let h2 = Self::hash(&tri_bytes, 1);
41        let num_bits = usize::from(self.size) * 8;
42
43        for i in 0..self.num_hashes {
44            let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
45                % u64::try_from(num_bits).unwrap_or(0);
46            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
47            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
48            if let Some(byte) = self.bits.get_mut(byte_idx) {
49                *byte |= 1 << bit_idx;
50            }
51        }
52    }
53
54    /// Check if a trigram may be present in the filter.
55    ///
56    /// Returns `true` if the trigram might be present (could be a false positive).
57    /// Returns `false` only if the trigram is definitely not present.
58    #[must_use]
59    pub fn contains(&self, trigram: Trigram) -> bool {
60        let tri_bytes = trigram.to_le_bytes();
61        let h1 = Self::hash(&tri_bytes, 0);
62        let h2 = Self::hash(&tri_bytes, 1);
63        let num_bits = usize::from(self.size) * 8;
64
65        for i in 0..self.num_hashes {
66            let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
67                % u64::try_from(num_bits).unwrap_or(0);
68            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
69            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
70            if self
71                .bits
72                .get(byte_idx)
73                .is_none_or(|&b| b & (1 << bit_idx) == 0)
74            {
75                return false;
76            }
77        }
78        true
79    }
80
81    fn hash(data: &[u8], seed: u64) -> u64 {
82        let mut hasher = Xxh64::new(seed);
83        hasher.write(data);
84        hasher.finish()
85    }
86
87    /// Serialize the bloom filter to a writer.
88    ///
89    /// # Errors
90    ///
91    /// Returns an I/O error if the writer fails.
92    pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
93        w.write_all(&self.size.to_le_bytes())?;
94        w.write_all(&[self.num_hashes, 0x00])?;
95        w.write_all(&self.bits)?;
96        Ok(())
97    }
98
99    /// Load from a slice (borrowed).
100    #[must_use]
101    pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
102        if data.len() < 4 {
103            return None;
104        }
105        let size = data
106            .get(0..2)?
107            .try_into()
108            .ok()
109            .map_or(0, u16::from_le_bytes);
110        let size = usize::from(size);
111        let num_hashes = *data.get(2)?;
112        let total_size = 4 + size;
113        if data.len() < total_size {
114            return None;
115        }
116        data.get(4..total_size)
117            .map(|bits| (bits, usize::from(num_hashes)))
118    }
119
120    /// Check if a slice (from mmap) contains a trigram.
121    #[must_use]
122    pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
123        let tri_bytes = trigram.to_le_bytes();
124        let mut h1_hasher = Xxh64::new(0);
125        h1_hasher.write(&tri_bytes);
126        let h1 = h1_hasher.finish();
127
128        let mut h2_hasher = Xxh64::new(1);
129        h2_hasher.write(&tri_bytes);
130        let h2 = h2_hasher.finish();
131
132        let num_bits = bits.len() * 8;
133
134        for i in 0..num_hashes {
135            let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
136                % u64::try_from(num_bits).unwrap_or(0);
137            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
138            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
139            if bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
140                return false;
141            }
142        }
143        true
144    }
145}
146
#[cfg(test)]
#[allow(clippy::as_conversions, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// An inserted trigram is reported as present; an unrelated one is not.
    #[test]
    fn basic() {
        let present = 0x0001_0203;
        let absent = 0x0004_0506;

        let mut filter = BloomFilter::new(256, 5);
        filter.insert(present);

        assert!(filter.contains(present));
        assert!(!filter.contains(absent));
    }

    /// With 200 keys inserted, fewer than 20 of 1000 absent keys may match.
    #[test]
    fn false_positives() {
        let mut filter = BloomFilter::new(256, 5);
        (0..200).for_each(|key| filter.insert(key as u32));

        let fp = (200..1200)
            .filter(|&probe| filter.contains(probe as u32))
            .count();
        assert!(fp < 20, "FPR too high: {fp}/1000");
    }
}