#![cfg_attr(coverage_nightly, coverage(off))]
use blake3::Hasher;
use std::collections::{HashMap, HashSet};
use super::types::{FragmentId, MinHashSignature};
/// Locality-sensitive hashing (LSH) index over MinHash signatures using the banding
/// technique.
///
/// Each signature is split into consecutive bands of `rows_per_band` values. Fragments
/// whose values hash identically in at least one band become query candidates for each
/// other; the full signatures are kept so candidates can be scored exactly.
#[derive(Debug, Clone)]
pub struct LshIndex {
/// Number of bands bucketed per signature; chunks beyond this count are ignored.
bands: usize,
/// Number of signature values hashed together to form one band.
rows_per_band: usize,
/// One map per band, from band hash to the ids of the fragments that produced it.
buckets: Vec<HashMap<u64, Vec<FragmentId>>>,
/// Full signature of every indexed fragment, used to score candidates.
signatures: HashMap<FragmentId, MinHashSignature>,
}
impl LshIndex {
    /// Creates an empty index with `num_bands` bands of `rows_per_band` rows each.
    ///
    /// Each signature is split into consecutive chunks of `rows_per_band` values and only
    /// the first `num_bands` chunks are bucketed. `rows_per_band` must be non-zero for
    /// [`insert`](Self::insert) and [`query`](Self::query) to work (see their `# Panics`).
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn new(num_bands: usize, rows_per_band: usize) -> Self {
        Self {
            bands: num_bands,
            rows_per_band,
            buckets: (0..num_bands).map(|_| HashMap::new()).collect(),
            signatures: HashMap::new(),
        }
    }

    /// Indexes `signature` under `fragment_id`, replacing any signature previously stored
    /// for the same id.
    ///
    /// When an id is re-inserted, its bucket entries for the old signature are removed
    /// first. This way it is only returned by [`query`](Self::query) for bands the new
    /// signature actually matches, and bucket lists never accumulate duplicate ids.
    ///
    /// # Panics
    ///
    /// Panics if the index was created with `rows_per_band == 0` (`slice::chunks` rejects a
    /// zero chunk size).
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn insert(&mut self, fragment_id: FragmentId, signature: MinHashSignature) {
        if let Some(previous) = self.signatures.remove(&fragment_id) {
            self.unindex(fragment_id, &previous);
        }
        // `take(self.bands)` ignores any chunks beyond the configured band count.
        for (band_idx, band) in signature
            .values
            .chunks(self.rows_per_band)
            .take(self.bands)
            .enumerate()
        {
            let band_hash = self.hash_band(band);
            self.buckets[band_idx]
                .entry(band_hash)
                .or_default()
                .push(fragment_id);
        }
        // The signature was only borrowed above, so it can be stored without cloning.
        self.signatures.insert(fragment_id, signature);
    }

    /// Removes `fragment_id` from every bucket that `signature` hashes into, dropping
    /// buckets that become empty.
    fn unindex(&mut self, fragment_id: FragmentId, signature: &MinHashSignature) {
        for (band_idx, band) in signature
            .values
            .chunks(self.rows_per_band)
            .take(self.bands)
            .enumerate()
        {
            let band_hash = self.hash_band(band);
            let band_buckets = &mut self.buckets[band_idx];
            if let Some(bucket) = band_buckets.get_mut(&band_hash) {
                bucket.retain(|&id| id != fragment_id);
                if bucket.is_empty() {
                    band_buckets.remove(&band_hash);
                }
            }
        }
    }

    /// Returns the ids of all indexed fragments that share at least one band hash with
    /// `query`.
    ///
    /// Candidates are not scored; use [`find_similar`](Self::find_similar) to filter by
    /// similarity.
    ///
    /// # Panics
    ///
    /// Panics if the index was created with `rows_per_band == 0`.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn query(&self, query: &MinHashSignature) -> HashSet<FragmentId> {
        query
            .values
            .chunks(self.rows_per_band)
            .take(self.bands)
            .enumerate()
            .filter_map(|(band_idx, band)| self.buckets[band_idx].get(&self.hash_band(band)))
            .flatten()
            .copied()
            .collect()
    }

    /// Returns the candidates of `query` whose similarity, as computed by
    /// `MinHashSignature::jaccard_similarity`, is at least `threshold`.
    ///
    /// Results are sorted by descending similarity. The relative order of equal scores is
    /// unspecified.
    ///
    /// # Panics
    ///
    /// Panics if the index was created with `rows_per_band == 0`.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn find_similar(&self, query: &MinHashSignature, threshold: f64) -> Vec<(FragmentId, f64)> {
        let mut matches: Vec<(FragmentId, f64)> = self
            .query(query)
            .into_iter()
            .filter_map(|id| {
                self.signatures
                    .get(&id)
                    .map(|sig| (id, query.jaccard_similarity(sig)))
            })
            .filter(|&(_, similarity)| similarity >= threshold)
            .collect();
        // HashSet iteration order is arbitrary; present the closest matches first.
        matches.sort_by(|a, b| b.1.total_cmp(&a.1));
        matches
    }

    /// Returns the number of distinct fragments stored in the index.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn len(&self) -> usize {
        self.signatures.len()
    }

    /// Returns `true` when no fragment has been inserted.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn is_empty(&self) -> bool {
        self.signatures.is_empty()
    }

    /// Returns the signature stored for `fragment_id`, if any.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn get_signature(&self, fragment_id: FragmentId) -> Option<&MinHashSignature> {
        self.signatures.get(&fragment_id)
    }

    /// Hashes one band with BLAKE3, feeding each value as little-endian bytes, and returns
    /// the first 8 bytes of the digest as a little-endian `u64`.
    fn hash_band(&self, band: &[u64]) -> u64 {
        let mut hasher = Hasher::new();
        for value in band {
            hasher.update(&value.to_le_bytes());
        }
        let digest = hasher.finalize();
        let mut prefix = [0u8; 8];
        // A BLAKE3 digest is 32 bytes, so the 8-byte prefix is always present.
        prefix.copy_from_slice(&digest.as_bytes()[..8]);
        u64::from_le_bytes(prefix)
    }

    /// Probability that two signatures with the given Jaccard similarity share at least
    /// one band: `1 - (1 - s^r)^b`, with `r = rows_per_band` and `b = bands`.
    ///
    /// The input is clamped to `[0, 1]`; NaN propagates. The formula assumes every
    /// signature yields `bands` full bands. A shorter signature, or a trailing partial
    /// chunk, is bucketed with fewer or shorter bands than modelled here.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn collision_probability(&self, jaccard_similarity: f64) -> f64 {
        // Similarities outside [0, 1] are invalid and would yield meaningless results
        // (e.g. probabilities above 1).
        let s = jaccard_similarity.clamp(0.0, 1.0);
        let r = self.rows_per_band as f64;
        let b = self.bands as f64;
        1.0 - (1.0 - s.powf(r)).powf(b)
    }
}