1use std::cmp::Ordering;
26use std::collections::HashSet;
27use std::hash::{BuildHasher, Hasher};
28use std::{cmp, io};
29
30use siphasher::sip::SipHasher24;
31
/// Number of remainder bits of the Golomb-Rice code; the tests in this file
/// pass it as the `p` argument of the filter reader and writer.
// NOTE(review): presumably the BIP 158 "basic filter" value - confirm.
pub const P: u8 = 19;
/// Range multiplier; the tests in this file pass it as the `m` argument of the
/// filter reader and writer, where hashes of `n` elements are mapped into
/// `[0, n * m)`.
// NOTE(review): presumably the BIP 158 "basic filter" value - confirm.
pub const M: u64 = 784_931;
35
/// Holds the two 64-bit key words that this file's `BuildHasher`
/// implementation hands to every `SipHasher24` it creates.
pub struct SipHasher24Builder {
    /// First key word.
    k0: u64,
    /// Second key word.
    k1: u64,
}

impl SipHasher24Builder {
    /// Creates a builder that stores the key words `k0` and `k1`.
    pub fn new(k0: u64, k1: u64) -> Self {
        Self { k0, k1 }
    }
}
47
48impl BuildHasher for SipHasher24Builder {
49 type Hasher = SipHasher24;
50 fn build_hasher(&self) -> Self::Hasher {
51 SipHasher24::new_with_keys(self.k0, self.k1)
52 }
53}
54
55pub struct GCSFilterReader<H> {
57 filter: GCSFilter<H>,
58 m: u64,
59}
60
61impl<H: BuildHasher> GCSFilterReader<H> {
62 pub fn new(hasher_builder: H, m: u64, p: u8) -> GCSFilterReader<H> {
64 GCSFilterReader {
65 filter: GCSFilter::new(hasher_builder, p),
66 m,
67 }
68 }
69
70 pub fn match_any(
73 &self,
74 reader: &mut dyn io::Read,
75 query: &mut dyn Iterator<Item = &[u8]>,
76 ) -> Result<bool, io::Error> {
77 let mut decoder = reader;
78
79 let mut length_data = [0u8; 8];
81 let n_elements = decoder
82 .read_exact(&mut length_data)
83 .map(|()| u64::from_le_bytes(length_data))
84 .unwrap_or(0);
85
86 let reader = &mut decoder;
87 let nm = n_elements * self.m;
89 let mut mapped = query
90 .map(|e| map_to_range(self.filter.hash(e), nm))
91 .collect::<Vec<_>>();
92 mapped.sort_unstable();
94 if mapped.is_empty() {
95 return Ok(false);
96 }
97 if n_elements == 0 {
98 return Ok(false);
99 }
100
101 let mut reader = BitStreamReader::new(reader);
103 let mut data = self.filter.golomb_rice_decode(&mut reader)?;
104 let mut remaining = n_elements - 1;
105 for p in mapped {
106 loop {
107 match data.cmp(&p) {
108 Ordering::Equal => {
109 return Ok(true);
110 }
111 Ordering::Less => {
112 if remaining > 0 {
113 data += self.filter.golomb_rice_decode(&mut reader)?;
114 remaining -= 1;
115 } else {
116 return Ok(false);
117 }
118 }
119 Ordering::Greater => {
120 break;
121 }
122 }
123 }
124 }
125 Ok(false)
126 }
127
128 pub fn match_all(
131 &self,
132 reader: &mut dyn io::Read,
133 query: &mut dyn Iterator<Item = &[u8]>,
134 ) -> Result<bool, io::Error> {
135 let mut decoder = reader;
136
137 let mut length_data = [0u8; 8];
139 let n_elements = decoder
140 .read_exact(&mut length_data)
141 .map(|()| u64::from_le_bytes(length_data))
142 .unwrap_or(0);
143
144 let reader = &mut decoder;
145 let nm = n_elements * self.m;
147 let mut mapped = query
148 .map(|e| map_to_range(self.filter.hash(e), nm))
149 .collect::<Vec<_>>();
150 mapped.sort_unstable();
152 mapped.dedup();
153 if mapped.is_empty() {
154 return Ok(false);
155 }
156 if n_elements == 0 {
157 return Ok(false);
158 }
159
160 let mut reader = BitStreamReader::new(reader);
162 let mut data = self.filter.golomb_rice_decode(&mut reader)?;
163 let mut remaining = n_elements - 1;
164 for p in mapped {
165 loop {
166 match data.cmp(&p) {
167 Ordering::Equal => {
168 break;
169 }
170 Ordering::Less => {
171 if remaining > 0 {
172 data += self.filter.golomb_rice_decode(&mut reader)?;
173 remaining -= 1;
174 } else {
175 return Ok(false);
176 }
177 }
178 Ordering::Greater => {
179 return Ok(false);
180 }
181 }
182 }
183 }
184 Ok(true)
185 }
186}
187
/// Scales `hash` into `[0, nm)` by keeping the upper 64 bits of the 128-bit
/// product `hash * nm`; returns 0 when `nm` is 0.
fn map_to_range(hash: u64, nm: u64) -> u64 {
    let wide_product = u128::from(hash) * u128::from(nm);
    (wide_product >> 64) as u64
}
192
193pub struct GCSFilterWriter<'a, H> {
195 filter: GCSFilter<H>,
196 writer: &'a mut dyn io::Write,
197 elements: HashSet<Vec<u8>>,
198 m: u64,
199}
200
201impl<'a, H: BuildHasher> GCSFilterWriter<'a, H> {
202 pub fn new(
204 writer: &'a mut dyn io::Write,
205 hasher_builder: H,
206 m: u64,
207 p: u8,
208 ) -> GCSFilterWriter<'a, H> {
209 GCSFilterWriter {
210 filter: GCSFilter::new(hasher_builder, p),
211 writer,
212 elements: HashSet::new(),
213 m,
214 }
215 }
216
217 pub fn add_element(&mut self, element: &[u8]) {
219 if !element.is_empty() {
220 self.elements.insert(element.to_vec());
221 }
222 }
223
224 pub fn finish(&mut self) -> Result<usize, io::Error> {
226 let nm = self.elements.len() as u64 * self.m;
227
228 let mut mapped: Vec<_> = self
230 .elements
231 .iter()
232 .map(|e| map_to_range(self.filter.hash(e.as_slice()), nm))
233 .collect();
234 mapped.sort_unstable();
235
236 let mut wrote = self.writer.write(&(mapped.len() as u64).to_le_bytes())?;
239
240 let mut writer = BitStreamWriter::new(self.writer);
242 let mut last = 0;
243 for data in mapped {
244 wrote += self.filter.golomb_rice_encode(&mut writer, data - last)?;
245 last = data;
246 }
247 wrote += writer.flush()?;
248 Ok(wrote)
249 }
250}
251
252struct GCSFilter<H> {
254 hasher_builder: H,
255 p: u8,
256}
257
258impl<H: BuildHasher> GCSFilter<H> {
259 fn new(hasher_builder: H, p: u8) -> GCSFilter<H> {
261 GCSFilter { hasher_builder, p }
262 }
263
264 fn golomb_rice_encode(&self, writer: &mut BitStreamWriter, n: u64) -> Result<usize, io::Error> {
266 let mut wrote = 0;
267 let mut q = n >> self.p;
268 while q > 0 {
269 let nbits = cmp::min(q, 64);
270 wrote += writer.write(!0u64, nbits as u8)?;
271 q -= nbits;
272 }
273 wrote += writer.write(0, 1)?;
274 wrote += writer.write(n, self.p)?;
275 Ok(wrote)
276 }
277
278 fn golomb_rice_decode(&self, reader: &mut BitStreamReader) -> Result<u64, io::Error> {
280 let mut q = 0u64;
281 while reader.read(1)? == 1 {
282 q += 1;
283 }
284 let r = reader.read(self.p)?;
285 Ok((q << self.p) + r)
286 }
287
288 fn hash(&self, element: &[u8]) -> u64 {
290 let mut hasher = self.hasher_builder.build_hasher();
291 hasher.write(element);
292 hasher.finish()
293 }
294}
295
/// Reads a byte stream bit by bit, taking the most significant unread bit of
/// each byte first.
pub struct BitStreamReader<'a> {
    /// The byte currently being consumed.
    buffer: [u8; 1],
    /// Number of bits of `buffer` already consumed; 8 means a new byte is needed.
    offset: u8,
    /// Underlying byte source.
    reader: &'a mut dyn io::Read,
}

impl<'a> BitStreamReader<'a> {
    /// Wraps `reader`; no byte is fetched until the first call to `read`.
    pub fn new(reader: &'a mut dyn io::Read) -> BitStreamReader {
        BitStreamReader {
            reader,
            // Start in the "byte exhausted" state so the first read refills.
            offset: 8,
            buffer: [0u8],
        }
    }

    /// Reads the next `nbits` bits and returns them in the low bits of the
    /// result, with earlier bits more significant.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Other` if `nbits` exceeds 64 and propagates the
    /// error of the underlying `read_exact` when a further byte is needed but
    /// cannot be read.
    pub fn read(&mut self, mut nbits: u8) -> Result<u64, io::Error> {
        if nbits > 64 {
            let too_wide = io::Error::new(
                io::ErrorKind::Other,
                "can not read more than 64 bits at once",
            );
            return Err(too_wide);
        }

        let mut value = 0u64;
        while nbits != 0 {
            if self.offset == 8 {
                self.reader.read_exact(&mut self.buffer)?;
                self.offset = 0;
            }
            // Take whatever is left in the current byte, up to the amount still wanted.
            let take = cmp::min(8 - self.offset, nbits);
            // Drop the consumed high bits, then right-align the next `take` bits.
            let chunk = (self.buffer[0] << self.offset) >> (8 - take);
            value <<= take;
            value |= chunk as u64;
            self.offset += take;
            nbits -= take;
        }
        Ok(value)
    }
}
336
/// Writes a byte stream bit by bit, filling each byte from its most
/// significant bit downwards.
pub struct BitStreamWriter<'a> {
    /// The byte currently being assembled.
    buffer: [u8; 1],
    /// Number of bits of `buffer` already filled.
    offset: u8,
    /// Underlying byte sink.
    writer: &'a mut dyn io::Write,
}

impl<'a> BitStreamWriter<'a> {
    /// Wraps `writer` with an empty pending byte.
    pub fn new(writer: &'a mut dyn io::Write) -> BitStreamWriter {
        BitStreamWriter {
            writer,
            offset: 0,
            buffer: [0u8],
        }
    }

    /// Appends the low `nbits` bits of `data`, most significant first, and
    /// returns how many complete bytes were handed to the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Other` if `nbits` exceeds 64 and propagates errors
    /// of the underlying writer.
    pub fn write(&mut self, data: u64, mut nbits: u8) -> Result<usize, io::Error> {
        if nbits > 64 {
            let too_wide = io::Error::new(
                io::ErrorKind::Other,
                "can not write more than 64 bits at once",
            );
            return Err(too_wide);
        }

        let mut bytes_emitted = 0;
        while nbits != 0 {
            let take = cmp::min(8 - self.offset, nbits);
            // Left-align the pending bits in the 64-bit word, then move them
            // into the still-free low part of the byte being assembled.
            let aligned = data << (64 - nbits);
            self.buffer[0] |= (aligned >> (56 + self.offset)) as u8;
            self.offset += take;
            nbits -= take;
            if self.offset == 8 {
                bytes_emitted += self.flush()?;
            }
        }
        Ok(bytes_emitted)
    }

    /// Writes the pending byte, if it holds any bits, with its unfilled low
    /// bits left at zero. Returns 1 if a byte was written and 0 otherwise.
    /// The underlying writer itself is not flushed.
    pub fn flush(&mut self) -> Result<usize, io::Error> {
        if self.offset == 0 {
            return Ok(0);
        }
        self.writer.write_all(&self.buffer)?;
        self.buffer[0] = 0u8;
        self.offset = 0;
        Ok(1)
    }
}
387
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::HashSet;
    use std::io::Cursor;

    /// Builds a reader with the same key and parameters as the writer used in
    /// `test_filter`.
    fn new_reader() -> GCSFilterReader<SipHasher24Builder> {
        GCSFilterReader::new(SipHasher24Builder::new(0, 0), M, P)
    }

    /// Decodes every hex string of `items` into its bytes.
    fn decode_all(items: &[&str]) -> Vec<Vec<u8>> {
        items.iter().map(|text| hex::decode(text).unwrap()).collect()
    }

    /// Runs `match_any` for `query` against a fresh copy of `filter`.
    fn any_match(filter: &[u8], query: &[Vec<u8>]) -> bool {
        let mut input = Cursor::new(filter.to_vec());
        let found = new_reader()
            .match_any(&mut input, &mut query.iter().map(|v| v.as_slice()))
            .unwrap();
        found
    }

    /// Runs `match_all` for `query` against a fresh copy of `filter`.
    fn all_match(filter: &[u8], query: &[Vec<u8>]) -> bool {
        let mut input = Cursor::new(filter.to_vec());
        let found = new_reader()
            .match_all(&mut input, &mut query.iter().map(|v| v.as_slice()))
            .unwrap();
        found
    }

    #[test]
    fn test_filter() {
        // "000000", "111111", ..., "ffffff": one three-byte pattern per hex digit.
        let patterns: HashSet<Vec<u8>> = "0123456789abcdef"
            .chars()
            .map(|digit| {
                let text: String = std::iter::repeat(digit).take(6).collect();
                hex::decode(text).unwrap()
            })
            .collect();

        let mut out = Cursor::new(Vec::new());
        {
            let mut writer = GCSFilterWriter::new(&mut out, SipHasher24Builder::new(0, 0), M, P);
            for pattern in &patterns {
                writer.add_element(pattern.as_slice());
            }
            writer.finish().unwrap();
        }
        let bytes = out.into_inner();

        // One of the two queried items is in the filter.
        assert!(any_match(&bytes, &decode_all(&["abcdef", "eeeeee"])));

        // Neither queried item is in the filter.
        assert!(!any_match(&bytes, &decode_all(&["abcdef", "123456"])));

        // Every stored pattern is found.
        let mut everything: Vec<Vec<u8>> = patterns.iter().cloned().collect();
        assert!(all_match(&bytes, &everything));

        // A single unknown item makes `match_all` fail.
        everything.push(hex::decode("abcdef").unwrap());
        assert!(!all_match(&bytes, &everything));

        // An empty query matches nothing, for both query kinds.
        let nothing: Vec<Vec<u8>> = Vec::new();
        assert!(!any_match(&bytes, &nothing));
        assert!(!all_match(&bytes, &nothing));
    }

    #[test]
    fn test_bit_stream() {
        // (value, bit width) pairs written and then read back in this order.
        let fields: [(u64, u8); 7] = [(0, 1), (2, 2), (6, 3), (11, 4), (1, 5), (32, 6), (7, 7)];

        let mut out = Cursor::new(Vec::new());
        {
            let mut writer = BitStreamWriter::new(&mut out);
            for &(value, nbits) in fields.iter() {
                writer.write(value, nbits).unwrap();
            }
            writer.flush().unwrap();
        }
        let bytes = out.into_inner();
        let rendered = format!(
            "{:08b}{:08b}{:08b}{:08b}",
            bytes[0], bytes[1], bytes[2], bytes[3]
        );
        assert_eq!("01011010110000110000000001110000", rendered);

        let mut input = Cursor::new(bytes);
        let mut reader = BitStreamReader::new(&mut input);
        for &(value, nbits) in fields.iter() {
            assert_eq!(reader.read(nbits).unwrap(), value);
        }
        // Only four padding bits are left, so asking for five must fail.
        assert!(reader.read(5).is_err());
    }
}