use crate::data::expr::{eval_bytecode, eval_bytecode_pred, Bytecode};
use crate::data::tuple::Tuple;
use crate::fts::tokenizer::TextAnalyzer;
use crate::fts::TokenizerConfig;
use crate::runtime::relation::RelationHandle;
use crate::runtime::transact::SessionTx;
use crate::{DataValue, Expr, SourceSpan, Symbol};
use itertools::Itertools;
use miette::{bail, miette, Result};
use quadrature::integrate;
use rand::{thread_rng, RngCore};
use rustc_hash::FxHashSet;
use smartstring::{LazyCompact, SmartString};
use std::cmp::min;
use std::hash::{Hash, Hasher};
use twox_hash::XxHash32;
impl<'a> SessionTx<'a> {
    /// Removes the LSH index rows that belong to `tuple`.
    ///
    /// * `bytes` - the band keys currently stored for the tuple. When `None`,
    ///   they are looked up in the inverse index (keyed by the leading
    ///   `inv_idx_handle.metadata.keys.len()` values of `tuple`) and the inverse
    ///   entry is deleted as well; if no inverse entry exists the call does
    ///   nothing. When `Some`, only the listed band rows are deleted and the
    ///   inverse entry is left untouched.
    /// * `idx_handle` - the band index; its rows are keyed by the band key
    ///   followed by the values of `tuple`.
    ///
    /// # Errors
    /// Propagates any error from the relation handles or the store.
    ///
    /// # Panics
    /// Panics if the stored inverse value is not a list of byte strings.
    pub(crate) fn del_lsh_index_item(
        &mut self,
        tuple: &[DataValue],
        bytes: Option<Vec<Vec<u8>>>,
        idx_handle: &RelationHandle,
        inv_idx_handle: &RelationHandle,
    ) -> Result<()> {
        let bytes = match bytes {
            Some(supplied) => supplied,
            None => {
                let inv_key_part = &tuple[..inv_idx_handle.metadata.keys.len()];
                match inv_idx_handle.get_val_only(self, inv_key_part)? {
                    // Nothing was ever indexed for this tuple.
                    None => return Ok(()),
                    Some(mut found) => {
                        let inv_key =
                            inv_idx_handle.encode_key_for_store(tuple, Default::default())?;
                        self.store_tx.del(&inv_key)?;
                        decode_lsh_band_keys(found.pop())
                    }
                }
            }
        };
        // Slot 0 is a placeholder that is overwritten with each band key in turn.
        let mut key = Vec::with_capacity(tuple.len() + 1);
        key.push(DataValue::Bot);
        key.extend_from_slice(tuple);
        for chunk in bytes {
            key[0] = DataValue::Bytes(chunk);
            let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
            self.store_tx.del(&key_bytes)?;
        }
        Ok(())
    }

    /// Indexes `tuple` in the LSH index, replacing whatever was indexed for the
    /// same key before.
    ///
    /// The value to index is obtained by running `extractor` over `tuple`:
    /// * `Null` - nothing is indexed; any previous band rows and the previous
    ///   inverse entry are removed.
    /// * a list - its elements are min-hashed directly.
    /// * a string - its unique n-grams (size `manifest.n_gram`, produced by
    ///   `tokenizer`) are min-hashed.
    ///
    /// The signature is cut into `manifest.n_bands` bands of
    /// `manifest.n_rows_in_band` hashes each. Every band, suffixed with its
    /// little-endian `u16` band number, becomes one key in `idx_handle`. The
    /// list of band keys is stored in `inv_idx_handle` under the first
    /// `rel_handle.metadata.keys.len()` values of `tuple`.
    ///
    /// # Errors
    /// Fails if the extracted value is neither `Null`, a list nor a string, or
    /// if evaluation, encoding or the store reports an error.
    ///
    /// # Panics
    /// Panics if the stored inverse value is not a list of byte strings, or if
    /// the signature is shorter than `n_bands * n_rows_in_band` hashes.
    pub(crate) fn put_lsh_index_item(
        &mut self,
        tuple: &[DataValue],
        extractor: &[Bytecode],
        stack: &mut Vec<DataValue>,
        tokenizer: &TextAnalyzer,
        rel_handle: &RelationHandle,
        idx_handle: &RelationHandle,
        inv_idx_handle: &RelationHandle,
        manifest: &MinHashLshIndexManifest,
        hash_perms: &HashPermutations,
    ) -> Result<()> {
        let inv_key_part = &tuple[..rel_handle.metadata.keys.len()];
        let had_previous_entry = match inv_idx_handle.get_val_only(self, inv_key_part)? {
            Some(mut found) => {
                let stale = decode_lsh_band_keys(found.pop());
                self.del_lsh_index_item(tuple, Some(stale), idx_handle, inv_idx_handle)?;
                true
            }
            None => false,
        };
        let to_index = eval_bytecode(extractor, tuple, stack)?;
        let min_hash = match to_index {
            DataValue::Null => {
                // The band rows are already gone; without this the inverse
                // entry would keep pointing at keys that no longer exist.
                if had_previous_entry {
                    let inv_key =
                        inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
                    self.store_tx.del(&inv_key)?;
                }
                return Ok(());
            }
            DataValue::List(l) => HashValues::new(l.iter(), hash_perms),
            DataValue::Str(s) => {
                let n_grams = tokenizer.unique_ngrams(&s, manifest.n_gram);
                HashValues::new(n_grams.iter(), hash_perms)
            }
            _ => bail!("Cannot put value {:?} into a LSH index", to_index),
        };
        let bytes = min_hash.get_bytes();
        let chunk_size = manifest.n_rows_in_band * std::mem::size_of::<u32>();
        let chunks = (0..manifest.n_bands)
            .map(|i| {
                let mut byte_range = bytes[i * chunk_size..(i + 1) * chunk_size].to_vec();
                // The band number keeps equal hash runs in different bands apart.
                byte_range.extend_from_slice(&(i as u16).to_le_bytes());
                byte_range
            })
            .collect_vec();
        // One slot for the band key plus the key columns of the base relation.
        let mut key = Vec::with_capacity(inv_key_part.len() + 1);
        key.push(DataValue::Bot);
        key.extend_from_slice(inv_key_part);
        for chunk in chunks.iter() {
            key[0] = DataValue::Bytes(chunk.clone());
            let key_bytes = idx_handle.encode_key_for_store(&key, Default::default())?;
            self.store_tx.put(&key_bytes, &[])?;
        }
        let inv_val_part = vec![DataValue::List(
            chunks.into_iter().map(DataValue::Bytes).collect_vec(),
        )];
        let inv_key = inv_idx_handle.encode_key_for_store(inv_key_part, Default::default())?;
        let inv_val =
            inv_idx_handle.encode_val_only_for_store(&inv_val_part, Default::default())?;
        self.store_tx.put(&inv_key, &inv_val)?;
        Ok(())
    }

    /// Returns tuples of the base relation that share at least one LSH band
    /// with the query value `q`.
    ///
    /// `q` is min-hashed the same way values are when they are indexed (lists
    /// element-wise, strings via their unique n-grams); a `Null` query yields no
    /// results. Each band of the query signature is used as a key prefix to scan
    /// `config.idx_handle`; the distinct keys found are then resolved against
    /// `config.base_handle`, filtered with `filter_code` when present, and
    /// returned. At most `config.k` tuples are returned when `k` is set.
    ///
    /// Candidate collection stops as soon as `k` distinct keys are known, but
    /// only when there is no filter: with a filter, an unknown number of
    /// candidates may be rejected later, so every band has to be scanned.
    ///
    /// # Errors
    /// Fails if `q` is not `Null`, a list or a string, if a candidate key has no
    /// tuple in the base relation, or if scanning or filtering reports an error.
    pub(crate) fn lsh_search(
        &self,
        q: &DataValue,
        config: &LshSearch,
        stack: &mut Vec<DataValue>,
        filter_code: &Option<(Vec<Bytecode>, SourceSpan)>,
        perms: &HashPermutations,
        tokenizer: &TextAnalyzer,
    ) -> Result<Vec<Tuple>> {
        let bytes = match q {
            DataValue::Null => {
                return Ok(vec![]);
            }
            DataValue::List(l) => HashValues::new(l.iter(), perms).get_bytes().to_vec(),
            DataValue::Str(s) => {
                let n_grams = tokenizer.unique_ngrams(s, config.manifest.n_gram);
                HashValues::new(n_grams.iter(), perms).get_bytes().to_vec()
            }
            _ => bail!("Cannot search for value {:?} in a LSH index", q),
        };
        // A limit of zero can never produce a result, so skip all scanning.
        if config.k == Some(0) {
            return Ok(vec![]);
        }
        let chunk_size = config.manifest.n_rows_in_band * std::mem::size_of::<u32>();
        let mut key_prefix = Vec::with_capacity(1);
        let mut found_tuples: FxHashSet<_> = FxHashSet::default();
        let early_stopper = if filter_code.is_some() {
            None
        } else {
            config.k
        };
        // Only the first `n_bands` bands are ever written to the index; any
        // further chunks of the signature cannot match anything.
        let bands = bytes
            .chunks_exact(chunk_size)
            .take(config.manifest.n_bands)
            .enumerate();
        'bands: for (i, chunk) in bands {
            key_prefix.clear();
            let mut band_key = chunk.to_vec();
            band_key.extend_from_slice(&(i as u16).to_le_bytes());
            key_prefix.push(DataValue::Bytes(band_key));
            for ks in config.idx_handle.scan_prefix(self, &key_prefix) {
                let ks = ks?;
                // Element 0 is the band key; the rest is the base relation key.
                found_tuples.insert(ks[1..].to_vec());
                if let Some(k) = early_stopper {
                    if found_tuples.len() >= k {
                        // Enough candidates: leave the band loop too, not just
                        // the current scan.
                        break 'bands;
                    }
                }
            }
        }
        let mut ret = vec![];
        for key in found_tuples {
            let orig_tuple = config
                .base_handle
                .get(self, &key)?
                .ok_or_else(|| miette!("Tuple not found in base LSH relation"))?;
            if let Some((filter_code, span)) = filter_code {
                if !eval_bytecode_pred(filter_code, &orig_tuple, stack, *span)? {
                    continue;
                }
            }
            ret.push(orig_tuple);
            if let Some(k) = config.k {
                if ret.len() >= k {
                    break;
                }
            }
        }
        Ok(ret)
    }
}

/// Unpacks the value stored in the inverse LSH index (a list of byte strings)
/// into the band keys it contains.
///
/// Panics if the value has any other shape, since the index only ever writes
/// lists of byte strings there.
fn decode_lsh_band_keys(stored: Option<DataValue>) -> Vec<Vec<u8>> {
    match stored {
        Some(DataValue::List(items)) => items
            .into_iter()
            .map(|item| match item {
                DataValue::Bytes(raw) => raw,
                _ => unreachable!(),
            })
            .collect_vec(),
        _ => unreachable!(),
    }
}
/// Parameters of one MinHash-LSH lookup, as consumed by `SessionTx::lsh_search`.
#[derive(Clone, Debug)]
pub(crate) struct LshSearch {
/// Relation from which `lsh_search` fetches the full tuple for each candidate key.
pub(crate) base_handle: RelationHandle,
/// Index relation that `lsh_search` scans by band-key prefix.
pub(crate) idx_handle: RelationHandle,
/// Index settings; `lsh_search` takes the n-gram size and the band layout from it.
pub(crate) manifest: MinHashLshIndexManifest,
/// Symbols yielded by `all_bindings`.
pub(crate) bindings: Vec<Symbol>,
/// Limit on the number of tuples `lsh_search` returns; `None` means no limit.
pub(crate) k: Option<usize>,
/// NOTE(review): presumably the symbol bound to the query value; not read in this file - confirm.
pub(crate) query: Symbol,
/// NOTE(review): presumably the uncompiled form of the filter that `lsh_search`
/// receives as bytecode; not read in this file - confirm.
pub(crate) filter: Option<Expr>,
/// NOTE(review): presumably the source location of the search clause, for
/// diagnostics; not read in this file - confirm.
pub(crate) span: SourceSpan,
}
impl LshSearch {
    /// Iterates over the symbols held in `bindings`, in their stored order.
    pub(crate) fn all_bindings(&self) -> impl Iterator<Item = &Symbol> {
        let symbols: &[Symbol] = &self.bindings;
        symbols.iter()
    }
}
/// A MinHash signature: one 32-bit slot per permutation seed. `init` fills every
/// slot with `u32::MAX` and `update` lowers each slot to the smallest hash seen
/// for that seed.
pub(crate) struct HashValues(pub(crate) Vec<u32>);
/// The 32-bit seeds used to build a MinHash signature; each seed initialises one
/// `XxHash32` hasher in `HashValues::update`.
pub(crate) struct HashPermutations(pub(crate) Vec<u32>);
/// Serializable description of a MinHash-LSH index.
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct MinHashLshIndexManifest {
/// NOTE(review): presumably the name of the relation being indexed - confirm.
pub(crate) base_relation: SmartString<LazyCompact>,
/// NOTE(review): presumably the name of this index - confirm.
pub(crate) index_name: SmartString<LazyCompact>,
/// NOTE(review): presumably the source text of the expression that extracts the
/// value to index; the indexing code in this file receives it as bytecode - confirm.
pub(crate) extractor: String,
/// N-gram size passed to the tokenizer's `unique_ngrams` when a string is
/// indexed or searched.
pub(crate) n_gram: usize,
/// NOTE(review): presumably the configuration from which the `TextAnalyzer` is built - confirm.
pub(crate) tokenizer: TokenizerConfig,
/// NOTE(review): presumably token filters applied after tokenization - confirm.
pub(crate) filters: Vec<TokenizerConfig>,
/// NOTE(review): presumably the number of hash permutations (the signature length) - confirm.
pub(crate) num_perm: usize,
/// Number of band keys written to the index for each indexed tuple.
pub(crate) n_bands: usize,
/// Number of 32-bit hash values that make up one band.
pub(crate) n_rows_in_band: usize,
/// NOTE(review): presumably the similarity threshold the band layout was tuned for - confirm.
pub(crate) threshold: f64,
/// Raw bytes of the permutation seeds; decoded by `get_hash_perms`.
pub(crate) perms: Vec<u8>,
}
impl MinHashLshIndexManifest {
    /// Decodes the serialized `perms` bytes into a `HashPermutations`.
    pub(crate) fn get_hash_perms(&self) -> HashPermutations {
        let raw: &[u8] = &self.perms;
        HashPermutations::from_bytes(raw)
    }
}
/// A `(b, r)` pair as chosen by `LshParams::find_optimal_params`, which only
/// considers pairs with `b * r` not exceeding the number of permutations.
#[derive(Clone, Debug)]
pub(crate) struct LshParams {
/// Exponent `b` in the collision curve `1 - (1 - s^r)^b`.
/// NOTE(review): presumably the number of bands - confirm.
pub b: usize,
/// Exponent `r` in the collision curve `1 - (1 - s^r)^b`.
/// NOTE(review): presumably the number of rows per band - confirm.
pub r: usize,
}
/// Weights for `LshParams::find_optimal_params`: field 0 multiplies the
/// false-positive term and field 1 the false-negative term of the error.
#[derive(Clone)]
pub(crate) struct Weights(pub(crate) f64, pub(crate) f64);
/// Value passed as the last argument of `integrate` by both probability
/// functions of `LshParams`.
/// NOTE(review): presumably the error tolerance of the numerical integration - confirm.
const _ALLOWED_INTEGRATE_ERR: f64 = 0.001;
impl LshParams {
    /// Searches every `(b, r)` with `b >= 1`, `r >= 1` and `b * r <= num_perm`
    /// for the pair with the smallest weighted error
    /// `false_positive * weights.0 + false_negative * weights.1`.
    ///
    /// Pairs are visited with `b` ascending and, for each `b`, `r` ascending;
    /// the first pair reaching the minimum wins. When `num_perm` is zero no
    /// pair is visited and `LshParams { b: 0, r: 0 }` is returned.
    pub fn find_optimal_params(threshold: f64, num_perm: usize, weights: &Weights) -> LshParams {
        let Weights(fp_weight, fn_weight) = weights;
        let mut best = LshParams { b: 0, r: 0 };
        let mut lowest_cost = f64::INFINITY;
        let candidates = (1..num_perm + 1)
            .flat_map(|bands| (1..num_perm / bands + 1).map(move |rows| (bands, rows)));
        for (bands, rows) in candidates {
            let fp_area = Self::false_positive_probability(threshold, bands, rows);
            let fn_area = Self::false_negative_probability(threshold, bands, rows);
            let cost = fp_area * fp_weight + fn_area * fn_weight;
            // Strict comparison: on a tie the earlier candidate is kept.
            if cost < lowest_cost {
                lowest_cost = cost;
                best = LshParams { b: bands, r: rows };
            }
        }
        best
    }

    /// Integrates the collision curve `1 - (1 - s^r)^b` for `s` from `0` to
    /// `threshold`.
    fn false_positive_probability(threshold: f64, b: usize, r: usize) -> f64 {
        let collision = |s: f64| -> f64 { 1. - (1. - s.powi(r as i32)).powf(b as f64) };
        integrate(collision, 0.0, threshold, _ALLOWED_INTEGRATE_ERR).integral
    }

    /// Integrates the complement of the collision curve, `1 - (1 - (1 - s^r)^b)`,
    /// for `s` from `threshold` to `1`.
    fn false_negative_probability(threshold: f64, b: usize, r: usize) -> f64 {
        let miss = |s: f64| -> f64 { 1. - (1. - (1. - s.powi(r as i32)).powf(b as f64)) };
        integrate(miss, threshold, 1.0, _ALLOWED_INTEGRATE_ERR).integral
    }
}
impl HashPermutations {
    /// Creates `n_perms` seeds, each drawn with `next_u32` from `thread_rng()`.
    pub(crate) fn new(n_perms: usize) -> Self {
        let mut rng = thread_rng();
        Self((0..n_perms).map(|_| rng.next_u32()).collect())
    }

    /// Views the seeds as raw bytes, four per seed, in the platform's native
    /// byte order. `from_bytes` is the inverse.
    pub(crate) fn as_bytes(&self) -> &[u8] {
        // SAFETY: the pointer and length describe exactly the initialized
        // memory owned by `self.0`; `u8` has no alignment requirement and
        // every byte of a `u32` is a valid `u8`. The returned slice borrows
        // `self`, so the vector cannot be changed or freed while it is alive.
        unsafe {
            std::slice::from_raw_parts(
                self.0.as_ptr() as *const u8,
                self.0.len() * std::mem::size_of::<u32>(),
            )
        }
    }

    /// Rebuilds the seeds from bytes produced by `as_bytes`.
    ///
    /// Each group of four bytes is read as one native-endian `u32`; up to
    /// three trailing bytes that do not form a full group are ignored. The
    /// input may start at any address: unlike a pointer cast to `*const u32`,
    /// this does not require the slice to be 4-byte aligned.
    pub(crate) fn from_bytes(bytes: &[u8]) -> Self {
        const WIDTH: usize = std::mem::size_of::<u32>();
        Self(
            bytes
                .chunks_exact(WIDTH)
                .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
                .collect(),
        )
    }
}
impl HashValues {
    /// Builds the signature of `values`: starts from `init` and folds every
    /// value in with `update`.
    pub(crate) fn new<T: Hash>(values: impl Iterator<Item = T>, perms: &HashPermutations) -> Self {
        let mut signature = Self::init(perms);
        signature.update(values, perms);
        signature
    }

    /// Creates an empty signature: one slot per seed, each set to `u32::MAX`.
    pub(crate) fn init(perms: &HashPermutations) -> Self {
        let slots = perms.0.len();
        Self(vec![u32::MAX; slots])
    }

    /// Folds more values into the signature. Every value is hashed once per
    /// seed, and each slot keeps the smallest hash seen for its seed.
    ///
    /// Panics if the signature has fewer slots than `perms` has seeds.
    pub(crate) fn update<T: Hash>(
        &mut self,
        values: impl Iterator<Item = T>,
        perms: &HashPermutations,
    ) {
        values.for_each(|item| {
            for (slot, &seed) in perms.0.iter().enumerate() {
                let mut state = XxHash32::with_seed(seed);
                item.hash(&mut state);
                let candidate = state.finish() as u32;
                let current = self.0[slot];
                self.0[slot] = min(current, candidate);
            }
        });
    }

    /// Fraction of slots in which the two signatures hold the same value.
    #[cfg(test)]
    pub(crate) fn jaccard(&self, other_minhash: &Self) -> f32 {
        let mut agreeing = 0usize;
        for (mine, theirs) in self.0.iter().zip_eq(&other_minhash.0) {
            if mine == theirs {
                agreeing += 1;
            }
        }
        agreeing as f32 / self.0.len() as f32
    }

    /// Views the signature as raw bytes, four per slot, in the platform's
    /// native byte order.
    pub(crate) fn get_bytes(&self) -> &[u8] {
        let byte_len = self.0.len() * std::mem::size_of::<u32>();
        let start = self.0.as_ptr() as *const u8;
        // SAFETY: `start` and `byte_len` cover exactly the initialized memory
        // owned by `self.0`, `u8` has no alignment requirement, and the
        // returned slice borrows `self`.
        unsafe { std::slice::from_raw_parts(start, byte_len) }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Signatures must not depend on element order, must diverge once the
    /// underlying sets differ, and the seeds must survive a byte round trip.
    #[test]
    fn test_minhash() {
        let perms = HashPermutations::new(20000);
        let mut left = HashValues::new([1, 2, 3, 4, 5, 6].iter(), &perms);
        let mut right = HashValues::new([4, 3, 2, 1, 5, 6].iter(), &perms);

        // Same elements in a different order give the same signature.
        assert_eq!(left.0, right.0);
        assert_eq!(left.jaccard(&right), 1.0);

        // Extra elements on one side lower the estimated similarity.
        left.update([7, 8, 9].iter(), &perms);
        let after_left_grew = left.jaccard(&right);
        assert!(after_left_grew < 1.0);
        println!("{:?}", after_left_grew);

        // Different extra elements on the other side keep it below one.
        right.update([17, 18, 19].iter(), &perms);
        let after_both_grew = left.jaccard(&right);
        assert!(after_both_grew < 1.0);
        println!("{:?}", after_both_grew);

        let restored = HashPermutations::from_bytes(perms.as_bytes());
        assert_eq!(perms.0, restored.0);
    }
}