lsh_rs2/
multi_probe.rs

1//! Multi probe LSH
2use crate::data::{Integer, Numeric};
3use crate::{prelude::*, utils::create_rng};
4use fnv::FnvHashSet;
5use itertools::Itertools;
6use ndarray::prelude::*;
7use ndarray::stack;
8use num::{Float, One, Zero};
9use rand::distributions::Uniform;
10use rand::seq::SliceRandom;
11use rand::Rng;
12use statrs::function::factorial::binomial;
13use std::cmp::Ordering;
14use std::collections::BinaryHeap;
15
/// Query directed probing
///
/// Implementation of the paper:
///
/// Lv, Q., Josephson, W., Wang, Z., Charikar, M., & Li, K. (2007).
/// Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity Search
/// Retrieved from https://www.cs.princeton.edu/cass/papers/mplsh_vldb07.pdf
pub trait QueryDirectedProbe<N, K> {
    /// Returns the sequence of hashes that should be probed for query `q`.
    ///
    /// The implementations generated in this file by `impl_query_directed_probe!`
    /// return the unperturbed hash of `q` first, followed by `budget` perturbed
    /// hashes, and fail with an error when fewer than `budget` perturbations
    /// can be generated.
    fn query_directed_probe(&self, q: &[N], budget: usize) -> Result<Vec<Vec<K>>>;
}
27
/// Step wise probing
///
/// Probes are derived from the query hash alone, by perturbing first one hash
/// index, then two, and so on, without looking at how close the query lies to
/// a bucket boundary.
pub trait StepWiseProbe<N, K>: VecHash<N, K> {
    /// Returns the perturbed hashes to probe for query `q`.
    ///
    /// * `budget` - upper bound on the number of perturbed hashes to generate.
    /// * `hash_len` - number of indexes in a hash that may be perturbed.
    fn step_wise_probe(&self, q: &[N], budget: usize, hash_len: usize) -> Result<Vec<Vec<K>>>;
}
32
33impl<N> StepWiseProbe<N, i8> for SignRandomProjections<N>
34where
35    N: Numeric,
36{
37    fn step_wise_probe(&self, q: &[N], budget: usize, hash_len: usize) -> Result<Vec<Vec<i8>>> {
38        let probing_seq = step_wise_probing(hash_len, budget, false);
39        let original_hash = self.hash_vec_query(q);
40
41        let a = probing_seq
42            .iter()
43            .map(|pertub| {
44                original_hash
45                    .iter()
46                    .zip(pertub)
47                    .map(
48                        |(&original, &shift)| {
49                            if shift == 1 {
50                                original * -1
51                            } else {
52                                original
53                            }
54                        },
55                    )
56                    .collect_vec()
57            })
58            .collect_vec();
59        Ok(a)
60    }
61}
62
63fn uniform_without_replacement<T: Copy>(bucket: &mut [T], n: usize) -> Vec<T> {
64    // https://stackoverflow.com/questions/196017/unique-non-repeating-random-numbers-in-o1#196065
65    let mut max_idx = bucket.len() - 1;
66    let mut rng = create_rng(0);
67
68    let mut samples = Vec::with_capacity(n);
69
70    for _ in 0..n {
71        let idx = rng.sample(Uniform::new(0, max_idx));
72        debug_assert!(idx < bucket.len());
73        unsafe {
74            samples.push(*bucket.get_unchecked(idx));
75        };
76        bucket.swap(idx, max_idx);
77        max_idx -= 1;
78    }
79    samples
80}
81
82fn create_hash_permutation(hash_len: usize, n: usize) -> Vec<i8> {
83    let mut permut = vec![0; hash_len];
84    let shift_options = [-1i8, 1];
85
86    let mut idx: Vec<usize> = (0..hash_len).collect();
87    let candidate_idx = uniform_without_replacement(&mut idx, n);
88
89    let mut rng = create_rng(0);
90    for i in candidate_idx {
91        debug_assert!(i < permut.len());
92        let v = *shift_options.choose(&mut rng).unwrap();
93        // bounds check not needed as i cannot be larger than permut
94        unsafe { *permut.get_unchecked_mut(i) += v }
95    }
96    permut
97}
98
99/// Retrieve perturbation indexes. Every index in a hash can be perturbed by +1 or -1.
100///
101/// First retrieve all hashes where 1 index is changed,
102/// then all combinations where two indexes are changed etc.
103///
104/// # Arguments
105/// * - `hash_length` The hash length is used to determine all the combinations of indexes that can be shifted.
106/// * - `n_perturbation` The number of indexes allowed to be changed. We generally first deplete
107/// * - `two_shifts` If true every index is changed by +1 and -1, else only by +1.
108fn step_wise_perturb(
109    hash_length: usize,
110    n_perturbations: usize,
111    two_shifts: bool,
112) -> impl Iterator<Item = Vec<(usize, i8)>> {
113    let multiply;
114    if two_shifts {
115        multiply = 2
116    } else {
117        multiply = 1
118    }
119
120    let idx = 0..hash_length * multiply;
121    let switchpoint = hash_length;
122    let a = idx.combinations(n_perturbations).map(move |comb| {
123        // return of comb are indexes and perturbations (-1 or +1).
124        // where idx are the indexes that are perturbed.
125        // if n_perturbations is 2 output could be:
126        // comb -> [(0, -1), (3, 1)]
127        // if n_perturbations is 4 output could be:
128        // comb -> [(1, -1), (9, -1), (4, 1), (3, -1)]
129        comb.iter()
130            .map(|&i| {
131                if i >= switchpoint {
132                    (i - switchpoint, -1)
133                } else {
134                    (i, 1)
135                }
136            })
137            .collect_vec()
138    });
139    a
140}
141
142/// Generates new hashes by step wise shifting one indexes.
143/// First all one index shifts are returned (these are closer to the original hash)
144/// then the two index shifts, three index shifts etc.
145///
146/// This is done until the budget is depleted.
147fn step_wise_probing(hash_len: usize, mut budget: usize, two_shifts: bool) -> Vec<Vec<i8>> {
148    let mut hash_perturbs = Vec::with_capacity(budget);
149
150    let n = hash_len as u64;
151    // number of combinations (indexes we allow to perturb)
152    let mut k = 1;
153    while budget > 0 && k <= n {
154        // binomial coefficient
155        // times two as we have -1 and +1.
156        let multiply;
157        if two_shifts {
158            multiply = 2
159        } else {
160            multiply = 1
161        }
162        let n_combinations = binomial(n, k) as usize * multiply;
163
164        step_wise_perturb(n as usize, k as usize, two_shifts)
165            .take(budget as usize)
166            .for_each(|v| {
167                let mut new_perturb = vec![0; hash_len];
168                v.iter().for_each(|(idx, shift)| {
169                    debug_assert!(*idx < new_perturb.len());
170                    let v = unsafe { new_perturb.get_unchecked_mut(*idx) };
171                    *v += *shift;
172                });
173                hash_perturbs.push(new_perturb)
174            });
175        k += 1;
176        budget -= n_combinations;
177    }
178    hash_perturbs
179}
180
/// A perturbation set of the query directed probing algorithm.
///
/// A state does not own the sorted boundary distances; it only stores which
/// positions of the sorted order (`z`) are part of the perturbation.
#[derive(PartialEq, Clone)]
struct PerturbState<'a, N, K>
where
    N: Numeric + Float + Copy,
{
    // original sorted zj: indexes into `distances`, ordered from the smallest
    // to the largest distance (an argsort of `distances`)
    z: &'a [usize],
    // original xi(delta): distance of the query to the slot boundaries
    distances: &'a [N],
    // selection of zjs, stored as positions in `z`
    // We start with the first one, as this is the lowest score.
    selection: Vec<usize>,
    // values in `z` that are >= switchpoint map to delta = +1,
    // values below it map to delta = -1
    switchpoint: usize,
    // hash of the query; taken out (leaving `None`) by `gen_hash`
    original_hash: Option<Vec<K>>,
}
196
197impl<'a, N, K> PerturbState<'a, N, K>
198where
199    N: Numeric + Float,
200    K: Integer,
201{
202    fn new(z: &'a [usize], distances: &'a [N], switchpoint: usize, hash: Vec<K>) -> Self {
203        PerturbState {
204            z,
205            distances,
206            selection: vec![0],
207            switchpoint,
208            original_hash: Some(hash),
209        }
210    }
211
212    fn score(&self) -> N {
213        let mut score = Zero::zero();
214        for &index in self.selection.iter() {
215            debug_assert!(index < self.z.len());
216            let zj = unsafe { *self.z.get_unchecked(index) };
217            debug_assert!(zj < self.distances.len());
218            unsafe { score += self.distances.get_unchecked(zj).clone() };
219        }
220        score
221    }
222
223    // map zj value to (i, delta) as in paper
224    fn i_delta(&self) -> Vec<(usize, K)> {
225        let mut out = Vec::with_capacity(self.z.len());
226        for &idx in self.selection.iter() {
227            debug_assert!(idx < self.z.len());
228            let zj = unsafe { *self.z.get_unchecked(idx) };
229            let delta;
230            let index;
231            if zj >= self.switchpoint {
232                delta = One::one();
233                index = zj - self.switchpoint;
234            } else {
235                delta = K::from_i8(-1).unwrap();
236                index = zj;
237            }
238            out.push((index, delta))
239        }
240        out
241    }
242
243    fn check_bounds(&mut self, max: usize) -> Result<()> {
244        if max == self.z.len() - 1 {
245            Err(Error::Failed("Out of bounds".to_string()))
246        } else {
247            self.selection.push(max + 1);
248            Ok(())
249        }
250    }
251
252    fn shift(&mut self) -> Result<()> {
253        let max = self.selection.pop().unwrap();
254        self.check_bounds(max)
255    }
256
257    fn expand(&mut self) -> Result<()> {
258        let max = self.selection[self.selection.len() - 1];
259        self.check_bounds(max)
260    }
261
262    fn gen_hash(&mut self) -> Vec<K> {
263        let mut hash = self.original_hash.take().expect("hash already taken");
264        for (i, delta) in self.i_delta() {
265            debug_assert!(i < hash.len());
266            let ptr = unsafe { hash.get_unchecked_mut(i) };
267            *ptr += delta
268        }
269        hash
270    }
271}
272
273// implement ordering so that we can create a min heap
274impl<N, K> Ord for PerturbState<'_, N, K>
275where
276    N: Numeric + Float,
277    K: Integer,
278{
279    fn cmp(&self, other: &PerturbState<N, K>) -> Ordering {
280        self.partial_cmp(other).unwrap()
281    }
282}
283
284impl<N, K> PartialOrd for PerturbState<'_, N, K>
285where
286    N: Numeric + Float,
287    K: Integer,
288{
289    fn partial_cmp(&self, other: &PerturbState<N, K>) -> Option<Ordering> {
290        other.score().partial_cmp(&self.score())
291    }
292}
293
// Marker impl required by `Ord`.
// Note that the derived `PartialEq` compares all fields, whereas the ordering
// above only compares scores.
impl<N, K> Eq for PerturbState<'_, N, K>
where
    N: Numeric + Float,
    K: Integer,
{
}
300
301macro_rules! impl_query_directed_probe {
302    ($vechash:ident) => {
303        impl<N, K> $vechash<N, K>
304        where
305            N: Numeric + Float,
306            K: Integer,
307        {
308            /// Computes the distance between the query hash and the boundary of the slot r (W in the paper)
309            ///
310            /// As stated by Multi-Probe LSH paper:
311            /// For δ ∈ {−1, +1}, let xi(δ) be the distance of q from the boundary of the slot
312            fn distance_to_bound(&self, q: &[N], hash: Option<&Vec<K>>) -> (Array1<N>, Array1<N>) {
313                let hash = match hash {
314                    None => self.hash_vec(q).to_vec(),
315                    Some(h) => h.iter().map(|&k| N::from(k).unwrap()).collect_vec(),
316                };
317                let f = self.a.dot(&aview1(q)) + &self.b;
318                let xi_min1 = f - &aview1(&hash) * self.r;
319                let xi_plus1: Array1<N> = xi_min1.map(|x| self.r - *x);
320                (xi_min1, xi_plus1)
321            }
322        }
323
324        impl<N, K> QueryDirectedProbe<N, K> for $vechash<N, K>
325        where
326            N: Numeric + Float,
327            K: Integer,
328        {
329            fn query_directed_probe(&self, q: &[N], budget: usize) -> Result<Vec<Vec<K>>> {
330                // https://www.cs.princeton.edu/cass/papers/mplsh_vldb07.pdf
331                // https://www.youtube.com/watch?v=c5DHtx5VxX8
332                let hash = self.hash_vec_query(q);
333                let (xi_min, xi_plus) = self.distance_to_bound(q, Some(&hash));
334                // >= this point = +1
335                // < this point = -1
336                let switchpoint = xi_min.len();
337
338                let distances: Vec<N> = stack!(Axis(0), xi_min, xi_plus).to_vec();
339
340                // indexes of the least scores to the highest
341                // all below is an argsort
342                let z = distances.clone();
343                let mut z = z.iter().enumerate().collect::<Vec<_>>();
344                z.sort_unstable_by(|(_idx_a, a), (_idx_b, b)| a.partial_cmp(b).unwrap());
345                let z = z.iter().map(|(idx, _)| *idx).collect::<Vec<_>>();
346
347                let mut hashes = Vec::with_capacity(budget + 1);
348                hashes.push(hash.clone());
349                // Algorithm 1 from paper
350                let mut heap = BinaryHeap::new();
351                let a0 = PerturbState::new(&z, &distances, switchpoint, hash);
352                heap.push(a0);
353                for _ in 0..budget {
354                    let mut ai = match heap.pop() {
355                        Some(ai) => ai,
356                        None => {
357                            return Err(Error::Failed(
358                                "All query directed probing combinations depleted".to_string(),
359                            ))
360                        }
361                    };
362                    let mut a_s = ai.clone();
363                    let mut a_e = ai.clone();
364                    if a_s.shift().is_ok() {
365                        heap.push(a_s);
366                    }
367                    if a_e.expand().is_ok() {
368                        heap.push(a_e);
369                    }
370                    hashes.push(ai.gen_hash())
371                }
372                Ok(hashes)
373            }
374        }
375    };
376}
377impl_query_directed_probe!(L2);
378impl_query_directed_probe!(MIPS);
379
380impl<N, K, H, T> LSH<H, N, T, K>
381where
382    N: Numeric,
383    K: Integer,
384    H: VecHash<N, K>,
385    T: HashTables<N, K>,
386{
387    pub fn multi_probe_bucket_union(&self, v: &[N]) -> Result<FnvHashSet<u32>> {
388        self.validate_vec(v)?;
389        let mut bucket_union = FnvHashSet::default();
390
391        // Check if hasher has implemented this trait. If so follow this more specialized path.
392        // Only L2 should have implemented it. This is the trick to choose a different function
393        // path for the L2 struct.
394        let h0 = &self.hashers[0];
395        if h0.as_query_directed_probe().is_some() {
396            for (i, hasher) in self.hashers.iter().enumerate() {
397                if let Some(h) = hasher.as_query_directed_probe() {
398                    let hashes = h.query_directed_probe(v, self._multi_probe_budget)?;
399                    for hash in hashes {
400                        self.process_bucket_union_result(&hash, i, &mut bucket_union)?
401                    }
402                }
403            }
404        } else if h0.as_step_wise_probe().is_some() {
405            for (i, hasher) in self.hashers.iter().enumerate() {
406                if let Some(h) = hasher.as_step_wise_probe() {
407                    let hashes =
408                        h.step_wise_probe(v, self._multi_probe_budget, self.n_projections)?;
409                    for hash in hashes {
410                        self.process_bucket_union_result(&hash, i, &mut bucket_union)?
411                    }
412                }
413            }
414        } else {
415            unimplemented!()
416        }
417        Ok(bucket_union)
418    }
419}
420
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_permutation() {
        let permut = create_hash_permutation(5, 3);
        println!("{:?}", permut);
        // Independent of the random draws: 3 distinct positions are shifted by
        // -1 or +1, the remaining positions stay 0.
        assert_eq!(permut.len(), 5);
        assert_eq!(permut.iter().filter(|&&shift| shift != 0).count(), 3);
        assert!(permut.iter().all(|&shift| shift >= -1 && shift <= 1));
    }

    #[test]
    fn test_step_wise_perturb() {
        let a = step_wise_perturb(4, 2, true);
        assert_eq!(
            vec![vec![(0, 1), (1, 1)], vec![(0, 1), (2, 1)]],
            a.take(2).collect_vec()
        );
    }

    #[test]
    fn test_step_wise_probe() {
        let a = step_wise_probing(4, 20, true);
        assert_eq!(vec![1, 0, 0, 0], a[0]);
        assert_eq!(vec![0, 1, -1, 0], a[a.len() - 1]);
    }

    #[test]
    fn test_l2_xi_distances() {
        let l2 = L2::<f32>::new(4, 4., 3, 1);
        let (xi_min, xi_plus) = l2.distance_to_bound(&[1., 2., 3., 1.], None);
        assert_eq!(xi_min, arr1(&[2.0210547, 1.9154847, 0.89937115]));
        assert_eq!(xi_plus, arr1(&[1.9789453, 2.0845153, 3.1006289]));
    }

    #[test]
    fn test_perturbstate() {
        let distances = [1., 0.1, 3., 2., 9., 4., 0.8, 5.];
        // argsort
        let z = vec![1, 6, 0, 3, 2, 5, 7, 4];
        let switchpoint = 4;
        let a0 = PerturbState::new(&z, &distances, switchpoint, vec![0, 0, 0, 0]);
        // initial selection is the first zj [0]
        // This leads to:
        //   distance/score:    0.1
        //   index:             1
        //   delta:             -1
        assert_eq!(a0.clone().gen_hash(), [0, -1, 0, 0]);
        assert_eq!(a0.score(), 0.1);
        assert_eq!(a0.selection, [0]);

        // after expansion operation selection is [0, 1]
        // This leads to:
        //   distance/score:    0.1 + 0.8
        //   index:             [1, 2]
        //   delta:             [-1, 1]

        let mut ae = a0.clone();
        ae.expand().unwrap();
        assert_eq!(ae.gen_hash(), [0, -1, 1, 0]);
        assert_eq!(ae.score(), 0.1 + 0.8);
        assert_eq!(ae.selection, [0, 1]);

        // after shift operation selection is [1]
        // This leads to:
        //   distance/score:    0.8
        //   index:             2
        //   delta:             1
        let mut a_s = a0.clone();
        a_s.shift().unwrap();
        assert_eq!(a_s.gen_hash(), [0, 0, 1, 0]);
        assert_eq!(a_s.score(), 0.8);
        assert_eq!(a_s.selection, [1]);
    }

    #[test]
    fn test_query_directed_probe() {
        let l2 = <L2>::new(4, 4., 3, 1);
        let hashes = l2.query_directed_probe(&[1., 2., 3., 1.], 4).unwrap();
        println!("{:?}", hashes);
        // the unperturbed hash plus one hash per unit of budget
        assert_eq!(hashes.len(), 5);
        // every probe has one entry per projection
        assert!(hashes.iter().all(|hash| hash.len() == 3));
    }

    #[test]
    fn test_query_directed_bounds() {
        // if shift and expand operation have reached the end of the vecs an error should be returned
        let mut lsh = hi8::LshMem::new(2, 1, 1).multi_probe(1000).l2(4.).unwrap();
        lsh.store_vec(&[1.]).unwrap();
        assert!(lsh.query_bucket_ids(&[1.]).is_err())
    }
}