cosine_lsh/
lib.rs

1mod hash;
2use hash::{Hash, HashTableKey, Hyperlanes, QueryResult, Point, HashTableBucket, HashTable};
3use core::borrow::Borrow;
4use std::collections::HashMap;
5
/// Returns the squared Euclidean distance between `p1` and `p2`.
///
/// Elements are paired positionally; if the slices differ in length, the
/// trailing elements of the longer one are ignored. Two empty slices yield
/// `0.0`.
fn euclidean_dist_square(p1: &[f64], p2: &[f64]) -> f64 {
    let mut total = 0.0;
    for (a, b) in p1.iter().zip(p2) {
        let diff = b - a;
        total += diff.powi(2);
    }
    total
}
11
12struct CosineLshParam {
13    dim: i64,
14    l: i64,
15    m: i64,
16    h: i64,
17    hyperplanes: Hyperlanes,
18}
19
20impl CosineLshParam {
21    fn new(dim: i64, l: i64, m: i64, h: i64, hyperplanes: Hyperlanes)
22        -> CosineLshParam {
23        CosineLshParam{dim, l, m, hyperplanes, h}
24    }
25
26    fn hash(&self, point: &[f64]) -> Vec<HashTableKey> {
27        let simhash = Hash::new(self.hyperplanes.borrow(), point);
28        let mut hvs :Vec<HashTableKey> = Vec::with_capacity(self.l as usize);
29        for i in 0..hvs.capacity() {
30            let mut s = Vec::with_capacity(self.m as usize);
31            for j in 0..self.m {
32                s.push(simhash.sig.0[i*self.m as usize+j as usize]);
33            }
34            hvs.push(s);
35        };
36        hvs
37    }
38}
39
40
41pub struct CosineLSH<T> {
42    tables: Vec<HashTable<T>>,
43    next_id: u64,
44    param: CosineLshParam
45}
46
47impl<T> CosineLSH<T>
48    where T: Clone + Copy {
49    pub fn new(dim: i64, l: i64, m: i64) -> Self {
50        CosineLSH {
51            tables: vec![HashTable::new(); l as usize],
52            next_id: 0,
53            param: CosineLshParam::new(dim, l, m, m * l, Hyperlanes::new(m * l, dim))
54        }
55    }
56
57    pub fn insert(&mut self, point: Vec<f64>, extra_data: T) -> Option<()> {
58        if let Some(hvs) = self.to_basic_hash_table_keys(self.param.hash(point.as_slice())) {
59            for (a, b) in self.tables.iter_mut().enumerate() {
60                let j = hvs[a];
61
62                self.next_id += 1;
63                b.entry(j).
64                    or_insert_with(HashTableBucket::new).
65                    push(Point { vector: point.clone(), id: self.next_id, extra_data });
66            };
67        } else {
68            return None
69        }
70        Some(())
71    }
72
73    pub fn query(&self, q: Vec<f64>, max_result: usize) -> Option<Vec<QueryResult<T>>> {
74        let mut seen :HashMap<u64,&Point<T>> = HashMap::new();
75        if let Some(hvs) = self.to_basic_hash_table_keys(self.param.hash(&q)) {
76            self.tables.iter().enumerate().for_each(|(i, table)|
77                {
78                    if let Some(candidates) = table.get(hvs[i].borrow()) {
79                        candidates.iter().
80                            for_each(|p| { seen.entry(p.id).or_insert(p); });
81                    }
82                }
83            )
84        }
85
86
87        let mut distances :Vec<QueryResult<T>> = Vec::with_capacity(seen.len());
88        for (_, value) in seen {
89            let distance = euclidean_dist_square(&q, &value.vector);
90            distances.push(QueryResult{
91                distance,
92                vector: &value.vector,
93                id: value.id,
94                extra_data: value.extra_data});
95        }
96        if distances.is_empty() { return None }
97        distances.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
98        if max_result> 0 && distances.len() > max_result as usize {
99            Some(distances[0..max_result].to_vec())
100        } else {
101            Some(distances)
102        }
103    }
104
105    fn to_basic_hash_table_keys(&self, keys: Vec<HashTableKey>) -> Option<Vec<u64>> {
106        let mut basic_keys :Vec<u64> = Vec::with_capacity(self.param.l as usize);
107        for key in keys {
108            let mut s = "".to_string();
109            for (_, hash_val) in key.iter().enumerate() {
110                match hash_val {
111                    0 => s.push_str("0"),
112                    1 => s.push_str("1"),
113                    _ => return None
114                }
115            }
116            basic_keys.push(s.parse().unwrap());
117        }
118        Some(basic_keys)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the squared distance for inputs of unequal length: only the
    /// first three pairs contribute (32² + 9² + 61² = 4826).
    #[test]
    fn euclidian_test() {
        let a = [34.0, 12.0, 65.0, 29.0];
        let b = [2.0, 3.0, 4.0];

        let got = euclidean_dist_square(&a, &b);

        assert_eq!(got, 4826.0);
    }
}