bloomy_rs/
lib.rs

1#![feature(generic_const_exprs)]
2mod bitarray;
3
4use ahash::AHasher;
5use std::hash::Hash;
6use std::hash::Hasher;
7
/// State for generating a fixed number of bit indices from two base hashes.
///
/// Instances are created by `HashesIter::new`; the indices themselves are produced by
/// the `Iterator` implementation.
struct HashesIter {
    /// First base hash; the starting index of the sequence.
    h1: u64,
    /// Second base hash; multiplied by the step counter to offset from `h1`.
    h2: u64,
    /// Bit mask applied to the base hashes and to every yielded index.
    mod_mask: u64,
    /// How many indices have been yielded so far.
    cur_iter: usize,
    /// Total number of indices to yield before the iterator is exhausted.
    iters: usize,
}
15
16impl HashesIter {
17    fn new<T: Hash>(data: &T, iters: usize, mod_mask: u64) -> Self {
18        let (mut h1, mut h2) = Self::hash(data);
19        h1 &= mod_mask;
20        h2 &= mod_mask;
21
22        Self {
23            h1,
24            h2,
25            mod_mask,
26            cur_iter: 0,
27            iters,
28        }
29    }
30
31    fn hash<T: Hash>(data: &T) -> (u64, u64) {
32        let mut hasher = AHasher::default();
33        data.hash(&mut hasher);
34        let hash1 = hasher.finish();
35        hash1.hash(&mut hasher);
36        let hash2 = hasher.finish();
37        (hash1, hash2)
38    }
39}
40
41impl std::iter::Iterator for HashesIter {
42    type Item = u64;
43
44    fn next(&mut self) -> Option<Self::Item> {
45        if self.cur_iter == self.iters {
46            return None;
47        }
48        let h2i = ((self.cur_iter as u64 & self.mod_mask) * self.h2) & self.mod_mask;
49        self.cur_iter += 1;
50        Some((self.h1 + h2i) & self.mod_mask)
51    }
52}
53
54pub struct BloomFilter<const N: usize>
55where
56    [u8; N / 8]: Sized,
57{
58    bitarray: bitarray::BitArray<N>,
59    hash_functions_num: usize,
60}
61
62impl<const N: usize> BloomFilter<N>
63where
64    [u8; N / 8]: Sized,
65{
66    const N_BOUND_MASK: u64 = N as u64 - 1;
67    pub fn new(expected_inserts_number: usize) -> Self {
68        const LN_2: f64 = std::f64::consts::LN_2;
69        let hash_functions_num =
70            ((N as f64 / expected_inserts_number as f64) * LN_2).ceil() as usize;
71        Self {
72            bitarray: bitarray::BitArray::new(),
73            hash_functions_num,
74        }
75    }
76
77    pub fn insert<T: Hash>(&mut self, data: &T) {
78        let hashes_iter = HashesIter::new(data, self.hash_functions_num, Self::N_BOUND_MASK);
79        for hash in hashes_iter {
80            self.bitarray.set(hash as usize);
81        }
82    }
83
84    pub fn contains<T: Hash>(&self, data: &T) -> bool {
85        let hashes_iter = HashesIter::new(data, self.hash_functions_num, Self::N_BOUND_MASK);
86        for hash in hashes_iter {
87            if !self.bitarray.get(hash as usize) {
88                return false;
89            }
90        }
91        true
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use std::collections::HashSet;

    /// Evaluates `e^(-filter_sz / inserts_num * (ln 2)^2)`, the value the stress test
    /// prints as the false-positive probability.
    fn false_pos_prob(filter_sz: usize, inserts_num: usize) -> f64 {
        let bits_per_item = filter_sz as f64 / inserts_num as f64;
        let exponent = -(bits_per_item * std::f64::consts::LN_2.powi(2));
        std::f64::consts::E.powf(exponent)
    }

    /// Inserts random values, checks each is reported as present, then probes with
    /// fresh random values, counting false positives and rejecting false negatives.
    #[test]
    fn stress_test() {
        const INSERTS_COUNT: usize = 1000;
        const READS_COUNT: usize = 100000;
        const FILTER_SIZE: usize = 4096;

        let mut rng = rand::thread_rng();
        let mut bf = BloomFilter::<FILTER_SIZE>::new(INSERTS_COUNT);
        println!(
            "False positive probability: {}",
            false_pos_prob(FILTER_SIZE, INSERTS_COUNT)
        );

        let mut inserted = HashSet::new();
        for _ in 0..INSERTS_COUNT {
            let value = rng.gen::<u64>();
            bf.insert(&value);
            inserted.insert(value);
            assert!(bf.contains(&value));
        }

        let mut false_positives = 0;
        for _ in 0..READS_COUNT {
            let probe = rng.gen::<u64>();
            let really_inserted = inserted.contains(&probe);
            if bf.contains(&probe) {
                if !really_inserted {
                    false_positives += 1;
                }
            } else {
                // The filter must never deny a value that was actually inserted.
                assert!(!really_inserted);
            }
        }

        let percent = false_positives as f64 / READS_COUNT as f64 * 100.0;
        println!(
            "False positives: {} out of {} reads ({}%)",
            false_positives, READS_COUNT, percent
        );
    }
}
147
#[cfg(test)]
mod gen_data_for_bench {
    use rand::Rng;
    use std::fs::File;
    use std::io::prelude::*;
    use std::io::BufWriter;
    use std::path::Path;

    /// Number of lines written to the generated input file.
    const INPUT_SIZE: usize = 10000;

    /// Writes `INPUT_SIZE` lines of random lowercase letters to `input.txt`.
    ///
    /// Each line holds between 10 and 99 characters drawn from `'a'..='z'`. Nothing is
    /// written when the file already exists.
    #[test]
    fn write_input() {
        const FILENAME: &str = "input.txt";
        if Path::new(FILENAME).exists() {
            return;
        }
        let mut rng = rand::thread_rng();
        // Buffer the output so each line does not cost separate write calls.
        let mut out = BufWriter::new(File::create(FILENAME).unwrap());
        for _ in 0..INPUT_SIZE {
            let str_len = rng.gen_range(10..100);
            // Inclusive range: the exclusive `'a'..'z'` never produced the letter 'z'.
            let val: String = (0..str_len).map(|_| rng.gen_range('a'..='z')).collect();
            writeln!(out, "{}", val).unwrap();
        }
        // Flush explicitly; errors raised while dropping a `BufWriter` are discarded.
        out.flush().unwrap();
    }
}