1use crate::hash::bloom_hash;
2
3mod hash;
4
/// A Bloom filter policy: builds compact filters over a set of byte-string keys
/// (`create_filter`) and answers "may this key be in the set?" queries against
/// such filters (`key_may_match`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BloomFilter {
    /// Number of hash probes per key; `new` clamps it to the range `1..=30`,
    /// which also guarantees it fits in the filter's trailing byte.
    k: usize,
    /// Filter bits allotted per key when sizing a new filter.
    bits_per_key: usize,
}
9
10impl BloomFilter {
11 pub fn new(bits_per_key: usize) -> BloomFilter {
12 let mut k = (bits_per_key as f64 * 0.69) as usize;
13 k = k.clamp(1, 30);
14
15 BloomFilter { k, bits_per_key }
16 }
17
18 pub fn create_filter(&self, keys: &[Vec<u8>]) -> Vec<u8> {
19 let mut bits = keys.len() * self.bits_per_key;
20
21 if bits < 64 {
22 bits = 64;
23 }
24
25 let bytes = (bits + 7) / 8;
26 bits = bytes * 8;
27
28 let mut filter = vec![0; bytes];
29 filter.push(self.k as u8);
30
31 for key in keys {
32 let mut h = bloom_hash(key);
33 let delta = h.rotate_right(17);
34
35 for _ in 0..self.k {
36 let bitpos = (h % bits as u32) as usize;
37 filter[bitpos / 8] |= 1 << (bitpos % 8);
38 h = h.wrapping_add(delta);
39 }
40 }
41 filter
42 }
43
44 pub fn key_may_match(&self, key: &[u8], bloom_filter: &[u8]) -> bool {
45 let len = bloom_filter.len();
46
47 if len < 2 {
48 return false;
49 }
50
51 let bits = (len - 1) * 8;
52
53 let k = bloom_filter.last().unwrap();
54 if *k > 30 {
55 return true;
56 }
57
58 let mut h = bloom_hash(key);
59 let delta = h.rotate_right(17);
60 for _ in 0..*k {
61 let bitpos = (h % bits as u32) as usize;
62 if (bloom_filter[bitpos / 8] & (1 << (bitpos % 8))) == 0 {
63 return false;
64 }
65
66 h = h.wrapping_add(delta);
67 }
68
69 true
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::BloomFilter;

    /// Encodes `value` as its 4-byte little-endian representation; used to
    /// generate distinct binary keys for the tests below.
    fn encode_fixed32(value: u32) -> [u8; 4] {
        value.to_le_bytes()
    }

    #[test]
    fn test_empty_filter() {
        // Anything shorter than two bytes is not a valid filter, so no key may match.
        let filter = BloomFilter::new(10);
        let keys: [&[u8]; 4] = [b"hello", b"world", b"x", b"foo"];
        for key in &keys {
            assert!(
                !filter.key_may_match(key, &[]),
                "{:?} matched an empty filter",
                key
            );
        }
    }

    #[test]
    fn test_small() {
        let bloom_filter = BloomFilter::new(10);
        let keys = vec![b"hello".to_vec(), b"world".to_vec()];
        let filter = bloom_filter.create_filter(&keys);

        // Inserted keys must always match (no false negatives).
        assert!(bloom_filter.key_may_match(b"hello", &filter));
        assert!(bloom_filter.key_may_match(b"world", &filter));
        // These absent keys are expected to miss with the current hash function.
        assert!(!bloom_filter.key_may_match(b"x", &filter));
        assert!(!bloom_filter.key_may_match(b"foo", &filter));
    }

    /// Step schedule for `test_varying_length`: +1 below 10, +10 below 100,
    /// +100 below 1000, and +1000 otherwise.
    fn next_length(length: usize) -> usize {
        let step = if length < 10 {
            1
        } else if length < 100 {
            10
        } else if length < 1000 {
            100
        } else {
            1000
        };
        length + step
    }

    /// Fraction of 10 000 probe keys that `filter` reports as possible matches.
    /// The probe keys start at 1e9, so they are disjoint from the at most
    /// 10 000 keys (starting at 0) that `test_varying_length` inserts.
    fn false_positive_rate(bloom_filter: &BloomFilter, filter: &[u8]) -> f64 {
        const PROBES: u32 = 10_000;
        let hits = (0..PROBES)
            .filter(|&i| bloom_filter.key_may_match(&encode_fixed32(i + 1_000_000_000), filter))
            .count();
        hits as f64 / f64::from(PROBES)
    }

    #[test]
    fn test_varying_length() {
        let bloom_filter = BloomFilter::new(10);
        let mut mediocre_filters = 0;
        let mut good_filters = 0;

        let mut length = 1;
        while length <= 10_000 {
            let keys: Vec<Vec<u8>> = (0..length)
                .map(|i| encode_fixed32(i as u32).to_vec())
                .collect();
            let filter = bloom_filter.create_filter(&keys);

            // Size must stay close to 10 bits per key plus a small fixed overhead.
            assert!(
                filter.len() <= (length * 10 / 8) + 40,
                "filter of {} bytes too large for {} keys",
                filter.len(),
                length
            );

            // Every inserted key must match (no false negatives).
            for key in &keys {
                assert!(
                    bloom_filter.key_may_match(key, &filter),
                    "false negative for {:?} at length {}",
                    key,
                    length
                );
            }

            let rate = false_positive_rate(&bloom_filter, &filter);

            println!(
                "False positives: {:5.2}% @ length = {:6} ; bytes = {:6}",
                rate * 100.0,
                length,
                filter.len()
            );

            assert!(
                rate <= 0.02,
                "false-positive rate {} too high at length {}",
                rate,
                length
            );
            if rate > 0.0125 {
                mediocre_filters += 1;
            } else {
                good_filters += 1;
            }

            length = next_length(length);
        }

        println!(
            "Filters: {} good, {} mediocre",
            good_filters,
            mediocre_filters
        );

        assert!(mediocre_filters <= good_filters / 5);
    }
}
178}