Skip to main content

silk/
bloom.rs

1use serde::{Deserialize, Serialize};
2
3use crate::entry::Hash;
4
/// A simple Bloom filter for sync negotiation.
///
/// Used during sync to quickly determine which entries a peer likely has.
/// False positives are expected (~1% with default parameters); false negatives
/// never occur. Subsequent sync rounds resolve false positives via explicit
/// `need` lists.
///
/// Parameters: `num_bits` total bits in the bitvec, `num_hashes` probe
/// positions per item. The probe positions are derived from the first 16 bytes
/// of the entry hash via enhanced double hashing (see `indices`), not from
/// independent hash functions.
///
/// Instances obtained through deserialization come from a peer and are not
/// checked automatically; call `validate` before using them.
//
// NOTE(review): the derived `Serialize`/`Deserialize` impls may make the field
// order part of the wire format (depends on the rmp_serde struct encoding in
// use) — confirm before reordering or inserting fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BloomFilter {
    // Bit storage packed into 64-bit words: bit `i` lives in word `i / 64`
    // at bit position `i % 64`.
    bits: Vec<u64>,
    // Number of addressable bits; probe indices are reduced modulo this value.
    num_bits: usize,
    // Number of probe positions set/checked per item.
    num_hashes: u32,
    // Number of `insert` calls (duplicates are counted again; `merge` adds the
    // other filter's count).
    count: usize,
}
21
22impl BloomFilter {
23    /// S-05: Validate bloom filter dimensions after deserialization.
24    /// Prevents panics from malformed sync offers (division by zero, out of bounds).
25    pub fn validate(&self) -> Result<(), String> {
26        if self.num_bits == 0 {
27            return Err("bloom filter num_bits must be > 0".into());
28        }
29        if self.bits.len() * 64 < self.num_bits {
30            return Err(format!(
31                "bloom filter bits array too small: {} words for {} bits",
32                self.bits.len(),
33                self.num_bits
34            ));
35        }
36        if self.num_hashes == 0 || self.num_hashes > 32 {
37            return Err(format!(
38                "bloom filter num_hashes {} out of range [1, 32]",
39                self.num_hashes
40            ));
41        }
42        Ok(())
43    }
44
45    /// Create a new Bloom filter sized for `expected_items` with the given
46    /// false positive rate.
47    ///
48    /// Computes optimal bit count and hash count from the standard formulas:
49    /// - m = -n * ln(p) / (ln(2)^2)
50    /// - k = (m/n) * ln(2)
51    pub fn new(expected_items: usize, fp_rate: f64) -> Self {
52        assert!(expected_items > 0, "expected_items must be > 0");
53        assert!((0.0..1.0).contains(&fp_rate), "fp_rate must be in (0, 1)");
54
55        let n = expected_items as f64;
56        let ln2 = std::f64::consts::LN_2;
57        let ln2_sq = ln2 * ln2;
58
59        let num_bits = ((-n * fp_rate.ln()) / ln2_sq).ceil() as usize;
60        let num_bits = num_bits.max(64); // minimum 64 bits
61        let num_hashes = ((num_bits as f64 / n) * ln2).ceil() as u32;
62        let num_hashes = num_hashes.max(1);
63
64        let words = num_bits.div_ceil(64);
65        Self {
66            bits: vec![0u64; words],
67            num_bits,
68            num_hashes,
69            count: 0,
70        }
71    }
72
73    /// Insert a hash into the filter.
74    pub fn insert(&mut self, hash: &Hash) {
75        for idx in self.indices(hash) {
76            let word = idx / 64;
77            let bit = idx % 64;
78            self.bits[word] |= 1u64 << bit;
79        }
80        self.count += 1;
81    }
82
83    /// Check if a hash might be in the filter.
84    ///
85    /// Returns `true` if the item is probably present (with false positive rate),
86    /// `false` if the item is definitely NOT present.
87    pub fn contains(&self, hash: &Hash) -> bool {
88        for idx in self.indices(hash) {
89            let word = idx / 64;
90            let bit = idx % 64;
91            if self.bits[word] & (1u64 << bit) == 0 {
92                return false;
93            }
94        }
95        true
96    }
97
98    /// Number of items inserted.
99    pub fn count(&self) -> usize {
100        self.count
101    }
102
103    /// Merge another filter into this one (union).
104    ///
105    /// Both filters must have the same dimensions (num_bits, num_hashes).
106    pub fn merge(&mut self, other: &BloomFilter) {
107        assert_eq!(self.num_bits, other.num_bits, "bloom filter size mismatch");
108        assert_eq!(
109            self.num_hashes, other.num_hashes,
110            "bloom filter hash count mismatch"
111        );
112        for (a, b) in self.bits.iter_mut().zip(other.bits.iter()) {
113            *a |= *b;
114        }
115        self.count += other.count;
116    }
117
118    /// Serialize to MessagePack bytes.
119    pub fn to_bytes(&self) -> Vec<u8> {
120        rmp_serde::to_vec(self).expect("bloom filter serialization should not fail")
121    }
122
123    /// Deserialize from MessagePack bytes.
124    pub fn from_bytes(bytes: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
125        rmp_serde::from_slice(bytes)
126    }
127
128    /// Compute the bit indices for a given hash.
129    ///
130    /// Uses enhanced double hashing: h_i = h1 + i*h2 + i^2 (mod num_bits)
131    /// where h1 and h2 are derived from the first 16 bytes of the BLAKE3 hash.
132    fn indices(&self, hash: &Hash) -> Vec<usize> {
133        // Split the 32-byte hash into two 8-byte values for double hashing.
134        let h1 = u64::from_le_bytes(hash[0..8].try_into().unwrap());
135        let h2 = u64::from_le_bytes(hash[8..16].try_into().unwrap());
136        let m = self.num_bits as u64;
137
138        (0..self.num_hashes)
139            .map(|i| {
140                let i = i as u64;
141                // Enhanced double hashing with quadratic probing
142                let idx = h1
143                    .wrapping_add(i.wrapping_mul(h2))
144                    .wrapping_add(i.wrapping_mul(i));
145                (idx % m) as usize
146            })
147            .collect()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Produce a well-distributed test hash: BLAKE3 over a 32-byte buffer
    /// whose first byte is `seed` and whose remaining bytes are zero.
    fn make_hash(seed: u8) -> Hash {
        let mut input = [0u8; 32];
        input[0] = seed;
        *blake3::hash(&input).as_bytes()
    }

    /// BLAKE3 of the little-endian `u64` encoding of `i`.
    fn counter_hash(i: usize) -> Hash {
        *blake3::hash(&(i as u64).to_le_bytes()).as_bytes()
    }

    #[test]
    fn bloom_insert_and_check() {
        let present = [make_hash(1), make_hash(2)];
        let absent = make_hash(3);

        let mut bloom = BloomFilter::new(100, 0.01);
        for h in &present {
            bloom.insert(h);
        }

        for h in &present {
            assert!(bloom.contains(h));
        }
        // Never inserted. A false positive is theoretically possible, but with
        // a 1% target and only 2 of 100 expected items present it is
        // vanishingly unlikely.
        assert!(!bloom.contains(&absent));
    }

    #[test]
    fn bloom_empty_contains_nothing() {
        let bloom = BloomFilter::new(100, 0.01);
        // Every possible seed byte must be reported absent.
        assert!((u8::MIN..=u8::MAX).all(|seed| !bloom.contains(&make_hash(seed))));
    }

    #[test]
    fn bloom_false_positive_rate() {
        // 1000 inserted items, 10000 probes of items that were never inserted.
        // Target FPR is 1%; up to 2% is tolerated for statistical variance.
        let n = 1000;
        let test_count = 10_000;

        let mut bloom = BloomFilter::new(n, 0.01);
        (0..n).for_each(|i| bloom.insert(&counter_hash(i)));

        let false_positives = (n..(n + test_count))
            .filter(|&i| bloom.contains(&counter_hash(i)))
            .count();

        let fpr = false_positives as f64 / test_count as f64;
        assert!(
            fpr < 0.02,
            "false positive rate {fpr:.4} exceeds 2% threshold"
        );
    }

    #[test]
    fn bloom_merge_union() {
        let (h1, h2, h3) = (make_hash(1), make_hash(2), make_hash(3));

        let mut bloom_a = BloomFilter::new(100, 0.01);
        bloom_a.insert(&h1);
        bloom_a.insert(&h2);

        let mut bloom_b = BloomFilter::new(100, 0.01);
        bloom_b.insert(&h2);
        bloom_b.insert(&h3);

        bloom_a.merge(&bloom_b);

        // The union must report all three hashes.
        for h in [&h1, &h2, &h3] {
            assert!(bloom_a.contains(h));
        }
    }

    #[test]
    fn bloom_serialization_roundtrip() {
        let (h1, h2) = (make_hash(1), make_hash(2));

        let mut bloom = BloomFilter::new(100, 0.01);
        bloom.insert(&h1);
        bloom.insert(&h2);

        let restored = BloomFilter::from_bytes(&bloom.to_bytes()).unwrap();

        assert!(restored.contains(&h1));
        assert!(restored.contains(&h2));
        assert!(!restored.contains(&make_hash(3)));
        assert_eq!(restored.count(), 2);
        assert_eq!(restored.num_bits, bloom.num_bits);
        assert_eq!(restored.num_hashes, bloom.num_hashes);
    }

    #[test]
    fn bloom_count_tracks_inserts() {
        let mut bloom = BloomFilter::new(100, 0.01);
        assert_eq!(bloom.count(), 0);

        // The counter must advance by exactly one per insert.
        for (seed, expected) in [(1u8, 1usize), (2, 2)] {
            bloom.insert(&make_hash(seed));
            assert_eq!(bloom.count(), expected);
        }
    }
}