mih_rs/index/
ops.rs

1use anyhow::{anyhow, Result};
2
3use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
4
5use crate::{hamdist, index::*, Index};
6
7impl<T: CodeInt> Index<T> {
8    /// Builds an index from binary codes.
9    /// The number of blocks for multi-index is set to the optimal one
10    /// estimated from the number of input codes.
11    /// The input database `codes` is stolen, but the reference can be gotten with [`Index::codes()`].
12    ///
13    /// # Arguments
14    ///
15    /// - `codes`: Vector of binary codes of type [`CodeInt`].
16    ///
17    /// # Errors
18    ///
19    /// `anyhow::Error` will be returned when
20    ///
21    ///  - the `codes` is empty, or
22    ///  - the number of entries in `codes` is more than `u32::max_value()`.
23    pub fn new(codes: Vec<T>) -> Result<Self> {
24        let num_codes = codes.len() as f64;
25        let dimensions = T::dimensions() as f64;
26
27        let blocks = (dimensions / num_codes.log2()).round() as usize;
28        if blocks < 2 {
29            Self::with_blocks(codes, 2)
30        } else {
31            Self::with_blocks(codes, blocks)
32        }
33    }
34
35    /// Builds an index from binary codes with a manually specified number of blocks.
36    /// The input database `codes` is stolen, but the reference can be gotten with [`Index::codes()`].
37    ///
38    /// # Arguments
39    ///
40    /// - `codes`: Vector of binary codes of type [`CodeInt`].
41    /// - `num_blocks`: The number of blocks for multi-index.
42    ///
43    /// # Errors
44    ///
45    /// `anyhow::Error` will be returned when
46    ///
47    ///  - the `codes` is empty,
48    ///  - the number of entries in `codes` is more than `u32::max_value()`, or
49    ///  - `num_blocks` is less than 2 or more than the number of dimensions in a binary code.
50    pub fn with_blocks(codes: Vec<T>, num_blocks: usize) -> Result<Self> {
51        if codes.is_empty() {
52            return Err(anyhow!("The input codes must not be empty"));
53        }
54
55        if (u32::max_value() as usize) < codes.len() {
56            return Err(anyhow!(
57                "The number of codes {} must not be no more than {}.",
58                codes.len(),
59                u32::max_value()
60            ));
61        }
62
63        let num_dimensions = T::dimensions();
64        if num_blocks < 2 || num_dimensions < num_blocks {
65            return Err(anyhow!(
66                "The number of blocks {} must not be in [2,{}]",
67                num_blocks,
68                num_dimensions
69            ));
70        }
71
72        let mut masks = vec![T::default(); num_blocks];
73        let mut begs = vec![0; num_blocks + 1];
74
75        for b in 0..num_blocks {
76            let dim = (b + num_dimensions) / num_blocks;
77            if 64 == dim {
78                masks[b] = T::from_u64(u64::max_value()).unwrap();
79            } else {
80                masks[b] = T::from_u64((1 << dim) - 1).unwrap();
81            }
82            begs[b + 1] = begs[b] + dim;
83        }
84
85        let mut tables = Vec::<sparsehash::Table>::with_capacity(num_blocks);
86
87        for b in 0..num_blocks {
88            let beg = begs[b];
89            let dim = begs[b + 1] - begs[b];
90
91            let mut table = sparsehash::Table::new(dim)?;
92
93            for &code in &codes {
94                let chunk = (code >> beg) & masks[b];
95                table.count_insert(chunk.to_u64().unwrap() as usize);
96            }
97
98            for (id, &code) in codes.iter().enumerate() {
99                let chunk = (code >> beg) & masks[b];
100                table.data_insert(chunk.to_u64().unwrap() as usize, id as u32);
101            }
102
103            tables.push(table);
104        }
105
106        Ok(Self {
107            num_blocks,
108            codes,
109            tables,
110            masks,
111            begs,
112        })
113    }
114
115    /// Returns a searcher [`RangeSearcher`] to find neighbor codes
116    /// whose Hamming distances to a query code are within a query radius.
117    ///
118    /// # Examples
119    ///
120    /// ```
121    /// use mih_rs::Index;
122    ///
123    /// let codes: Vec<u64> = vec![
124    ///     0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
125    ///     0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
126    ///     0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
127    ///     0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
128    ///     0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
129    ///     0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
130    ///     0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
131    ///     0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
132    /// ];
133    ///
134    /// let index = Index::new(codes).unwrap();
135    /// let mut searcher = index.range_searcher();
136    ///
137    /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
138    /// let answers = searcher.run(qcode, 2);
139    /// assert_eq!(answers, vec![1, 4, 6]);
140    /// ```
141    pub fn range_searcher(&self) -> RangeSearcher<T> {
142        RangeSearcher {
143            index: self,
144            siggen: siggen::SigGenerator64::new(),
145            answers: Vec::with_capacity(1 << 10),
146        }
147    }
148
149    /// Returns a searcher [`TopkSearcher`] to find top-K codes that are closest to a query code.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// use mih_rs::Index;
155    ///
156    /// let codes: Vec<u64> = vec![
157    ///     0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
158    ///     0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
159    ///     0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
160    ///     0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
161    ///     0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
162    ///     0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
163    ///     0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
164    ///     0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
165    /// ];
166    ///
167    /// let index = Index::new(codes).unwrap();
168    /// let mut searcher = index.topk_searcher();
169    ///
170    /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
171    /// let answers = searcher.run(qcode, 4);
172    /// assert_eq!(answers, vec![4, 1, 6, 0]);
173    /// ```
174    pub fn topk_searcher(&self) -> TopkSearcher<T> {
175        TopkSearcher {
176            index: self,
177            siggen: siggen::SigGenerator64::new(),
178            answers: Vec::with_capacity(1 << 10),
179            checked: std::collections::HashSet::new(),
180        }
181    }
182
183    /// Gets the reference of the input database.
184    ///
185    /// # Examples
186    ///
187    /// ```
188    /// use mih_rs::Index;
189    ///
190    /// let codes: Vec<u64> = vec![
191    ///     0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
192    ///     0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
193    ///     0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
194    ///     0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
195    ///     0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
196    ///     0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
197    ///     0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
198    ///     0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
199    /// ];
200    ///
201    /// let index = Index::new(codes.clone()).unwrap();
202    /// assert_eq!(codes, index.codes());
203    /// ```
204    pub fn codes(&self) -> &[T] {
205        &self.codes
206    }
207
208    /// Gets the number of defined blocks in multi-index.
209    pub fn num_blocks(&self) -> usize {
210        self.num_blocks
211    }
212
213    /// Serializes the index into the file.
214    pub fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> Result<()> {
215        writer.write_u64::<LittleEndian>(self.num_blocks as u64)?;
216        writer.write_u64::<LittleEndian>(self.codes.len() as u64)?;
217        for x in &self.codes {
218            x.serialize_into(&mut writer)?;
219        }
220        writer.write_u64::<LittleEndian>(self.tables.len() as u64)?;
221        for x in &self.tables {
222            x.serialize_into(&mut writer)?;
223        }
224        writer.write_u64::<LittleEndian>(self.masks.len() as u64)?;
225        for x in &self.masks {
226            x.serialize_into(&mut writer)?;
227        }
228        writer.write_u64::<LittleEndian>(self.begs.len() as u64)?;
229        for &x in &self.begs {
230            writer.write_u64::<LittleEndian>(x as u64)?;
231        }
232        Ok(())
233    }
234
235    /// Deserializes the index from the file.
236    pub fn deserialize_from<R: std::io::Read>(mut reader: R) -> Result<Self> {
237        let num_blocks = reader.read_u64::<LittleEndian>()? as usize;
238        let codes = {
239            let len = reader.read_u64::<LittleEndian>()? as usize;
240            let mut codes = Vec::with_capacity(len);
241            for _ in 0..len {
242                codes.push(T::deserialize_from(&mut reader)?);
243            }
244            codes
245        };
246        let tables = {
247            let len = reader.read_u64::<LittleEndian>()? as usize;
248            let mut tables = Vec::with_capacity(len);
249            for _ in 0..len {
250                tables.push(sparsehash::Table::deserialize_from(&mut reader)?);
251            }
252            tables
253        };
254        let masks = {
255            let len = reader.read_u64::<LittleEndian>()? as usize;
256            let mut masks = Vec::with_capacity(len);
257            for _ in 0..len {
258                masks.push(T::deserialize_from(&mut reader)?);
259            }
260            masks
261        };
262        let begs = {
263            let len = reader.read_u64::<LittleEndian>()? as usize;
264            let mut begs = Vec::with_capacity(len);
265            for _ in 0..len {
266                begs.push(reader.read_u64::<LittleEndian>()? as usize);
267            }
268            begs
269        };
270        Ok(Self {
271            num_blocks,
272            codes,
273            tables,
274            masks,
275            begs,
276        })
277    }
278
279    fn get_dim(&self, b: usize) -> usize {
280        self.begs[b + 1] - self.begs[b]
281    }
282
283    fn get_chunk(&self, code: T, b: usize) -> u64 {
284        let chunk = (code >> self.begs[b]) & self.masks[b];
285        chunk.to_u64().unwrap()
286    }
287}
288
289impl<'a, T> RangeSearcher<'a, T>
290where
291    T: CodeInt,
292{
293    /// Searches neighbor codes whose Hamming distances to a query code are within a query radius.
294    ///
295    /// # Arguments
296    ///
297    /// - `qcode`: Binary code of the query.
298    /// - `radius`: Threshold to be searched.
299    ///
300    /// # Returns
301    ///
302    /// A slice of ids of codes whose Hamming distances to `qcode` are within `radius`.
303    /// The ids are sorted.
304    /// Note that the values of the slice will be updated in the next [`RangeSearcher::run()`].
305    ///
306    /// # Examples
307    ///
308    /// ```
309    /// use mih_rs::Index;
310    ///
311    /// let codes: Vec<u64> = vec![
312    ///     0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
313    ///     0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
314    ///     0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
315    ///     0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
316    ///     0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
317    ///     0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
318    ///     0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
319    ///     0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
320    /// ];
321    ///
322    /// let index = Index::new(codes).unwrap();
323    /// let mut searcher = index.range_searcher();
324    ///
325    /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
326    /// let answers = searcher.run(qcode, 2);
327    /// assert_eq!(answers, vec![1, 4, 6]);
328    /// ```
329    pub fn run(&mut self, qcode: T, radius: usize) -> &[u32] {
330        self.answers.clear();
331        let num_blocks = self.index.num_blocks();
332
333        for b in 0..num_blocks {
334            // Based on the general pigeonhole principle
335            if b + radius + 1 < num_blocks {
336                continue;
337            }
338
339            let rad = (b + radius + 1 - num_blocks) / num_blocks;
340            let dim = self.index.get_dim(b);
341            let qcd = self.index.get_chunk(qcode, b);
342
343            let table = &self.index.tables[b];
344
345            // Search with r errors
346            for r in 0..rad + 1 {
347                self.siggen.init(qcd, dim, r);
348                while self.siggen.has_next() {
349                    let sig = self.siggen.next();
350                    if let Some(a) = table.access(sig as usize) {
351                        for v in a {
352                            self.answers.push(*v as u32);
353                        }
354                    }
355                }
356            }
357        }
358
359        let mut n = 0;
360        if !self.answers.is_empty() {
361            self.answers.sort_unstable();
362            for i in 0..self.answers.len() {
363                if i == 0 || self.answers[i - 1] != self.answers[i] {
364                    let dist = hamdist(qcode, self.index.codes[self.answers[i] as usize]);
365                    if dist <= radius {
366                        self.answers[n] = self.answers[i];
367                        n += 1;
368                    }
369                }
370            }
371        }
372
373        self.answers.resize(n, u32::default());
374        &self.answers
375    }
376}
377
378impl<'a, T> TopkSearcher<'a, T>
379where
380    T: CodeInt,
381{
382    /// Searches top-K codes that are closest to a query code.
383    ///
384    /// # Arguments
385    ///
386    /// - `qcode`: Binary code of the query.
387    /// - `topk`: Threshold to be searched.
388    ///
389    /// # Returns
390    ///
391    /// A slice of ids of the `topk` nearest neighbor codes to `qcode`.
392    /// The ids are sorted in the Hamming distances to `qcode`.
393    /// Note that the values of the slice will be updated in the next [`TopkSearcher::run()`].
394    ///
395    /// # Examples
396    ///
397    /// ```
398    /// use mih_rs::Index;
399    ///
400    /// let codes: Vec<u64> = vec![
401    ///     0b1111111111111111111111011111111111111111111111111011101111111111, // #zeros = 3
402    ///     0b1111111111111111111111111111111101111111111011111111111111111111, // #zeros = 2
403    ///     0b1111111011011101111111111111111101111111111111111111111111111111, // #zeros = 4
404    ///     0b1111111111111101111111111111111111111000111111111110001111111110, // #zeros = 8
405    ///     0b1101111111111111111111111111111111111111111111111111111111111111, // #zeros = 1
406    ///     0b1111111111111111101111111011111111111111111101001110111111111111, // #zeros = 6
407    ///     0b1111111111111111111111111111111111101111111111111111011111111111, // #zeros = 2
408    ///     0b1110110101011011011111111111111101111111111111111000011111111111, // #zeros = 11
409    /// ];
410    ///
411    /// let index = Index::new(codes).unwrap();
412    /// let mut searcher = index.topk_searcher();
413    ///
414    /// let qcode: u64 = 0b1111111111111111111111111111111111111111111111111111111111111111; // #zeros = 0
415    /// let answers = searcher.run(qcode, 4);
416    /// assert_eq!(answers, vec![4, 1, 6, 0]);
417    /// ```
418    pub fn run(&mut self, qcode: T, topk: usize) -> &[u32] {
419        let num_blocks = self.index.num_blocks();
420        let num_dimensions = T::dimensions();
421
422        let mut n = 0;
423        let mut r = 0;
424
425        let mut counts = vec![0; num_dimensions + 1];
426
427        self.answers
428            .resize((num_dimensions + 1) * topk, u32::default());
429        self.checked.clear();
430
431        while n < topk {
432            for b in 0..num_blocks {
433                let dim = self.index.get_dim(b);
434                let qcd = self.index.get_chunk(qcode, b);
435                let table = &self.index.tables[b];
436
437                self.siggen.init(qcd, dim, r);
438                while self.siggen.has_next() {
439                    let sig = self.siggen.next();
440                    if let Some(a) = table.access(sig as usize) {
441                        for &v in a {
442                            let id = v as usize;
443                            if self.checked.insert(id) {
444                                let dist = hamdist(qcode, self.index.codes[id]);
445                                if counts[dist] < topk {
446                                    self.answers[dist * topk + counts[dist]] = id as u32;
447                                }
448                                counts[dist] += 1;
449                            }
450                        }
451                    }
452                }
453
454                n += counts[r * num_blocks + b];
455                if topk <= n {
456                    break;
457                }
458            }
459
460            r += 1;
461        }
462
463        n = 0;
464        r = 0;
465        while n < topk {
466            let mut i = 0;
467            while i < counts[r] && n < topk {
468                self.answers[n] = self.answers[r * topk + i];
469                i += 1;
470                n += 1;
471            }
472            r += 1;
473        }
474
475        self.answers.resize(topk, u32::default());
476        &self.answers
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ls;

    use rand::distributions::{Distribution, Standard};
    use rand::{thread_rng, Rng};

    use std::collections::BTreeSet;

    /// Number of random codes indexed by every test.
    const NUM_CODES: usize = 10000;

    /// Every `QUERY_STEP`-th code of the database is used as a query.
    const QUERY_STEP: usize = 100;

    /// Generates `size` random codes.
    pub fn gen_random_codes<T>(size: usize) -> Vec<T>
    where
        Standard: Distribution<T>,
    {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen::<T>()).collect()
    }

    /// Reference top-K search: returns the ids of all codes that are at most as far
    /// from `qcode` as the `topk`-th nearest code (ties included).
    fn naive_topk_search<T: CodeInt>(codes: &[T], qcode: T, topk: usize) -> Vec<u32> {
        let mut cands = ls::exhaustive_search(codes, qcode);
        cands.sort_by_key(|x| x.1);

        let max_dist = cands[topk - 1].1;

        // `take_while` stops at the end of the candidates, so the scan stays in
        // bounds even when every candidate is within `max_dist`.
        cands
            .iter()
            .take_while(|x| x.1 <= max_dist)
            .map(|x| x.0)
            .collect()
    }

    /// Checks the range searcher against the linear-scan range search.
    fn do_range_search<T: CodeInt>(codes: Vec<T>) {
        let index = Index::new(codes).unwrap();
        let mut searcher = index.range_searcher();

        for rad in 0..6 {
            for qi in (0..index.codes().len()).step_by(QUERY_STEP) {
                let qcode = index.codes()[qi];
                let ans1 = ls::range_search(index.codes(), qcode, rad);
                let ans2 = searcher.run(qcode, rad);
                assert_eq!(ans1, ans2);
            }
        }
    }

    /// Checks the top-K searcher against the linear-scan reference.
    fn do_topk_search<T: CodeInt>(codes: Vec<T>) {
        let index = Index::new(codes).unwrap();
        let mut searcher = index.topk_searcher();

        for &topk in &[1, 10, 100] {
            for qi in (0..index.codes().len()).step_by(QUERY_STEP) {
                let qcode = index.codes()[qi];
                let ans1 = naive_topk_search(index.codes(), qcode, topk);
                let ans2 = searcher.run(qcode, topk);
                // The subset check alone would also pass for an empty result.
                assert_eq!(ans2.len(), topk);
                let set1: BTreeSet<u32> = ans1.into_iter().collect();
                let set2: BTreeSet<u32> = ans2.iter().copied().collect();
                assert!(set2.is_subset(&set1));
            }
        }
    }

    /// Checks that an index survives a serialization round trip unchanged.
    fn do_serialize<T>(codes: Vec<T>)
    where
        T: CodeInt,
        Index<T>: PartialEq + std::fmt::Debug,
    {
        let index = Index::new(codes).unwrap();

        let mut data = vec![];
        index.serialize_into(&mut data).unwrap();
        let other = Index::<T>::deserialize_from(&data[..]).unwrap();

        assert_eq!(index, other);
    }

    #[test]
    fn range_search_u8_works() {
        do_range_search(gen_random_codes::<u8>(NUM_CODES));
    }

    #[test]
    fn range_search_u16_works() {
        do_range_search(gen_random_codes::<u16>(NUM_CODES));
    }

    #[test]
    fn range_search_u32_works() {
        do_range_search(gen_random_codes::<u32>(NUM_CODES));
    }

    #[test]
    fn range_search_u64_works() {
        do_range_search(gen_random_codes::<u64>(NUM_CODES));
    }

    #[test]
    fn topk_search_u8_works() {
        do_topk_search(gen_random_codes::<u8>(NUM_CODES));
    }

    #[test]
    fn topk_search_u16_works() {
        do_topk_search(gen_random_codes::<u16>(NUM_CODES));
    }

    #[test]
    fn topk_search_u32_works() {
        do_topk_search(gen_random_codes::<u32>(NUM_CODES));
    }

    #[test]
    fn topk_search_u64_works() {
        do_topk_search(gen_random_codes::<u64>(NUM_CODES));
    }

    #[test]
    fn serialize_u8_works() {
        do_serialize(gen_random_codes::<u8>(NUM_CODES));
    }

    #[test]
    fn serialize_u16_works() {
        do_serialize(gen_random_codes::<u16>(NUM_CODES));
    }

    #[test]
    fn serialize_u32_works() {
        do_serialize(gen_random_codes::<u32>(NUM_CODES));
    }

    #[test]
    fn serialize_u64_works() {
        do_serialize(gen_random_codes::<u64>(NUM_CODES));
    }
}