commonware_cryptography/
bloomfilter.rs

1//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
2
3use crate::{
4    sha256::{Digest, Sha256},
5    Hasher,
6};
7use bytes::{Buf, BufMut};
8use commonware_codec::{
9    codec::{Read, Write},
10    error::Error as CodecError,
11    EncodeSize, FixedSize,
12};
13use commonware_utils::bitmap::BitMap;
14use core::num::{NonZeroU64, NonZeroU8, NonZeroUsize};
15
16/// The length of a half of a [Digest].
17const HALF_DIGEST_LEN: usize = 16;
18
19/// The length of a full [Digest].
20const FULL_DIGEST_LEN: usize = Digest::SIZE;
21
22/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
23///
24/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions
25/// from two hash values, which are in turn derived from a single [Digest]. This provides
26/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations.
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub struct BloomFilter {
29    hashers: u8,
30    bits: BitMap,
31}
32
33impl BloomFilter {
34    /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits.
35    pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
36        Self {
37            hashers: hashers.get(),
38            bits: BitMap::zeroes(bits.get() as u64),
39        }
40    }
41
42    /// Generate `num_hashers` bit indices for a given item.
43    fn indices(&self, item: &[u8], bits: u64) -> impl Iterator<Item = u64> {
44        // Extract two 128-bit hash values from the SHA256 digest of the item
45        let digest = Sha256::hash(item);
46        let mut h1_bytes = [0u8; HALF_DIGEST_LEN];
47        h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]);
48        let h1 = u128::from_be_bytes(h1_bytes);
49        let mut h2_bytes = [0u8; HALF_DIGEST_LEN];
50        h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]);
51        let h2 = u128::from_be_bytes(h2_bytes);
52
53        // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization:
54        //
55        // `h_i(x) = (h1(x) + i * h2(x)) mod m`
56        let hashers = self.hashers as u128;
57        let bits = bits as u128;
58        (0..hashers)
59            .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits)
60            .map(|index| index as u64)
61    }
62
63    /// Inserts an item into the [BloomFilter].
64    pub fn insert(&mut self, item: &[u8]) {
65        let indices = self.indices(item, self.bits.len());
66        for index in indices {
67            self.bits.set(index, true);
68        }
69    }
70
71    /// Checks if an item is possibly in the [BloomFilter].
72    ///
73    /// Returns `true` if the item is probably in the set, and `false` if it is definitely not.
74    pub fn contains(&self, item: &[u8]) -> bool {
75        let indices = self.indices(item, self.bits.len());
76        for index in indices {
77            if !self.bits.get(index) {
78                return false;
79            }
80        }
81        true
82    }
83}
84
85impl Write for BloomFilter {
86    fn write(&self, buf: &mut impl BufMut) {
87        self.hashers.write(buf);
88        self.bits.write(buf);
89    }
90}
91
92impl Read for BloomFilter {
93    // The number of hashers and the number of bits that the bitmap must have.
94    type Cfg = (NonZeroU8, NonZeroU64);
95
96    fn read_cfg(
97        buf: &mut impl Buf,
98        (hashers_cfg, bits_cfg): &Self::Cfg,
99    ) -> Result<Self, CodecError> {
100        let hashers = u8::read_cfg(buf, &())?;
101        if hashers != hashers_cfg.get() {
102            return Err(CodecError::Invalid(
103                "BloomFilter",
104                "hashers doesn't match config",
105            ));
106        }
107        let bits = BitMap::read_cfg(buf, &bits_cfg.get())?;
108        if bits.len() != bits_cfg.get() {
109            return Err(CodecError::Invalid(
110                "BloomFilter",
111                "bitmap length doesn't match config",
112            ));
113        }
114        Ok(Self { hashers, bits })
115    }
116}
117
118impl EncodeSize for BloomFilter {
119    fn encode_size(&self) -> usize {
120        self.hashers.encode_size() + self.bits.encode_size()
121    }
122}
123
124#[cfg(feature = "arbitrary")]
125impl arbitrary::Arbitrary<'_> for BloomFilter {
126    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
127        let hashers = u8::arbitrary(u)?;
128        // Ensure at least 1 bit to avoid empty bitmap
129        let bits_len = u.arbitrary_len::<u64>()?.max(1);
130        let mut bits = BitMap::with_capacity(bits_len as u64);
131        for _ in 0..bits_len {
132            bits.push(u.arbitrary::<bool>()?);
133        }
134        Ok(Self { hashers, bits })
135    }
136}
137
138#[cfg(test)]
139mod tests {
140    use super::*;
141    use commonware_codec::{Decode, Encode};
142    use commonware_utils::{NZUsize, NZU64, NZU8};
143
144    #[test]
145    fn test_insert_and_contains() {
146        let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(1000));
147        let item1 = b"hello";
148        let item2 = b"world";
149        let item3 = b"bloomfilter";
150
151        bf.insert(item1);
152        bf.insert(item2);
153
154        assert!(bf.contains(item1));
155        assert!(bf.contains(item2));
156        assert!(!bf.contains(item3));
157    }
158
159    #[test]
160    fn test_empty() {
161        let bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
162        assert!(!bf.contains(b"anything"));
163    }
164
165    #[test]
166    fn test_false_positives() {
167        let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(100));
168        for i in 0..10usize {
169            bf.insert(&i.to_be_bytes());
170        }
171
172        // Check for inserted items
173        for i in 0..10usize {
174            assert!(bf.contains(&i.to_be_bytes()));
175        }
176
177        // Check for non-inserted items and count false positives
178        let mut false_positives = 0;
179        for i in 100..1100usize {
180            if bf.contains(&i.to_be_bytes()) {
181                false_positives += 1;
182            }
183        }
184
185        // A small bloom filter with many items will have some false positives.
186        // The exact number is probabilistic, but it should not be zero and not all should be FPs.
187        assert!(false_positives > 0);
188        assert!(false_positives < 1000);
189    }
190
191    #[test]
192    fn test_codec_roundtrip() {
193        let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
194        bf.insert(b"test1");
195        bf.insert(b"test2");
196
197        let cfg = (NZU8!(5), NZU64!(100));
198
199        let encoded = bf.encode();
200        let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap();
201
202        assert_eq!(bf, decoded);
203    }
204
205    #[test]
206    fn test_codec_empty() {
207        let bf = BloomFilter::new(NZU8!(4), NZUsize!(128));
208        let cfg = (NZU8!(4), NZU64!(128));
209        let encoded = bf.encode();
210        let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap();
211        assert_eq!(bf, decoded);
212    }
213
214    #[test]
215    fn test_codec_with_invalid_hashers() {
216        let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
217        bf.insert(b"test1");
218        let encoded = bf.encode();
219
220        // Too large
221        let cfg = (NZU8!(10), NZU64!(100));
222        let decoded = BloomFilter::decode_cfg(encoded.clone(), &cfg);
223        assert!(matches!(
224            decoded,
225            Err(CodecError::Invalid(
226                "BloomFilter",
227                "hashers doesn't match config"
228            ))
229        ));
230
231        // Too small
232        let cfg = (NZU8!(4), NZU64!(100));
233        let decoded = BloomFilter::decode_cfg(encoded, &cfg);
234        assert!(matches!(
235            decoded,
236            Err(CodecError::Invalid(
237                "BloomFilter",
238                "hashers doesn't match config"
239            ))
240        ));
241    }
242
243    #[test]
244    fn test_codec_with_invalid_bits() {
245        let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
246        bf.insert(b"test1");
247        let encoded = bf.encode();
248
249        // Wrong bit count
250        let cfg = (NZU8!(5), NZU64!(99));
251        let result = BloomFilter::decode_cfg(encoded.clone(), &cfg);
252        assert!(matches!(result, Err(CodecError::InvalidLength(100))));
253
254        let cfg = (NZU8!(5), NZU64!(101));
255        let result = BloomFilter::decode_cfg(encoded, &cfg);
256        assert!(matches!(
257            result,
258            Err(CodecError::Invalid(
259                "BloomFilter",
260                "bitmap length doesn't match config"
261            ))
262        ));
263    }
264
265    #[cfg(feature = "arbitrary")]
266    mod conformance {
267        use super::*;
268        use commonware_codec::conformance::CodecConformance;
269
270        commonware_conformance::conformance_tests! {
271            CodecConformance<BloomFilter>,
272        }
273    }
274}