commonware_cryptography/
bloomfilter.rs

1//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
2
3use crate::{hash, sha256::Digest};
4use bytes::{Buf, BufMut};
5use commonware_codec::{
6    codec::{Read, Write},
7    config::RangeCfg,
8    error::Error as CodecError,
9    EncodeSize, FixedSize,
10};
11use commonware_utils::BitVec;
12use std::num::{NonZeroU8, NonZeroUsize};
13
14/// The length of a half of a [Digest].
15const HALF_DIGEST_LEN: usize = 16;
16
17/// The length of a full [Digest].
18const FULL_DIGEST_LEN: usize = Digest::SIZE;
19
20/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
21///
22/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions
23/// from two hash values, which are in turn derived from a single [Digest]. This provides
24/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations.
25#[derive(Clone, Debug, PartialEq, Eq)]
26pub struct BloomFilter {
27    hashers: u8,
28    bits: BitVec,
29}
30
31impl BloomFilter {
32    /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits.
33    pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
34        Self {
35            hashers: hashers.get(),
36            bits: BitVec::zeroes(bits.get()),
37        }
38    }
39
40    /// Generate `num_hashers` bit indices for a given item.
41    fn indices(&self, item: &[u8], bits: usize) -> impl Iterator<Item = usize> {
42        // Extract two 128-bit hash values from the SHA256 digest of the item
43        let digest = hash(item);
44        let mut h1_bytes = [0u8; HALF_DIGEST_LEN];
45        h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]);
46        let h1 = u128::from_be_bytes(h1_bytes);
47        let mut h2_bytes = [0u8; HALF_DIGEST_LEN];
48        h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]);
49        let h2 = u128::from_be_bytes(h2_bytes);
50
51        // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization:
52        //
53        // `h_i(x) = (h1(x) + i * h2(x)) mod m`
54        let hashers = self.hashers as u128;
55        let bits = bits as u128;
56        (0..hashers)
57            .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits)
58            .map(|index| index as usize)
59    }
60
61    /// Inserts an item into the [BloomFilter].
62    pub fn insert(&mut self, item: &[u8]) {
63        let indices = self.indices(item, self.bits.len());
64        for index in indices {
65            self.bits.set(index);
66        }
67    }
68
69    /// Checks if an item is possibly in the [BloomFilter].
70    ///
71    /// Returns `true` if the item is probably in the set, and `false` if it is definitely not.
72    pub fn contains(&self, item: &[u8]) -> bool {
73        let indices = self.indices(item, self.bits.len());
74        for index in indices {
75            if !self.bits.get(index).expect("index out of bounds") {
76                return false;
77            }
78        }
79        true
80    }
81}
82
83impl Write for BloomFilter {
84    fn write(&self, buf: &mut impl BufMut) {
85        self.hashers.write(buf);
86        self.bits.write(buf);
87    }
88}
89
90impl Read for BloomFilter {
91    type Cfg = (RangeCfg, RangeCfg);
92
93    fn read_cfg(
94        buf: &mut impl Buf,
95        (hashers_cfg, bits_cfg): &Self::Cfg,
96    ) -> Result<Self, CodecError> {
97        let hashers = u8::read_cfg(buf, &())?;
98        if !hashers_cfg.contains(&(hashers as usize)) {
99            return Err(CodecError::Invalid("BloomFilter", "invalid hashers"));
100        }
101        let bits = BitVec::read_cfg(buf, bits_cfg)?;
102        Ok(Self { hashers, bits })
103    }
104}
105
106impl EncodeSize for BloomFilter {
107    fn encode_size(&self) -> usize {
108        self.hashers.encode_size() + self.bits.encode_size()
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_utils::{NZUsize, NZU8};

    /// Inserted items are reported as present; an item that was never inserted is not.
    #[test]
    fn test_insert_and_contains() {
        let mut bloom = BloomFilter::new(NZU8!(10), NZUsize!(1000));
        let first = b"hello";
        let second = b"world";
        let missing = b"bloomfilter";

        bloom.insert(first);
        bloom.insert(second);

        assert!(bloom.contains(first));
        assert!(bloom.contains(second));
        assert!(!bloom.contains(missing));
    }

    /// A freshly created filter contains nothing.
    #[test]
    fn test_empty() {
        let bloom = BloomFilter::new(NZU8!(5), NZUsize!(100));
        assert!(!bloom.contains(b"anything"));
    }

    /// A deliberately undersized filter never loses inserted items but does produce some
    /// (not exclusively) false positives.
    #[test]
    fn test_false_positives() {
        let mut bloom = BloomFilter::new(NZU8!(10), NZUsize!(100));
        (0..10usize).for_each(|i| bloom.insert(&i.to_be_bytes()));

        // Every inserted item must be found.
        for i in 0..10usize {
            assert!(bloom.contains(&i.to_be_bytes()));
        }

        // Count how many of 1000 never-inserted items are reported as present.
        let false_positives = (100..1100usize)
            .filter(|i| bloom.contains(&i.to_be_bytes()))
            .count();

        // With this few bits the filter is saturated enough to yield some false positives,
        // yet not so saturated that every probe matches.
        assert!(false_positives > 0);
        assert!(false_positives < 1000);
    }

    /// Encoding then decoding a populated filter yields an equal filter.
    #[test]
    fn test_codec_roundtrip() {
        let mut original = BloomFilter::new(NZU8!(5), NZUsize!(100));
        original.insert(b"test1");
        original.insert(b"test2");

        let cfg = ((1..=100).into(), (100..=100).into());
        let restored = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();

        assert_eq!(original, restored);
    }

    /// Encoding then decoding an untouched filter yields an equal filter.
    #[test]
    fn test_codec_empty() {
        let original = BloomFilter::new(NZU8!(4), NZUsize!(128));
        let cfg = ((1..=100).into(), (128..=128).into());
        let restored = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();
        assert_eq!(original, restored);
    }

    /// Decoding fails when the encoded hasher count lies outside the configured range.
    #[test]
    fn test_codec_with_invalid_hashers() {
        let mut bloom = BloomFilter::new(NZU8!(5), NZUsize!(100));
        bloom.insert(b"test1");
        let encoded = bloom.encode();

        // Encoded count (5) is below the accepted range.
        let cfg = ((10..=10).into(), (100..=100).into());
        let outcome = BloomFilter::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(
            outcome,
            Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
        ));

        // Encoded count (5) is above the accepted range.
        let cfg = ((0..5).into(), (100..=100).into());
        let outcome = BloomFilter::decode_cfg(encoded, &cfg);
        assert!(matches!(
            outcome,
            Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
        ));
    }

    /// Decoding fails when the encoded bit length lies outside the configured range.
    #[test]
    fn test_codec_with_invalid_bits() {
        let mut bloom = BloomFilter::new(NZU8!(5), NZUsize!(100));
        bloom.insert(b"test1");
        let encoded = bloom.encode();

        // Encoded length (100) is below the accepted range.
        let cfg_above = ((5..=5).into(), (101..=110).into());
        let outcome_above = BloomFilter::decode_cfg(encoded.clone(), &cfg_above);
        assert!(matches!(outcome_above, Err(CodecError::InvalidLength(100))));

        // Encoded length (100) is above the accepted range.
        let cfg_below = ((5..=5).into(), (0..100).into());
        let outcome_below = BloomFilter::decode_cfg(encoded, &cfg_below);
        assert!(matches!(outcome_below, Err(CodecError::InvalidLength(100))));
    }
}