commonware_cryptography/
bloomfilter.rs

1//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
2
3use crate::{
4    sha256::{Digest, Sha256},
5    Hasher,
6};
7use bytes::{Buf, BufMut};
8use commonware_codec::{
9    codec::{Read, Write},
10    config::RangeCfg,
11    error::Error as CodecError,
12    EncodeSize, FixedSize,
13};
14use commonware_utils::BitVec;
15use core::num::{NonZeroU8, NonZeroUsize};
16
17/// The length of a half of a [Digest].
18const HALF_DIGEST_LEN: usize = 16;
19
20/// The length of a full [Digest].
21const FULL_DIGEST_LEN: usize = Digest::SIZE;
22
23/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter).
24///
25/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions
26/// from two hash values, which are in turn derived from a single [Digest]. This provides
27/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations.
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct BloomFilter {
30    hashers: u8,
31    bits: BitVec,
32}
33
34impl BloomFilter {
35    /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits.
36    pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
37        Self {
38            hashers: hashers.get(),
39            bits: BitVec::zeroes(bits.get()),
40        }
41    }
42
43    /// Generate `num_hashers` bit indices for a given item.
44    fn indices(&self, item: &[u8], bits: usize) -> impl Iterator<Item = usize> {
45        // Extract two 128-bit hash values from the SHA256 digest of the item
46        let digest = Sha256::hash(item);
47        let mut h1_bytes = [0u8; HALF_DIGEST_LEN];
48        h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]);
49        let h1 = u128::from_be_bytes(h1_bytes);
50        let mut h2_bytes = [0u8; HALF_DIGEST_LEN];
51        h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]);
52        let h2 = u128::from_be_bytes(h2_bytes);
53
54        // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization:
55        //
56        // `h_i(x) = (h1(x) + i * h2(x)) mod m`
57        let hashers = self.hashers as u128;
58        let bits = bits as u128;
59        (0..hashers)
60            .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits)
61            .map(|index| index as usize)
62    }
63
64    /// Inserts an item into the [BloomFilter].
65    pub fn insert(&mut self, item: &[u8]) {
66        let indices = self.indices(item, self.bits.len());
67        for index in indices {
68            self.bits.set(index);
69        }
70    }
71
72    /// Checks if an item is possibly in the [BloomFilter].
73    ///
74    /// Returns `true` if the item is probably in the set, and `false` if it is definitely not.
75    pub fn contains(&self, item: &[u8]) -> bool {
76        let indices = self.indices(item, self.bits.len());
77        for index in indices {
78            if !self.bits.get(index).expect("index out of bounds") {
79                return false;
80            }
81        }
82        true
83    }
84}
85
86impl Write for BloomFilter {
87    fn write(&self, buf: &mut impl BufMut) {
88        self.hashers.write(buf);
89        self.bits.write(buf);
90    }
91}
92
93impl Read for BloomFilter {
94    type Cfg = (RangeCfg, RangeCfg);
95
96    fn read_cfg(
97        buf: &mut impl Buf,
98        (hashers_cfg, bits_cfg): &Self::Cfg,
99    ) -> Result<Self, CodecError> {
100        let hashers = u8::read_cfg(buf, &())?;
101        if !hashers_cfg.contains(&(hashers as usize)) {
102            return Err(CodecError::Invalid("BloomFilter", "invalid hashers"));
103        }
104        let bits = BitVec::read_cfg(buf, bits_cfg)?;
105        Ok(Self { hashers, bits })
106    }
107}
108
109impl EncodeSize for BloomFilter {
110    fn encode_size(&self) -> usize {
111        self.hashers.encode_size() + self.bits.encode_size()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_utils::{NZUsize, NZU8};

    /// Inserted items are reported as present; an item that was never inserted is not.
    #[test]
    fn test_insert_and_contains() {
        let mut filter = BloomFilter::new(NZU8!(10), NZUsize!(1000));
        let inserted: [&[u8]; 2] = [b"hello", b"world"];

        for item in inserted {
            filter.insert(item);
        }

        for item in inserted {
            assert!(filter.contains(item));
        }
        assert!(!filter.contains(b"bloomfilter"));
    }

    /// A freshly created filter contains nothing.
    #[test]
    fn test_empty() {
        let filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        assert!(!filter.contains(b"anything"));
    }

    /// A deliberately undersized filter yields some, but not only, false positives.
    #[test]
    fn test_false_positives() {
        let mut filter = BloomFilter::new(NZU8!(10), NZUsize!(100));
        (0..10usize).for_each(|i| filter.insert(&i.to_be_bytes()));

        // Everything that was inserted must be found.
        assert!((0..10usize).all(|i| filter.contains(&i.to_be_bytes())));

        // Probe values that were never inserted and count how many are reported present.
        let false_positives = (100..1100usize)
            .filter(|i| filter.contains(&i.to_be_bytes()))
            .count();

        // The exact count is probabilistic, but a small, crowded filter should produce at
        // least one false positive without degenerating into reporting everything.
        assert!(false_positives > 0);
        assert!(false_positives < 1000);
    }

    /// Encoding then decoding a populated filter reproduces it exactly.
    #[test]
    fn test_codec_roundtrip() {
        let mut original = BloomFilter::new(NZU8!(5), NZUsize!(100));
        for item in [b"test1", b"test2"] {
            original.insert(item);
        }

        let cfg: (RangeCfg, RangeCfg) = ((1..=100).into(), (100..=100).into());
        let restored = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();

        assert_eq!(original, restored);
    }

    /// Encoding then decoding an empty filter reproduces it exactly.
    #[test]
    fn test_codec_empty() {
        let original = BloomFilter::new(NZU8!(4), NZUsize!(128));
        let cfg: (RangeCfg, RangeCfg) = ((1..=100).into(), (128..=128).into());

        let restored = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();

        assert_eq!(original, restored);
    }

    /// Decoding fails when the encoded hasher count falls outside the allowed range.
    #[test]
    fn test_codec_with_invalid_hashers() {
        let mut filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        filter.insert(b"test1");
        let encoded = filter.encode();

        let rejecting: [(RangeCfg, RangeCfg); 2] = [
            // Encoded count (5) is too small for the range.
            ((10..=10).into(), (100..=100).into()),
            // Encoded count (5) is too large for the range.
            ((0..5).into(), (100..=100).into()),
        ];

        for cfg in &rejecting {
            let outcome = BloomFilter::decode_cfg(encoded.clone(), cfg);
            assert!(matches!(
                outcome,
                Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
            ));
        }
    }

    /// Decoding fails when the encoded bit count falls outside the allowed range.
    #[test]
    fn test_codec_with_invalid_bits() {
        let mut filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        filter.insert(b"test1");
        let encoded = filter.encode();

        let rejecting: [(RangeCfg, RangeCfg); 2] = [
            // Encoded length (100) is too small for the range.
            ((5..=5).into(), (101..=110).into()),
            // Encoded length (100) is too large for the range.
            ((5..=5).into(), (0..100).into()),
        ];

        for cfg in &rejecting {
            let outcome = BloomFilter::decode_cfg(encoded.clone(), cfg);
            assert!(matches!(outcome, Err(CodecError::InvalidLength(100))));
        }
    }
}