use crate::data::Integer;
use crate::table::general::Bucket;
use crate::{data::Numeric, prelude::*, utils::create_rng};
use fnv::FnvHashSet;
use ndarray::prelude::*;
use num::Float;
use rand::Rng;
use rayon::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::path::Path;
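/// Locality Sensitive Hashing index.
///
/// Generic over the hashing scheme `H`, the numeric type of the data points `N`,
/// the hash table backend `T`, and the integer type `K` of the computed hashes
/// (defaults to `i8`).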
pub struct LSH<H, N, T, K = i8>
where
N: Numeric,
H: VecHash<N, K>,
T: HashTables<N, K>,
K: Integer,
{
pub n_hash_tables: usize,
pub n_projections: usize,
pub hashers: Vec<H>,
pub dim: usize,
pub hash_tables: Option<T>,
_seed: u64,
only_index_storage: bool,
_multi_probe: bool,
pub(crate) _multi_probe_budget: usize,
_db_path: String,
phantom: PhantomData<(N, K)>,
}
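/// Finalize a builder: create the hash table backend, store (or reload) the
/// hashers, and return a ready-to-use `LSH`.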
fn lsh_from_lsh<
N: Numeric,
T: HashTables<N, K>,
H: VecHash<N, K> + Serialize + DeserializeOwned,
K: Integer,
>(
lsh: &mut LSH<H, N, T, K>,
hashers: Vec<H>,
) -> Result<LSH<H, N, T, K>> {
let mut ht = *T::new(lsh.n_hash_tables, lsh.only_index_storage, &lsh._db_path)?;
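// Try to persist the hashers in the backend; if that fails (e.g. the backend
// already contains hashers from an earlier run), load the existing ones instead.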
let hashers = match ht.store_hashers(&hashers) {
Ok(_) => hashers,
Err(_) => match ht.load_hashers() {
Err(e) => panic!("could not load hashers: {}", e),
Ok(hashers) => hashers,
},
};
let lsh = LSH {
n_hash_tables: lsh.n_hash_tables,
n_projections: lsh.n_projections,
hashers,
dim: lsh.dim,
hash_tables: Some(ht),
_seed: lsh._seed,
only_index_storage: lsh.only_index_storage,
_multi_probe: lsh._multi_probe,
_multi_probe_budget: lsh._multi_probe_budget,
_db_path: lsh._db_path.clone(),
phantom: PhantomData,
};
Ok(lsh)
}
impl<N, T> LSH<SignRandomProjections<N>, N, T, i8>
where
N: Numeric + DeserializeOwned,
T: HashTables<N, i8>,
{
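/// Create a SignRandomProjections (cosine similarity) hasher per hash table,
/// seeded from the builder's seed.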
pub fn srp(&mut self) -> Result<Self> {
let mut rng = create_rng(self._seed);
let mut hashers = Vec::with_capacity(self.n_hash_tables);
for _ in 0..self.n_hash_tables {
let seed = rng.gen();
let hasher = SignRandomProjections::new(self.n_projections, self.dim, seed);
hashers.push(hasher);
}
lsh_from_lsh(self, hashers)
}
}
impl<N, T, K> LSH<L2<N, K>, N, T, K>
where
N: Numeric + Float + DeserializeOwned,
K: Integer + DeserializeOwned,
T: HashTables<N, K>,
{
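/// Create an L2-distance hasher per hash table, seeded from the builder's seed.
///
/// # Arguments
///
/// * `r` - Bucket width of the hash function; larger values put more points in
///   the same bucket.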
pub fn l2(&mut self, r: f32) -> Result<Self> {
let mut rng = create_rng(self._seed);
let mut hashers = Vec::with_capacity(self.n_hash_tables);
for _ in 0..self.n_hash_tables {
let seed = rng.gen();
let hasher = L2::new(self.dim, r, self.n_projections, seed);
hashers.push(hasher);
}
lsh_from_lsh(self, hashers)
}
}
impl<N, T, K> LSH<MIPS<N, K>, N, T, K>
where
N: Numeric + Float + DeserializeOwned,
K: Integer + DeserializeOwned,
T: HashTables<N, K>,
{
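/// Create hashers for Maximum Inner Product Search (MIPS), an asymmetric
/// extension of the L2 scheme.
///
/// # Arguments
///
/// * `r` - Bucket width of the underlying L2 hash.
/// * `U` - Norm scaling bound used in the asymmetric transformation.
/// * `m` - Number of extra components appended to the vectors by that transformation.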
#[allow(non_snake_case)]
pub fn mips(&mut self, r: f32, U: N, m: usize) -> Result<Self> {
let mut rng = create_rng(self._seed);
let mut hashers = Vec::with_capacity(self.n_hash_tables);
for _ in 0..self.n_hash_tables {
let seed = rng.gen();
let hasher = MIPS::new(self.dim, r, U, m, self.n_projections, seed);
hashers.push(hasher);
}
lsh_from_lsh(self, hashers)
}
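/// Fit the MIPS hashers on the data points. MIPS uses dataset statistics (such
/// as the maximum vector norm) in its asymmetric transformation, so call this
/// before storing or querying vectors.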
pub fn fit(&mut self, vs: &[Vec<N>]) -> Result<()> {
self.hashers.iter_mut().for_each(|h| h.fit(vs));
Ok(())
}
}
impl<N, T, K> LSH<MinHash<N, K>, N, T, K>
where
N: Integer + DeserializeOwned,
K: Integer + DeserializeOwned,
T: HashTables<N, K>,
{
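/// Create a MinHash hasher per hash table, for Jaccard-similarity search on
/// integer data.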
pub fn minhash(&mut self) -> Result<Self> {
let mut rng = create_rng(self._seed);
let mut hashers = Vec::with_capacity(self.n_hash_tables);
for _ in 0..self.n_hash_tables {
let seed = rng.gen();
let hasher = MinHash::new(self.n_projections, self.dim, seed);
hashers.push(hasher);
}
lsh_from_lsh(self, hashers)
}
}
impl<H, N, T, K> LSH<H, N, T, K>
where
N: Numeric,
H: VecHash<N, K> + Sync,
T: HashTables<N, K> + Sync,
K: Integer,
{
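/// Parallel (rayon) variant of `query_bucket_ids_batch`.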
pub fn query_bucket_ids_batch_par(&self, vs: &[Vec<N>]) -> Result<Vec<Vec<u32>>> {
vs.into_par_iter()
.map(|v| self.query_bucket_ids(v))
.collect()
}
pub fn query_bucket_ids_batch_arr_par(&self, vs: ArrayView2<N>) -> Result<Vec<Vec<u32>>> {
vs.axis_iter(Axis(0))
.into_par_iter()
.map(|v| self.query_bucket_ids(v.as_slice().unwrap()))
.collect()
}
}
impl<H, N, T, K> LSH<H, N, T, K>
where
H: VecHash<N, K>,
N: Numeric + Sync,
T: HashTables<N, K>,
K: Integer,
{
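/// Store a batch of data points in every hash table.
/// Returns the storage indices assigned to the points.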
pub fn store_vecs(&mut self, vs: &[Vec<N>]) -> Result<Vec<u32>> {
// Guard against an empty batch before peeking at the first vector.
if vs.is_empty() {
return Ok(Vec::new());
}
self.validate_vec(&vs[0])?;
self.hash_tables
.as_mut()
.unwrap()
.increase_storage(vs.len());
let mut ht = self.hash_tables.take().unwrap();
let mut insert_idx = Vec::with_capacity(vs.len());
for (i, proj) in self.hashers.iter().enumerate() {
for v in vs.iter() {
let hash = proj.hash_vec_put(v);
match (ht.put(hash, v, i), i) {
(Ok(idx), 0) => insert_idx.push(idx),
(Err(e), _) => return Err(e),
_ => {}
}
}
}
self.hash_tables.replace(ht);
Ok(insert_idx)
}
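/// Store the rows of a 2D array as data points in every hash table.
/// Returns the storage indices assigned to the rows.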
pub fn store_array(&mut self, vs: ArrayView2<N>) -> Result<Vec<u32>> {
self.validate_vec(vs.slice(s![0, ..]).as_slice().unwrap())?;
self.hash_tables
.as_mut()
.unwrap()
.increase_storage(vs.len_of(Axis(0)));
let mut ht = self.hash_tables.take().unwrap();
let mut insert_idx = Vec::with_capacity(vs.len_of(Axis(0)));
for (i, proj) in self.hashers.iter().enumerate() {
for v in vs.axis_iter(Axis(0)) {
let hash = proj.hash_vec_put(v.as_slice().unwrap());
match (ht.put(hash, v.as_slice().unwrap(), i), i) {
(Ok(idx), 0) => insert_idx.push(idx),
(Err(e), _) => return Err(e),
_ => {}
}
}
}
self.hash_tables.replace(ht);
Ok(insert_idx)
}
}
impl<H, N, T, K> LSH<H, N, T, K>
where
N: Numeric,
H: VecHash<N, K>,
T: HashTables<N, K>,
K: Integer,
{
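/// Create a new LSH builder. Chain one of the hasher constructors
/// (`srp`, `l2`, `mips`, `minhash`) to obtain a usable index.
///
/// # Arguments
///
/// * `n_projections` - Number of projections per hasher (the hash length); more
///   projections generally mean fewer collisions.
/// * `n_hash_tables` - Number of hash tables; more tables generally yield more
///   collision candidates per query.
/// * `dim` - Dimensionality of the data points.
///
/// A rough usage sketch (the sign-random-projection hasher and the in-memory
/// backend are chosen here purely for illustration):
///
/// ```ignore
/// let mut builder: LSH<SignRandomProjections<f32>, f32, MemoryTable<f32, i8>> =
///     LSH::new(15, 30, 100);
/// let lsh = builder.seed(42).srp().unwrap();
/// ```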
pub fn new(n_projections: usize, n_hash_tables: usize, dim: usize) -> Self {
let lsh = LSH {
n_hash_tables,
n_projections,
hashers: Vec::with_capacity(0),
dim,
hash_tables: None,
_seed: 0,
only_index_storage: false,
_multi_probe: false,
_multi_probe_budget: 16,
_db_path: "./lsh.db3".to_string(),
phantom: PhantomData,
};
lsh
}
pub(crate) fn validate_vec<A>(&self, v: &[A]) -> Result<()> {
if v.len() != self.dim {
return Err(Error::Failed(
"data point is not valid, are the dimensions correct?".to_string(),
));
};
Ok(())
}
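/// Set the seed of the RNG used to create the hashers.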
pub fn seed(&mut self, seed: u64) -> &mut Self {
self._seed = seed;
self
}
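/// Only store unique indices of the data points, not the points themselves.
/// Saves storage, but `query_bucket` is then unavailable; use `query_bucket_ids`.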
pub fn only_index(&mut self) -> &mut Self {
self.only_index_storage = true;
self
}
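/// Enable multi-probe queries, which also examine neighboring buckets to
/// increase recall. `budget` caps the number of extra probing hashes generated.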
pub fn multi_probe(&mut self, budget: usize) -> &mut Self {
self._multi_probe = true;
self._multi_probe_budget = budget;
self
}
pub fn base(&mut self) -> &mut Self {
self._multi_probe = false;
self
}
pub fn increase_storage(&mut self, upper_bound: usize) -> Result<&mut Self> {
self.hash_tables
.as_mut()
.unwrap()
.increase_storage(upper_bound);
Ok(self)
}
pub fn set_database_file(&mut self, path: &str) -> &mut Self {
self._db_path = path.to_string();
self
}
pub fn describe(&self) -> Result<String> {
self.hash_tables.as_ref().unwrap().describe()
}
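/// Store a single data point in every hash table.
/// Returns the storage index assigned to the point.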
pub fn store_vec(&mut self, v: &[N]) -> Result<u32> {
self.validate_vec(v)?;
let mut idx = 0;
let mut ht = self.hash_tables.take().unwrap();
for (i, proj) in self.hashers.iter().enumerate() {
let hash = proj.hash_vec_put(v);
idx = ht.put(hash, v, i)?;
}
self.hash_tables.replace(ht);
Ok(idx)
}
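/// Update a stored data point by its storage index: the hash of `old_v` is
/// replaced by the hash of `new_v` in every hash table.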
pub fn update_by_idx(&mut self, idx: u32, new_v: &[N], old_v: &[N]) -> Result<()> {
let mut ht = self.hash_tables.take().unwrap();
for (i, proj) in self.hashers.iter().enumerate() {
let new_hash = proj.hash_vec_put(new_v);
let old_hash = proj.hash_vec_put(old_v);
ht.update_by_idx(&old_hash, new_hash, idx, i)?;
}
self.hash_tables.replace(ht);
Ok(())
}
fn query_bucket_union(&self, v: &[N]) -> Result<Bucket> {
self.validate_vec(v)?;
if self._multi_probe {
return self.multi_probe_bucket_union(v);
}
let mut bucket_union = FnvHashSet::default();
for (i, proj) in self.hashers.iter().enumerate() {
let hash = proj.hash_vec_query(v);
self.process_bucket_union_result(&hash, i, &mut bucket_union)?;
}
Ok(bucket_union)
}
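/// Query every hash table and return the union of the data points found in the
/// matching buckets. Not available in `only_index` mode; use `query_bucket_ids`.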
pub fn query_bucket(&self, v: &[N]) -> Result<Vec<&Vec<N>>> {
self.validate_vec(v)?;
if self.only_index_storage {
return Err(Error::Failed(
"cannot query bucket, use query_bucket_ids".to_string(),
));
}
let bucket_union = self.query_bucket_union(v)?;
bucket_union
.iter()
.map(|&idx| Ok(self.hash_tables.as_ref().unwrap().idx_to_datapoint(idx)?))
.collect()
}
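/// Query every hash table and return the union of the storage indices found in
/// the matching buckets.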
pub fn query_bucket_ids(&self, v: &[N]) -> Result<Vec<u32>> {
self.validate_vec(v)?;
let bucket_union = self.query_bucket_union(v)?;
Ok(bucket_union.iter().copied().collect())
}
pub fn query_bucket_ids_batch(&self, vs: &[Vec<N>]) -> Result<Vec<Vec<u32>>> {
vs.iter().map(|v| self.query_bucket_ids(v)).collect()
}
pub fn query_bucket_ids_batch_arr(&self, vs: ArrayView2<N>) -> Result<Vec<Vec<u32>>> {
vs.axis_iter(Axis(0))
.map(|v| self.query_bucket_ids(v.as_slice().unwrap()))
.collect()
}
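/// Remove a data point's hashes from every hash table. Tables where the hash is
/// not found are silently skipped.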
pub fn delete_vec(&mut self, v: &[N]) -> Result<()> {
self.validate_vec(v)?;
for (i, proj) in self.hashers.iter().enumerate() {
let hash = proj.hash_vec_query(v);
let mut ht = self.hash_tables.take().unwrap();
ht.delete(&hash, v, i).unwrap_or_default();
self.hash_tables = Some(ht)
}
Ok(())
}
pub(crate) fn process_bucket_union_result(
&self,
hash: &[K],
hash_table_idx: usize,
bucket_union: &mut Bucket,
) -> Result<()> {
match self
.hash_tables
.as_ref()
.unwrap()
.query_bucket(hash, hash_table_idx)
{
Err(Error::NotFound) => Ok(()),
Ok(bucket) => {
*bucket_union = bucket_union.union(&bucket).copied().collect();
Ok(())
}
Err(e) => Err(e),
}
}
}
#[cfg(feature = "sqlite")]
impl<N, H, K> LSH<H, N, SqlTable<N, K>, K>
where
N: Numeric,
H: VecHash<N, K> + Serialize,
K: Integer,
{
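/// Commit the pending SQLite transaction.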
pub fn commit(&mut self) -> Result<()> {
let ht = self.hash_tables.as_mut().unwrap();
ht.commit()?;
Ok(())
}
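/// Start a new SQLite transaction. Call `commit` to persist the inserts made
/// during the transaction.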
pub fn init_transaction(&mut self) -> Result<()> {
let ht = self.hash_tables.as_mut().unwrap();
ht.init_transaction()?;
Ok(())
}
}
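/// Bincode-serialized snapshot of the LSH state, written by `dump` and read by `load`.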
#[derive(Serialize, Deserialize)]
struct IntermediatBlob {
hash_tables: Vec<u8>,
hashers: Vec<u8>,
n_hash_tables: usize,
n_projections: usize,
dim: usize,
_seed: u64,
}
impl<H, N, K> LSH<H, N, MemoryTable<N, K>, K>
where
H: Serialize + DeserializeOwned + VecHash<N, K>,
N: Numeric + DeserializeOwned,
K: Integer + DeserializeOwned,
{
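/// Deserialize a previously dumped LSH state (hashers, hash tables and settings)
/// from `path`.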
pub fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
let mut f = File::open(path)?;
let mut buf: Vec<u8> = vec![];
f.read_to_end(&mut buf)?;
let ib: IntermediatBlob = bincode::deserialize(&buf)?;
self.hashers = bincode::deserialize(&ib.hashers)?;
self.hash_tables = bincode::deserialize(&ib.hash_tables)?;
self.n_hash_tables = ib.n_hash_tables;
self.n_projections = ib.n_projections;
self.dim = ib.dim;
self._seed = ib._seed;
Ok(())
}
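/// Serialize the hashers, hash tables and settings to `path` so they can be
/// restored later with `load`.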
pub fn dump<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let hash_tables = bincode::serialize(&self.hash_tables)?;
let hashers = bincode::serialize(&self.hashers)?;
let ib = IntermediatBlob {
hash_tables,
hashers,
n_hash_tables: self.n_hash_tables,
n_projections: self.n_projections,
dim: self.dim,
_seed: self._seed,
};
let mut f = File::create(path)?;
let blob = bincode::serialize(&ib)?;
f.write_all(&blob)?;
Ok(())
}
}