Skip to main content

commonware_cryptography/bloomfilter/
mod.rs

1//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
2
3#[cfg(all(test, feature = "arbitrary"))]
4mod conformance;
5
6use crate::{sha256::Sha256, Hasher};
7use bytes::{Buf, BufMut};
8use commonware_codec::{
9    codec::{Read, Write},
10    error::Error as CodecError,
11    EncodeSize, FixedSize,
12};
13use commonware_utils::bitmap::BitMap;
14use core::{
15    marker::PhantomData,
16    num::{NonZeroU64, NonZeroU8, NonZeroUsize},
17};
18#[cfg(feature = "std")]
19use {
20    commonware_utils::rational::BigRationalExt,
21    num_rational::BigRational,
22    num_traits::{One, ToPrimitive, Zero},
23};
24
/// Rational approximation of ln(2): 14397/20769 ≈ 0.6931966.
///
/// This slightly overestimates ln(2) ≈ 0.6931472 (absolute error ≈ 4.9e-5). It is used by
/// [BloomFilter::optimal_hashers] and [BloomFilter::estimated_count]; changing it would change
/// their deterministic outputs, so it must stay fixed.
#[cfg(feature = "std")]
const LN2: (u64, u64) = (14397, 20769);

/// Rational approximation of 1/ln(2): 29145/20201 ≈ 1.4427504.
///
/// This slightly overestimates 1/ln(2) ≈ 1.4426950 (absolute error ≈ 5.5e-5), which errs
/// toward allocating more bits in [BloomFilter::optimal_bits]. Changing it would change the
/// filter sizes produced by [BloomFilter::with_rate], so it must stay fixed.
#[cfg(feature = "std")]
const LN2_INV: (u64, u64) = (29145, 20201);
32
33/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
34///
35/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions
36/// from two hash values, which are in turn derived from a single hash digest. This provides
37/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations.
38///
39/// # Hasher Selection
40///
41/// The `H` type parameter specifies the hash function to use. It defaults to [Sha256].
42/// The hasher's digest must be at least 16 bytes (128 bits) long, this is enforced at
43/// compile time.
44///
45/// When choosing a hasher, consider:
46///
47/// - **Security**: If the bloom filter accepts untrusted input, use a cryptographically
48///   secure hash function to prevent attackers from crafting inputs that cause excessive
49///   collisions (degrading the filter to always return `true`).
50///
51/// - **Determinism**: If the bloom filter must produce consistent results across runs
52///   or machines (e.g. for serialization or consensus-critical applications), avoid keyed
53///   or randomized hash functions. Both [Sha256] and [Blake3](crate::blake3::Blake3)
54///   are deterministic.
55///
56/// - **Performance**: Hash function performance varies with the size of items inserted
57///   and queried. [Sha256] is faster for smaller items (up to ~2KB), while
58///   [Blake3](crate::blake3::Blake3) is faster for larger items (4KB+).
59#[derive(Clone, Debug)]
60pub struct BloomFilter<H: Hasher = Sha256> {
61    hashers: u8,
62    bits: BitMap,
63    _marker: PhantomData<H>,
64}
65
66impl<H: Hasher> PartialEq for BloomFilter<H> {
67    fn eq(&self, other: &Self) -> bool {
68        self.hashers == other.hashers && self.bits == other.bits
69    }
70}
71
72impl<H: Hasher> Eq for BloomFilter<H> {}
73
74impl<H: Hasher> BloomFilter<H> {
75    /// Compile-time assertion that the digest is at least 16 bytes.
76    const _ASSERT_DIGEST_AT_LEAST_16_BYTES: () = assert!(
77        <H::Digest as FixedSize>::SIZE >= 16,
78        "digest must be at least 128 bits (16 bytes)"
79    );
80
81    /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits.
82    ///
83    /// The number of bits will be rounded up to the next power of 2. If that would
84    /// overflow, the maximum power of 2 for the platform (2^63 on 64-bit) is used.
85    pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
86        let bits = bits
87            .get()
88            .checked_next_power_of_two()
89            .unwrap_or(1 << (usize::BITS - 1));
90        Self {
91            hashers: hashers.get(),
92            bits: BitMap::zeroes(bits as u64),
93            _marker: PhantomData,
94        }
95    }
96
97    /// Creates a new [BloomFilter] with optimal parameters for the expected number
98    /// of items and desired false positive rate.
99    ///
100    /// Uses exact rational arithmetic for full determinism across all platforms.
101    ///
102    /// # Arguments
103    ///
104    /// * `expected_items` - Number of items expected to be inserted
105    /// * `fp_rate` - False positive rate as a rational (e.g., `BigRational::from_frac_u64(1, 100)` for 1%)
106    ///
107    /// # Panics
108    ///
109    /// Panics if `fp_rate` is not in (0, 1).
110    #[cfg(feature = "std")]
111    pub fn with_rate(expected_items: NonZeroUsize, fp_rate: BigRational) -> Self {
112        let bits = Self::optimal_bits(expected_items.get(), &fp_rate);
113        let hashers = Self::optimal_hashers(expected_items.get(), bits);
114        Self {
115            hashers,
116            bits: BitMap::zeroes(bits as u64),
117            _marker: PhantomData,
118        }
119    }
120
121    /// Returns the number of hashers used by the filter.
122    pub const fn hashers(&self) -> NonZeroU8 {
123        NonZeroU8::new(self.hashers).expect("hashers is never zero")
124    }
125
126    /// Returns the number of bits used by the filter.
127    pub const fn bits(&self) -> NonZeroUsize {
128        NonZeroUsize::new(self.bits.len() as usize).expect("bits is never zero")
129    }
130
131    /// Generate `num_hashers` bit indices for a given item.
132    fn indices(&self, item: &[u8]) -> impl Iterator<Item = u64> {
133        #[allow(path_statements)]
134        Self::_ASSERT_DIGEST_AT_LEAST_16_BYTES;
135
136        // Extract two 64-bit hash values from the digest of the item
137        let digest = H::hash(item);
138        let h1 = u64::from_be_bytes(digest[0..8].try_into().unwrap());
139        let mut h2 = u64::from_be_bytes(digest[8..16].try_into().unwrap());
140
141        // Ensure h2 is odd (non-zero). If h2 were 0, all k hash functions would
142        // produce the same index (h1), defeating the purpose of multiple hashers.
143        h2 |= 1;
144
145        // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization:
146        //
147        // `h_i(x) = (h1(x) + i * h2(x)) mod m`
148        let hashers = self.hashers as u64;
149        let mask = self.bits.len() - 1;
150        (0..hashers).map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) & mask)
151    }
152
153    /// Inserts an item into the [BloomFilter].
154    pub fn insert(&mut self, item: &[u8]) {
155        let indices = self.indices(item);
156        for index in indices {
157            self.bits.set(index, true);
158        }
159    }
160
161    /// Checks if an item is possibly in the [BloomFilter].
162    ///
163    /// Returns `true` if the item is probably in the set, and `false` if it is definitely not.
164    pub fn contains(&self, item: &[u8]) -> bool {
165        let indices = self.indices(item);
166        for index in indices {
167            if !self.bits.get(index) {
168                return false;
169            }
170        }
171        true
172    }
173
174    /// Estimates the current false positive probability.
175    ///
176    /// This approximates the false positive rate as `f^k` where `f` is the fill ratio
177    /// (proportion of bits set to 1) and `k` is the number of hash functions.
178    ///
179    /// Returns a [`BigRational`] for exact representation and cross-platform determinism.
180    #[cfg(feature = "std")]
181    pub fn estimated_false_positive_rate(&self) -> BigRational {
182        let ones = self.bits.count_ones();
183        let len = self.bits.len();
184        let fill_ratio = BigRational::new(ones.into(), len.into());
185        fill_ratio.pow(self.hashers as i32)
186    }
187
188    /// Estimates the number of items that have been inserted.
189    ///
190    /// Uses the formula `n = -(m/k) * ln(1 - x/m)` where `m` is the number of bits,
191    /// `k` is the number of hash functions, and `x` is the number of bits set to 1.
192    ///
193    /// Returns a [`BigRational`] using `log2_floor` for the logarithm computation.
194    #[cfg(feature = "std")]
195    pub fn estimated_count(&self) -> BigRational {
196        let m = self.bits.len();
197        let x = self.bits.count_ones();
198        let k = self.hashers as u64;
199        if x >= m {
200            return BigRational::from_usize(usize::MAX);
201        }
202
203        // ln(1 - x/m) = log2(1 - x/m) * ln(2)
204        let one_minus_fill = BigRational::new((m - x).into(), m.into());
205        let log2_val = one_minus_fill.log2_floor(16);
206        let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1);
207        let ln_result = &log2_val * &ln2;
208
209        // n = -(m/k) * ln(1 - x/m)
210        let m_over_k = BigRational::new(m.into(), k.into());
211        -m_over_k * ln_result
212    }
213
214    /// Calculates the optimal number of hash functions for a given capacity and bit count.
215    ///
216    /// Uses [`BigRational`] for determinism. The result is clamped to [1, 16] since
217    /// beyond ~10-12 hashes provides negligible improvement while increasing CPU cost.
218    #[cfg(feature = "std")]
219    pub fn optimal_hashers(expected_items: usize, bits: usize) -> u8 {
220        if expected_items == 0 {
221            return 1;
222        }
223
224        // k = (m/n) * ln(2)
225        let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1);
226        let k_ratio = BigRational::from_usize(bits) * ln2 / BigRational::from_usize(expected_items);
227        k_ratio.to_integer().to_u8().unwrap_or(16).clamp(1, 16)
228    }
229
230    /// Calculates the optimal number of bits for a given capacity and false positive rate.
231    ///
232    /// Uses exact rational arithmetic for full determinism across all platforms.
233    /// The result is rounded up to the next power of 2. If that would overflow, the maximum
234    /// power of 2 for the platform (2^63 on 64-bit) is used.
235    ///
236    /// Formula: m = -n * log2(p) / ln(2)
237    ///
238    /// # Panics
239    ///
240    /// Panics if `fp_rate` is not in (0, 1).
241    #[cfg(feature = "std")]
242    pub fn optimal_bits(expected_items: usize, fp_rate: &BigRational) -> usize {
243        assert!(
244            fp_rate > &BigRational::zero() && fp_rate < &BigRational::one(),
245            "false positive rate must be in (0, 1)"
246        );
247
248        // log2(p) is negative for p < 1. Use floor to get a more negative value,
249        // which results in more bits (conservative choice to not exceed target FP rate).
250        let log2_p = fp_rate.log2_floor(16);
251
252        // m = -n * log2(p) / ln(2) = -n * log2(p) * (1/ln(2))
253        // Since log2(p) < 0 for p < 1, -log2(p) > 0
254        let n = BigRational::from_usize(expected_items);
255        let ln2_inv = BigRational::from_frac_u64(LN2_INV.0, LN2_INV.1);
256        let bits_rational = -(&n * &log2_p * &ln2_inv);
257
258        let raw = bits_rational.ceil_to_u128().unwrap_or(1) as usize;
259        raw.max(1)
260            .checked_next_power_of_two()
261            .unwrap_or(1 << (usize::BITS - 1))
262    }
263}
264
265impl<H: Hasher> Write for BloomFilter<H> {
266    fn write(&self, buf: &mut impl BufMut) {
267        self.hashers.write(buf);
268        self.bits.write(buf);
269    }
270}
271
272impl<H: Hasher> Read for BloomFilter<H> {
273    // The number of hashers and the number of bits that the bitmap must have.
274    type Cfg = (NonZeroU8, NonZeroU64);
275
276    fn read_cfg(
277        buf: &mut impl Buf,
278        (hashers_cfg, bits_cfg): &Self::Cfg,
279    ) -> Result<Self, CodecError> {
280        if !bits_cfg.get().is_power_of_two() {
281            return Err(CodecError::Invalid(
282                "BloomFilter",
283                "bits must be a power of 2",
284            ));
285        }
286        let hashers = u8::read_cfg(buf, &())?;
287        if hashers != hashers_cfg.get() {
288            return Err(CodecError::Invalid(
289                "BloomFilter",
290                "hashers doesn't match config",
291            ));
292        }
293        let bits = BitMap::read_cfg(buf, &bits_cfg.get())?;
294        if bits.len() != bits_cfg.get() {
295            return Err(CodecError::Invalid(
296                "BloomFilter",
297                "bitmap length doesn't match config",
298            ));
299        }
300        Ok(Self {
301            hashers,
302            bits,
303            _marker: PhantomData,
304        })
305    }
306}
307
308impl<H: Hasher> EncodeSize for BloomFilter<H> {
309    fn encode_size(&self) -> usize {
310        self.hashers.encode_size() + self.bits.encode_size()
311    }
312}
313
#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for BloomFilter<H> {
    /// Builds a filter with at least one hasher and a power-of-two bitmap of at most 2^16
    /// bits, each bit drawn from the input.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // A valid filter needs at least one hasher.
        let hashers = u8::arbitrary(u)?.max(1);

        // Draw the length from the u16 range to keep allocations small, then round it up to a
        // power of two as every valid filter requires (0 rounds up to 1).
        let len = u.int_in_range(0..=u64::from(u16::MAX))?.next_power_of_two();

        let mut bits = BitMap::with_capacity(len);
        for _ in 0..len {
            let bit = u.arbitrary::<bool>()?;
            bits.push(bit);
        }

        Ok(Self {
            hashers,
            bits,
            _marker: PhantomData,
        })
    }
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_utils::{NZUsize, NZU64, NZU8};

    #[test]
    fn test_insert_and_contains() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(10), NZUsize!(1000));
        let item1 = b"hello";
        let item2 = b"world";
        let item3 = b"bloomfilter";

        bf.insert(item1);
        bf.insert(item2);

        assert!(bf.contains(item1));
        assert!(bf.contains(item2));
        assert!(!bf.contains(item3));
    }

    #[test]
    fn test_empty() {
        let bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(100));
        assert!(!bf.contains(b"anything"));
    }

    #[test]
    fn test_new_rounds_bits_up_to_power_of_two() {
        // Non-powers of two are rounded up.
        let bf = BloomFilter::<Sha256>::new(NZU8!(3), NZUsize!(100));
        assert_eq!(bf.bits().get(), 128);
        assert_eq!(bf.hashers().get(), 3);

        // Exact powers of two are kept as-is.
        let bf = BloomFilter::<Sha256>::new(NZU8!(3), NZUsize!(64));
        assert_eq!(bf.bits().get(), 64);
    }

    #[test]
    fn test_false_positives() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(10), NZUsize!(100));
        for i in 0..10usize {
            bf.insert(&i.to_be_bytes());
        }

        // Check for inserted items
        for i in 0..10usize {
            assert!(bf.contains(&i.to_be_bytes()));
        }

        // Check for non-inserted items and count false positives
        let mut false_positives = 0;
        for i in 100..1100usize {
            if bf.contains(&i.to_be_bytes()) {
                false_positives += 1;
            }
        }

        // The filter is heavily loaded for its size (10 items x 10 hashers into 128 bits),
        // so some false positives are expected. The exact count is deterministic for SHA-256
        // but not derived analytically, so only sanity bounds are asserted.
        assert!(false_positives > 0);
        assert!(false_positives < 1000);
    }

    #[test]
    fn test_codec_roundtrip() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
        bf.insert(b"test1");
        bf.insert(b"test2");

        let cfg = (NZU8!(5), NZU64!(128));

        let encoded = bf.encode();
        let decoded = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg).unwrap();

        assert_eq!(bf, decoded);
    }

    #[test]
    fn test_codec_empty() {
        let bf = BloomFilter::<Sha256>::new(NZU8!(4), NZUsize!(128));
        let cfg = (NZU8!(4), NZU64!(128));
        let encoded = bf.encode();
        let decoded = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg).unwrap();
        assert_eq!(bf, decoded);
    }

    #[test]
    fn test_codec_with_invalid_hashers() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
        bf.insert(b"test1");
        let encoded = bf.encode();

        // Too large
        let cfg = (NZU8!(10), NZU64!(128));
        let decoded = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(
            decoded,
            Err(CodecError::Invalid(
                "BloomFilter",
                "hashers doesn't match config"
            ))
        ));

        // Too small
        let cfg = (NZU8!(4), NZU64!(128));
        let decoded = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg);
        assert!(matches!(
            decoded,
            Err(CodecError::Invalid(
                "BloomFilter",
                "hashers doesn't match config"
            ))
        ));
    }

    #[test]
    fn test_codec_with_invalid_bits() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(5), NZUsize!(128));
        bf.insert(b"test1");
        let encoded = bf.encode();

        // Configured length smaller than the encoded bitmap: rejected by the bitmap decoder
        let cfg = (NZU8!(5), NZU64!(64));
        let result = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(result, Err(CodecError::InvalidLength(128))));

        // Configured length larger than the encoded bitmap: rejected by the exact-length check
        let cfg = (NZU8!(5), NZU64!(256));
        let result = BloomFilter::<Sha256>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(
            result,
            Err(CodecError::Invalid(
                "BloomFilter",
                "bitmap length doesn't match config"
            ))
        ));

        // Non-power-of-2 bits
        let cfg = (NZU8!(5), NZU64!(100));
        let result = BloomFilter::<Sha256>::decode_cfg(encoded, &cfg);
        assert!(matches!(
            result,
            Err(CodecError::Invalid(
                "BloomFilter",
                "bits must be a power of 2"
            ))
        ));
    }

    #[test]
    fn test_statistics() {
        let mut bf = BloomFilter::<Sha256>::new(NZU8!(7), NZUsize!(1024));

        // Empty filter should have 0 estimated count and FP rate
        assert_eq!(bf.estimated_count(), BigRational::zero());
        assert_eq!(bf.estimated_false_positive_rate(), BigRational::zero());

        // Insert some items
        for i in 0..100usize {
            bf.insert(&i.to_be_bytes());
        }

        // Estimated count should be reasonably close to 100
        let estimated = bf.estimated_count();
        let lower = BigRational::from_usize(75);
        let upper = BigRational::from_usize(125);
        assert!(estimated > lower && estimated < upper);

        // FP rate should be non-zero after insertions
        assert!(bf.estimated_false_positive_rate() > BigRational::zero());
        assert!(bf.estimated_false_positive_rate() < BigRational::one());
    }

    #[test]
    fn test_with_rate() {
        // Create a filter for 1000 items with 1% false positive rate
        let fp_rate = BigRational::from_frac_u64(1, 100);
        let mut bf = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate.clone());

        // Verify getters return expected values
        let expected_bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_rate);
        let expected_hashers = BloomFilter::<Sha256>::optimal_hashers(1000, expected_bits);
        assert_eq!(bf.bits().get(), expected_bits);
        assert_eq!(bf.hashers().get(), expected_hashers);

        // Insert 1000 items
        for i in 0..1000usize {
            bf.insert(&i.to_be_bytes());
        }

        // All inserted items should be found
        for i in 0..1000usize {
            assert!(bf.contains(&i.to_be_bytes()));
        }

        // Count false positives on non-inserted items
        let mut false_positives = 0;
        for i in 1000..2000usize {
            if bf.contains(&i.to_be_bytes()) {
                false_positives += 1;
            }
        }

        // With 1% target FP rate, we expect around 10 false positives out of 1000
        // Allow some variance (should be well under 2%)
        assert!(false_positives < 20);
    }

    #[test]
    fn test_optimal_hashers() {
        // For 1000 items in 10000 bits, optimal k = (10000/1000) * ln(2) ≈ 6.93,
        // which truncates to 6
        let k = BloomFilter::<Sha256>::optimal_hashers(1000, 10000);
        assert_eq!(k, 6);

        // For 100 items in 1000 bits, optimal k = (1000/100) * ln(2) ≈ 6.93,
        // which truncates to 6
        let k = BloomFilter::<Sha256>::optimal_hashers(100, 1000);
        assert_eq!(k, 6);

        // Edge case: very few bits per item, clamped to 1
        let k = BloomFilter::<Sha256>::optimal_hashers(1000, 100);
        assert_eq!(k, 1);

        // Edge case: many bits per item, clamped to 16
        let k = BloomFilter::<Sha256>::optimal_hashers(100, 100000);
        assert_eq!(k, 16);

        // Edge case: zero items returns 1
        let k = BloomFilter::<Sha256>::optimal_hashers(0, 1000);
        assert_eq!(k, 1);

        // Edge case: extreme inputs must not panic, and the result stays clamped
        let k = BloomFilter::<Sha256>::optimal_hashers(1 << 48, 1000);
        assert_eq!(k, 1);
        let k = BloomFilter::<Sha256>::optimal_hashers(usize::MAX, usize::MAX);
        assert!((1..=16).contains(&k));
    }

    #[test]
    fn test_optimal_bits() {
        // For 1000 items with 1% FP rate
        // Formula: m = -n * ln(p) / (ln(2))^2 = -1000 * ln(0.01) / 0.4805 ≈ 9585
        // Rounded to next power of 2 = 16384
        let fp_1pct = BigRational::from_frac_u64(1, 100);
        let bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_1pct);
        assert_eq!(bits, 16384);
        assert!(bits.is_power_of_two());

        // For 10000 items with 0.001% FP rate (need significantly more bits)
        // Formula: m = -10000 * ln(0.00001) / 0.4805 ≈ 239627
        // Rounded to next power of 2 = 262144
        let fp_001pct = BigRational::from_frac_u64(1, 100_000);
        let bits_lower_fp = BloomFilter::<Sha256>::optimal_bits(10000, &fp_001pct);
        assert_eq!(bits_lower_fp, 262144);
        assert!(bits_lower_fp.is_power_of_two());
    }

    #[test]
    fn test_bits_extreme_values() {
        let fp_0_01pct = BigRational::from_frac_u64(1, 10_000);
        let fp_1pct = BigRational::from_frac_u64(1, 100);

        // Very large expected_items
        let bits = BloomFilter::<Sha256>::optimal_bits(usize::MAX / 2, &fp_0_01pct);
        assert!(bits.is_power_of_two());
        assert!(bits > 0);

        // Large but reasonable values
        let bits = BloomFilter::<Sha256>::optimal_bits(1_000_000_000, &fp_0_01pct);
        assert!(bits.is_power_of_two());

        // Zero items still yield the minimum one-bit filter
        let bits = BloomFilter::<Sha256>::optimal_bits(0, &fp_1pct);
        assert!(bits.is_power_of_two());
        assert_eq!(bits, 1);
    }

    #[test]
    fn test_with_rate_deterministic() {
        let fp_rate = BigRational::from_frac_u64(1, 100);
        let bf1 = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate.clone());
        let bf2 = BloomFilter::<Sha256>::with_rate(NZUsize!(1000), fp_rate);
        assert_eq!(bf1.bits(), bf2.bits());
        assert_eq!(bf1.hashers(), bf2.hashers());
    }

    #[test]
    fn test_optimal_bits_matches_formula() {
        // For 1000 items at 1% FP rate
        // m = -1000 * log2(0.01) / ln(2) ≈ 9585
        // Rounded to power of 2 = 16384
        let fp_rate = BigRational::from_frac_u64(1, 100);
        let bits = BloomFilter::<Sha256>::optimal_bits(1000, &fp_rate);
        assert_eq!(bits, 16384);
    }
}