Skip to main content

ailake_query/
bloom.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Compact Bloom filter for per-file BM25 term pruning (Phase F).
3//!
//! Uses k=4 hash probes derived from two differently seeded FNV-1a 64-bit hashes
//! (double-hashing trick: h_i(x) = (h1(x) + i*h2(x)) mod m, with wrapping 64-bit
//! arithmetic). No external hash dep — FNV-1a is trivially inlined.
7//!
8//! Serialization: 8-byte little-endian num_bits header + bit words (u64, LE).
9
/// Number of bit positions probed (and set on insert) for every term.
const K: usize = 4;
11
/// Seeded 64-bit FNV-1a hash of `data`.
///
/// The standard FNV offset basis is XORed with `seed` to form the initial state,
/// so `seed == 0` yields plain FNV-1a. Each input byte is XORed into the state,
/// which is then multiplied (wrapping) by the 64-bit FNV prime.
fn fnv64a(data: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01B3;

    data.iter().fold(OFFSET_BASIS ^ seed, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
21
/// Probabilistic set membership test for string terms.
///
/// Optimized for the BM25 file-pruning use case: insert all tokenized terms from
/// a data file at write time; at search time, check whether any query term
/// *may* be present. False positives keep the file (safe); false negatives
/// are impossible.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomFilter {
    /// Bit array packed into 64-bit words: bit `i` lives at bit `i % 64` of
    /// word `i / 64`.
    bits: Vec<u64>,
    /// Number of addressable bits; a non-zero multiple of 64 equal to
    /// `64 * bits.len()`.
    num_bits: usize,
}
33
34impl BloomFilter {
35    /// Construct a Bloom filter sized for `capacity` items at `fpr` false-positive rate.
36    ///
37    /// Formula: `m = -n * ln(p) / ln(2)²`, rounded up to the next 64-bit boundary.
38    /// A `capacity=10_000, fpr=0.01` filter uses ~12 KB of bit storage.
39    pub fn with_capacity(capacity: usize, fpr: f64) -> Self {
40        let n = capacity.max(1) as f64;
41        let bits_f = -n * fpr.ln() / std::f64::consts::LN_2.powi(2);
42        let num_bits = (bits_f.ceil() as usize).max(64);
43        // Round up to next multiple of 64 so word-aligned ops are always valid.
44        let num_bits = (num_bits + 63) & !63;
45        Self {
46            bits: vec![0u64; num_bits / 64],
47            num_bits,
48        }
49    }
50
51    fn probes(&self, term: &[u8]) -> [usize; K] {
52        let h1 = fnv64a(term, 0);
53        let h2 = fnv64a(term, 0xcbf29ce484222325u64);
54        let m = self.num_bits as u64;
55        let mut out = [0usize; K];
56        for (i, slot) in out.iter_mut().enumerate().take(K) {
57            *slot = (h1.wrapping_add((i as u64).wrapping_mul(h2)) % m) as usize;
58        }
59        out
60    }
61
62    pub fn insert(&mut self, term: &str) {
63        for pos in self.probes(term.as_bytes()) {
64            self.bits[pos / 64] |= 1u64 << (pos % 64);
65        }
66    }
67
68    /// Returns `true` if `term` *might* be in the set (false positives possible).
69    /// Returns `false` only when the term is *definitely absent*.
70    pub fn may_contain(&self, term: &str) -> bool {
71        for pos in self.probes(term.as_bytes()) {
72            if self.bits[pos / 64] & (1u64 << (pos % 64)) == 0 {
73                return false;
74            }
75        }
76        true
77    }
78
79    /// Serialize to bytes: `u64_le(num_bits) || u64_le(word_0) || ... || u64_le(word_N)`.
80    pub fn to_bytes(&self) -> Vec<u8> {
81        let mut out = Vec::with_capacity(8 + self.bits.len() * 8);
82        out.extend_from_slice(&(self.num_bits as u64).to_le_bytes());
83        for &w in &self.bits {
84            out.extend_from_slice(&w.to_le_bytes());
85        }
86        out
87    }
88
89    /// Deserialize from bytes produced by `to_bytes`. Returns `None` on malformed input.
90    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
91        if bytes.len() < 8 {
92            return None;
93        }
94        let num_bits = u64::from_le_bytes(bytes[0..8].try_into().ok()?) as usize;
95        if num_bits == 0 || !num_bits.is_multiple_of(64) {
96            return None;
97        }
98        let word_count = num_bits / 64;
99        if bytes.len() < 8 + word_count * 8 {
100            return None;
101        }
102        let bits: Vec<u64> = (0..word_count)
103            .map(|i| u64::from_le_bytes(bytes[8 + i * 8..16 + i * 8].try_into().unwrap()))
104            .collect();
105        Some(Self { bits, num_bits })
106    }
107
108    pub fn num_bits(&self) -> usize {
109        self.num_bits
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// A Bloom filter never yields false negatives: every inserted term is found.
    #[test]
    fn inserted_terms_always_found() {
        let mut bf = BloomFilter::with_capacity(100, 0.01);
        let terms = ["rust", "iceberg", "puffin", "vector", "bloom"];
        for t in &terms {
            bf.insert(t);
        }
        for t in &terms {
            assert!(
                bf.may_contain(t),
                "term '{t}' should be found after insertion"
            );
        }
    }

    /// With the filter half full, the observed false-positive count stays low.
    #[test]
    fn absent_terms_mostly_absent() {
        let mut bf = BloomFilter::with_capacity(1000, 0.01);
        for i in 0..500u32 {
            bf.insert(&format!("term_{i}"));
        }
        // Check 100 absent terms — expect ~1% false positives max with fpr=0.01
        let fp: usize = (500u32..600)
            .filter(|i| bf.may_contain(&format!("absent_{i}")))
            .count();
        // Allow up to 5% in practice (probabilistic)
        assert!(fp <= 5, "too many false positives: {fp}/100");
    }

    /// `to_bytes` followed by `from_bytes` preserves both size and membership.
    #[test]
    fn roundtrip_serialization() {
        let mut bf = BloomFilter::with_capacity(50, 0.01);
        bf.insert("hello");
        bf.insert("world");
        let bytes = bf.to_bytes();
        let restored = BloomFilter::from_bytes(&bytes).expect("deserialization should succeed");
        assert!(restored.may_contain("hello"));
        assert!(restored.may_contain("world"));
        // num_bits preserved
        assert_eq!(restored.num_bits(), bf.num_bits());
    }

    /// Input shorter than the 8-byte header is rejected.
    #[test]
    fn from_bytes_returns_none_on_short_input() {
        assert!(BloomFilter::from_bytes(&[]).is_none());
        assert!(BloomFilter::from_bytes(&[0u8; 7]).is_none());
    }

    /// A header declaring zero bits, or a bit count that is not a multiple of 64,
    /// is rejected even when enough payload bytes follow.
    #[test]
    fn from_bytes_returns_none_on_invalid_header() {
        assert!(BloomFilter::from_bytes(&0u64.to_le_bytes()).is_none());

        let mut misaligned = 65u64.to_le_bytes().to_vec();
        misaligned.extend_from_slice(&[0u8; 16]);
        assert!(BloomFilter::from_bytes(&misaligned).is_none());
    }

    /// A payload that is even one byte short of the declared word count is rejected.
    #[test]
    fn from_bytes_returns_none_on_truncated_payload() {
        let mut bytes = BloomFilter::with_capacity(50, 0.01).to_bytes();
        bytes.pop();
        assert!(BloomFilter::from_bytes(&bytes).is_none());
    }

    /// A freshly built filter has no bits set, so nothing can match.
    #[test]
    fn definitely_absent_term_returns_false() {
        let bf = BloomFilter::with_capacity(64, 0.001);
        // Never inserted anything — every term must be absent.
        assert!(
            !bf.may_contain("anything"),
            "empty filter must return false for all terms"
        );
    }
}