1use bytes::{Bytes, BytesMut, BufMut};
7
/// A Bloom filter over byte-string keys, with a compact binary encoding.
///
/// A key that was added is always reported as present; a key that was never
/// added may still be reported as present (false positive).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BloomFilter {
    /// Bit array; bit `pos` lives in byte `pos / 8` at bit `pos % 8`
    /// (least-significant bit first).
    bits: Vec<u8>,
    /// Number of addressable bits; probe hashes are reduced modulo this value.
    num_bits: usize,
    /// Number of bit positions probed per key.
    num_hashes: u32,
}
18
19impl BloomFilter {
20 pub fn new(num_items: usize, bits_per_key: usize) -> Self {
26 let num_bits = num_items * bits_per_key;
27 let num_bytes = (num_bits + 7) / 8;
28
29 let num_hashes = ((bits_per_key as f64) * 0.693).ceil() as u32;
32 let num_hashes = num_hashes.max(1).min(30); Self {
35 bits: vec![0u8; num_bytes],
36 num_bits,
37 num_hashes,
38 }
39 }
40
41 pub fn add(&mut self, key: &[u8]) {
43 let hash = Self::hash(key);
44 for i in 0..self.num_hashes {
45 let bit_pos = Self::bloom_hash(hash, i) % (self.num_bits as u64);
46 self.set_bit(bit_pos as usize);
47 }
48 }
49
50 pub fn contains(&self, key: &[u8]) -> bool {
52 let hash = Self::hash(key);
53 for i in 0..self.num_hashes {
54 let bit_pos = Self::bloom_hash(hash, i) % (self.num_bits as u64);
55 if !self.get_bit(bit_pos as usize) {
56 return false; }
58 }
59 true }
61
62 pub fn encode(&self) -> Bytes {
64 let mut buf = BytesMut::new();
65 buf.put_u32_le(self.num_bits as u32);
66 buf.put_u32_le(self.num_hashes);
67 buf.put_slice(&self.bits);
68 buf.freeze()
69 }
70
71 pub fn decode(data: &[u8]) -> Option<Self> {
73 if data.len() < 8 {
74 return None;
75 }
76
77 let num_bits = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
78 let num_hashes = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
79
80 if num_bits == 0 {
82 return None;
83 }
84
85 let num_bytes = (num_bits + 7) / 8;
86 if data.len() < 8 + num_bytes {
87 return None;
88 }
89
90 let bits = data[8..8 + num_bytes].to_vec();
91
92 Some(Self {
93 bits,
94 num_bits,
95 num_hashes,
96 })
97 }
98
99 pub fn size(&self) -> usize {
101 8 + self.bits.len()
102 }
103
104 fn set_bit(&mut self, pos: usize) {
107 let byte_idx = pos / 8;
108 let bit_idx = pos % 8;
109 if byte_idx < self.bits.len() {
110 self.bits[byte_idx] |= 1 << bit_idx;
111 }
112 }
113
114 fn get_bit(&self, pos: usize) -> bool {
115 let byte_idx = pos / 8;
116 let bit_idx = pos % 8;
117 if byte_idx < self.bits.len() {
118 (self.bits[byte_idx] & (1 << bit_idx)) != 0
119 } else {
120 false
121 }
122 }
123
124 fn hash(key: &[u8]) -> u64 {
126 let mut hash = 0xcbf29ce484222325u64;
128 for &byte in key {
129 hash ^= byte as u64;
130 hash = hash.wrapping_mul(0x100000001b3);
131 }
132 hash
133 }
134
135 fn bloom_hash(hash: u64, i: u32) -> u64 {
137 let h1 = hash;
138 let h2 = hash.wrapping_shr(32);
139 h1.wrapping_add((i as u64).wrapping_mul(h2))
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bloom_basic() {
        let present: [&[u8]; 3] = [b"key1", b"key2", b"key3"];
        let absent: [&[u8]; 2] = [b"key4", b"key5"];

        let mut bloom = BloomFilter::new(100, 10);
        present.iter().for_each(|key| bloom.add(key));

        // Every inserted key must be found (no false negatives).
        assert!(present.iter().all(|key| bloom.contains(key)));
        // Hashing is deterministic, so these specific misses are stable across runs.
        assert!(absent.iter().all(|key| !bloom.contains(key)));
    }

    #[test]
    fn test_bloom_encode_decode() {
        let mut original = BloomFilter::new(50, 10);
        original.add(b"test1");
        original.add(b"test2");

        let round_tripped = BloomFilter::decode(&original.encode()).unwrap();

        // Header fields and the raw bit array must survive the round trip.
        assert_eq!(
            (round_tripped.num_bits, round_tripped.num_hashes, &round_tripped.bits),
            (original.num_bits, original.num_hashes, &original.bits),
        );

        assert!(round_tripped.contains(b"test1") && round_tripped.contains(b"test2"));
        assert!(!round_tripped.contains(b"test3"));
    }

    #[test]
    fn test_bloom_false_positive_rate() {
        const INSERTED: usize = 1000;
        const PROBES: usize = 10000;

        let mut bloom = BloomFilter::new(INSERTED, 10);
        for n in 0..INSERTED {
            bloom.add(format!("key{}", n).as_bytes());
        }

        // Only never-inserted keys are probed, so every hit is a false positive.
        let false_positives = (INSERTED..INSERTED + PROBES)
            .filter(|n| bloom.contains(format!("key{}", n).as_bytes()))
            .count();

        let fp_rate = false_positives as f64 / PROBES as f64;
        assert!(fp_rate < 0.02, "False positive rate too high: {}", fp_rate);
    }

    #[test]
    fn test_bloom_empty() {
        // A freshly built filter has every bit cleared, so nothing can match.
        assert!(!BloomFilter::new(10, 10).contains(b"test"));
    }

    #[test]
    fn test_bloom_size() {
        let (items, bits_per_key) = (100, 10);
        let bloom = BloomFilter::new(items, bits_per_key);

        // 8-byte header plus the bit array rounded up to whole bytes.
        let expected = 8 + (items * bits_per_key + 7) / 8;
        assert_eq!(bloom.size(), expected);
    }

    #[test]
    fn test_bloom_decode_invalid() {
        // Shorter than the 8-byte header.
        assert!(BloomFilter::decode(&[0, 1, 2]).is_none());
        // Header present but declaring zero bits.
        assert!(BloomFilter::decode(&[0u8; 8]).is_none());
    }

    #[test]
    fn test_bloom_num_hashes() {
        // ceil(bits_per_key * 0.693): 6.93 -> 7 and 3.465 -> 4.
        for &(bits_per_key, expected) in &[(10, 7), (5, 4)] {
            assert_eq!(BloomFilter::new(100, bits_per_key).num_hashes, expected);
        }
    }
}