use alloc::vec::Vec;
use core::hash::{BuildHasherDefault, Hasher};
use hashbrown::{HashMap, HashSet};
use smallvec::SmallVec;
use crate::classical::minhash::{MinHashSig, jaccard};
use crate::error::{Error, Result};
const CANDIDATE_INLINE: usize = 4;
#[derive(Default, Clone, Copy)]
struct U64IdentityHasher(u64);
impl Hasher for U64IdentityHasher {
#[inline]
fn finish(&self) -> u64 {
self.0
}
#[inline]
fn write_u64(&mut self, n: u64) {
self.0 = n;
}
#[inline]
fn write(&mut self, _bytes: &[u8]) {
debug_assert!(false, "U64IdentityHasher only accepts u64 keys");
}
}
type U64Hasher = BuildHasherDefault<U64IdentityHasher>;
type BandTable = HashMap<u64, SmallVec<[u64; CANDIDATE_INLINE]>, U64Hasher>;
pub struct LshIndex<const H: usize> {
bands: usize,
rows: usize,
tables: Vec<BandTable>,
sigs: HashMap<u64, MinHashSig<H>>,
}
impl<const H: usize> LshIndex<H> {
pub fn with_bands_rows(bands: usize, rows: usize) -> Result<Self> {
if bands == 0 || rows == 0 {
return Err(Error::Config("bands and rows must be > 0".into()));
}
if bands * rows != H {
return Err(Error::Config(alloc::format!(
"bands * rows ({} * {} = {}) must equal H = {}",
bands,
rows,
bands * rows,
H,
)));
}
let mut tables = Vec::with_capacity(bands);
for _ in 0..bands {
tables.push(BandTable::with_hasher(U64Hasher::default()));
}
Ok(Self {
bands,
rows,
tables,
sigs: HashMap::new(),
})
}
#[inline]
#[must_use]
pub fn bands(&self) -> usize {
self.bands
}
#[inline]
#[must_use]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.sigs.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.sigs.is_empty()
}
#[inline]
#[must_use]
pub fn get(&self, id: u64) -> Option<&MinHashSig<H>> {
self.sigs.get(&id)
}
pub fn insert(&mut self, id: u64, sig: MinHashSig<H>) {
if self.sigs.contains_key(&id) {
self.remove(id);
}
for (band, table) in self.tables.iter_mut().enumerate() {
let key = band_key(&sig, band, self.rows);
table.entry(key).or_default().push(id);
}
self.sigs.insert(id, sig);
}
#[cfg(feature = "parallel")]
#[cfg_attr(docsrs, doc(cfg(feature = "parallel")))]
pub fn extend_par<I>(&mut self, items: I)
where
I: IntoIterator<Item = (u64, MinHashSig<H>)>,
{
use rayon::prelude::*;
let items: alloc::vec::Vec<(u64, MinHashSig<H>)> = items.into_iter().collect();
for (id, sig) in &items {
debug_assert!(
!self.sigs.contains_key(id),
"LshIndex::extend_par: id {id} already exists; remove() first"
);
self.sigs.insert(*id, *sig);
}
let rows = self.rows;
let items_ref = items.as_slice();
self.tables
.par_iter_mut()
.enumerate()
.for_each(|(band, table)| {
for (id, sig) in items_ref {
let key = band_key(sig, band, rows);
table.entry(key).or_default().push(*id);
}
});
}
pub fn remove(&mut self, id: u64) -> Option<MinHashSig<H>> {
let sig = self.sigs.remove(&id)?;
for (band, table) in self.tables.iter_mut().enumerate() {
let key = band_key(&sig, band, self.rows);
if let Some(list) = table.get_mut(&key) {
list.retain(|v| *v != id);
if list.is_empty() {
table.remove(&key);
}
}
}
Some(sig)
}
#[must_use]
pub fn query(&self, sig: &MinHashSig<H>) -> Vec<u64> {
let mut seen: HashSet<u64> = HashSet::with_capacity(self.bands * 4);
let mut out: Vec<u64> = Vec::new();
for (band, table) in self.tables.iter().enumerate() {
let key = band_key(sig, band, self.rows);
if let Some(list) = table.get(&key) {
for &id in list {
if seen.insert(id) {
out.push(id);
}
}
}
}
out
}
#[must_use]
pub fn query_with_threshold(&self, sig: &MinHashSig<H>, threshold: f32) -> Vec<u64> {
let candidates = self.query(sig);
let threshold = threshold.clamp(0.0, 1.0);
candidates
.into_iter()
.filter(|id| {
self.sigs
.get(id)
.map(|other| jaccard(sig, other) >= threshold)
.unwrap_or(false)
})
.collect()
}
}
fn band_key<const H: usize>(sig: &MinHashSig<H>, band: usize, rows: usize) -> u64 {
let start = band * rows;
let end = start + rows;
debug_assert!(end <= H, "band slice out of range");
let slice = &sig.hashes[start..end];
let bytes = bytemuck::cast_slice::<u64, u8>(slice);
xxhash_rust::xxh3::xxh3_64(bytes)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::canonical::Canonicalizer;
use crate::classical::Fingerprinter;
use crate::classical::minhash::MinHashFingerprinter;
use crate::tokenize::{ShingleTokenizer, WordTokenizer};
fn make() -> LshIndex<128> {
LshIndex::<128>::with_bands_rows(16, 8).unwrap()
}
fn fp() -> MinHashFingerprinter<ShingleTokenizer<WordTokenizer>, 128> {
MinHashFingerprinter::<_, 128>::new(
Canonicalizer::default(),
ShingleTokenizer {
k: 5,
inner: WordTokenizer,
},
)
}
#[test]
fn rejects_mismatched_h() {
let r = LshIndex::<128>::with_bands_rows(7, 9);
assert!(matches!(r, Err(Error::Config(_))));
}
#[test]
fn rejects_zero_dimensions() {
let r = LshIndex::<128>::with_bands_rows(0, 128);
assert!(matches!(r, Err(Error::Config(_))));
let r = LshIndex::<128>::with_bands_rows(128, 0);
assert!(matches!(r, Err(Error::Config(_))));
}
#[test]
fn empty_index() {
let idx = make();
assert!(idx.is_empty());
assert_eq!(idx.len(), 0);
assert_eq!(idx.bands(), 16);
assert_eq!(idx.rows(), 8);
}
#[test]
fn insert_and_get() {
let mut idx = make();
let f = fp();
let s = f.fingerprint("the quick brown fox jumps").unwrap();
idx.insert(42, s);
assert_eq!(idx.len(), 1);
assert_eq!(idx.get(42), Some(&s));
assert!(idx.get(43).is_none());
}
#[test]
fn self_query_hits() {
let mut idx = make();
let f = fp();
let s = f
.fingerprint("the quick brown fox jumps over the lazy dog")
.unwrap();
idx.insert(7, s);
let neighbours = idx.query(&s);
assert_eq!(neighbours, alloc::vec![7]);
}
#[test]
fn near_duplicate_is_a_candidate() {
let mut idx = LshIndex::<128>::with_bands_rows(64, 2).unwrap();
let f = fp();
let s1 = f
.fingerprint("the quick brown fox jumps over the lazy dog at noon today")
.unwrap();
let s2 = f
.fingerprint("the quick brown fox jumps over the lazy dog at dusk today")
.unwrap();
idx.insert(1, s1);
idx.insert(2, s2);
let mut hits = idx.query(&s1);
hits.sort();
assert!(hits.contains(&1));
assert!(hits.contains(&2), "near-duplicate missed: {hits:?}");
}
#[test]
fn dissimilar_doc_does_not_collide() {
let mut idx = make();
let f = fp();
let s1 = f
.fingerprint("the quick brown fox jumps over the lazy dog")
.unwrap();
let s2 = f
.fingerprint("astronomers detect cosmic background radiation in space")
.unwrap();
idx.insert(1, s1);
idx.insert(2, s2);
let hits = idx.query(&s1);
assert!(hits.contains(&1));
assert!(!hits.contains(&2), "false positive: {hits:?}");
}
#[test]
fn dedup_repeat_inserts() {
let mut idx = make();
let f = fp();
let s = f.fingerprint("the quick brown fox").unwrap();
idx.insert(1, s);
idx.insert(1, s);
idx.insert(1, s);
assert_eq!(idx.len(), 1);
let hits = idx.query(&s);
assert_eq!(hits, alloc::vec![1]);
}
#[test]
fn replace_changes_signature() {
let mut idx = make();
let f = fp();
let s1 = f.fingerprint("alpha beta gamma delta epsilon").unwrap();
let s2 = f.fingerprint("zeta eta theta iota kappa").unwrap();
idx.insert(1, s1);
idx.insert(1, s2);
assert_eq!(idx.get(1), Some(&s2));
assert_eq!(idx.query(&s2), alloc::vec![1]);
let hits = idx.query(&s1);
assert!(!hits.contains(&1), "old bands not scrubbed: {hits:?}");
}
#[test]
fn remove_takes_signature_out() {
let mut idx = make();
let f = fp();
let s = f.fingerprint("the quick brown fox").unwrap();
idx.insert(1, s);
let removed = idx.remove(1);
assert_eq!(removed, Some(s));
assert!(idx.is_empty());
assert!(idx.query(&s).is_empty());
}
#[test]
fn remove_missing_returns_none() {
let mut idx = make();
assert!(idx.remove(99).is_none());
}
#[cfg(feature = "parallel")]
#[test]
fn extend_par_matches_serial_insert() {
let f = fp();
let docs: alloc::vec::Vec<alloc::string::String> = (0..200)
.map(|i| alloc::format!("the quick brown fox jumps over the lazy dog {i}"))
.collect();
let sigs: alloc::vec::Vec<_> = docs.iter().map(|d| f.fingerprint(d).unwrap()).collect();
let mut serial = make();
for (i, sig) in sigs.iter().enumerate() {
serial.insert(i as u64, *sig);
}
let mut parallel = make();
let pairs: alloc::vec::Vec<_> = sigs
.iter()
.enumerate()
.map(|(i, sig)| (i as u64, *sig))
.collect();
parallel.extend_par(pairs);
assert_eq!(parallel.len(), serial.len());
for i in 0..200u64 {
assert_eq!(parallel.get(i), serial.get(i));
let mut p = parallel.query(serial.get(i).unwrap());
let mut s = serial.query(serial.get(i).unwrap());
p.sort_unstable();
s.sort_unstable();
assert_eq!(p, s, "candidate set differs for id {i}");
}
}
#[test]
fn threshold_filter_drops_far_candidates() {
let mut idx = make();
let f = fp();
let s1 = f
.fingerprint("the quick brown fox jumps over the lazy dog")
.unwrap();
let s2 = f
.fingerprint("the quick brown fox leaps over the lazy dog")
.unwrap();
idx.insert(1, s1);
idx.insert(2, s2);
let strict = idx.query_with_threshold(&s1, 0.95);
assert!(strict.contains(&1));
assert!(!strict.contains(&2));
}
}