1use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// A fixed-size Bloom filter over trigrams.
///
/// Membership is probabilistic: `contains` may report a trigram that was
/// never inserted, with a likelihood that depends on `size`, `num_hashes`
/// and how many trigrams have been inserted.
#[derive(Clone, Debug)]
pub struct BloomFilter {
    /// Length of the bit array in bytes; the filter addresses `size * 8` bits.
    pub size: u16,
    /// Number of bit positions probed (and set on insert) per trigram.
    pub num_hashes: u8,
    /// Backing bit array. Bit `n` lives in byte `n / 8` at bit `n % 8`,
    /// counting from the least-significant bit.
    pub bits: Vec<u8>,
}
24
25impl BloomFilter {
26 #[must_use]
28 pub fn new(size: usize, num_hashes: u8) -> Self {
29 Self {
30 size: u16::try_from(size).unwrap_or(0),
31 num_hashes,
32 bits: vec![0u8; size],
33 }
34 }
35
36 pub fn insert(&mut self, trigram: Trigram) {
38 let tri_bytes = trigram.to_le_bytes();
39 let h1 = Self::hash(&tri_bytes, 0);
40 let h2 = Self::hash(&tri_bytes, 1);
41 let num_bits = usize::from(self.size) * 8;
42
43 for i in 0..self.num_hashes {
44 let bit_pos =
45 (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
46 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
47 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
48 if let Some(byte) = self.bits.get_mut(byte_idx) {
49 *byte |= 1 << bit_idx;
50 }
51 }
52 }
53
54 #[must_use]
59 pub fn contains(&self, trigram: Trigram) -> bool {
60 let tri_bytes = trigram.to_le_bytes();
61 let h1 = Self::hash(&tri_bytes, 0);
62 let h2 = Self::hash(&tri_bytes, 1);
63 let num_bits = usize::from(self.size) * 8;
64
65 for i in 0..self.num_hashes {
66 let bit_pos =
67 (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
68 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
69 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
70 if self.bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
71 return false;
72 }
73 }
74 true
75 }
76
77 fn hash(data: &[u8], seed: u64) -> u64 {
78 let mut hasher = Xxh64::new(seed);
79 hasher.write(data);
80 hasher.finish()
81 }
82
83 pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
89 w.write_all(&self.size.to_le_bytes())?;
90 w.write_all(&[self.num_hashes, 0x00])?;
91 w.write_all(&self.bits)?;
92 Ok(())
93 }
94
95 #[must_use]
97 pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
98 if data.len() < 4 {
99 return None;
100 }
101 let size = data
102 .get(0..2)?
103 .try_into()
104 .ok()
105 .map_or(0, u16::from_le_bytes);
106 let size = usize::from(size);
107 let num_hashes = *data.get(2)?;
108 let total_size = 4 + size;
109 if data.len() < total_size {
110 return None;
111 }
112 data.get(4..total_size)
113 .map(|bits| (bits, usize::from(num_hashes)))
114 }
115
116 #[must_use]
118 pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
119 let tri_bytes = trigram.to_le_bytes();
120 let mut h1_hasher = Xxh64::new(0);
121 h1_hasher.write(&tri_bytes);
122 let h1 = h1_hasher.finish();
123
124 let mut h2_hasher = Xxh64::new(1);
125 h2_hasher.write(&tri_bytes);
126 let h2 = h2_hasher.finish();
127
128 let num_bits = bits.len() * 8;
129
130 for i in 0..num_hashes {
131 let bit_pos =
132 (h1.wrapping_add(u64::from(i).wrapping_mul(h2))) % u64::try_from(num_bits).unwrap_or(0);
133 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
134 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
135 if bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
136 return false;
137 }
138 }
139 true
140 }
141}
142
#[cfg(test)]
#[allow(clippy::as_conversions, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// An inserted trigram is reported present and an unrelated one is not.
    #[test]
    fn basic() {
        let inserted = 0x0001_0203;
        let absent = 0x0004_0506;

        let mut filter = BloomFilter::new(256, 5);
        filter.insert(inserted);

        assert!(filter.contains(inserted));
        assert!(!filter.contains(absent));
    }

    /// With 200 keys in a 256-byte, 5-hash filter, fewer than 20 of 1000
    /// never-inserted keys may be reported as present.
    #[test]
    fn false_positives() {
        let mut filter = BloomFilter::new(256, 5);
        (0..200).for_each(|key| filter.insert(key as u32));

        let fp = (200..1200)
            .filter(|&key| filter.contains(key as u32))
            .count();
        assert!(fp < 20, "FPR too high: {fp}/1000");
    }
}