1use bv::BitVec;
3use fnv::FnvHasher;
4use rand::{self, Rng};
5use serde::{Deserialize, Serialize};
6use gemachain_sdk::sanitize::{Sanitize, SanitizeError};
7use std::fmt;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::{cmp, hash::Hasher, marker::PhantomData};
10
/// Supplies the hash values a bloom filter uses to choose bits for a key.
///
/// The filters in this file call [`BloomHashIndex::hash_at_index`] once for
/// every entry of their `keys` vector, passing that entry as `hash_index`, and
/// reduce the result modulo the filter's bit count.
pub trait BloomHashIndex {
    /// Returns the hash of `self` that corresponds to `hash_index`.
    ///
    /// A filter can only recognise a previously added key if this returns the
    /// same value every time it is called with the same `hash_index`.
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
16
17#[derive(Serialize, Deserialize, Default, Clone, PartialEq, AbiExample)]
18pub struct Bloom<T: BloomHashIndex> {
19 pub keys: Vec<u64>,
20 pub bits: BitVec<u64>,
21 num_bits_set: u64,
22 _phantom: PhantomData<T>,
23}
24
25impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(
28 f,
29 "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
30 self.keys.len(),
31 self.bits.len(),
32 self.num_bits_set
33 )?;
34 const MAX_PRINT_BITS: u64 = 10;
35 for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
36 if self.bits.get(i) {
37 write!(f, "1")?;
38 } else {
39 write!(f, "0")?;
40 }
41 }
42 if self.bits.len() > MAX_PRINT_BITS {
43 write!(f, "..")?;
44 }
45 write!(f, " }}")
46 }
47}
48
49impl<T: BloomHashIndex> Sanitize for Bloom<T> {
50 fn sanitize(&self) -> Result<(), SanitizeError> {
51 if self.bits.is_empty() {
53 Err(SanitizeError::InvalidValue)
54 } else {
55 Ok(())
56 }
57 }
58}
59
60impl<T: BloomHashIndex> Bloom<T> {
61 pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
62 let bits = BitVec::new_fill(false, num_bits as u64);
63 Bloom {
64 keys,
65 bits,
66 num_bits_set: 0,
67 _phantom: PhantomData::default(),
68 }
69 }
70 pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
77 let m = Self::num_bits(num_items as f64, false_rate);
78 let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
79 let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
80 let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
81 Self::new(num_bits, keys)
82 }
83 fn num_bits(num_items: f64, false_rate: f64) -> f64 {
84 let n = num_items;
85 let p = false_rate;
86 ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
87 }
88 fn num_keys(num_bits: f64, num_items: f64) -> f64 {
89 let n = num_items;
90 let m = num_bits;
91 if n == 0.0 {
93 0.0
94 } else {
95 1f64.max(((m / n) * 2f64.ln()).round())
96 }
97 }
98 fn pos(&self, key: &T, k: u64) -> u64 {
99 key.hash_at_index(k) % self.bits.len()
100 }
101 pub fn clear(&mut self) {
102 self.bits = BitVec::new_fill(false, self.bits.len());
103 self.num_bits_set = 0;
104 }
105 pub fn add(&mut self, key: &T) {
106 for k in &self.keys {
107 let pos = self.pos(key, *k);
108 if !self.bits.get(pos) {
109 self.num_bits_set += 1;
110 self.bits.set(pos, true);
111 }
112 }
113 }
114 pub fn contains(&self, key: &T) -> bool {
115 for k in &self.keys {
116 let pos = self.pos(key, *k);
117 if !self.bits.get(pos) {
118 return false;
119 }
120 }
121 true
122 }
123}
124
125fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
126 let mut hasher = FnvHasher::with_key(hash_index);
127 hasher.write(slice);
128 hasher.finish()
129}
130
131impl<T: AsRef<[u8]>> BloomHashIndex for T {
132 fn hash_at_index(&self, hash_index: u64) -> u64 {
133 slice_hash(self.as_ref(), hash_index)
134 }
135}
136
/// Bloom filter whose bits live in `AtomicU64` words, so keys can be added
/// through a shared reference.
pub struct AtomicBloom<T> {
    /// Number of addressable bits; hashes are reduced modulo this value.
    num_bits: u64,
    /// Hash indices handed to `BloomHashIndex::hash_at_index`; one bit is
    /// probed per entry.
    keys: Vec<u64>,
    /// Bit storage in 64-bit words: bit `p` is bit `p & 63` of word `p >> 6`.
    bits: Vec<AtomicU64>,
    /// Ties the filter to its key type without storing a `T`.
    _phantom: PhantomData<T>,
}
143
144impl<T: BloomHashIndex> From<Bloom<T>> for AtomicBloom<T> {
145 fn from(bloom: Bloom<T>) -> Self {
146 AtomicBloom {
147 num_bits: bloom.bits.len(),
148 keys: bloom.keys,
149 bits: bloom
150 .bits
151 .into_boxed_slice()
152 .iter()
153 .map(|&x| AtomicU64::new(x))
154 .collect(),
155 _phantom: PhantomData::default(),
156 }
157 }
158}
159
160impl<T: BloomHashIndex> AtomicBloom<T> {
161 fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
162 let pos = key.hash_at_index(hash_index) % self.num_bits;
163 let index = pos >> 6;
166 let mask = 1u64 << (pos & 63);
169 (index as usize, mask)
170 }
171
172 pub fn add(&self, key: &T) {
173 for k in &self.keys {
174 let (index, mask) = self.pos(key, *k);
175 self.bits[index].fetch_or(mask, Ordering::Relaxed);
176 }
177 }
178
179 pub fn contains(&self, key: &T) -> bool {
180 self.keys.iter().all(|k| {
181 let (index, mask) = self.pos(key, *k);
182 let bit = self.bits[index].load(Ordering::Relaxed) & mask;
183 bit != 0u64
184 })
185 }
186
187 pub fn mock_clone(&self) -> Self {
189 Self {
190 keys: self.keys.clone(),
191 bits: self
192 .bits
193 .iter()
194 .map(|v| AtomicU64::new(v.load(Ordering::Relaxed)))
195 .collect(),
196 ..*self
197 }
198 }
199}
200
201impl<T: BloomHashIndex> From<AtomicBloom<T>> for Bloom<T> {
202 fn from(atomic_bloom: AtomicBloom<T>) -> Self {
203 let bits: Vec<_> = atomic_bloom
204 .bits
205 .into_iter()
206 .map(AtomicU64::into_inner)
207 .collect();
208 let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
209 let mut bits: BitVec<u64> = bits.into();
210 bits.truncate(atomic_bloom.num_bits);
211 Bloom {
212 keys: atomic_bloom.keys,
213 bits,
214 num_bits_set,
215 _phantom: PhantomData::default(),
216 }
217 }
218}
219
#[cfg(test)]
mod test {
    use super::*;
    use gemachain_sdk::hash::{hash, Hash};
    use rayon::prelude::*;

    /// Draws `count` random hashes from `rng`.
    fn random_hashes<R: Rng>(rng: &mut R, count: usize) -> Vec<Hash> {
        std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut *rng))
            .take(count)
            .collect()
    }

    /// Draws `probes` random hashes from `rng` and counts how many of them
    /// `contains` accepts.
    fn count_hits<R: Rng>(rng: &mut R, probes: usize, contains: impl Fn(&Hash) -> bool) -> usize {
        std::iter::repeat_with(|| gemachain_sdk::hash::new_rand(&mut *rng))
            .take(probes)
            .filter(|probe| contains(probe))
            .count()
    }

    /// `Bloom::random` picks the expected sizes at a 10% false rate with a
    /// 100-bit cap.
    #[test]
    fn test_bloom_filter() {
        // (num_items, expected number of keys, expected number of bits)
        let cases: [(usize, usize, u64); 3] = [(0, 0, 1), (10, 3, 48), (100, 1, 100)];
        for &(num_items, want_keys, want_bits) in cases.iter() {
            let bloom: Bloom<Hash> = Bloom::random(num_items, 0.1, 100);
            assert_eq!(bloom.keys.len(), want_keys);
            assert_eq!(bloom.bits.len(), want_bits);
        }
    }

    /// A key is absent before `add` and present after it.
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        // Fixed keys keep the test deterministic.
        bloom.keys = vec![0, 1, 2, 3];

        for &word in [&b"hello"[..], &b"world"[..]].iter() {
            let key = hash(word);
            assert!(!bloom.contains(&key));
            bloom.add(&key);
            assert!(bloom.contains(&key));
        }
    }

    /// Two randomly built filters do not share the same key set.
    #[test]
    fn test_random() {
        let mut first: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut second: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        first.keys.sort_unstable();
        second.keys.sort_unstable();
        assert_ne!(first.keys, second.keys);
    }

    /// Spot-checks the sizing formulas against known values.
    #[test]
    fn test_filter_math() {
        // (num_items, false_rate, expected bits)
        let bit_cases = [(100f64, 0.1f64, 480u64), (100f64, 0.01f64, 959u64)];
        for &(items, rate, want) in bit_cases.iter() {
            assert_eq!(Bloom::<Hash>::num_bits(items, rate) as u64, want);
        }

        // (num_bits, num_items, expected keys)
        let key_cases = [
            (1000f64, 50f64, 14u64),
            (2000f64, 50f64, 28u64),
            (2000f64, 25f64, 55u64),
            (20f64, 1000f64, 1u64),
        ];
        for &(bits, items, want) in key_cases.iter() {
            assert_eq!(Bloom::<Hash>::num_keys(bits, items) as u64, want);
        }
    }

    /// The `Debug` output lists short bit vectors in full and abbreviates
    /// long ones.
    #[test]
    fn test_debug() {
        let mut small: Bloom<Hash> = Bloom::new(3, vec![100]);
        small.add(&Hash::default());
        assert_eq!(
            format!("{:?}", small),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        let mut large: Bloom<Hash> = Bloom::new(1000, vec![100]);
        large.add(&Hash::default());
        large.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{:?}", large),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    /// Keys added in parallel through `AtomicBloom` are all found after
    /// converting back to `Bloom`.
    #[test]
    fn test_atomic_bloom() {
        let mut rng = rand::thread_rng();
        let hash_values = random_hashes(&mut rng, 1200);

        let bloom: AtomicBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        assert_eq!(bloom.bits.len(), 97);
        hash_values.par_iter().for_each(|value| bloom.add(value));

        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        assert!(hash_values.iter().all(|value| bloom.contains(value)));

        let false_positive = count_hits(&mut rng, 10_000, |probe| bloom.contains(probe));
        assert!(false_positive < 2_000, "false_positive: {}", false_positive);
    }

    /// Converting between `Bloom` and `AtomicBloom` repeatedly preserves the
    /// bits, and the final bits match a filter built directly.
    #[test]
    fn test_atomic_bloom_round_trip() {
        let mut rng = rand::thread_rng();
        let keys: Vec<u64> = (0..5).map(|_| rng.gen()).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values = random_hashes(&mut rng, 1000);
        hash_values.iter().for_each(|value| bloom.add(value));
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {}", num_bits_set);

        // Bloom -> AtomicBloom: same membership.
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        assert!(hash_values.iter().all(|value| bloom.contains(value)));

        // AtomicBloom -> Bloom: the set-bit count is recovered.
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert!(hash_values.iter().all(|value| bloom.contains(value)));

        // Re-adding the same values in parallel changes nothing.
        let bloom: AtomicBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|value| bloom.add(value));
        assert!(hash_values.iter().all(|value| bloom.contains(value)));

        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        assert!(hash_values.iter().all(|value| bloom.contains(value)));

        // Add a second batch through the atomic filter.
        let more_hash_values = random_hashes(&mut rng, 1000);
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        more_hash_values.par_iter().for_each(|value| bloom.add(value));
        assert!(hash_values.iter().all(|value| bloom.contains(value)));
        assert!(more_hash_values.iter().all(|value| bloom.contains(value)));
        let false_positive = count_hits(&mut rng, 10_000, |probe| bloom.contains(probe));
        assert!(false_positive < 2000, "false_positive: {}", false_positive);

        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        assert!(hash_values.iter().all(|value| bloom.contains(value)));
        assert!(more_hash_values.iter().all(|value| bloom.contains(value)));
        let false_positive = count_hits(&mut rng, 10_000, |probe| bloom.contains(probe));
        assert!(false_positive < 2000, "false_positive: {}", false_positive);

        // A filter built directly from both batches has identical bits.
        let bits = bloom.bits;
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for value in hash_values.iter().chain(more_hash_values.iter()) {
            bloom.add(value);
        }
        assert_eq!(bits, bloom.bits);
    }
}