gemachain_runtime/
bloom.rs

1//! Simple Bloom Filter
2use bv::BitVec;
3use fnv::FnvHasher;
4use rand::{self, Rng};
5use serde::{Deserialize, Serialize};
6use gemachain_sdk::sanitize::{Sanitize, SanitizeError};
7use std::fmt;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::{cmp, hash::Hasher, marker::PhantomData};
10
/// Produces a 64-bit hash of `self` for a given `hash_index`.
///
/// The hash must be stable: the same value and `hash_index` always yield the
/// same result. Implementations should make a best effort for different
/// `hash_index` values to yield different hashes of the same value.
pub trait BloomHashIndex {
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
16
/// Bloom filter over items of type `T`, backed by a `BitVec<u64>`.
///
/// The type is serializable, so the field layout below is left untouched.
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, AbiExample)]
pub struct Bloom<T: BloomHashIndex> {
    /// Hash indices, one per probe; each is passed to
    /// `BloomHashIndex::hash_at_index` to derive a bit position.
    pub keys: Vec<u64>,
    /// The filter's bit array.
    pub bits: BitVec<u64>,
    /// Number of bits counted as set in `bits`.
    num_bits_set: u64,
    /// Ties the filter to its item type without storing a `T`.
    _phantom: PhantomData<T>,
}
24
25impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        write!(
28            f,
29            "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
30            self.keys.len(),
31            self.bits.len(),
32            self.num_bits_set
33        )?;
34        const MAX_PRINT_BITS: u64 = 10;
35        for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
36            if self.bits.get(i) {
37                write!(f, "1")?;
38            } else {
39                write!(f, "0")?;
40            }
41        }
42        if self.bits.len() > MAX_PRINT_BITS {
43            write!(f, "..")?;
44        }
45        write!(f, " }}")
46    }
47}
48
49impl<T: BloomHashIndex> Sanitize for Bloom<T> {
50    fn sanitize(&self) -> Result<(), SanitizeError> {
51        // Avoid division by zero in self.pos(...).
52        if self.bits.is_empty() {
53            Err(SanitizeError::InvalidValue)
54        } else {
55            Ok(())
56        }
57    }
58}
59
60impl<T: BloomHashIndex> Bloom<T> {
61    pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
62        let bits = BitVec::new_fill(false, num_bits as u64);
63        Bloom {
64            keys,
65            bits,
66            num_bits_set: 0,
67            _phantom: PhantomData::default(),
68        }
69    }
70    /// Create filter optimal for num size given the `FALSE_RATE`.
71    ///
72    /// The keys are randomized for picking data out of a collision resistant hash of size
73    /// `keysize` bytes.
74    ///
75    /// See <https://hur.st/bloomfilter/>.
76    pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
77        let m = Self::num_bits(num_items as f64, false_rate);
78        let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
79        let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
80        let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
81        Self::new(num_bits, keys)
82    }
83    fn num_bits(num_items: f64, false_rate: f64) -> f64 {
84        let n = num_items;
85        let p = false_rate;
86        ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
87    }
88    fn num_keys(num_bits: f64, num_items: f64) -> f64 {
89        let n = num_items;
90        let m = num_bits;
91        // infinity as usize is zero in rust 1.43 but 2^64-1 in rust 1.45; ensure it's zero here
92        if n == 0.0 {
93            0.0
94        } else {
95            1f64.max(((m / n) * 2f64.ln()).round())
96        }
97    }
98    fn pos(&self, key: &T, k: u64) -> u64 {
99        key.hash_at_index(k) % self.bits.len()
100    }
101    pub fn clear(&mut self) {
102        self.bits = BitVec::new_fill(false, self.bits.len());
103        self.num_bits_set = 0;
104    }
105    pub fn add(&mut self, key: &T) {
106        for k in &self.keys {
107            let pos = self.pos(key, *k);
108            if !self.bits.get(pos) {
109                self.num_bits_set += 1;
110                self.bits.set(pos, true);
111            }
112        }
113    }
114    pub fn contains(&self, key: &T) -> bool {
115        for k in &self.keys {
116            let pos = self.pos(key, *k);
117            if !self.bits.get(pos) {
118                return false;
119            }
120        }
121        true
122    }
123}
124
125fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
126    let mut hasher = FnvHasher::with_key(hash_index);
127    hasher.write(slice);
128    hasher.finish()
129}
130
131impl<T: AsRef<[u8]>> BloomHashIndex for T {
132    fn hash_at_index(&self, hash_index: u64) -> u64 {
133        slice_hash(self.as_ref(), hash_index)
134    }
135}
136
/// Bloom filter whose bit array is stored as a vector of `AtomicU64` words.
pub struct AtomicBloom<T> {
    /// Number of addressable bits; hashes are reduced modulo this value.
    num_bits: u64,
    /// Hash indices, one per probe.
    keys: Vec<u64>,
    /// Bit storage, 64 bits per word.
    bits: Vec<AtomicU64>,
    /// Ties the filter to its item type without storing a `T`.
    _phantom: PhantomData<T>,
}
143
144impl<T: BloomHashIndex> From<Bloom<T>> for AtomicBloom<T> {
145    fn from(bloom: Bloom<T>) -> Self {
146        AtomicBloom {
147            num_bits: bloom.bits.len(),
148            keys: bloom.keys,
149            bits: bloom
150                .bits
151                .into_boxed_slice()
152                .iter()
153                .map(|&x| AtomicU64::new(x))
154                .collect(),
155            _phantom: PhantomData::default(),
156        }
157    }
158}
159
160impl<T: BloomHashIndex> AtomicBloom<T> {
161    fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
162        let pos = key.hash_at_index(hash_index) % self.num_bits;
163        // Divide by 64 to figure out which of the
164        // AtomicU64 bit chunks we need to modify.
165        let index = pos >> 6;
166        // (pos & 63) is equivalent to mod 64 so that we can find
167        // the index of the bit within the AtomicU64 to modify.
168        let mask = 1u64 << (pos & 63);
169        (index as usize, mask)
170    }
171
172    pub fn add(&self, key: &T) {
173        for k in &self.keys {
174            let (index, mask) = self.pos(key, *k);
175            self.bits[index].fetch_or(mask, Ordering::Relaxed);
176        }
177    }
178
179    pub fn contains(&self, key: &T) -> bool {
180        self.keys.iter().all(|k| {
181            let (index, mask) = self.pos(key, *k);
182            let bit = self.bits[index].load(Ordering::Relaxed) & mask;
183            bit != 0u64
184        })
185    }
186
187    // Only for tests and simulations.
188    pub fn mock_clone(&self) -> Self {
189        Self {
190            keys: self.keys.clone(),
191            bits: self
192                .bits
193                .iter()
194                .map(|v| AtomicU64::new(v.load(Ordering::Relaxed)))
195                .collect(),
196            ..*self
197        }
198    }
199}
200
201impl<T: BloomHashIndex> From<AtomicBloom<T>> for Bloom<T> {
202    fn from(atomic_bloom: AtomicBloom<T>) -> Self {
203        let bits: Vec<_> = atomic_bloom
204            .bits
205            .into_iter()
206            .map(AtomicU64::into_inner)
207            .collect();
208        let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
209        let mut bits: BitVec<u64> = bits.into();
210        bits.truncate(atomic_bloom.num_bits);
211        Bloom {
212            keys: atomic_bloom.keys,
213            bits,
214            num_bits_set,
215            _phantom: PhantomData::default(),
216        }
217    }
218}
219
#[cfg(test)]
mod test {
    use super::*;
    use rayon::prelude::*;
    use gemachain_sdk::hash::{hash, Hash};

    // Checks the key count and bit count chosen by `Bloom::random` for three
    // sizing scenarios.
    #[test]
    fn test_bloom_filter() {
        //empty
        let bloom: Bloom<Hash> = Bloom::random(0, 0.1, 100);
        assert_eq!(bloom.keys.len(), 0);
        assert_eq!(bloom.bits.len(), 1);

        //normal
        let bloom: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 48);

        //saturated
        let bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        assert_eq!(bloom.keys.len(), 1);
        assert_eq!(bloom.bits.len(), 100);
    }
    // A key is reported as absent before `add` and as present after it.
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        //known keys to avoid false positives in the test
        bloom.keys = vec![0, 1, 2, 3];

        let key = hash(b"hello");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));

        let key = hash(b"world");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));
    }
    // Two independently created random filters are expected to end up with
    // different key sets.
    #[test]
    fn test_random() {
        let mut b1: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut b2: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        b1.keys.sort_unstable();
        b2.keys.sort_unstable();
        assert_ne!(b1.keys, b2.keys);
    }
    // Bloom filter math in python
    // n number of items
    // p false rate
    // m number of bits
    // k number of keys
    //
    // n = ceil(m / (-k / log(1 - exp(log(p) / k))))
    // p = pow(1 - exp(-k / (m / n)), k)
    // m = ceil((n * log(p)) / log(1 / pow(2, log(2))));
    // k = round((m / n) * log(2));
    #[test]
    fn test_filter_math() {
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.1f64) as u64, 480u64);
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.01f64) as u64, 959u64);
        assert_eq!(Bloom::<Hash>::num_keys(1000f64, 50f64) as u64, 14u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 50f64) as u64, 28u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 25f64) as u64, 55u64);
        //ensure min keys is 1
        assert_eq!(Bloom::<Hash>::num_keys(20f64, 1000f64) as u64, 1u64);
    }

    // Pins the exact `Debug` output for a filter shorter than the print limit
    // and for one longer than it (which ends in `..`).
    #[test]
    fn test_debug() {
        let mut b: Bloom<Hash> = Bloom::new(3, vec![100]);
        b.add(&Hash::default());
        assert_eq!(
            format!("{:?}", b),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        let mut b: Bloom<Hash> = Bloom::new(1000, vec![100]);
        b.add(&Hash::default());
        b.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{:?}", b),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    // Fills an `AtomicBloom` through rayon's `par_iter`, converts it back to a
    // `Bloom`, and checks membership of every inserted value plus an upper
    // bound on false positives among fresh random values.
    #[test]
    fn test_atomic_bloom() {
        let mut rng = rand::thread_rng();
        let hash_values: Vec<_> = std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
            .take(1200)
            .collect();
        let bloom: AtomicBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        // 6168 bits occupy 97 words of 64 bits.
        assert_eq!(bloom.bits.len(), 97);
        hash_values.par_iter().for_each(|v| bloom.add(v));
        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        for hash_value in hash_values {
            assert!(bloom.contains(&hash_value));
        }
        let false_positive = std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2_000, "false_positive: {}", false_positive);
    }

    // Converts back and forth between `Bloom` and `AtomicBloom` several times
    // (no inserts, re-inserting the same values, inserting new values) and
    // checks that membership, bit length and the set-bit count stay consistent.
    #[test]
    fn test_atomic_bloom_round_trip() {
        let mut rng = rand::thread_rng();
        let keys: Vec<_> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values: Vec<_> = std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
            .take(1000)
            .collect();
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {}", num_bits_set);
        // Round-trip with no inserts.
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, re-inserting the same hash values.
        let bloom: AtomicBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| bloom.add(v));
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, inserting new hash values.
        let more_hash_values: Vec<_> =
            std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
                .take(1000)
                .collect();
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        more_hash_values.par_iter().for_each(|v| bloom.add(v));
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {}", false_positive);
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {}", false_positive);
        // Assert that the bits vector precisely match if no atomic ops were
        // used.
        let bits = bloom.bits;
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        for hash_value in &more_hash_values {
            bloom.add(hash_value);
        }
        assert_eq!(bits, bloom.bits);
    }
}