use std::collections::{HashMap, HashSet};
use xxhash_rust::xxh3::xxh3_64;
/// Locality-sensitive hashing index that groups signatures by banded hashes.
///
/// A signature (a slice of `u64` hash values) is cut into up to `bands`
/// consecutive bands of `rows_per_band` values each. Every band is hashed to a
/// bucket key, and items that share a key in at least one band are reported as
/// candidates for each other.
#[derive(Debug, Clone)]
pub struct LshIndex {
    /// Number of bands a signature is split into.
    bands: usize,
    /// Number of consecutive hash values that make up one band.
    rows_per_band: usize,
    /// One bucket table per band: band hash -> item indices stored under it.
    buckets: Vec<HashMap<u64, Vec<usize>>>,
    /// Number of `insert` calls made so far.
    num_items: usize,
}
impl LshIndex {
    /// Creates an empty index with `bands` bands of `rows_per_band` rows each.
    ///
    /// One empty bucket table is allocated per band.
    pub fn new(bands: usize, rows_per_band: usize) -> Self {
        Self {
            bands,
            rows_per_band,
            buckets: (0..bands).map(|_| HashMap::new()).collect(),
            num_items: 0,
        }
    }

    /// Creates an empty index whose banding is chosen by [`select_lsh_params`]
    /// for signatures of `num_hashes` values and the given similarity
    /// `threshold`.
    pub fn with_threshold(num_hashes: usize, threshold: f64) -> Self {
        let (bands, rows) = select_lsh_params(num_hashes, threshold);
        Self::new(bands, rows)
    }

    /// Yields `(band, bucket_key)` for every band that starts inside
    /// `hash_values`.
    ///
    /// Shared by [`insert`](Self::insert) and [`query`](Self::query) so both
    /// slice and hash a signature in exactly the same way. A band that runs
    /// past the end of the signature is truncated; iteration stops at the first
    /// band that would start at or past the end.
    ///
    /// Takes the banding parameters by value instead of `&self` so the returned
    /// iterator borrows only `hash_values`, leaving the caller free to mutate
    /// the bucket tables while iterating.
    fn band_hashes<'a>(
        bands: usize,
        rows_per_band: usize,
        hash_values: &'a [u64],
    ) -> impl Iterator<Item = (usize, u64)> + 'a {
        (0..bands)
            .map(move |band| (band, band * rows_per_band))
            .take_while(move |&(_, start)| start < hash_values.len())
            .map(move |(band, start)| {
                let end = (start + rows_per_band).min(hash_values.len());
                (band, hash_band(&hash_values[start..end], band))
            })
    }

    /// Stores `idx` in the bucket of every band covered by `hash_values`.
    ///
    /// The item counter is incremented once per call, even when the signature
    /// is empty or `idx` has been inserted before.
    pub fn insert(&mut self, idx: usize, hash_values: &[u64]) {
        for (band, band_hash) in Self::band_hashes(self.bands, self.rows_per_band, hash_values) {
            self.buckets[band].entry(band_hash).or_default().push(idx);
        }
        self.num_items += 1;
    }

    /// Inserts every signature in `items`, using its position in the slice as
    /// the item index.
    ///
    /// Positions always start at 0, regardless of what the index already holds.
    pub fn insert_bulk(&mut self, items: &[Vec<u64>]) {
        for (idx, hash_values) in items.iter().enumerate() {
            self.insert(idx, hash_values);
        }
    }

    /// Returns the indices of all stored items that share at least one band
    /// bucket with `hash_values`, sorted ascending and without duplicates.
    pub fn query(&self, hash_values: &[u64]) -> Vec<usize> {
        let candidates: HashSet<usize> =
            Self::band_hashes(self.bands, self.rows_per_band, hash_values)
                .filter_map(|(band, band_hash)| self.buckets[band].get(&band_hash))
                .flatten()
                .copied()
                .collect();

        let mut result: Vec<usize> = candidates.into_iter().collect();
        result.sort_unstable();
        result
    }

    /// Returns every unordered pair of distinct item indices that share a
    /// bucket in at least one band.
    ///
    /// Each pair is reported once as `(smaller, larger)` and the result is
    /// sorted. An index that was inserted more than once is never paired with
    /// itself.
    pub fn candidate_pairs(&self) -> Vec<(usize, usize)> {
        let mut pairs = HashSet::new();
        for members in self.buckets.iter().flat_map(|band| band.values()) {
            for (position, &first) in members.iter().enumerate() {
                for &second in &members[position + 1..] {
                    // Repeated inserts of one index put it in a bucket twice;
                    // a pair of an item with itself carries no information.
                    if first != second {
                        pairs.insert((first.min(second), first.max(second)));
                    }
                }
            }
        }

        let mut result: Vec<(usize, usize)> = pairs.into_iter().collect();
        result.sort_unstable();
        result
    }

    /// Returns the number of `insert` calls made on this index.
    pub fn len(&self) -> usize {
        self.num_items
    }

    /// Returns `true` when nothing has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.num_items == 0
    }

    /// Returns the number of bands.
    pub fn bands(&self) -> usize {
        self.bands
    }

    /// Returns the number of rows (hash values) per band.
    pub fn rows_per_band(&self) -> usize {
        self.rows_per_band
    }

    /// Evaluates the banding S-curve `1 - (1 - s^r)^b` for this index, where
    /// `r` is the rows per band and `b` the number of bands.
    ///
    /// `similarity` is clamped to `[0, 1]` first so that out-of-range input
    /// cannot produce a result outside `[0, 1]`. A NaN input yields NaN.
    pub fn collision_probability(&self, similarity: f64) -> f64 {
        let s = similarity.clamp(0.0, 1.0);
        let r = self.rows_per_band as f64;
        let b = self.bands as f64;
        1.0 - (1.0 - s.powf(r)).powf(b)
    }
}
/// Picks `(bands, rows_per_band)` whose S-curve inflection point lies closest
/// to `threshold`.
///
/// Every row count `r` from 1 up to `min(num_permutations, 20)` is tried with
/// `b = num_permutations / r` bands, and scored by
/// `|(1 / b)^(1 / r) - threshold|`. The lowest score wins; on a tie the
/// smaller `r` is kept. The returned values always satisfy
/// `bands * rows_per_band <= max(num_permutations, 1)` and `rows_per_band >= 1`.
///
/// When no candidate scores below the initial error (`num_permutations == 0`,
/// or a NaN/infinite `threshold`), the result is a single band spanning
/// `max(num_permutations, 1)` rows.
pub fn select_lsh_params(num_permutations: usize, threshold: f64) -> (usize, usize) {
    // Largest rows-per-band value that is tried.
    const MAX_ROWS_PER_BAND: usize = 20;

    // Never report zero rows per band: a zero-width band would hash every
    // non-empty signature to the same bucket.
    let mut best = (1, num_permutations.max(1));
    let mut best_error = f64::MAX;

    for rows in 1..=num_permutations.min(MAX_ROWS_PER_BAND) {
        // `rows <= num_permutations` here, so there is always at least one band.
        let bands = num_permutations / rows;
        let inflection = (1.0 / bands as f64).powf(1.0 / rows as f64);
        let error = (inflection - threshold).abs();
        // Strict comparison keeps the first (smallest `rows`) of equal scores.
        if error < best_error {
            best_error = error;
            best = (bands, rows);
        }
    }

    best
}
/// Hashes one band of a signature into a 64-bit bucket key.
///
/// The hashed byte string is the band index as a little-endian `u64` followed
/// by each value in little-endian order, so equal slices in different bands
/// are fed different input.
fn hash_band(values: &[u64], band_idx: usize) -> u64 {
    // Eight bytes for the band index plus eight per value.
    let mut bytes: Vec<u8> = Vec::with_capacity((values.len() + 1) * 8);
    bytes.extend((band_idx as u64).to_le_bytes());
    bytes.extend(values.iter().flat_map(|word| word.to_le_bytes()));
    xxh3_64(&bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A signature queried against an index that contains it must come back.
    #[test]
    fn test_identical_signatures_always_found() {
        let signature = (0..128).collect::<Vec<u64>>();
        let mut lsh = LshIndex::with_threshold(128, 0.5);
        lsh.insert(0, &signature);

        let found = lsh.query(&signature);
        assert!(
            found.contains(&0),
            "Identical signature must always be found as candidate"
        );
    }

    /// Unrelated pseudo-random signatures should mostly not collide when the
    /// index is tuned for a high threshold.
    #[test]
    fn test_very_different_signatures_rarely_found() {
        let num_hashes = 128;
        let mut lsh = LshIndex::with_threshold(num_hashes, 0.8);

        let mut signatures: Vec<Vec<u64>> = Vec::with_capacity(100);
        for doc in 0..100u64 {
            // Each slot hashes a value unique to (doc, slot).
            let signature = (0..num_hashes as u64)
                .map(|slot| xxh3_64(&(doc * 1000 + slot).to_le_bytes()))
                .collect();
            signatures.push(signature);
        }
        for (position, signature) in signatures.iter().enumerate() {
            lsh.insert(position, signature);
        }

        let hits = lsh.query(&signatures[0]).len();
        assert!(
            hits < 50,
            "Expected few false positives at high threshold, got {} candidates out of 100",
            hits
        );
    }

    /// The chosen parameters fit the signature length and land near the
    /// requested threshold.
    #[test]
    fn test_parameter_selection_reasonable() {
        let (bands, rows) = select_lsh_params(128, 0.5);
        assert!(bands > 0, "bands must be positive");
        assert!(rows > 0, "rows must be positive");
        assert!(
            bands * rows <= 128,
            "bands * rows must not exceed num_permutations"
        );

        let inflection = (bands as f64).recip().powf((rows as f64).recip());
        let distance = (inflection - 0.5).abs();
        assert!(
            distance < 0.15,
            "Inflection point {inflection} should be near threshold 0.5"
        );
    }

    /// Same sanity checks across a range of thresholds.
    #[test]
    fn test_parameter_selection_various_thresholds() {
        for threshold in [0.3, 0.5, 0.7, 0.9] {
            let (bands, rows) = select_lsh_params(128, threshold);
            assert!(bands > 0);
            assert!(rows > 0);
            assert!(bands * rows <= 128);

            let inflection = (bands as f64).recip().powf((rows as f64).recip());
            let distance = (inflection - threshold).abs();
            assert!(
                distance < 0.2,
                "threshold={threshold}, inflection={inflection}, bands={bands}, rows={rows}"
            );
        }
    }

    /// A fresh index is empty and returns no candidates.
    #[test]
    fn test_empty_index() {
        let lsh = LshIndex::new(16, 8);
        assert_eq!(lsh.len(), 0);
        assert!(lsh.is_empty());
        assert!(lsh.query(&[1, 2, 3, 4, 5, 6, 7, 8]).is_empty());
    }

    /// An empty signature is counted but lands in no bucket.
    #[test]
    fn test_empty_signature() {
        let mut lsh = LshIndex::new(4, 2);
        lsh.insert(0, &[]);
        assert_eq!(lsh.len(), 1);
        assert!(lsh.query(&[]).is_empty());
    }

    /// Items colliding in several bands appear once, in ascending order.
    #[test]
    fn test_query_returns_sorted_deduplicated() {
        let signature: Vec<u64> = (0..16).collect();
        let mut lsh = LshIndex::new(4, 4);
        for item in 0..3 {
            lsh.insert(item, &signature);
        }

        let found = lsh.query(&signature);
        assert!(
            found.windows(2).all(|pair| pair[0] <= pair[1]),
            "Results must be sorted"
        );

        let distinct = found.iter().copied().collect::<HashSet<usize>>();
        assert_eq!(
            distinct.len(),
            found.len(),
            "Results must be deduplicated"
        );
    }

    /// Only the two identical signatures form a candidate pair.
    #[test]
    fn test_candidate_pairs() {
        let first: Vec<u64> = (0..16).collect();
        let twin: Vec<u64> = (0..16).collect();
        let distant: Vec<u64> = (1000..1016).collect();

        let mut lsh = LshIndex::new(4, 4);
        lsh.insert(0, &first);
        lsh.insert(1, &twin);
        lsh.insert(2, &distant);

        let pairs = lsh.candidate_pairs();
        assert!(
            pairs.contains(&(0, 1)),
            "Identical signatures should be candidate pair"
        );
        for unrelated in [(0, 2), (1, 2)] {
            assert!(
                !pairs.contains(&unrelated),
                "Very different signatures should not be candidate pair"
            );
        }
    }

    /// The S-curve hits its end points and stays within [0, 1] in between.
    #[test]
    fn test_collision_probability() {
        let lsh = LshIndex::new(16, 8);

        let prob_identical = lsh.collision_probability(1.0);
        assert!(
            (prob_identical - 1.0).abs() < f64::EPSILON,
            "P(collision | s=1.0) should be 1.0, got {prob_identical}"
        );

        let prob_zero = lsh.collision_probability(0.0);
        assert!(
            prob_zero.abs() < f64::EPSILON,
            "P(collision | s=0.0) should be ~0.0, got {prob_zero}"
        );

        let prob_mid = lsh.collision_probability(0.5);
        assert!(
            (0.0..=1.0).contains(&prob_mid),
            "P(collision | s=0.5) should be in [0,1], got {prob_mid}"
        );
    }

    /// Bulk insertion counts one item per signature.
    #[test]
    fn test_bulk_insert() {
        let items: Vec<Vec<u64>> = (0..10).map(|value| vec![value; 16]).collect();
        let mut lsh = LshIndex::new(4, 4);
        lsh.insert_bulk(&items);
        assert_eq!(lsh.len(), 10);
    }

    /// A signature shorter than one band is still indexed and retrievable.
    #[test]
    fn test_short_signature_handled() {
        let signature: Vec<u64> = vec![42, 43];
        let mut lsh = LshIndex::new(16, 8);
        lsh.insert(0, &signature);

        let found = lsh.query(&signature);
        assert!(
            found.contains(&0),
            "Short signature should still be findable"
        );
    }
}