// light_bloom_filter/lib.rs
use std::f64::consts::LN_2;

use thiserror::Error;
/// Errors returned by [`BloomFilter`] operations.
#[derive(Debug, Error, PartialEq)]
pub enum BloomFilterError {
    /// Every probe bit for the value was already set, so the insert was
    /// rejected: the value is (probabilistically) already present.
    #[error("Bloom filter is full")]
    Full,
    /// The backing byte store does not provide exactly `capacity` bits
    /// (`store.len() * 8 != capacity`).
    #[error("Invalid store capacity")]
    InvalidStoreCapacity,
}
12
13impl From<BloomFilterError> for u32 {
14 fn from(e: BloomFilterError) -> u32 {
15 match e {
16 BloomFilterError::Full => 14201,
17 BloomFilterError::InvalidStoreCapacity => 14202,
18 }
19 }
20}
21
#[cfg(feature = "solana")]
impl From<BloomFilterError> for solana_program_error::ProgramError {
    /// Wraps the filter's numeric error code in a Solana custom program error.
    fn from(e: BloomFilterError) -> Self {
        solana_program_error::ProgramError::Custom(e.into())
    }
}
28
#[cfg(feature = "pinocchio")]
impl From<BloomFilterError> for pinocchio::program_error::ProgramError {
    /// Wraps the filter's numeric error code in a pinocchio custom program
    /// error (mirrors the "solana" feature impl above).
    fn from(e: BloomFilterError) -> Self {
        pinocchio::program_error::ProgramError::Custom(e.into())
    }
}
35
/// A Bloom filter over a caller-provided mutable byte slice.
#[derive(Debug)]
pub struct BloomFilter<'a> {
    /// Number of hash probes (hash functions) applied per value.
    pub num_iters: usize,
    /// Total number of bits in the filter; [`BloomFilter::new`] enforces
    /// `capacity == store.len() * 8`.
    pub capacity: u64,
    /// Backing bit store, borrowed mutably for the filter's lifetime.
    pub store: &'a mut [u8],
}
42
43impl<'a> BloomFilter<'a> {
44 pub fn calculate_bloom_filter_size(n: usize, p: f64) -> usize {
46 let m = -((n as f64) * p.ln()) / (LN_2 * LN_2);
47 m.ceil() as usize
48 }
49
50 pub fn calculate_optimal_hash_functions(n: usize, m: usize) -> usize {
51 let k = (m as f64 / n as f64) * LN_2;
52 k.ceil() as usize
53 }
54
55 pub fn new(
56 num_iters: usize,
57 capacity: u64,
58 store: &'a mut [u8],
59 ) -> Result<Self, BloomFilterError> {
60 if store.len() * 8 != capacity as usize {
62 return Err(BloomFilterError::InvalidStoreCapacity);
63 }
64 Ok(Self {
65 num_iters,
66 capacity,
67 store,
68 })
69 }
70
71 pub fn probe_index_keccak(value_bytes: &[u8; 32], iteration: usize, capacity: &u64) -> usize {
72 let iter_bytes: [u8; 8] = iteration.to_le_bytes();
73 let mut combined_bytes = [0u8; 40];
74 combined_bytes[..32].copy_from_slice(value_bytes);
75 combined_bytes[32..].copy_from_slice(&iter_bytes);
76
77 let hash = solana_nostd_keccak::hash(&combined_bytes);
78
79 let mut index = 0u64;
80 for chunk in hash.chunks(8) {
81 let value = u64::from_le_bytes(chunk.try_into().unwrap());
82 index = value.wrapping_add(index) % *capacity;
83 }
84 index as usize
85 }
86
87 pub fn insert(&mut self, value: &[u8; 32]) -> Result<(), BloomFilterError> {
88 if self._insert(value, true) {
89 Ok(())
90 } else {
91 Err(BloomFilterError::Full)
92 }
93 }
94
95 pub fn contains(&mut self, value: &[u8; 32]) -> bool {
97 !self._insert(value, false)
98 }
99
100 fn _insert(&mut self, value: &[u8; 32], insert: bool) -> bool {
101 let mut all_bits_set = true;
102 use bitvec::prelude::*;
103
104 let bits = BitSlice::<u8, Msb0>::from_slice_mut(self.store);
105 for i in 0..self.num_iters {
106 let probe_index = Self::probe_index_keccak(value, i, &(self.capacity));
107 if bits[probe_index] {
108 continue;
109 } else if insert {
110 all_bits_set = false;
111 bits.set(probe_index, true);
112 } else if !bits[probe_index] && !insert {
113 return true;
114 }
115 }
116 !all_bits_set
117 }
118}
119
#[cfg(test)]
mod test {
    use light_hasher::bigint::bigint_to_be_bytes_array;
    use num_bigint::RandBigInt;
    use rand::thread_rng;

    use super::*;

    /// Smoke test: an inserted value is contained, a distinct value is not,
    /// and re-inserting an existing value fails with `Full`.
    #[test]
    fn test_insert_and_contains() -> Result<(), BloomFilterError> {
        let capacity = 128_000 * 8;
        let mut store = [0u8; 128_000];
        let mut bf = BloomFilter {
            num_iters: 3,
            capacity,
            store: &mut store,
        };

        let value1 = [1u8; 32];
        let value2 = [2u8; 32];

        bf.insert(&value1)?;
        assert!(bf.contains(&value1));
        assert!(!bf.contains(&value2));

        Ok(())
    }

    #[test]
    fn short_rnd_test() {
        let capacity = 500;
        let bloom_filter_capacity = 20_000 * 8;
        let optimal_hash_functions = 3;
        rnd_test(
            1000,
            capacity,
            bloom_filter_capacity,
            optimal_hash_functions,
            false,
        );
    }

    #[ignore = "bench"]
    #[test]
    fn bench_bloom_filter() {
        let capacity = 5000;
        let bloom_filter_capacity =
            BloomFilter::calculate_bloom_filter_size(capacity, 0.000_000_000_1);
        let optimal_hash_functions = 15;
        let iterations = 1_000_000;
        rnd_test(
            iterations,
            capacity,
            bloom_filter_capacity,
            optimal_hash_functions,
            true,
        );
    }

    /// Runs `num_iters` rounds; each round inserts `capacity` distinct random
    /// values into a fresh filter of `bloom_filter_capacity` **bits** and
    /// checks insert/contains invariants. With `bench == true`, prints
    /// statistics about failed (false-positive) insertions.
    fn rnd_test(
        num_iters: usize,
        capacity: usize,
        bloom_filter_capacity: usize,
        optimal_hash_functions: usize,
        bench: bool,
    ) {
        println!("Optimal hash functions: {}", optimal_hash_functions);
        println!(
            "Bloom filter capacity (kb): {}",
            bloom_filter_capacity / 8 / 1_000
        );
        let mut num_total_txs = 0;
        let mut rng = thread_rng();
        let mut failed_vec = Vec::new();
        for j in 0..num_iters {
            let mut inserted_values = Vec::new();
            // `bloom_filter_capacity` is a bit count; allocate bytes
            // accordingly so `store.len() * 8 == capacity`, matching the
            // invariant enforced by `BloomFilter::new`. (Previously this
            // allocated `bloom_filter_capacity` *bytes*, 8x too much; probing
            // behavior is unchanged since probes stay below `capacity`.)
            let mut store = vec![0; bloom_filter_capacity / 8];
            let mut bf = BloomFilter {
                num_iters: optimal_hash_functions,
                capacity: bloom_filter_capacity as u64,
                store: &mut store,
            };
            if j == 0 {
                println!("Bloom filter capacity: {}", bf.capacity);
                println!("Bloom filter size: {}", bf.store.len());
                // store.len() is already in bytes.
                println!("Bloom filter size (kb): {}", bf.store.len() / 1_000);
                println!("num iters: {}", bf.num_iters);
            }
            for i in 0..capacity {
                num_total_txs += 1;
                let value = {
                    // Draw a fresh random value, re-drawing on (unlikely)
                    // collision. The previous `while` loop skipped the first
                    // draw, making the first value of every round the
                    // constant zero instead of random.
                    let mut candidate = rng.gen_biguint(254);
                    while inserted_values.contains(&candidate) {
                        candidate = rng.gen_biguint(254);
                    }
                    inserted_values.push(candidate.clone());
                    candidate
                };
                let value: [u8; 32] = bigint_to_be_bytes_array(&value).unwrap();
                match bf.insert(&value) {
                    Ok(_) => {
                        assert!(bf.contains(&value));
                    }
                    Err(_) => {
                        // False positive: every probe bit was already set.
                        println!("Failed to insert iter: {}", i);
                        println!("total iter {}", j);
                        println!("num_total_txs {}", num_total_txs);
                        failed_vec.push(i);
                    }
                };
                assert!(bf.contains(&value));
                // Re-inserting an already-represented value must fail.
                assert!(bf.insert(&value).is_err());
            }
        }
        if bench {
            println!("total num tx {}", num_total_txs);
            let average = failed_vec.iter().sum::<usize>() as f64 / failed_vec.len() as f64;
            println!("average failed insertions: {}", average);
            println!(
                "max failed insertions: {}",
                failed_vec.iter().max().unwrap()
            );
            println!(
                "min failed insertions: {}",
                failed_vec.iter().min().unwrap()
            );

            let num_chunks = 10;
            // Guard against `num_iters < num_chunks`: `chunks(0)` panics.
            let chunk_size = (num_iters / num_chunks).max(1);
            failed_vec.sort();
            for (i, chunk) in failed_vec.chunks(chunk_size).enumerate() {
                let average = chunk.iter().sum::<usize>() as f64 / chunk.len() as f64;
                println!("chunk: {} average failed insertions: {}", i, average);
                println!(
                    "chunk: {} max failed insertions: {}",
                    i,
                    chunk.iter().max().unwrap()
                );
                println!(
                    "chunk: {} min failed insertions: {}",
                    i,
                    chunk.iter().min().unwrap()
                );
            }
        }
    }
}