1use crate::data::{Integer, Numeric};
3use crate::{prelude::*, utils::create_rng};
4use fnv::FnvHashSet;
5use itertools::Itertools;
6use ndarray::prelude::*;
7use ndarray::stack;
8use num::{Float, One, Zero};
9use rand::distributions::Uniform;
10use rand::seq::SliceRandom;
11use rand::Rng;
12use statrs::function::factorial::binomial;
13use std::cmp::Ordering;
14use std::collections::BinaryHeap;
15
/// Query-directed multi-probing.
///
/// Implementors derive additional hashes to look up for a query vector `q`.
/// The implementations generated in this file return the query's own hash
/// first, followed by `budget` perturbed hashes, and fail with an error when
/// the candidate perturbations are exhausted before the budget is spent.
pub trait QueryDirectedProbe<N, K> {
    /// Returns the hashes to probe for `q`, generating `budget` perturbations.
    fn query_directed_probe(&self, q: &[N], budget: usize) -> Result<Vec<Vec<K>>>;
}
27
/// Step-wise multi-probing.
///
/// Implementors derive additional hashes for a query vector `q` by applying
/// perturbation vectors of length `hash_len` to the query's hash.
pub trait StepWiseProbe<N, K>: VecHash<N, K> {
    /// Returns the perturbed hashes to probe for `q`.
    ///
    /// `budget` is forwarded to the perturbation generator as the number of
    /// probes to produce; `hash_len` is the length of each perturbation vector.
    fn step_wise_probe(&self, q: &[N], budget: usize, hash_len: usize) -> Result<Vec<Vec<K>>>;
}
32
33impl<N> StepWiseProbe<N, i8> for SignRandomProjections<N>
34where
35 N: Numeric,
36{
37 fn step_wise_probe(&self, q: &[N], budget: usize, hash_len: usize) -> Result<Vec<Vec<i8>>> {
38 let probing_seq = step_wise_probing(hash_len, budget, false);
39 let original_hash = self.hash_vec_query(q);
40
41 let a = probing_seq
42 .iter()
43 .map(|pertub| {
44 original_hash
45 .iter()
46 .zip(pertub)
47 .map(
48 |(&original, &shift)| {
49 if shift == 1 {
50 original * -1
51 } else {
52 original
53 }
54 },
55 )
56 .collect_vec()
57 })
58 .collect_vec();
59 Ok(a)
60 }
61}
62
63fn uniform_without_replacement<T: Copy>(bucket: &mut [T], n: usize) -> Vec<T> {
64 let mut max_idx = bucket.len() - 1;
66 let mut rng = create_rng(0);
67
68 let mut samples = Vec::with_capacity(n);
69
70 for _ in 0..n {
71 let idx = rng.sample(Uniform::new(0, max_idx));
72 debug_assert!(idx < bucket.len());
73 unsafe {
74 samples.push(*bucket.get_unchecked(idx));
75 };
76 bucket.swap(idx, max_idx);
77 max_idx -= 1;
78 }
79 samples
80}
81
82fn create_hash_permutation(hash_len: usize, n: usize) -> Vec<i8> {
83 let mut permut = vec![0; hash_len];
84 let shift_options = [-1i8, 1];
85
86 let mut idx: Vec<usize> = (0..hash_len).collect();
87 let candidate_idx = uniform_without_replacement(&mut idx, n);
88
89 let mut rng = create_rng(0);
90 for i in candidate_idx {
91 debug_assert!(i < permut.len());
92 let v = *shift_options.choose(&mut rng).unwrap();
93 unsafe { *permut.get_unchecked_mut(i) += v }
95 }
96 permut
97}
98
99fn step_wise_perturb(
109 hash_length: usize,
110 n_perturbations: usize,
111 two_shifts: bool,
112) -> impl Iterator<Item = Vec<(usize, i8)>> {
113 let multiply;
114 if two_shifts {
115 multiply = 2
116 } else {
117 multiply = 1
118 }
119
120 let idx = 0..hash_length * multiply;
121 let switchpoint = hash_length;
122 let a = idx.combinations(n_perturbations).map(move |comb| {
123 comb.iter()
130 .map(|&i| {
131 if i >= switchpoint {
132 (i - switchpoint, -1)
133 } else {
134 (i, 1)
135 }
136 })
137 .collect_vec()
138 });
139 a
140}
141
142fn step_wise_probing(hash_len: usize, mut budget: usize, two_shifts: bool) -> Vec<Vec<i8>> {
148 let mut hash_perturbs = Vec::with_capacity(budget);
149
150 let n = hash_len as u64;
151 let mut k = 1;
153 while budget > 0 && k <= n {
154 let multiply;
157 if two_shifts {
158 multiply = 2
159 } else {
160 multiply = 1
161 }
162 let n_combinations = binomial(n, k) as usize * multiply;
163
164 step_wise_perturb(n as usize, k as usize, two_shifts)
165 .take(budget as usize)
166 .for_each(|v| {
167 let mut new_perturb = vec![0; hash_len];
168 v.iter().for_each(|(idx, shift)| {
169 debug_assert!(*idx < new_perturb.len());
170 let v = unsafe { new_perturb.get_unchecked_mut(*idx) };
171 *v += *shift;
172 });
173 hash_perturbs.push(new_perturb)
174 });
175 k += 1;
176 budget -= n_combinations;
177 }
178 hash_perturbs
179}
180
/// A candidate perturbation set used by query-directed probing.
///
/// The state is ordered by its score (see the `Ord`/`PartialOrd` impls) so it
/// can be kept in a `BinaryHeap`.
#[derive(PartialEq, Clone)]
struct PerturbState<'a, N, K>
where
    N: Numeric + Float + Copy,
{
    /// Indices into `distances`; the caller in this file passes them sorted by
    /// ascending distance.
    z: &'a [usize],
    /// Distances looked up through `z` when scoring. The caller in this file
    /// passes the lower-bound distances followed by the upper-bound distances.
    distances: &'a [N],
    /// The perturbation set itself, stored as indices into `z`.
    selection: Vec<usize>,
    /// Values of `z` that are `>= switchpoint` stand for a `+1` shift at
    /// position `value - switchpoint`; smaller values stand for a `-1` shift
    /// at position `value`.
    switchpoint: usize,
    /// The unperturbed hash. It is moved out by `gen_hash`, so it is `None`
    /// afterwards.
    original_hash: Option<Vec<K>>,
}
196
197impl<'a, N, K> PerturbState<'a, N, K>
198where
199 N: Numeric + Float,
200 K: Integer,
201{
202 fn new(z: &'a [usize], distances: &'a [N], switchpoint: usize, hash: Vec<K>) -> Self {
203 PerturbState {
204 z,
205 distances,
206 selection: vec![0],
207 switchpoint,
208 original_hash: Some(hash),
209 }
210 }
211
212 fn score(&self) -> N {
213 let mut score = Zero::zero();
214 for &index in self.selection.iter() {
215 debug_assert!(index < self.z.len());
216 let zj = unsafe { *self.z.get_unchecked(index) };
217 debug_assert!(zj < self.distances.len());
218 unsafe { score += self.distances.get_unchecked(zj).clone() };
219 }
220 score
221 }
222
223 fn i_delta(&self) -> Vec<(usize, K)> {
225 let mut out = Vec::with_capacity(self.z.len());
226 for &idx in self.selection.iter() {
227 debug_assert!(idx < self.z.len());
228 let zj = unsafe { *self.z.get_unchecked(idx) };
229 let delta;
230 let index;
231 if zj >= self.switchpoint {
232 delta = One::one();
233 index = zj - self.switchpoint;
234 } else {
235 delta = K::from_i8(-1).unwrap();
236 index = zj;
237 }
238 out.push((index, delta))
239 }
240 out
241 }
242
243 fn check_bounds(&mut self, max: usize) -> Result<()> {
244 if max == self.z.len() - 1 {
245 Err(Error::Failed("Out of bounds".to_string()))
246 } else {
247 self.selection.push(max + 1);
248 Ok(())
249 }
250 }
251
252 fn shift(&mut self) -> Result<()> {
253 let max = self.selection.pop().unwrap();
254 self.check_bounds(max)
255 }
256
257 fn expand(&mut self) -> Result<()> {
258 let max = self.selection[self.selection.len() - 1];
259 self.check_bounds(max)
260 }
261
262 fn gen_hash(&mut self) -> Vec<K> {
263 let mut hash = self.original_hash.take().expect("hash already taken");
264 for (i, delta) in self.i_delta() {
265 debug_assert!(i < hash.len());
266 let ptr = unsafe { hash.get_unchecked_mut(i) };
267 *ptr += delta
268 }
269 hash
270 }
271}
272
273impl<N, K> Ord for PerturbState<'_, N, K>
275where
276 N: Numeric + Float,
277 K: Integer,
278{
279 fn cmp(&self, other: &PerturbState<N, K>) -> Ordering {
280 self.partial_cmp(other).unwrap()
281 }
282}
283
284impl<N, K> PartialOrd for PerturbState<'_, N, K>
285where
286 N: Numeric + Float,
287 K: Integer,
288{
289 fn partial_cmp(&self, other: &PerturbState<N, K>) -> Option<Ordering> {
290 other.score().partial_cmp(&self.score())
291 }
292}
293
// Marker impl: `Ord` requires `Eq`. Note that equality comes from the derived
// `PartialEq` (field-wise comparison), while the ordering compares scores only.
impl<N, K> Eq for PerturbState<'_, N, K>
where
    N: Numeric + Float,
    K: Integer,
{
}
300
301macro_rules! impl_query_directed_probe {
302 ($vechash:ident) => {
303 impl<N, K> $vechash<N, K>
304 where
305 N: Numeric + Float,
306 K: Integer,
307 {
308 fn distance_to_bound(&self, q: &[N], hash: Option<&Vec<K>>) -> (Array1<N>, Array1<N>) {
313 let hash = match hash {
314 None => self.hash_vec(q).to_vec(),
315 Some(h) => h.iter().map(|&k| N::from(k).unwrap()).collect_vec(),
316 };
317 let f = self.a.dot(&aview1(q)) + &self.b;
318 let xi_min1 = f - &aview1(&hash) * self.r;
319 let xi_plus1: Array1<N> = xi_min1.map(|x| self.r - *x);
320 (xi_min1, xi_plus1)
321 }
322 }
323
324 impl<N, K> QueryDirectedProbe<N, K> for $vechash<N, K>
325 where
326 N: Numeric + Float,
327 K: Integer,
328 {
329 fn query_directed_probe(&self, q: &[N], budget: usize) -> Result<Vec<Vec<K>>> {
330 let hash = self.hash_vec_query(q);
333 let (xi_min, xi_plus) = self.distance_to_bound(q, Some(&hash));
334 let switchpoint = xi_min.len();
337
338 let distances: Vec<N> = stack!(Axis(0), xi_min, xi_plus).to_vec();
339
340 let z = distances.clone();
343 let mut z = z.iter().enumerate().collect::<Vec<_>>();
344 z.sort_unstable_by(|(_idx_a, a), (_idx_b, b)| a.partial_cmp(b).unwrap());
345 let z = z.iter().map(|(idx, _)| *idx).collect::<Vec<_>>();
346
347 let mut hashes = Vec::with_capacity(budget + 1);
348 hashes.push(hash.clone());
349 let mut heap = BinaryHeap::new();
351 let a0 = PerturbState::new(&z, &distances, switchpoint, hash);
352 heap.push(a0);
353 for _ in 0..budget {
354 let mut ai = match heap.pop() {
355 Some(ai) => ai,
356 None => {
357 return Err(Error::Failed(
358 "All query directed probing combinations depleted".to_string(),
359 ))
360 }
361 };
362 let mut a_s = ai.clone();
363 let mut a_e = ai.clone();
364 if a_s.shift().is_ok() {
365 heap.push(a_s);
366 }
367 if a_e.expand().is_ok() {
368 heap.push(a_e);
369 }
370 hashes.push(ai.gen_hash())
371 }
372 Ok(hashes)
373 }
374 }
375 };
376}
377impl_query_directed_probe!(L2);
378impl_query_directed_probe!(MIPS);
379
380impl<N, K, H, T> LSH<H, N, T, K>
381where
382 N: Numeric,
383 K: Integer,
384 H: VecHash<N, K>,
385 T: HashTables<N, K>,
386{
387 pub fn multi_probe_bucket_union(&self, v: &[N]) -> Result<FnvHashSet<u32>> {
388 self.validate_vec(v)?;
389 let mut bucket_union = FnvHashSet::default();
390
391 let h0 = &self.hashers[0];
395 if h0.as_query_directed_probe().is_some() {
396 for (i, hasher) in self.hashers.iter().enumerate() {
397 if let Some(h) = hasher.as_query_directed_probe() {
398 let hashes = h.query_directed_probe(v, self._multi_probe_budget)?;
399 for hash in hashes {
400 self.process_bucket_union_result(&hash, i, &mut bucket_union)?
401 }
402 }
403 }
404 } else if h0.as_step_wise_probe().is_some() {
405 for (i, hasher) in self.hashers.iter().enumerate() {
406 if let Some(h) = hasher.as_step_wise_probe() {
407 let hashes =
408 h.step_wise_probe(v, self._multi_probe_budget, self.n_projections)?;
409 for hash in hashes {
410 self.process_bucket_union_result(&hash, i, &mut bucket_union)?
411 }
412 }
413 }
414 } else {
415 unimplemented!()
416 }
417 Ok(bucket_union)
418 }
419}
420
#[cfg(test)]
mod test {
    use super::*;

    /// A permutation keeps the hash length and perturbs exactly `n` positions
    /// by one step.
    #[test]
    fn test_permutation() {
        let permut = create_hash_permutation(5, 3);
        println!("{:?}", permut);
        assert_eq!(permut.len(), 5);
        assert_eq!(permut.iter().filter(|&&shift| shift != 0).count(), 3);
        assert!(permut.iter().all(|&shift| shift == -1 || shift == 0 || shift == 1));
    }

    /// Perturbation sets are enumerated in lexicographic slot order.
    #[test]
    fn test_step_wise_perturb() {
        let a = step_wise_perturb(4, 2, true);
        assert_eq!(
            vec![vec![(0, 1), (1, 1)], vec![(0, 1), (2, 1)]],
            a.take(2).collect_vec()
        );
    }

    /// Single-position perturbations come first; the budget cuts the second
    /// round short.
    #[test]
    fn test_step_wise_probe() {
        let a = step_wise_probing(4, 20, true);
        assert_eq!(vec![1, 0, 0, 0], a[0]);
        assert_eq!(vec![0, 1, -1, 0], a[a.len() - 1]);
    }

    #[test]
    fn test_l2_xi_distances() {
        let l2 = L2::<f32>::new(4, 4., 3, 1);
        let (xi_min, xi_plus) = l2.distance_to_bound(&[1., 2., 3., 1.], None);
        assert_eq!(xi_min, arr1(&[2.0210547, 1.9154847, 0.89937115]));
        assert_eq!(xi_plus, arr1(&[1.9789453, 2.0845153, 3.1006289]));
    }

    #[test]
    fn test_perturbstate() {
        let distances = [1., 0.1, 3., 2., 9., 4., 0.8, 5.];
        // Indices of `distances` in ascending order of distance.
        let z = vec![1, 6, 0, 3, 2, 5, 7, 4];
        let switchpoint = 4;
        let a0 = PerturbState::new(&z, &distances, switchpoint, vec![0, 0, 0, 0]);
        // `gen_hash` consumes the stored hash, hence the clone.
        assert_eq!(a0.clone().gen_hash(), [0, -1, 0, 0]);
        assert_eq!(a0.score(), 0.1);
        assert_eq!(a0.selection, [0]);

        // Expanding keeps the current entry and adds the next one.
        let mut ae = a0.clone();
        ae.expand().unwrap();
        assert_eq!(ae.gen_hash(), [0, -1, 1, 0]);
        assert_eq!(ae.score(), 0.1 + 0.8);
        assert_eq!(ae.selection, [0, 1]);

        // Shifting replaces the current entry by the next one.
        let mut a_s = a0.clone();
        a_s.shift().unwrap();
        assert_eq!(a_s.gen_hash(), [0, 0, 1, 0]);
        assert_eq!(a_s.score(), 0.8);
        assert_eq!(a_s.selection, [1]);
    }

    /// The result holds the query's own hash plus one hash per unit of budget.
    #[test]
    fn test_query_directed_probe() {
        let l2 = <L2>::new(4, 4., 3, 1);
        let hashes = l2.query_directed_probe(&[1., 2., 3., 1.], 4).unwrap();
        println!("{:?}", hashes);
        assert_eq!(hashes.len(), 5);
        assert!(hashes.iter().all(|hash| hash.len() == 3));
    }

    /// A budget larger than the number of available perturbation sets is
    /// reported as an error.
    #[test]
    fn test_query_directed_bounds() {
        let mut lsh = hi8::LshMem::new(2, 1, 1).multi_probe(1000).l2(4.).unwrap();
        lsh.store_vec(&[1.]).unwrap();
        assert!(lsh.query_bucket_ids(&[1.]).is_err())
    }
}