//! light_bloom_filter — lib.rs
//! Keccak-based bloom filter over a caller-provided mutable byte store.

1use std::f64::consts::LN_2;
2
3use thiserror::Error;
4
/// Errors returned by [`BloomFilter`] operations.
#[derive(Debug, Error, PartialEq)]
pub enum BloomFilterError {
    /// Every probe bit for the inserted value was already set, i.e. the
    /// value is (possibly falsely) already contained in the filter.
    #[error("Bloom filter is full")]
    Full,
    /// The store length (bytes) does not match the requested capacity (bits).
    #[error("Invalid store capacity")]
    InvalidStoreCapacity,
}
12
13impl From<BloomFilterError> for u32 {
14    fn from(e: BloomFilterError) -> u32 {
15        match e {
16            BloomFilterError::Full => 14201,
17            BloomFilterError::InvalidStoreCapacity => 14202,
18        }
19    }
20}
21
// With the `solana` feature, surface bloom filter errors as custom Solana
// program errors carrying the u32 code from `From<BloomFilterError> for u32`.
#[cfg(feature = "solana")]
impl From<BloomFilterError> for solana_program_error::ProgramError {
    fn from(e: BloomFilterError) -> Self {
        solana_program_error::ProgramError::Custom(e.into())
    }
}
28
// With the `pinocchio` feature, surface bloom filter errors as custom
// pinocchio program errors carrying the same u32 code.
#[cfg(feature = "pinocchio")]
impl From<BloomFilterError> for pinocchio::program_error::ProgramError {
    fn from(e: BloomFilterError) -> Self {
        pinocchio::program_error::ProgramError::Custom(e.into())
    }
}
35
/// A bloom filter over an externally provided, mutable byte store.
#[derive(Debug)]
pub struct BloomFilter<'a> {
    /// Number of hash functions (probe rounds) applied per value.
    pub num_iters: usize,
    /// Filter capacity in bits; `new` enforces `store.len() * 8 == capacity`.
    pub capacity: u64,
    /// Backing storage, one filter bit per stored bit.
    pub store: &'a mut [u8],
}
42
43impl<'a> BloomFilter<'a> {
44    // TODO: find source for this
45    pub fn calculate_bloom_filter_size(n: usize, p: f64) -> usize {
46        let m = -((n as f64) * p.ln()) / (LN_2 * LN_2);
47        m.ceil() as usize
48    }
49
50    pub fn calculate_optimal_hash_functions(n: usize, m: usize) -> usize {
51        let k = (m as f64 / n as f64) * LN_2;
52        k.ceil() as usize
53    }
54
55    pub fn new(
56        num_iters: usize,
57        capacity: u64,
58        store: &'a mut [u8],
59    ) -> Result<Self, BloomFilterError> {
60        // Capacity is in bits while store is in bytes.
61        if store.len() * 8 != capacity as usize {
62            return Err(BloomFilterError::InvalidStoreCapacity);
63        }
64        Ok(Self {
65            num_iters,
66            capacity,
67            store,
68        })
69    }
70
71    pub fn probe_index_keccak(value_bytes: &[u8; 32], iteration: usize, capacity: &u64) -> usize {
72        let iter_bytes: [u8; 8] = iteration.to_le_bytes();
73        let mut combined_bytes = [0u8; 40];
74        combined_bytes[..32].copy_from_slice(value_bytes);
75        combined_bytes[32..].copy_from_slice(&iter_bytes);
76
77        let hash = solana_nostd_keccak::hash(&combined_bytes);
78
79        let mut index = 0u64;
80        for chunk in hash.chunks(8) {
81            let value = u64::from_le_bytes(chunk.try_into().unwrap());
82            index = value.wrapping_add(index) % *capacity;
83        }
84        index as usize
85    }
86
87    pub fn insert(&mut self, value: &[u8; 32]) -> Result<(), BloomFilterError> {
88        if self._insert(value, true) {
89            Ok(())
90        } else {
91            Err(BloomFilterError::Full)
92        }
93    }
94
95    // TODO: reconsider &mut self
96    pub fn contains(&mut self, value: &[u8; 32]) -> bool {
97        !self._insert(value, false)
98    }
99
100    fn _insert(&mut self, value: &[u8; 32], insert: bool) -> bool {
101        let mut all_bits_set = true;
102        use bitvec::prelude::*;
103
104        let bits = BitSlice::<u8, Msb0>::from_slice_mut(self.store);
105        for i in 0..self.num_iters {
106            let probe_index = Self::probe_index_keccak(value, i, &(self.capacity));
107            if bits[probe_index] {
108                continue;
109            } else if insert {
110                all_bits_set = false;
111                bits.set(probe_index, true);
112            } else if !bits[probe_index] && !insert {
113                return true;
114            }
115        }
116        !all_bits_set
117    }
118}
119
#[cfg(test)]
mod test {
    use light_hasher::bigint::bigint_to_be_bytes_array;
    use num_bigint::RandBigInt;
    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_insert_and_contains() -> Result<(), BloomFilterError> {
        // 128 kB store = 128_000 * 8 usable bits.
        let capacity = 128_000 * 8;
        let mut store = [0u8; 128_000];
        let mut bf = BloomFilter {
            num_iters: 3,
            capacity,
            store: &mut store,
        };

        let value1 = [1u8; 32];
        let value2 = [2u8; 32];

        bf.insert(&value1)?;
        assert!(bf.contains(&value1));
        assert!(!bf.contains(&value2));

        Ok(())
    }

    #[test]
    fn short_rnd_test() {
        let capacity = 500;
        // Filter capacity is in bits (a 20 kB store).
        let bloom_filter_capacity = 20_000 * 8;
        let optimal_hash_functions = 3;
        rnd_test(
            1000,
            capacity,
            bloom_filter_capacity,
            optimal_hash_functions,
            false,
        );
    }

    /// Bench results:
    /// - 15310 CU for 10 insertions with 3 hash functions
    /// - capacity 5000 0.000_000_000_1 with 15 hash functions seems to not
    ///   produce any collisions
    #[ignore = "bench"]
    #[test]
    fn bench_bloom_filter() {
        let capacity = 5000;
        let bloom_filter_capacity =
            BloomFilter::calculate_bloom_filter_size(capacity, 0.000_000_000_1);
        let optimal_hash_functions = 15;
        let iterations = 1_000_000;
        rnd_test(
            iterations,
            capacity,
            bloom_filter_capacity,
            optimal_hash_functions,
            true,
        );
    }

    /// Runs `num_iters` rounds of inserting `capacity` unique random values
    /// into a fresh filter of `bloom_filter_capacity` bits, recording
    /// false-positive insertion failures; prints statistics when `bench`.
    fn rnd_test(
        num_iters: usize,
        capacity: usize,
        bloom_filter_capacity: usize,
        optimal_hash_functions: usize,
        bench: bool,
    ) {
        println!("Optimal hash functions: {}", optimal_hash_functions);
        println!(
            "Bloom filter capacity (kb): {}",
            bloom_filter_capacity / 8 / 1_000
        );
        let mut num_total_txs = 0;
        let mut rng = thread_rng();
        let mut failed_vec = Vec::new();
        for j in 0..num_iters {
            let mut inserted_values = Vec::new();
            // `bloom_filter_capacity` is in bits; allocate bytes, rounded up.
            // (The previous code allocated `bloom_filter_capacity` bytes,
            // i.e. 8x more than the filter ever uses, violating the
            // `store.len() * 8 == capacity` invariant enforced by `new`.)
            let mut store = vec![0u8; (bloom_filter_capacity + 7) / 8];
            let mut bf = BloomFilter {
                num_iters: optimal_hash_functions,
                capacity: bloom_filter_capacity as u64,
                store: &mut store,
            };
            if j == 0 {
                println!("Bloom filter capacity: {}", bf.capacity);
                println!("Bloom filter size: {}", bf.store.len());
                // Store length is already in bytes — no division by 8.
                println!("Bloom filter size (kb): {}", bf.store.len() / 1_000);
                println!("num iters: {}", bf.num_iters);
            }
            for i in 0..capacity {
                num_total_txs += 1;
                // Draw a fresh random value not inserted yet this round.
                // (The previous version seeded with 0, so the first value of
                // every round was always 0 rather than random.)
                let value = loop {
                    let candidate = rng.gen_biguint(254);
                    if !inserted_values.contains(&candidate) {
                        break candidate;
                    }
                };
                inserted_values.push(value.clone());
                let value: [u8; 32] = bigint_to_be_bytes_array(&value).unwrap();
                match bf.insert(&value) {
                    Ok(_) => {
                        assert!(bf.contains(&value));
                    }
                    Err(_) => {
                        println!("Failed to insert iter: {}", i);
                        println!("total iter {}", j);
                        println!("num_total_txs {}", num_total_txs);
                        failed_vec.push(i);
                    }
                };
                assert!(bf.contains(&value));
                // Re-inserting an already contained value must fail.
                assert!(bf.insert(&value).is_err());
            }
        }
        if bench {
            println!("total num tx {}", num_total_txs);
            if failed_vec.is_empty() {
                // Guard: `max()/min().unwrap()` below would panic on an
                // empty vec, and the average would be NaN.
                println!("no failed insertions");
                return;
            }
            let average = failed_vec.iter().sum::<usize>() as f64 / failed_vec.len() as f64;
            println!("average failed insertions: {}", average);
            println!(
                "max failed insertions: {}",
                failed_vec.iter().max().unwrap()
            );
            println!(
                "min failed insertions: {}",
                failed_vec.iter().min().unwrap()
            );

            let num_chunks = 10;
            // Guard against a zero chunk size (`chunks(0)` panics) when the
            // bench is run with fewer than `num_chunks` iterations.
            let chunk_size = (num_iters / num_chunks).max(1);
            failed_vec.sort();
            for (i, chunk) in failed_vec.chunks(chunk_size).enumerate() {
                let average = chunk.iter().sum::<usize>() as f64 / chunk.len() as f64;
                println!("chunk: {} average failed insertions: {}", i, average);
                println!(
                    "chunk: {} max failed insertions: {}",
                    i,
                    chunk.iter().max().unwrap()
                );
                println!(
                    "chunk: {} min failed insertions: {}",
                    i,
                    chunk.iter().min().unwrap()
                );
            }
        }
    }
}