use alloc::vec::Vec;
use hashbrown::HashMap;
use smallvec::SmallVec;
use crate::classical::minhash::{MinHashSig, jaccard};
use crate::error::{Error, Result};
/// Array length of the `SmallVec` backing each bucket in `LshIndex::tables`,
/// i.e. how many candidate ids a bucket holds in its inline storage.
const CANDIDATE_INLINE: usize = 4;
/// Banded locality-sensitive-hashing index over MinHash signatures of `H` values.
///
/// Each stored signature is cut into `bands` consecutive slices of `rows`
/// values; every slice is hashed to a bucket key in that band's table. Two
/// signatures become query candidates of each other when they share a bucket
/// in at least one band.
pub struct LshIndex<const H: usize> {
/// Number of bands; `tables` holds one bucket map per band.
bands: usize,
/// Number of signature values hashed together per band
/// (`with_bands_rows` enforces `bands * rows == H`).
rows: usize,
/// Per-band bucket maps: band-slice hash -> ids whose signature hashed there.
tables: Vec<HashMap<u64, SmallVec<[u64; CANDIDATE_INLINE]>>>,
/// Full signature of every indexed id; used to un-index on removal and to
/// score candidates in `query_with_threshold`.
sigs: HashMap<u64, MinHashSig<H>>,
}
impl<const H: usize> LshIndex<H> {
    /// Creates an empty index that splits each signature into `bands` bands of
    /// `rows` values.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when either dimension is zero, when
    /// `bands * rows` does not equal `H`, or when that product overflows
    /// `usize`.
    pub fn with_bands_rows(bands: usize, rows: usize) -> Result<Self> {
        if bands == 0 || rows == 0 {
            return Err(Error::Config("bands and rows must be > 0".into()));
        }
        // Checked multiplication: a wrapping product could otherwise equal `H`
        // by accident (release) or panic (debug) for absurdly large inputs.
        match bands.checked_mul(rows) {
            Some(product) if product == H => {}
            Some(product) => {
                return Err(Error::Config(alloc::format!(
                    "bands * rows ({} * {} = {}) must equal H = {}",
                    bands,
                    rows,
                    product,
                    H,
                )));
            }
            None => {
                return Err(Error::Config(alloc::format!(
                    "bands * rows ({} * {}) overflows usize; must equal H = {}",
                    bands,
                    rows,
                    H,
                )));
            }
        }
        Ok(Self {
            bands,
            rows,
            // One empty bucket map per band.
            tables: (0..bands).map(|_| HashMap::new()).collect(),
            sigs: HashMap::new(),
        })
    }

    /// Number of bands each signature is split into.
    #[inline]
    #[must_use]
    pub fn bands(&self) -> usize {
        self.bands
    }

    /// Number of signature values per band.
    #[inline]
    #[must_use]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of ids currently indexed.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.sigs.len()
    }

    /// Returns `true` when no id is indexed.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.sigs.is_empty()
    }

    /// Returns the signature stored for `id`, if any.
    #[inline]
    #[must_use]
    pub fn get(&self, id: u64) -> Option<&MinHashSig<H>> {
        self.sigs.get(&id)
    }

    /// Indexes `sig` under `id`, replacing any signature previously stored for
    /// that id.
    pub fn insert(&mut self, id: u64, sig: MinHashSig<H>) {
        // `remove` returns `None` without touching anything for an unknown id,
        // so no separate existence check is needed. For a known id it scrubs
        // the old band buckets so stale entries never linger.
        self.remove(id);
        let rows = self.rows;
        for (band, table) in self.tables.iter_mut().enumerate() {
            table
                .entry(band_key(&sig, band, rows))
                .or_default()
                .push(id);
        }
        self.sigs.insert(id, sig);
    }

    /// Removes `id` from the index and returns its signature, or `None` when
    /// the id was not indexed.
    pub fn remove(&mut self, id: u64) -> Option<MinHashSig<H>> {
        let sig = self.sigs.remove(&id)?;
        let rows = self.rows;
        for (band, table) in self.tables.iter_mut().enumerate() {
            let key = band_key(&sig, band, rows);
            if let Some(bucket) = table.get_mut(&key) {
                bucket.retain(|member| *member != id);
                // Drop buckets that became empty so the tables do not
                // accumulate dead keys.
                if bucket.is_empty() {
                    table.remove(&key);
                }
            }
        }
        Some(sig)
    }

    /// Returns every indexed id that shares at least one band bucket with
    /// `sig`.
    ///
    /// Each id appears once, in the order it is first encountered while
    /// scanning the bands from first to last.
    #[must_use]
    pub fn query(&self, sig: &MinHashSig<H>) -> Vec<u64> {
        let mut out: Vec<u64> = Vec::new();
        // Used purely as a set of ids already emitted.
        let mut seen: HashMap<u64, ()> = HashMap::new();
        for (band, table) in self.tables.iter().enumerate() {
            if let Some(bucket) = table.get(&band_key(sig, band, self.rows)) {
                out.extend(
                    bucket
                        .iter()
                        .copied()
                        .filter(|&id| seen.insert(id, ()).is_none()),
                );
            }
        }
        out
    }

    /// Like [`query`](Self::query), but keeps only candidates whose `jaccard`
    /// score against `sig` is at least `threshold`.
    ///
    /// `threshold` is clamped to `[0.0, 1.0]`. A NaN threshold is left as NaN
    /// by `clamp`, so (assuming `jaccard` yields a float) no candidate passes.
    #[must_use]
    pub fn query_with_threshold(&self, sig: &MinHashSig<H>, threshold: f32) -> Vec<u64> {
        let threshold = threshold.clamp(0.0, 1.0);
        self.query(sig)
            .into_iter()
            .filter(|id| {
                self.sigs
                    .get(id)
                    .map(|other| jaccard(sig, other) >= threshold)
                    .unwrap_or(false)
            })
            .collect()
    }
}
/// Hashes the `band`-th slice of `rows` consecutive values from `sig.hashes`
/// into a 64-bit bucket key.
///
/// The slice is viewed as raw bytes and fed to XXH3-64. Indexing panics if the
/// slice would run past the end of the signature; debug builds assert first.
fn band_key<const H: usize>(sig: &MinHashSig<H>, band: usize, rows: usize) -> u64 {
    let first = band * rows;
    let past_last = first + rows;
    debug_assert!(past_last <= H, "band slice out of range");
    xxhash_rust::xxh3::xxh3_64(bytemuck::cast_slice::<u64, u8>(
        &sig.hashes[first..past_last],
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::canonical::Canonicalizer;
    use crate::classical::Fingerprinter;
    use crate::classical::minhash::MinHashFingerprinter;
    use crate::tokenize::{ShingleTokenizer, WordTokenizer};

    /// Index over 128-value signatures laid out as 16 bands of 8 rows.
    fn index_16x8() -> LshIndex<128> {
        LshIndex::<128>::with_bands_rows(16, 8).unwrap()
    }

    /// MinHash fingerprinter using a `k = 5` shingle tokenizer over
    /// `WordTokenizer` and the default canonicalizer.
    fn fingerprinter() -> MinHashFingerprinter<ShingleTokenizer<WordTokenizer>, 128> {
        let tokenizer = ShingleTokenizer {
            k: 5,
            inner: WordTokenizer,
        };
        MinHashFingerprinter::<_, 128>::new(Canonicalizer::default(), tokenizer)
    }

    /// `7 * 9 != 128`, so construction must fail with a config error.
    #[test]
    fn rejects_mismatched_h() {
        let outcome = LshIndex::<128>::with_bands_rows(7, 9);
        assert!(matches!(outcome, Err(Error::Config(_))));
    }

    /// A zero in either dimension is rejected.
    #[test]
    fn rejects_zero_dimensions() {
        for (bands, rows) in [(0, 128), (128, 0)] {
            let outcome = LshIndex::<128>::with_bands_rows(bands, rows);
            assert!(matches!(outcome, Err(Error::Config(_))));
        }
    }

    /// A fresh index is empty and reports its configured shape.
    #[test]
    fn empty_index() {
        let index = index_16x8();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.bands(), 16);
        assert_eq!(index.rows(), 8);
    }

    /// An inserted signature can be read back by id; other ids stay absent.
    #[test]
    fn insert_and_get() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let sig = hasher.fingerprint("the quick brown fox jumps").unwrap();
        index.insert(42, sig);
        assert_eq!(index.len(), 1);
        assert_eq!(index.get(42), Some(&sig));
        assert!(index.get(43).is_none());
    }

    /// Querying with a stored signature returns exactly its own id.
    #[test]
    fn self_query_hits() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let sig = hasher
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        index.insert(7, sig);
        let neighbours = index.query(&sig);
        assert_eq!(neighbours, alloc::vec![7]);
    }

    /// With many narrow bands, two texts differing in one word collide.
    #[test]
    fn near_duplicate_is_a_candidate() {
        let mut index = LshIndex::<128>::with_bands_rows(64, 2).unwrap();
        let hasher = fingerprinter();
        let noon = hasher
            .fingerprint("the quick brown fox jumps over the lazy dog at noon today")
            .unwrap();
        let dusk = hasher
            .fingerprint("the quick brown fox jumps over the lazy dog at dusk today")
            .unwrap();
        index.insert(1, noon);
        index.insert(2, dusk);
        let mut hits = index.query(&noon);
        hits.sort_unstable();
        assert!(hits.contains(&1));
        assert!(hits.contains(&2), "near-duplicate missed: {hits:?}");
    }

    /// Unrelated texts must not land in a shared bucket.
    #[test]
    fn dissimilar_doc_does_not_collide() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let fox = hasher
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let cosmos = hasher
            .fingerprint("astronomers detect cosmic background radiation in space")
            .unwrap();
        index.insert(1, fox);
        index.insert(2, cosmos);
        let hits = index.query(&fox);
        assert!(hits.contains(&1));
        assert!(!hits.contains(&2), "false positive: {hits:?}");
    }

    /// Re-inserting the same id does not duplicate it in the index or results.
    #[test]
    fn dedup_repeat_inserts() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let sig = hasher.fingerprint("the quick brown fox").unwrap();
        index.insert(1, sig);
        index.insert(1, sig);
        index.insert(1, sig);
        assert_eq!(index.len(), 1);
        let hits = index.query(&sig);
        assert_eq!(hits, alloc::vec![1]);
    }

    /// Inserting a new signature under an existing id replaces the old one
    /// and removes the old band entries.
    #[test]
    fn replace_changes_signature() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let before = hasher.fingerprint("alpha beta gamma delta epsilon").unwrap();
        let after = hasher.fingerprint("zeta eta theta iota kappa").unwrap();
        index.insert(1, before);
        index.insert(1, after);
        assert_eq!(index.get(1), Some(&after));
        assert_eq!(index.query(&after), alloc::vec![1]);
        let hits = index.query(&before);
        assert!(!hits.contains(&1), "old bands not scrubbed: {hits:?}");
    }

    /// Removing an id returns its signature and leaves nothing queryable.
    #[test]
    fn remove_takes_signature_out() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let sig = hasher.fingerprint("the quick brown fox").unwrap();
        index.insert(1, sig);
        let removed = index.remove(1);
        assert_eq!(removed, Some(sig));
        assert!(index.is_empty());
        assert!(index.query(&sig).is_empty());
    }

    /// Removing an unknown id yields `None`.
    #[test]
    fn remove_missing_returns_none() {
        let mut index = index_16x8();
        assert!(index.remove(99).is_none());
    }

    /// A strict threshold keeps the identical document and drops the variant.
    #[test]
    fn threshold_filter_drops_far_candidates() {
        let mut index = index_16x8();
        let hasher = fingerprinter();
        let jumps = hasher
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let leaps = hasher
            .fingerprint("the quick brown fox leaps over the lazy dog")
            .unwrap();
        index.insert(1, jumps);
        index.insert(2, leaps);
        let strict = index.query_with_threshold(&jumps, 0.95);
        assert!(strict.contains(&1));
        assert!(!strict.contains(&2));
    }
}