Skip to main content

ix/
bloom.rs

//! Per-file bloom filters.
//!
//! 256-byte bitset, 5 hash probes, FPR ≈ 0.9% for 200 unique trigrams.
//! Eliminates candidate files before decoding posting lists.
5
6use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// Per-file bloom filter for fast trigram membership testing.
///
/// Each inserted trigram sets `num_hashes` bits whose positions come from
/// double hashing: probe `i` is `(h1 + i * h2) mod (8 * size)`, where `h1` and
/// `h2` are xxh64 digests of the trigram with seeds 0 and 1. With the
/// 256-byte / 5-probe configuration described in the module docs, the standard
/// `(1 - e^(-k*n/m))^k` estimate gives a false-positive rate of roughly 0.9%
/// at 200 distinct trigrams.
///
/// The serialized form stores `size` in its header followed by all of `bits`,
/// so `size` is expected to equal `bits.len()`; keep that in mind when
/// mutating the public fields directly.
#[derive(Clone)]
pub struct BloomFilter {
    /// Size of the bit array in bytes (the filter holds `8 * size` bits).
    pub size: u16,
    /// Number of bit positions (k) set per inserted trigram and checked per lookup.
    pub num_hashes: u8,
    /// The bit array backing the filter; bit `p` is `bits[p / 8] & (1 << (p % 8))`.
    pub bits: Vec<u8>,
}
24
25impl BloomFilter {
26    /// Create a new bloom filter with the given size and number of hash functions.
27    #[must_use]
28    pub fn new(size: usize, num_hashes: u8) -> Self {
29        Self {
30            size: u16::try_from(size).unwrap_or(0),
31            num_hashes,
32            bits: vec![0u8; size],
33        }
34    }
35
36    /// Insert a trigram into the bloom filter.
37    pub fn insert(&mut self, trigram: Trigram) {
38        let tri_bytes = trigram.to_le_bytes();
39        let h1 = Self::hash(&tri_bytes, 0);
40        let h2 = Self::hash(&tri_bytes, 1);
41        let num_bits = usize::from(self.size) * 8;
42
43        for i in 0..self.num_hashes {
44            let bit_pos =
45                (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
46            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
47            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
48            if let Some(byte) = self.bits.get_mut(byte_idx) {
49                *byte |= 1 << bit_idx;
50            }
51        }
52    }
53
54    /// Check if a trigram may be present in the filter.
55    ///
56    /// Returns `true` if the trigram might be present (could be a false positive).
57    /// Returns `false` only if the trigram is definitely not present.
58    #[must_use]
59    pub fn contains(&self, trigram: Trigram) -> bool {
60        let tri_bytes = trigram.to_le_bytes();
61        let h1 = Self::hash(&tri_bytes, 0);
62        let h2 = Self::hash(&tri_bytes, 1);
63        let num_bits = usize::from(self.size) * 8;
64
65        for i in 0..self.num_hashes {
66            let bit_pos =
67                (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
68            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
69            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
70            if self.bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
71                return false;
72            }
73        }
74        true
75    }
76
77    fn hash(data: &[u8], seed: u64) -> u64 {
78        let mut hasher = Xxh64::new(seed);
79        hasher.write(data);
80        hasher.finish()
81    }
82
83    /// Serialize the bloom filter to a writer.
84    ///
85    /// # Errors
86    ///
87    /// Returns an I/O error if the writer fails.
88    pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
89        w.write_all(&self.size.to_le_bytes())?;
90        w.write_all(&[self.num_hashes, 0x00])?;
91        w.write_all(&self.bits)?;
92        Ok(())
93    }
94
95    /// Load from a slice (borrowed).
96    #[must_use]
97    pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
98        if data.len() < 4 {
99            return None;
100        }
101        let size = data
102            .get(0..2)?
103            .try_into()
104            .ok()
105            .map_or(0, u16::from_le_bytes);
106        let size = usize::from(size);
107        let num_hashes = *data.get(2)?;
108        let total_size = 4 + size;
109        if data.len() < total_size {
110            return None;
111        }
112        data.get(4..total_size)
113            .map(|bits| (bits, usize::from(num_hashes)))
114    }
115
116    /// Check if a slice (from mmap) contains a trigram.
117    #[must_use]
118    pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
119        let tri_bytes = trigram.to_le_bytes();
120        let mut h1_hasher = Xxh64::new(0);
121        h1_hasher.write(&tri_bytes);
122        let h1 = h1_hasher.finish();
123
124        let mut h2_hasher = Xxh64::new(1);
125        h2_hasher.write(&tri_bytes);
126        let h2 = h2_hasher.finish();
127
128        let num_bits = bits.len() * 8;
129
130        for i in 0..num_hashes {
131            let bit_pos =
132                (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
133            let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
134            let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
135            if bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
136                return false;
137            }
138        }
139        true
140    }
141}
142
#[cfg(test)]
// Test code may use `as` casts, `unwrap` and indexing freely: any panic here is
// reported as a test failure, which is the desired outcome.
#[allow(clippy::as_conversions, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let mut bloom = BloomFilter::new(256, 5);
        let t1 = 0x0001_0203;
        let t2 = 0x0004_0506;
        bloom.insert(t1);
        assert!(bloom.contains(t1));
        assert!(!bloom.contains(t2));
    }

    #[test]
    fn false_positives() {
        let mut bloom = BloomFilter::new(256, 5);
        for i in 0..200 {
            bloom.insert(i as u32);
        }
        let mut fp = 0;
        for i in 200..1200 {
            if bloom.contains(i as u32) {
                fp += 1;
            }
        }
        assert!(fp < 20, "FPR too high: {fp}/1000");
    }

    /// The serialized (mmap) path must answer exactly like the in-memory filter.
    #[test]
    fn serialize_roundtrip_matches_contains() {
        let mut bloom = BloomFilter::new(256, 5);
        for i in 0..200u32 {
            bloom.insert(i);
        }
        let mut buf = Vec::new();
        bloom.serialize(&mut buf).unwrap();
        assert_eq!(buf.len(), 4 + 256);

        let (bits, num_hashes) = BloomFilter::from_slice(&buf).unwrap();
        assert_eq!(bits, bloom.bits.as_slice());
        assert_eq!(num_hashes, 5);

        let k = u8::try_from(num_hashes).unwrap();
        for i in 0..1200u32 {
            assert_eq!(BloomFilter::slice_contains(bits, k, i), bloom.contains(i));
        }
    }

    #[test]
    fn from_slice_rejects_truncated_input() {
        assert!(BloomFilter::from_slice(&[]).is_none());
        assert!(BloomFilter::from_slice(&[4, 0, 3]).is_none());
        // Header announces 4 bytes of bits but only 3 follow.
        assert!(BloomFilter::from_slice(&[4, 0, 3, 0, 1, 2, 3]).is_none());

        // Trailing bytes beyond the announced size are ignored.
        let (bits, k) = BloomFilter::from_slice(&[2, 0, 3, 0, 0xAA, 0xBB, 0xCC]).unwrap();
        assert_eq!(bits, &[0xAA, 0xBB]);
        assert_eq!(k, 3);
    }
}