1use crate::trigram::Trigram;
7use std::hash::Hasher;
8use std::io::Write;
9use xxhash_rust::xxh64::Xxh64;
10
/// Fixed-size Bloom filter over trigrams.
///
/// The fields are public so a filter can be inspected or assembled directly;
/// the methods on this type expect `bits` to hold exactly `size` bytes.
#[derive(Clone, Debug)]
pub struct BloomFilter {
    /// Number of bytes of bit storage; the filter addresses `size * 8` bits.
    pub size: u16,
    /// Number of bit positions set (on insert) or tested (on lookup) per trigram.
    pub num_hashes: u8,
    /// The bit array itself; bit `n` lives in byte `n / 8` at position `n % 8`.
    pub bits: Vec<u8>,
}
24
25impl BloomFilter {
26 #[must_use]
28 pub fn new(size: usize, num_hashes: u8) -> Self {
29 Self {
30 size: u16::try_from(size).unwrap_or(0),
31 num_hashes,
32 bits: vec![0u8; size],
33 }
34 }
35
36 pub fn insert(&mut self, trigram: Trigram) {
38 let tri_bytes = trigram.to_le_bytes();
39 let h1 = Self::hash(&tri_bytes, 0);
40 let h2 = Self::hash(&tri_bytes, 1);
41 let num_bits = usize::from(self.size) * 8;
42
43 for i in 0..self.num_hashes {
44 let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
45 % u64::try_from(num_bits).unwrap_or(0);
46 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
47 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
48 if let Some(byte) = self.bits.get_mut(byte_idx) {
49 *byte |= 1 << bit_idx;
50 }
51 }
52 }
53
54 #[must_use]
59 pub fn contains(&self, trigram: Trigram) -> bool {
60 let tri_bytes = trigram.to_le_bytes();
61 let h1 = Self::hash(&tri_bytes, 0);
62 let h2 = Self::hash(&tri_bytes, 1);
63 let num_bits = usize::from(self.size) * 8;
64
65 for i in 0..self.num_hashes {
66 let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
67 % u64::try_from(num_bits).unwrap_or(0);
68 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
69 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
70 if self
71 .bits
72 .get(byte_idx)
73 .is_none_or(|&b| b & (1 << bit_idx) == 0)
74 {
75 return false;
76 }
77 }
78 true
79 }
80
81 fn hash(data: &[u8], seed: u64) -> u64 {
82 let mut hasher = Xxh64::new(seed);
83 hasher.write(data);
84 hasher.finish()
85 }
86
87 pub fn serialize<W: Write>(&self, mut w: W) -> std::io::Result<()> {
93 w.write_all(&self.size.to_le_bytes())?;
94 w.write_all(&[self.num_hashes, 0x00])?;
95 w.write_all(&self.bits)?;
96 Ok(())
97 }
98
99 #[must_use]
101 pub fn from_slice(data: &[u8]) -> Option<(&[u8], usize)> {
102 if data.len() < 4 {
103 return None;
104 }
105 let size = data
106 .get(0..2)?
107 .try_into()
108 .ok()
109 .map_or(0, u16::from_le_bytes);
110 let size = usize::from(size);
111 let num_hashes = *data.get(2)?;
112 let total_size = 4 + size;
113 if data.len() < total_size {
114 return None;
115 }
116 data.get(4..total_size)
117 .map(|bits| (bits, usize::from(num_hashes)))
118 }
119
120 #[must_use]
122 pub fn slice_contains(bits: &[u8], num_hashes: u8, trigram: Trigram) -> bool {
123 let tri_bytes = trigram.to_le_bytes();
124 let mut h1_hasher = Xxh64::new(0);
125 h1_hasher.write(&tri_bytes);
126 let h1 = h1_hasher.finish();
127
128 let mut h2_hasher = Xxh64::new(1);
129 h2_hasher.write(&tri_bytes);
130 let h2 = h2_hasher.finish();
131
132 let num_bits = bits.len() * 8;
133
134 for i in 0..num_hashes {
135 let bit_pos = (h1.wrapping_add(u64::from(i).wrapping_mul(h2)))
136 % u64::try_from(num_bits).unwrap_or(0);
137 let byte_idx = usize::try_from(bit_pos / 8).unwrap_or(0);
138 let bit_idx = u8::try_from(bit_pos % 8).unwrap_or(0);
139 if bits.get(byte_idx).is_none_or(|&b| b & (1 << bit_idx) == 0) {
140 return false;
141 }
142 }
143 true
144 }
145}
146
#[cfg(test)]
// Tests may unwrap and index directly: a panic is the desired failure mode here.
#[allow(clippy::as_conversions, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// An inserted trigram is reported present and an unrelated one is not.
    #[test]
    fn basic() {
        let mut bloom = BloomFilter::new(256, 5);
        let t1 = 0x0001_0203;
        let t2 = 0x0004_0506;
        bloom.insert(t1);
        assert!(bloom.contains(t1));
        assert!(!bloom.contains(t2));
    }

    /// With 200 entries in 2048 bits, fewer than 2% of 1000 unseen trigrams
    /// may be reported as present.
    #[test]
    fn false_positives() {
        let mut bloom = BloomFilter::new(256, 5);
        for trigram in 0u32..200 {
            bloom.insert(trigram);
        }
        let fp = (200u32..1200)
            .filter(|&trigram| bloom.contains(trigram))
            .count();
        assert!(fp < 20, "FPR too high: {fp}/1000");
    }

    /// A serialized filter can be parsed back and queried through the
    /// borrowed-slice API with the same answers as the in-memory filter.
    #[test]
    fn serialized_round_trip() {
        let mut bloom = BloomFilter::new(256, 5);
        let present = 0x0001_0203;
        let absent = 0x0004_0506;
        bloom.insert(present);

        let mut buf = Vec::new();
        bloom.serialize(&mut buf).unwrap();
        assert_eq!(buf.len(), 4 + 256);

        let (bits, num_hashes) = BloomFilter::from_slice(&buf).unwrap();
        assert_eq!(bits, bloom.bits.as_slice());
        let num_hashes = u8::try_from(num_hashes).unwrap();
        assert_eq!(num_hashes, 5);
        assert!(BloomFilter::slice_contains(bits, num_hashes, present));
        assert!(!BloomFilter::slice_contains(bits, num_hashes, absent));

        // A buffer shorter than the declared size must be rejected.
        assert!(BloomFilter::from_slice(&buf[..buf.len() - 1]).is_none());
    }
}