1mod hash;
2use hash::{Hash, HashTableKey, Hyperlanes, QueryResult, Point, HashTableBucket, HashTable};
3use core::borrow::Borrow;
4use std::collections::HashMap;
5
/// Returns the squared Euclidean distance between `p1` and `p2`.
///
/// Coordinates are paired by position. When the slices differ in length the
/// surplus coordinates of the longer one are ignored, and two empty slices
/// yield `0.0`.
fn euclidean_dist_square(p1: &[f64], p2: &[f64]) -> f64 {
    let mut total = 0.0;
    for (a, b) in p1.iter().zip(p2) {
        total += (b - a).powi(2);
    }
    total
}
11
12struct CosineLshParam {
13 dim: i64,
14 l: i64,
15 m: i64,
16 h: i64,
17 hyperplanes: Hyperlanes,
18}
19
20impl CosineLshParam {
21 fn new(dim: i64, l: i64, m: i64, h: i64, hyperplanes: Hyperlanes)
22 -> CosineLshParam {
23 CosineLshParam{dim, l, m, hyperplanes, h}
24 }
25
26 fn hash(&self, point: &[f64]) -> Vec<HashTableKey> {
27 let simhash = Hash::new(self.hyperplanes.borrow(), point);
28 let mut hvs :Vec<HashTableKey> = Vec::with_capacity(self.l as usize);
29 for i in 0..hvs.capacity() {
30 let mut s = Vec::with_capacity(self.m as usize);
31 for j in 0..self.m {
32 s.push(simhash.sig.0[i*self.m as usize+j as usize]);
33 }
34 hvs.push(s);
35 };
36 hvs
37 }
38}
39
40
41pub struct CosineLSH<T> {
42 tables: Vec<HashTable<T>>,
43 next_id: u64,
44 param: CosineLshParam
45}
46
47impl<T> CosineLSH<T>
48 where T: Clone + Copy {
49 pub fn new(dim: i64, l: i64, m: i64) -> Self {
50 CosineLSH {
51 tables: vec![HashTable::new(); l as usize],
52 next_id: 0,
53 param: CosineLshParam::new(dim, l, m, m * l, Hyperlanes::new(m * l, dim))
54 }
55 }
56
57 pub fn insert(&mut self, point: Vec<f64>, extra_data: T) -> Option<()> {
58 if let Some(hvs) = self.to_basic_hash_table_keys(self.param.hash(point.as_slice())) {
59 for (a, b) in self.tables.iter_mut().enumerate() {
60 let j = hvs[a];
61
62 self.next_id += 1;
63 b.entry(j).
64 or_insert_with(HashTableBucket::new).
65 push(Point { vector: point.clone(), id: self.next_id, extra_data });
66 };
67 } else {
68 return None
69 }
70 Some(())
71 }
72
73 pub fn query(&self, q: Vec<f64>, max_result: usize) -> Option<Vec<QueryResult<T>>> {
74 let mut seen :HashMap<u64,&Point<T>> = HashMap::new();
75 if let Some(hvs) = self.to_basic_hash_table_keys(self.param.hash(&q)) {
76 self.tables.iter().enumerate().for_each(|(i, table)|
77 {
78 if let Some(candidates) = table.get(hvs[i].borrow()) {
79 candidates.iter().
80 for_each(|p| { seen.entry(p.id).or_insert(p); });
81 }
82 }
83 )
84 }
85
86
87 let mut distances :Vec<QueryResult<T>> = Vec::with_capacity(seen.len());
88 for (_, value) in seen {
89 let distance = euclidean_dist_square(&q, &value.vector);
90 distances.push(QueryResult{
91 distance,
92 vector: &value.vector,
93 id: value.id,
94 extra_data: value.extra_data});
95 }
96 if distances.is_empty() { return None }
97 distances.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
98 if max_result> 0 && distances.len() > max_result as usize {
99 Some(distances[0..max_result].to_vec())
100 } else {
101 Some(distances)
102 }
103 }
104
105 fn to_basic_hash_table_keys(&self, keys: Vec<HashTableKey>) -> Option<Vec<u64>> {
106 let mut basic_keys :Vec<u64> = Vec::with_capacity(self.param.l as usize);
107 for key in keys {
108 let mut s = "".to_string();
109 for (_, hash_val) in key.iter().enumerate() {
110 match hash_val {
111 0 => s.push_str("0"),
112 1 => s.push_str("1"),
113 _ => return None
114 }
115 }
116 basic_keys.push(s.parse().unwrap());
117 }
118 Some(basic_keys)
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// The slices differ in length on purpose: only the first three
    /// coordinate pairs contribute, giving 32^2 + 9^2 + 61^2 = 4826.
    #[test]
    fn euclidian_test() {
        let longer = [34.0, 12.0, 65.0, 29.0];
        let shorter = [2.0, 3.0, 4.0];
        let squared = euclidean_dist_square(&longer, &shorter);
        assert_eq!(squared, 4826.0);
    }
}