commonware_cryptography/
bloomfilter.rs1use crate::{hash, sha256::Digest};
4use bytes::{Buf, BufMut};
5use commonware_codec::{
6 codec::{Read, Write},
7 config::RangeCfg,
8 error::Error as CodecError,
9 EncodeSize, FixedSize,
10};
11use commonware_utils::BitVec;
12use std::num::{NonZeroU8, NonZeroUsize};
13
/// Byte length of each digest half that `BloomFilter::indices` converts into a `u128`
/// (16 bytes, i.e. the size of a `u128`).
const HALF_DIGEST_LEN: usize = std::mem::size_of::<u128>();
16
/// Total byte length of a [`Digest`], taken from `Digest::SIZE`; used as the end of the second
/// digest half when deriving bit indices.
const FULL_DIGEST_LEN: usize = Digest::SIZE;
19
20#[derive(Clone, Debug, PartialEq, Eq)]
26pub struct BloomFilter {
27 hashers: u8,
28 bits: BitVec,
29}
30
31impl BloomFilter {
32 pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
34 Self {
35 hashers: hashers.get(),
36 bits: BitVec::zeroes(bits.get()),
37 }
38 }
39
40 fn indices(&self, item: &[u8], bits: usize) -> impl Iterator<Item = usize> {
42 let digest = hash(item);
44 let mut h1_bytes = [0u8; HALF_DIGEST_LEN];
45 h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]);
46 let h1 = u128::from_be_bytes(h1_bytes);
47 let mut h2_bytes = [0u8; HALF_DIGEST_LEN];
48 h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]);
49 let h2 = u128::from_be_bytes(h2_bytes);
50
51 let hashers = self.hashers as u128;
55 let bits = bits as u128;
56 (0..hashers)
57 .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits)
58 .map(|index| index as usize)
59 }
60
61 pub fn insert(&mut self, item: &[u8]) {
63 let indices = self.indices(item, self.bits.len());
64 for index in indices {
65 self.bits.set(index);
66 }
67 }
68
69 pub fn contains(&self, item: &[u8]) -> bool {
73 let indices = self.indices(item, self.bits.len());
74 for index in indices {
75 if !self.bits.get(index).expect("index out of bounds") {
76 return false;
77 }
78 }
79 true
80 }
81}
82
83impl Write for BloomFilter {
84 fn write(&self, buf: &mut impl BufMut) {
85 self.hashers.write(buf);
86 self.bits.write(buf);
87 }
88}
89
90impl Read for BloomFilter {
91 type Cfg = (RangeCfg, RangeCfg);
92
93 fn read_cfg(
94 buf: &mut impl Buf,
95 (hashers_cfg, bits_cfg): &Self::Cfg,
96 ) -> Result<Self, CodecError> {
97 let hashers = u8::read_cfg(buf, &())?;
98 if !hashers_cfg.contains(&(hashers as usize)) {
99 return Err(CodecError::Invalid("BloomFilter", "invalid hashers"));
100 }
101 let bits = BitVec::read_cfg(buf, bits_cfg)?;
102 Ok(Self { hashers, bits })
103 }
104}
105
106impl EncodeSize for BloomFilter {
107 fn encode_size(&self) -> usize {
108 self.hashers.encode_size() + self.bits.encode_size()
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_utils::{NZUsize, NZU8};

    /// Inserted items are reported present; an item that was never inserted is not.
    #[test]
    fn test_insert_and_contains() {
        let mut filter = BloomFilter::new(NZU8!(10), NZUsize!(1000));
        let inserted: [&[u8]; 2] = [b"hello", b"world"];

        for item in inserted {
            filter.insert(item);
        }

        for item in inserted {
            assert!(filter.contains(item));
        }
        assert!(!filter.contains(b"bloomfilter"));
    }

    /// A freshly constructed filter contains nothing.
    #[test]
    fn test_empty() {
        let filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        let found = filter.contains(b"anything");
        assert!(!found);
    }

    /// A small, heavily loaded filter yields some false positives, but not for every probe.
    #[test]
    fn test_false_positives() {
        let mut filter = BloomFilter::new(NZU8!(10), NZUsize!(100));
        for value in 0..10usize {
            filter.insert(&value.to_be_bytes());
        }

        // Every inserted value must be found.
        assert!((0..10usize).all(|value| filter.contains(&value.to_be_bytes())));

        // Probe 1000 values that were never inserted and count the hits.
        let false_positives = (100..1100usize)
            .filter(|value| filter.contains(&value.to_be_bytes()))
            .count();

        assert!(false_positives > 0);
        assert!(false_positives < 1000);
    }

    /// Encoding then decoding a populated filter yields an equal filter.
    #[test]
    fn test_codec_roundtrip() {
        let mut original = BloomFilter::new(NZU8!(5), NZUsize!(100));
        original.insert(b"test1");
        original.insert(b"test2");

        let cfg = ((1..=100).into(), (100..=100).into());
        let decoded = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();

        assert_eq!(original, decoded);
    }

    /// Encoding then decoding a filter with no insertions yields an equal filter.
    #[test]
    fn test_codec_empty() {
        let original = BloomFilter::new(NZU8!(4), NZUsize!(128));

        let cfg = ((1..=100).into(), (128..=128).into());
        let decoded = BloomFilter::decode_cfg(original.encode(), &cfg).unwrap();

        assert_eq!(original, decoded);
    }

    /// Decoding fails when the encoded hasher count (5) is outside the configured range.
    #[test]
    fn test_codec_with_invalid_hashers() {
        let mut filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        filter.insert(b"test1");
        let encoded = filter.encode();

        // Range lies entirely above the encoded count.
        let cfg_only_ten = ((10..=10).into(), (100..=100).into());
        let outcome = BloomFilter::decode_cfg(encoded.clone(), &cfg_only_ten);
        assert!(matches!(
            outcome,
            Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
        ));

        // Exclusive upper bound leaves the encoded count just outside the range.
        let cfg_below_five = ((0..5).into(), (100..=100).into());
        let outcome = BloomFilter::decode_cfg(encoded, &cfg_below_five);
        assert!(matches!(
            outcome,
            Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
        ));
    }

    /// Decoding fails when the encoded bit length (100) is outside the configured range.
    #[test]
    fn test_codec_with_invalid_bits() {
        let mut filter = BloomFilter::new(NZU8!(5), NZUsize!(100));
        filter.insert(b"test1");
        let encoded = filter.encode();

        // Range starts above the encoded length.
        let cfg_above = ((5..=5).into(), (101..=110).into());
        let outcome = BloomFilter::decode_cfg(encoded.clone(), &cfg_above);
        assert!(matches!(outcome, Err(CodecError::InvalidLength(100))));

        // Exclusive upper bound leaves the encoded length just outside the range.
        let cfg_below = ((5..=5).into(), (0..100).into());
        let outcome = BloomFilter::decode_cfg(encoded, &cfg_below);
        assert!(matches!(outcome, Err(CodecError::InvalidLength(100))));
    }
}