mih_rs/index/ops.rs
1use anyhow::{anyhow, Result};
2
3use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
4
5use crate::{hamdist, index::*, Index};
6
7impl<T: CodeInt> Index<T> {
8 /// Builds an index from binary codes.
9 /// The number of blocks for multi-index is set to the optimal one
10 /// estimated from the number of input codes.
11 /// The input database `codes` is stolen, but the reference can be gotten with [`Index::codes()`].
12 ///
13 /// # Arguments
14 ///
15 /// - `codes`: Vector of binary codes of type [`CodeInt`].
16 ///
17 /// # Errors
18 ///
19 /// `anyhow::Error` will be returned when
20 ///
21 /// - the `codes` is empty, or
22 /// - the number of entries in `codes` is more than `u32::max_value()`.
23 pub fn new(codes: Vec<T>) -> Result<Self> {
24 let num_codes = codes.len() as f64;
25 let dimensions = T::dimensions() as f64;
26
27 let blocks = (dimensions / num_codes.log2()).round() as usize;
28 if blocks < 2 {
29 Self::with_blocks(codes, 2)
30 } else {
31 Self::with_blocks(codes, blocks)
32 }
33 }
34
35 /// Builds an index from binary codes with a manually specified number of blocks.
36 /// The input database `codes` is stolen, but the reference can be gotten with [`Index::codes()`].
37 ///
38 /// # Arguments
39 ///
40 /// - `codes`: Vector of binary codes of type [`CodeInt`].
41 /// - `num_blocks`: The number of blocks for multi-index.
42 ///
43 /// # Errors
44 ///
45 /// `anyhow::Error` will be returned when
46 ///
47 /// - the `codes` is empty,
48 /// - the number of entries in `codes` is more than `u32::max_value()`, or
49 /// - `num_blocks` is less than 2 or more than the number of dimensions in a binary code.
50 pub fn with_blocks(codes: Vec<T>, num_blocks: usize) -> Result<Self> {
51 if codes.is_empty() {
52 return Err(anyhow!("The input codes must not be empty"));
53 }
54
55 if (u32::max_value() as usize) < codes.len() {
56 return Err(anyhow!(
57 "The number of codes {} must not be no more than {}.",
58 codes.len(),
59 u32::max_value()
60 ));
61 }
62
63 let num_dimensions = T::dimensions();
64 if num_blocks < 2 || num_dimensions < num_blocks {
65 return Err(anyhow!(
66 "The number of blocks {} must not be in [2,{}]",
67 num_blocks,
68 num_dimensions
69 ));
70 }
71
72 let mut masks = vec![T::default(); num_blocks];
73 let mut begs = vec![0; num_blocks + 1];
74
75 for b in 0..num_blocks {
76 let dim = (b + num_dimensions) / num_blocks;
77 if 64 == dim {
78 masks[b] = T::from_u64(u64::max_value()).unwrap();
79 } else {
80 masks[b] = T::from_u64((1 << dim) - 1).unwrap();
81 }
82 begs[b + 1] = begs[b] + dim;
83 }
84
85 let mut tables = Vec::<sparsehash::Table>::with_capacity(num_blocks);
86
87 for b in 0..num_blocks {
88 let beg = begs[b];
89 let dim = begs[b + 1] - begs[b];
90
91 let mut table = sparsehash::Table::new(dim)?;
92
93 for &code in &codes {
94 let chunk = (code >> beg) & masks[b];
95 table.count_insert(chunk.to_u64().unwrap() as usize);
96 }
97
98 for (id, &code) in codes.iter().enumerate() {
99 let chunk = (code >> beg) & masks[b];
100 table.data_insert(chunk.to_u64().unwrap() as usize, id as u32);
101 }
102
103 tables.push(table);
104 }
105
106 Ok(Self {
107 num_blocks,
108 codes,
109 tables,
110 masks,
111 begs,
112 })
113 }
114
115 /// Returns a searcher [`RangeSearcher`] to find neighbor codes
116 /// whose Hamming distances to a query code are within a query radius.
117 ///
118 /// # Examples
119 ///
120 /// ```
121 /// use mih_rs::Index;
122 ///
123 /// let codes: Vec<u64> = vec![
124 /// 0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
125 /// 0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
126 /// 0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
127 /// 0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
128 /// 0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
129 /// 0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
130 /// 0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
131 /// 0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
132 /// ];
133 ///
134 /// let index = Index::new(codes).unwrap();
135 /// let mut searcher = index.range_searcher();
136 ///
137 /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
138 /// let answers = searcher.run(qcode, 2);
139 /// assert_eq!(answers, vec![1, 4, 6]);
140 /// ```
141 pub fn range_searcher(&self) -> RangeSearcher<T> {
142 RangeSearcher {
143 index: self,
144 siggen: siggen::SigGenerator64::new(),
145 answers: Vec::with_capacity(1 << 10),
146 }
147 }
148
149 /// Returns a searcher [`TopkSearcher`] to find top-K codes that are closest to a query code.
150 ///
151 /// # Examples
152 ///
153 /// ```
154 /// use mih_rs::Index;
155 ///
156 /// let codes: Vec<u64> = vec![
157 /// 0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
158 /// 0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
159 /// 0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
160 /// 0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
161 /// 0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
162 /// 0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
163 /// 0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
164 /// 0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
165 /// ];
166 ///
167 /// let index = Index::new(codes).unwrap();
168 /// let mut searcher = index.topk_searcher();
169 ///
170 /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
171 /// let answers = searcher.run(qcode, 4);
172 /// assert_eq!(answers, vec![4, 1, 6, 0]);
173 /// ```
174 pub fn topk_searcher(&self) -> TopkSearcher<T> {
175 TopkSearcher {
176 index: self,
177 siggen: siggen::SigGenerator64::new(),
178 answers: Vec::with_capacity(1 << 10),
179 checked: std::collections::HashSet::new(),
180 }
181 }
182
183 /// Gets the reference of the input database.
184 ///
185 /// # Examples
186 ///
187 /// ```
188 /// use mih_rs::Index;
189 ///
190 /// let codes: Vec<u64> = vec![
191 /// 0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
192 /// 0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
193 /// 0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
194 /// 0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
195 /// 0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
196 /// 0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
197 /// 0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
198 /// 0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
199 /// ];
200 ///
201 /// let index = Index::new(codes.clone()).unwrap();
202 /// assert_eq!(codes, index.codes());
203 /// ```
204 pub fn codes(&self) -> &[T] {
205 &self.codes
206 }
207
208 /// Gets the number of defined blocks in multi-index.
209 pub fn num_blocks(&self) -> usize {
210 self.num_blocks
211 }
212
213 /// Serializes the index into the file.
214 pub fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> Result<()> {
215 writer.write_u64::<LittleEndian>(self.num_blocks as u64)?;
216 writer.write_u64::<LittleEndian>(self.codes.len() as u64)?;
217 for x in &self.codes {
218 x.serialize_into(&mut writer)?;
219 }
220 writer.write_u64::<LittleEndian>(self.tables.len() as u64)?;
221 for x in &self.tables {
222 x.serialize_into(&mut writer)?;
223 }
224 writer.write_u64::<LittleEndian>(self.masks.len() as u64)?;
225 for x in &self.masks {
226 x.serialize_into(&mut writer)?;
227 }
228 writer.write_u64::<LittleEndian>(self.begs.len() as u64)?;
229 for &x in &self.begs {
230 writer.write_u64::<LittleEndian>(x as u64)?;
231 }
232 Ok(())
233 }
234
235 /// Deserializes the index from the file.
236 pub fn deserialize_from<R: std::io::Read>(mut reader: R) -> Result<Self> {
237 let num_blocks = reader.read_u64::<LittleEndian>()? as usize;
238 let codes = {
239 let len = reader.read_u64::<LittleEndian>()? as usize;
240 let mut codes = Vec::with_capacity(len);
241 for _ in 0..len {
242 codes.push(T::deserialize_from(&mut reader)?);
243 }
244 codes
245 };
246 let tables = {
247 let len = reader.read_u64::<LittleEndian>()? as usize;
248 let mut tables = Vec::with_capacity(len);
249 for _ in 0..len {
250 tables.push(sparsehash::Table::deserialize_from(&mut reader)?);
251 }
252 tables
253 };
254 let masks = {
255 let len = reader.read_u64::<LittleEndian>()? as usize;
256 let mut masks = Vec::with_capacity(len);
257 for _ in 0..len {
258 masks.push(T::deserialize_from(&mut reader)?);
259 }
260 masks
261 };
262 let begs = {
263 let len = reader.read_u64::<LittleEndian>()? as usize;
264 let mut begs = Vec::with_capacity(len);
265 for _ in 0..len {
266 begs.push(reader.read_u64::<LittleEndian>()? as usize);
267 }
268 begs
269 };
270 Ok(Self {
271 num_blocks,
272 codes,
273 tables,
274 masks,
275 begs,
276 })
277 }
278
279 fn get_dim(&self, b: usize) -> usize {
280 self.begs[b + 1] - self.begs[b]
281 }
282
283 fn get_chunk(&self, code: T, b: usize) -> u64 {
284 let chunk = (code >> self.begs[b]) & self.masks[b];
285 chunk.to_u64().unwrap()
286 }
287}
288
289impl<'a, T> RangeSearcher<'a, T>
290where
291 T: CodeInt,
292{
293 /// Searches neighbor codes whose Hamming distances to a query code are within a query radius.
294 ///
295 /// # Arguments
296 ///
297 /// - `qcode`: Binary code of the query.
298 /// - `radius`: Threshold to be searched.
299 ///
300 /// # Returns
301 ///
302 /// A slice of ids of codes whose Hamming distances to `qcode` are within `radius`.
303 /// The ids are sorted.
304 /// Note that the values of the slice will be updated in the next [`RangeSearcher::run()`].
305 ///
306 /// # Examples
307 ///
308 /// ```
309 /// use mih_rs::Index;
310 ///
311 /// let codes: Vec<u64> = vec![
312 /// 0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
313 /// 0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
314 /// 0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
315 /// 0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
316 /// 0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
317 /// 0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
318 /// 0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
319 /// 0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
320 /// ];
321 ///
322 /// let index = Index::new(codes).unwrap();
323 /// let mut searcher = index.range_searcher();
324 ///
325 /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
326 /// let answers = searcher.run(qcode, 2);
327 /// assert_eq!(answers, vec![1, 4, 6]);
328 /// ```
329 pub fn run(&mut self, qcode: T, radius: usize) -> &[u32] {
330 self.answers.clear();
331 let num_blocks = self.index.num_blocks();
332
333 for b in 0..num_blocks {
334 // Based on the general pigeonhole principle
335 if b + radius + 1 < num_blocks {
336 continue;
337 }
338
339 let rad = (b + radius + 1 - num_blocks) / num_blocks;
340 let dim = self.index.get_dim(b);
341 let qcd = self.index.get_chunk(qcode, b);
342
343 let table = &self.index.tables[b];
344
345 // Search with r errors
346 for r in 0..rad + 1 {
347 self.siggen.init(qcd, dim, r);
348 while self.siggen.has_next() {
349 let sig = self.siggen.next();
350 if let Some(a) = table.access(sig as usize) {
351 for v in a {
352 self.answers.push(*v as u32);
353 }
354 }
355 }
356 }
357 }
358
359 let mut n = 0;
360 if !self.answers.is_empty() {
361 self.answers.sort_unstable();
362 for i in 0..self.answers.len() {
363 if i == 0 || self.answers[i - 1] != self.answers[i] {
364 let dist = hamdist(qcode, self.index.codes[self.answers[i] as usize]);
365 if dist <= radius {
366 self.answers[n] = self.answers[i];
367 n += 1;
368 }
369 }
370 }
371 }
372
373 self.answers.resize(n, u32::default());
374 &self.answers
375 }
376}
377
378impl<'a, T> TopkSearcher<'a, T>
379where
380 T: CodeInt,
381{
382 /// Searches top-K codes that are closest to a query code.
383 ///
384 /// # Arguments
385 ///
386 /// - `qcode`: Binary code of the query.
387 /// - `topk`: Threshold to be searched.
388 ///
389 /// # Returns
390 ///
391 /// A slice of ids of the `topk` nearest neighbor codes to `qcode`.
392 /// The ids are sorted in the Hamming distances to `qcode`.
393 /// Note that the values of the slice will be updated in the next [`TopkSearcher::run()`].
394 ///
395 /// # Examples
396 ///
397 /// ```
398 /// use mih_rs::Index;
399 ///
400 /// let codes: Vec<u64> = vec![
401 /// 0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
402 /// 0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
403 /// 0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
404 /// 0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
405 /// 0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
406 /// 0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
407 /// 0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
408 /// 0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
409 /// ];
410 ///
411 /// let index = Index::new(codes).unwrap();
412 /// let mut searcher = index.topk_searcher();
413 ///
414 /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
415 /// let answers = searcher.run(qcode, 4);
416 /// assert_eq!(answers, vec![4, 1, 6, 0]);
417 /// ```
418 pub fn run(&mut self, qcode: T, topk: usize) -> &[u32] {
419 let num_blocks = self.index.num_blocks();
420 let num_dimensions = T::dimensions();
421
422 let mut n = 0;
423 let mut r = 0;
424
425 let mut counts = vec![0; num_dimensions + 1];
426
427 self.answers
428 .resize((num_dimensions + 1) * topk, u32::default());
429 self.checked.clear();
430
431 while n < topk {
432 for b in 0..num_blocks {
433 let dim = self.index.get_dim(b);
434 let qcd = self.index.get_chunk(qcode, b);
435 let table = &self.index.tables[b];
436
437 self.siggen.init(qcd, dim, r);
438 while self.siggen.has_next() {
439 let sig = self.siggen.next();
440 if let Some(a) = table.access(sig as usize) {
441 for &v in a {
442 let id = v as usize;
443 if self.checked.insert(id) {
444 let dist = hamdist(qcode, self.index.codes[id]);
445 if counts[dist] < topk {
446 self.answers[dist * topk + counts[dist]] = id as u32;
447 }
448 counts[dist] += 1;
449 }
450 }
451 }
452 }
453
454 n += counts[r * num_blocks + b];
455 if topk <= n {
456 break;
457 }
458 }
459
460 r += 1;
461 }
462
463 n = 0;
464 r = 0;
465 while n < topk {
466 let mut i = 0;
467 while i < counts[r] && n < topk {
468 self.answers[n] = self.answers[r * topk + i];
469 i += 1;
470 n += 1;
471 }
472 r += 1;
473 }
474
475 self.answers.resize(topk, u32::default());
476 &self.answers
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ls;

    use rand::distributions::{Distribution, Standard};
    use rand::{thread_rng, Rng};

    use std::collections::BTreeSet;

    /// Generates `size` random codes of type `T`.
    pub fn gen_random_codes<T>(size: usize) -> Vec<T>
    where
        Standard: Distribution<T>,
    {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen::<T>()).collect()
    }

    /// Reference top-K search by exhaustive scan.
    ///
    /// Returns the ids of ALL codes whose distance does not exceed that of the
    /// `topk`-th nearest one, so ties at the boundary are included and the
    /// result can be longer than `topk`.
    fn naive_topk_search<T: CodeInt>(codes: &[T], qcode: T, topk: usize) -> Vec<u32> {
        let mut cands = ls::exhaustive_search(codes, qcode);
        cands.sort_by_key(|x| x.1);

        let max_dist = cands[topk - 1].1;

        // `take_while` stops at the end of the list, so the scan stays in
        // bounds even when every candidate is within `max_dist`.
        cands
            .iter()
            .take_while(|x| x.1 <= max_dist)
            .map(|x| x.0)
            .collect()
    }

    /// Checks the range searcher against a linear scan for radii 0..6.
    /// Queries are taken from the database itself, which must hold at least 10000 codes.
    fn do_range_search<T: CodeInt>(codes: Vec<T>) {
        let index = Index::new(codes).unwrap();
        let mut searcher = index.range_searcher();

        for rad in 0..6 {
            for qi in (0..10000).step_by(100) {
                let qcode = index.codes()[qi];
                let ans1 = ls::range_search(index.codes(), qcode, rad);
                let ans2 = searcher.run(qcode, rad);
                assert_eq!(ans1, ans2);
            }
        }
    }

    /// Checks that every id returned by the top-K searcher is among the ids
    /// accepted by the exhaustive reference (which also contains boundary ties).
    /// Queries are taken from the database itself, which must hold at least 10000 codes.
    fn do_topk_search<T: CodeInt>(codes: Vec<T>) {
        let index = Index::new(codes).unwrap();
        let mut searcher = index.topk_searcher();

        for &topk in &[1, 10, 100] {
            for qi in (0..10000).step_by(100) {
                let qcode = index.codes()[qi];
                let ans1 = naive_topk_search(index.codes(), qcode, topk);
                let ans2 = searcher.run(qcode, topk);
                let set1: BTreeSet<u32> = ans1.into_iter().collect();
                let set2: BTreeSet<u32> = ans2.iter().copied().collect();
                assert!(set2.is_subset(&set1));
            }
        }
    }

    #[test]
    fn range_search_u8_works() {
        let codes = gen_random_codes::<u8>(10000);
        do_range_search(codes);
    }

    #[test]
    fn range_search_u16_works() {
        let codes = gen_random_codes::<u16>(10000);
        do_range_search(codes);
    }

    #[test]
    fn range_search_u32_works() {
        let codes = gen_random_codes::<u32>(10000);
        do_range_search(codes);
    }

    #[test]
    fn range_search_u64_works() {
        let codes = gen_random_codes::<u64>(10000);
        do_range_search(codes);
    }

    #[test]
    fn topk_search_u8_works() {
        let codes = gen_random_codes::<u8>(10000);
        do_topk_search(codes);
    }

    #[test]
    fn topk_search_u16_works() {
        let codes = gen_random_codes::<u16>(10000);
        do_topk_search(codes);
    }

    #[test]
    fn topk_search_u32_works() {
        let codes = gen_random_codes::<u32>(10000);
        do_topk_search(codes);
    }

    #[test]
    fn topk_search_u64_works() {
        let codes = gen_random_codes::<u64>(10000);
        do_topk_search(codes);
    }

    // The serialization tests check that an index survives a round trip
    // through an in-memory buffer unchanged.

    #[test]
    fn serialize_u8_works() {
        let codes = gen_random_codes::<u8>(10000);
        let index = Index::new(codes).unwrap();

        let mut data = vec![];
        index.serialize_into(&mut data).unwrap();
        let other = Index::<u8>::deserialize_from(&data[..]).unwrap();

        assert_eq!(index, other);
    }

    #[test]
    fn serialize_u16_works() {
        let codes = gen_random_codes::<u16>(10000);
        let index = Index::new(codes).unwrap();

        let mut data = vec![];
        index.serialize_into(&mut data).unwrap();
        let other = Index::<u16>::deserialize_from(&data[..]).unwrap();

        assert_eq!(index, other);
    }

    #[test]
    fn serialize_u32_works() {
        let codes = gen_random_codes::<u32>(10000);
        let index = Index::new(codes).unwrap();

        let mut data = vec![];
        index.serialize_into(&mut data).unwrap();
        let other = Index::<u32>::deserialize_from(&data[..]).unwrap();

        assert_eq!(index, other);
    }

    #[test]
    fn serialize_u64_works() {
        let codes = gen_random_codes::<u64>(10000);
        let index = Index::new(codes).unwrap();

        let mut data = vec![];
        index.serialize_into(&mut data).unwrap();
        let other = Index::<u64>::deserialize_from(&data[..]).unwrap();

        assert_eq!(index, other);
    }
}