1#![feature(generic_const_exprs)]
2mod bitarray;
3
4use ahash::AHasher;
5use std::hash::Hash;
6use std::hash::Hasher;
7
/// Iterator over the bit positions an item maps to, produced by double hashing
/// (`h1 + i * h2`, reduced with `mod_mask`).
struct HashesIter {
    /// Total number of positions to yield.
    iters: usize,
    /// Index `i` of the next position to yield (runs from 0 up to `iters`).
    cur_iter: usize,
    /// Bit mask applied to every produced position to keep it in range.
    mod_mask: u64,
    /// Base offset of the probe sequence.
    h1: u64,
    /// Stride between successive probes.
    h2: u64,
}
15
16impl HashesIter {
17 fn new<T: Hash>(data: &T, iters: usize, mod_mask: u64) -> Self {
18 let (mut h1, mut h2) = Self::hash(data);
19 h1 &= mod_mask;
20 h2 &= mod_mask;
21
22 Self {
23 h1,
24 h2,
25 mod_mask,
26 cur_iter: 0,
27 iters,
28 }
29 }
30
31 fn hash<T: Hash>(data: &T) -> (u64, u64) {
32 let mut hasher = AHasher::default();
33 data.hash(&mut hasher);
34 let hash1 = hasher.finish();
35 hash1.hash(&mut hasher);
36 let hash2 = hasher.finish();
37 (hash1, hash2)
38 }
39}
40
41impl std::iter::Iterator for HashesIter {
42 type Item = u64;
43
44 fn next(&mut self) -> Option<Self::Item> {
45 if self.cur_iter == self.iters {
46 return None;
47 }
48 let h2i = ((self.cur_iter as u64 & self.mod_mask) * self.h2) & self.mod_mask;
49 self.cur_iter += 1;
50 Some((self.h1 + h2i) & self.mod_mask)
51 }
52}
53
54pub struct BloomFilter<const N: usize>
55where
56 [u8; N / 8]: Sized,
57{
58 bitarray: bitarray::BitArray<N>,
59 hash_functions_num: usize,
60}
61
62impl<const N: usize> BloomFilter<N>
63where
64 [u8; N / 8]: Sized,
65{
66 const N_BOUND_MASK: u64 = N as u64 - 1;
67 pub fn new(expected_inserts_number: usize) -> Self {
68 const LN_2: f64 = std::f64::consts::LN_2;
69 let hash_functions_num =
70 ((N as f64 / expected_inserts_number as f64) * LN_2).ceil() as usize;
71 Self {
72 bitarray: bitarray::BitArray::new(),
73 hash_functions_num,
74 }
75 }
76
77 pub fn insert<T: Hash>(&mut self, data: &T) {
78 let hashes_iter = HashesIter::new(data, self.hash_functions_num, Self::N_BOUND_MASK);
79 for hash in hashes_iter {
80 self.bitarray.set(hash as usize);
81 }
82 }
83
84 pub fn contains<T: Hash>(&self, data: &T) -> bool {
85 let hashes_iter = HashesIter::new(data, self.hash_functions_num, Self::N_BOUND_MASK);
86 for hash in hashes_iter {
87 if !self.bitarray.get(hash as usize) {
88 return false;
89 }
90 }
91 true
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use std::collections::HashSet;

    /// Theoretical false-positive rate of a Bloom filter with `filter_sz` bits holding
    /// `inserts_num` items under the optimal number of hash functions:
    /// `exp(-(m / n) * ln(2)^2)`.
    fn false_pos_prob(filter_sz: usize, inserts_num: usize) -> f64 {
        let ln2_sq = std::f64::consts::LN_2.powi(2);
        let bits_per_item = filter_sz as f64 / inserts_num as f64;
        (-bits_per_item * ln2_sq).exp()
    }

    #[test]
    fn stress_test() {
        const INSERTS_COUNT: usize = 1000;
        const READS_COUNT: usize = 100000;
        const FILTER_SIZE: usize = 4096;
        let mut rng = rand::thread_rng();
        let mut bf = BloomFilter::<FILTER_SIZE>::new(INSERTS_COUNT);
        let expected_fp_rate = false_pos_prob(FILTER_SIZE, INSERTS_COUNT);
        println!("False positive probability: {}", expected_fp_rate);

        let mut inserted = HashSet::new();
        for _ in 0..INSERTS_COUNT {
            let data = rng.gen::<u64>();
            bf.insert(&data);
            inserted.insert(data);
            assert!(bf.contains(&data));
        }
        // No false negatives: every item must still be reported after *all* inserts,
        // not only right after its own insert. Random probes below would almost never
        // hit an inserted value, so this is checked explicitly here.
        for data in &inserted {
            assert!(bf.contains(data), "false negative for {}", data);
        }

        let mut false_positives = 0;
        for _ in 0..READS_COUNT {
            let data = rng.gen::<u64>();
            if !inserted.contains(&data) && bf.contains(&data) {
                false_positives += 1;
            }
        }
        let observed_fp_rate = false_positives as f64 / READS_COUNT as f64;
        println!(
            "False positives: {} out of {} reads ({}%)",
            false_positives,
            READS_COUNT,
            observed_fp_rate * 100.0
        );
        // Loose bound so random variation cannot make the test flaky, while a badly
        // distributed hash (e.g. collapsed probes) is still caught.
        assert!(
            observed_fp_rate < 2.0 * expected_fp_rate,
            "observed false-positive rate {} exceeds twice the theoretical {}",
            observed_fp_rate,
            expected_fp_rate
        );
    }
}
147
#[cfg(test)]
mod gen_data_for_bench {
    use rand::Rng;
    use std::fs::OpenOptions;
    use std::io::prelude::*;
    use std::io::{BufWriter, ErrorKind};

    const INPUT_SIZE: usize = 10000;

    /// Writes `INPUT_SIZE` random lines of lowercase ASCII letters (10 to 99 chars each)
    /// to `input.txt` for benchmarks. If the file already exists, it is left untouched.
    #[test]
    fn write_input() {
        const FILENAME: &str = "input.txt";
        // `create_new` fails atomically if the path exists, so there is no window between
        // an existence check and the create.
        let file = match OpenOptions::new().write(true).create_new(true).open(FILENAME) {
            Ok(file) => file,
            Err(e) if e.kind() == ErrorKind::AlreadyExists => return,
            Err(e) => panic!("cannot create {}: {}", FILENAME, e),
        };
        // Buffered so each line does not cost two write syscalls.
        let mut out = BufWriter::new(file);
        let mut rng = rand::thread_rng();
        let mut line = String::with_capacity(100);
        for _ in 0..INPUT_SIZE {
            let str_len = rng.gen_range(10..100);
            line.clear();
            // Inclusive range so that 'z' can be generated too.
            line.extend((0..str_len).map(|_| rng.gen_range('a'..='z')));
            writeln!(out, "{}", line).expect("failed to write benchmark input");
        }
        // Flush explicitly: dropping a BufWriter silently ignores flush errors.
        out.flush().expect("failed to flush benchmark input");
    }
}