1use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// A Bloom filter over trigrams, stored as a plain byte array.
///
/// Membership is tested by probing `num_hashes` bit positions inside `bits`.
/// The on-disk form is a 4-byte header (little-endian `size`, `num_hashes`,
/// one zero byte) followed by the raw `bits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BloomFilter {
    /// Length of `bits` in bytes, as recorded in the serialized header.
    pub size: u16,
    /// Number of bit positions set (on insert) or tested (on lookup) per trigram.
    pub num_hashes: u8,
    /// The bit array itself; bit `n` lives in byte `n / 8` at position `n % 8`.
    pub bits: Vec<u8>,
}
17
18impl BloomFilter {
19 pub fn new(size: usize, num_hashes: u8) -> Self {
20 Self {
21 size: size as u16,
22 num_hashes,
23 bits: vec![0u8; size],
24 }
25 }
26
27 pub fn insert(&mut self, trigram: Trigram) {
28 let tri_bytes = trigram.to_le_bytes();
29 let h1 = self.hash(&tri_bytes, 0);
30 let h2 = self.hash(&tri_bytes, 1);
31 let num_bits = (self.size as usize) * 8;
32
33 for i in 0..self.num_hashes {
34 let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
35 let byte_idx = (bit_pos / 8) as usize;
36 let bit_idx = (bit_pos % 8) as u8;
37 self.bits[byte_idx] |= 1 << bit_idx;
38 }
39 }
40
41 pub fn contains(&self, trigram: Trigram) -> bool {
42 let tri_bytes = trigram.to_le_bytes();
43 let h1 = self.hash(&tri_bytes, 0);
44 let h2 = self.hash(&tri_bytes, 1);
45 let num_bits = (self.size as usize) * 8;
46
47 for i in 0..self.num_hashes {
48 let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
49 let byte_idx = (bit_pos / 8) as usize;
50 let bit_idx = (bit_pos % 8) as u8;
51 if self.bits[byte_idx] & (1 << bit_idx) == 0 {
52 return false;
53 }
54 }
55 true
56 }
57
58 fn hash(&self, data: &[u8], seed: u64) -> u64 {
59 let mut hasher = Xxh64::new(seed);
60 hasher.write(data);
61 hasher.finish()
62 }
63
64 pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
65 w.write_all(&self.size.to_le_bytes())?;
66 w.write_all(&[self.num_hashes, 0x00])?; w.write_all(&self.bits)?;
68 Ok(())
69 }
70
71 pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
73 if data.len() < 4 {
74 return None;
75 }
76 let size = data[0..2].try_into().ok().map(u16::from_le_bytes).unwrap_or(0) as usize;
77 let num_hashes = data[2];
78 let total_size = 4 + size;
79 if data.len() < total_size {
80 return None;
81 }
82 Some((&data[4..total_size], num_hashes as usize))
83 }
84
85 pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
87 let tri_bytes = trigram.to_le_bytes();
88 let mut h1_hasher = Xxh64::new(0);
89 h1_hasher.write(&tri_bytes);
90 let h1 = h1_hasher.finish();
91
92 let mut h2_hasher = Xxh64::new(1);
93 h2_hasher.write(&tri_bytes);
94 let h2 = h2_hasher.finish();
95
96 let num_bits = bits.len() * 8;
97
98 for i in 0..num_hashes {
99 let bit_pos = (h1.wrapping_add((i as u64).wrapping_mul(h2))) % (num_bits as u64);
100 let byte_idx = (bit_pos / 8) as usize;
101 let bit_idx = (bit_pos % 8) as u8;
102 if bits[byte_idx] & (1 << bit_idx) == 0 {
103 return false;
104 }
105 }
106 true
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted trigram is reported present; an unrelated one is not.
    #[test]
    fn basic() {
        let present = 0x010203;
        let absent = 0x040506;

        let mut filter = BloomFilter::new(256, 5);
        filter.insert(present);

        assert!(filter.contains(present));
        assert!(!filter.contains(absent));
    }

    /// After 200 insertions, fewer than 20 of 1000 unseen keys may match.
    #[test]
    fn false_positives() {
        let mut filter = BloomFilter::new(256, 5);
        (0u32..200).for_each(|key| filter.insert(key));

        let fp = (200u32..1200).filter(|&key| filter.contains(key)).count();
        assert!(fp < 20, "FPR too high: {}/1000", fp);
    }
}