commonware_cryptography/
bloomfilter.rs1use crate::{
4 sha256::{Digest, Sha256},
5 Hasher,
6};
7use bytes::{Buf, BufMut};
8use commonware_codec::{
9 codec::{Read, Write},
10 config::RangeCfg,
11 error::Error as CodecError,
12 EncodeSize, FixedSize,
13};
14use commonware_utils::BitVec;
15use core::num::{NonZeroU8, NonZeroUsize};
16
17const HALF_DIGEST_LEN: usize = 16;
19
20const FULL_DIGEST_LEN: usize = Digest::SIZE;
22
23#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct BloomFilter {
30 hashers: u8,
31 bits: BitVec,
32}
33
34impl BloomFilter {
35 pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self {
37 Self {
38 hashers: hashers.get(),
39 bits: BitVec::zeroes(bits.get()),
40 }
41 }
42
43 fn indices(&self, item: &[u8], bits: usize) -> impl Iterator<Item = usize> {
45 let digest = Sha256::hash(item);
47 let mut h1_bytes = [0u8; HALF_DIGEST_LEN];
48 h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]);
49 let h1 = u128::from_be_bytes(h1_bytes);
50 let mut h2_bytes = [0u8; HALF_DIGEST_LEN];
51 h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]);
52 let h2 = u128::from_be_bytes(h2_bytes);
53
54 let hashers = self.hashers as u128;
58 let bits = bits as u128;
59 (0..hashers)
60 .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits)
61 .map(|index| index as usize)
62 }
63
64 pub fn insert(&mut self, item: &[u8]) {
66 let indices = self.indices(item, self.bits.len());
67 for index in indices {
68 self.bits.set(index);
69 }
70 }
71
72 pub fn contains(&self, item: &[u8]) -> bool {
76 let indices = self.indices(item, self.bits.len());
77 for index in indices {
78 if !self.bits.get(index).expect("index out of bounds") {
79 return false;
80 }
81 }
82 true
83 }
84}
85
86impl Write for BloomFilter {
87 fn write(&self, buf: &mut impl BufMut) {
88 self.hashers.write(buf);
89 self.bits.write(buf);
90 }
91}
92
93impl Read for BloomFilter {
94 type Cfg = (RangeCfg, RangeCfg);
95
96 fn read_cfg(
97 buf: &mut impl Buf,
98 (hashers_cfg, bits_cfg): &Self::Cfg,
99 ) -> Result<Self, CodecError> {
100 let hashers = u8::read_cfg(buf, &())?;
101 if !hashers_cfg.contains(&(hashers as usize)) {
102 return Err(CodecError::Invalid("BloomFilter", "invalid hashers"));
103 }
104 let bits = BitVec::read_cfg(buf, bits_cfg)?;
105 Ok(Self { hashers, bits })
106 }
107}
108
109impl EncodeSize for BloomFilter {
110 fn encode_size(&self) -> usize {
111 self.hashers.encode_size() + self.bits.encode_size()
112 }
113}
114
115#[cfg(test)]
116mod tests {
117 use super::*;
118 use commonware_codec::{Decode, Encode};
119 use commonware_utils::{NZUsize, NZU8};
120
121 #[test]
122 fn test_insert_and_contains() {
123 let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(1000));
124 let item1 = b"hello";
125 let item2 = b"world";
126 let item3 = b"bloomfilter";
127
128 bf.insert(item1);
129 bf.insert(item2);
130
131 assert!(bf.contains(item1));
132 assert!(bf.contains(item2));
133 assert!(!bf.contains(item3));
134 }
135
136 #[test]
137 fn test_empty() {
138 let bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
139 assert!(!bf.contains(b"anything"));
140 }
141
142 #[test]
143 fn test_false_positives() {
144 let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(100));
145 for i in 0..10usize {
146 bf.insert(&i.to_be_bytes());
147 }
148
149 for i in 0..10usize {
151 assert!(bf.contains(&i.to_be_bytes()));
152 }
153
154 let mut false_positives = 0;
156 for i in 100..1100usize {
157 if bf.contains(&i.to_be_bytes()) {
158 false_positives += 1;
159 }
160 }
161
162 assert!(false_positives > 0);
165 assert!(false_positives < 1000);
166 }
167
168 #[test]
169 fn test_codec_roundtrip() {
170 let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
171 bf.insert(b"test1");
172 bf.insert(b"test2");
173
174 let cfg = ((1..=100).into(), (100..=100).into());
175
176 let encoded = bf.encode();
177 let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap();
178
179 assert_eq!(bf, decoded);
180 }
181
182 #[test]
183 fn test_codec_empty() {
184 let bf = BloomFilter::new(NZU8!(4), NZUsize!(128));
185 let cfg = ((1..=100).into(), (128..=128).into());
186 let encoded = bf.encode();
187 let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap();
188 assert_eq!(bf, decoded);
189 }
190
191 #[test]
192 fn test_codec_with_invalid_hashers() {
193 let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
194 bf.insert(b"test1");
195 let encoded = bf.encode();
196
197 let cfg = ((10..=10).into(), (100..=100).into());
199 let decoded = BloomFilter::decode_cfg(encoded.clone(), &cfg);
200 assert!(matches!(
201 decoded,
202 Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
203 ));
204
205 let cfg = ((0..5).into(), (100..=100).into());
207 let decoded = BloomFilter::decode_cfg(encoded, &cfg);
208 assert!(matches!(
209 decoded,
210 Err(CodecError::Invalid("BloomFilter", "invalid hashers"))
211 ));
212 }
213
214 #[test]
215 fn test_codec_with_invalid_bits() {
216 let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100));
217 bf.insert(b"test1");
218 let encoded = bf.encode();
219
220 let cfg_small = ((5..=5).into(), (101..=110).into());
222 let result_small = BloomFilter::decode_cfg(encoded.clone(), &cfg_small);
223 assert!(matches!(result_small, Err(CodecError::InvalidLength(100))));
224
225 let cfg_large = ((5..=5).into(), (0..100).into());
227 let result_large = BloomFilter::decode_cfg(encoded.clone(), &cfg_large);
228 assert!(matches!(result_large, Err(CodecError::InvalidLength(100))));
229 }
230}