Skip to main content

ix/
bloom.rs

//! Per-file bloom filters.
//!
//! 256-byte bitset, 5 hashes, FPR ≈ 0.9% for 200 unique trigrams.
//! Eliminates candidate files before decoding posting lists.
5
6use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// A fixed-size bloom filter over trigrams.
///
/// The fields are public so a filter can be inspected or assembled directly;
/// `size` and `bits.len()` are expected to agree.
#[derive(Debug, Clone)]
pub struct BloomFilter {
    /// Length of the bitset in bytes (not bits); this is the value written to
    /// the serialized header.
    pub size: u16,
    /// Number of bit positions probed per trigram.
    pub num_hashes: u8,
    /// The bitset itself. Bit `n` lives in byte `n / 8` at bit `n % 8`,
    /// counting from the least significant bit.
    pub bits: Vec<u8>,
}
17
18impl BloomFilter {
19    pub fn new(size: usize, num_hashes: u8) -> Self {
20        Self {
21            size: size as u16,
22            num_hashes,
23            bits: vec![0u8; size],
24        }
25    }
26
27    pub fn insert(&mut self, trigram: Trigram) {
28        let tri_bytes = trigram.to_le_bytes();
29        let h1 = self.hash(&tri_bytes, 0);
30        let h2 = self.hash(&tri_bytes, 1);
31        let num_bits = (self.size as usize) * 8;
32
33        for i in 0..self.num_hashes {
34            let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
35            let byte_idx = (bit_pos / 8) as usize;
36            let bit_idx = (bit_pos % 8) as u8;
37            self.bits[byte_idx] |= 1 << bit_idx;
38        }
39    }
40
41    pub fn contains(&self, trigram: Trigram) -> bool {
42        let tri_bytes = trigram.to_le_bytes();
43        let h1 = self.hash(&tri_bytes, 0);
44        let h2 = self.hash(&tri_bytes, 1);
45        let num_bits = (self.size as usize) * 8;
46
47        for i in 0..self.num_hashes {
48            let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
49            let byte_idx = (bit_pos / 8) as usize;
50            let bit_idx = (bit_pos % 8) as u8;
51            if self.bits[byte_idx] & (1 << bit_idx) == 0 {
52                return false;
53            }
54        }
55        true
56    }
57
58    fn hash(&self, data: &[u8], seed: u64) -> u64 {
59        let mut hasher = Xxh64::new(seed);
60        hasher.write(data);
61        hasher.finish()
62    }
63
64    pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
65        w.write_all(&self.size.to_le_bytes())?;
66        w.write_all(&[self.num_hashes, 0x00])?; // padding
67        w.write_all(&self.bits)?;
68        Ok(())
69    }
70
71    /// Load from a slice (borrowed).
72    pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
73        if data.len() < 4 {
74            return None;
75        }
76        let size = data[0..2]
77            .try_into()
78            .ok()
79            .map(u16::from_le_bytes)
80            .unwrap_or(0) as usize;
81        let num_hashes = data[2];
82        let total_size = 4 + size;
83        if data.len() < total_size {
84            return None;
85        }
86        Some((&data[4..total_size], num_hashes as usize))
87    }
88
89    /// Check if a slice (from mmap) contains a trigram.
90    pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
91        let tri_bytes = trigram.to_le_bytes();
92        let mut h1_hasher = Xxh64::new(0);
93        h1_hasher.write(&tri_bytes);
94        let h1 = h1_hasher.finish();
95
96        let mut h2_hasher = Xxh64::new(1);
97        h2_hasher.write(&tri_bytes);
98        let h2 = h2_hasher.finish();
99
100        let num_bits = bits.len() * 8;
101
102        for i in 0..num_hashes {
103            let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
104            let byte_idx = (bit_pos / 8) as usize;
105            let bit_idx = (bit_pos % 8) as u8;
106            if bits[byte_idx] & (1 << bit_idx) == 0 {
107                return false;
108            }
109        }
110        true
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted trigram is reported present; an unrelated one is not.
    #[test]
    fn basic() {
        let mut bloom = BloomFilter::new(256, 5);
        let t1 = 0x010203;
        let t2 = 0x040506;
        bloom.insert(t1);
        assert!(bloom.contains(t1));
        assert!(!bloom.contains(t2));
    }

    /// With 200 entries in a 256-byte, 5-hash filter the false-positive count
    /// over 1000 absent keys stays within the allowed margin.
    #[test]
    fn false_positives() {
        let mut bloom = BloomFilter::new(256, 5);
        for i in 0..200u32 {
            bloom.insert(i);
        }
        let fp = (200u32..1200).filter(|&i| bloom.contains(i)).count();
        // Theoretical FPR is just under 1%; the bound allows up to 2% to
        // absorb sampling noise over only 1000 lookups.
        assert!(fp < 20, "FPR too high: {}/1000", fp);
    }

    /// A serialized filter parsed back from bytes answers the same queries as
    /// the in-memory filter it was written from.
    #[test]
    fn serialize_round_trip() {
        let mut bloom = BloomFilter::new(256, 5);
        let present = 0x010203;
        let absent = 0x040506;
        bloom.insert(present);

        let mut buf = Vec::new();
        bloom
            .serialize(&mut buf)
            .expect("writing to a Vec cannot fail");
        assert_eq!(buf.len(), 4 + 256);

        let (bits, num_hashes) =
            BloomFilter::from_slice(&buf).expect("buffer holds a complete filter");
        assert_eq!(bits.len(), 256);
        assert_eq!(num_hashes, 5);

        // The hash count was parsed from a single byte, so the cast is lossless.
        let num_hashes = num_hashes as u8;
        assert!(BloomFilter::slice_contains(bits, num_hashes, present));
        assert_eq!(
            BloomFilter::slice_contains(bits, num_hashes, absent),
            bloom.contains(absent)
        );
    }

    /// Truncated input is rejected instead of being read out of bounds.
    #[test]
    fn from_slice_rejects_short_input() {
        assert!(BloomFilter::from_slice(&[]).is_none());
        assert!(BloomFilter::from_slice(&[0x02, 0x00, 0x05]).is_none());
        // Header announces 2 payload bytes but only 1 follows.
        assert!(BloomFilter::from_slice(&[0x02, 0x00, 0x05, 0x00, 0xff]).is_none());
    }
}