#![allow(dead_code)]
use std::collections::HashMap;
/// A single LSH bucket: the distinct item ids that share one bucket key,
/// kept in the order they were first inserted.
#[derive(Debug, Clone, Default)]
pub struct LshBucket {
    /// Distinct ids in first-insertion order.
    items: Vec<u64>,
}

impl LshBucket {
    /// Creates an empty bucket.
    #[must_use]
    pub fn new() -> Self {
        Self { items: Vec::new() }
    }

    /// Adds `id` to the bucket unless it is already stored.
    ///
    /// Membership is decided by a linear scan over the stored ids.
    pub fn insert(&mut self, id: u64) {
        let already_present = self.items.iter().any(|&existing| existing == id);
        if already_present {
            return;
        }
        self.items.push(id);
    }

    /// Number of distinct ids held by the bucket.
    #[must_use]
    pub fn size(&self) -> usize {
        self.items().len()
    }

    /// The stored ids as a slice, in first-insertion order.
    #[must_use]
    pub fn items(&self) -> &[u64] {
        self.items.as_slice()
    }
}
/// Summary figures describing how items are spread over a set of buckets.
#[derive(Debug, Clone)]
pub struct BucketStats {
    /// Number of buckets the statistics were computed over.
    pub bucket_count: usize,
    /// Mean bucket size.
    pub avg_size: f64,
    /// Size of the largest bucket.
    pub max_size: usize,
    /// Sum of all bucket sizes.
    pub total_items: usize,
}

impl BucketStats {
    /// Returns the stored mean bucket size.
    #[must_use]
    pub fn avg_size(&self) -> f64 {
        let Self { avg_size, .. } = self;
        *avg_size
    }

    /// Returns the stored size of the largest bucket.
    #[must_use]
    pub fn max_size(&self) -> usize {
        let Self { max_size, .. } = self;
        *max_size
    }
}
/// Sign-of-projection LSH index over `f32` vectors.
///
/// Every table owns `bits_per_table` pseudo-random projection vectors. A
/// vector's bucket key in a table is the bit pattern formed by the signs of
/// its dot products with that table's projections. Only item ids are stored;
/// the vectors themselves are not retained.
#[derive(Debug)]
pub struct LshIndex {
    /// Number of independent hash tables.
    num_tables: usize,
    /// Number of sign bits per bucket key (at most 64).
    bits_per_table: usize,
    /// `projections[table][bit]` is a `dim`-length projection vector.
    projections: Vec<Vec<Vec<f32>>>,
    /// One bucket map per table, keyed by the sign-bit signature.
    tables: Vec<HashMap<u64, LshBucket>>,
    /// Expected dimensionality of inserted and queried vectors.
    dim: usize,
}

impl LshIndex {
    /// A bucket key is a `u64`, so it can carry at most this many sign bits.
    const MAX_BITS_PER_TABLE: usize = 64;

    /// Creates an empty index for `dim`-dimensional vectors.
    ///
    /// `bits_per_table` is clamped to 64 because bucket keys are `u64`; a
    /// larger value would overflow the shift used to build the key. The
    /// projections are derived deterministically from `seed`.
    #[must_use]
    pub fn new(dim: usize, num_tables: usize, bits_per_table: usize, seed: u64) -> Self {
        let bits_per_table = bits_per_table.min(Self::MAX_BITS_PER_TABLE);
        let projections = Self::generate_projections(dim, num_tables, bits_per_table, seed);
        let tables = vec![HashMap::new(); num_tables];
        Self {
            num_tables,
            bits_per_table,
            projections,
            tables,
            dim,
        }
    }

    /// Builds `num_tables x bits x dim` projection components from a 64-bit
    /// LCG seeded with `seed + 1`; each component is scaled to roughly
    /// `[-1, 1]`.
    // The u64 -> f32 casts intentionally drop precision: only a coarse
    // pseudo-random float is needed.
    #[allow(clippy::cast_precision_loss)]
    fn generate_projections(
        dim: usize,
        num_tables: usize,
        bits: usize,
        seed: u64,
    ) -> Vec<Vec<Vec<f32>>> {
        let mut state = seed.wrapping_add(1);
        let lcg_next = |s: &mut u64| -> f32 {
            *s = s
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let val = (*s >> 11) as f32 / (1u64 << 53) as f32;
            val * 2.0 - 1.0
        };
        // Draw order (table, then bit, then component) determines the
        // projections for a given seed, so it must not be rearranged.
        (0..num_tables)
            .map(|_| {
                (0..bits)
                    .map(|_| (0..dim).map(|_| lcg_next(&mut state)).collect())
                    .collect()
            })
            .collect()
    }

    /// Computes the sign-bit signature of `vec` for one table's projections:
    /// bit `i` is set when the dot product with projection `i` is `>= 0`.
    fn signature(projections: &[Vec<f32>], vec: &[f32]) -> u64 {
        projections
            .iter()
            .enumerate()
            .fold(0u64, |key, (bit_idx, proj)| {
                let dot: f32 = vec.iter().zip(proj.iter()).map(|(a, b)| a * b).sum();
                if dot >= 0.0 {
                    key | (1u64 << bit_idx)
                } else {
                    key
                }
            })
    }

    /// Bucket key of `vec` in table `table_idx`.
    ///
    /// # Panics
    /// Panics if `table_idx` is not a valid table index.
    fn bucket_key(&self, vec: &[f32], table_idx: usize) -> u64 {
        Self::signature(&self.projections[table_idx], vec)
    }

    /// Panics with the shared mismatch message when `vec` has the wrong length.
    fn assert_dim(&self, vec: &[f32]) {
        assert_eq!(
            vec.len(),
            self.dim,
            "Vector dimensionality mismatch: expected {}, got {}",
            self.dim,
            vec.len()
        );
    }

    /// Records `id` in the bucket that `vec` hashes to in every table.
    ///
    /// # Panics
    /// Panics if `vec.len()` differs from the index dimensionality.
    pub fn insert(&mut self, id: u64, vec: &[f32]) {
        self.assert_dim(vec);
        for (projections, table) in self.projections.iter().zip(self.tables.iter_mut()) {
            let key = Self::signature(projections, vec);
            table.entry(key).or_default().insert(id);
        }
    }

    /// Returns the ids sharing a bucket with `vec` in at least one table,
    /// sorted ascending and without duplicates.
    ///
    /// # Panics
    /// Panics if `vec.len()` differs from the index dimensionality.
    #[must_use]
    pub fn query(&self, vec: &[f32]) -> Vec<u64> {
        self.assert_dim(vec);
        let mut result: Vec<u64> = self
            .projections
            .iter()
            .zip(self.tables.iter())
            .filter_map(|(projections, table)| table.get(&Self::signature(projections, vec)))
            .flat_map(|bucket| bucket.items().iter().copied())
            .collect();
        result.sort_unstable();
        result.dedup();
        result
    }

    /// Returns at most `k` candidate ids for `vec`.
    ///
    /// The index stores no vectors, so candidates are not ranked by distance:
    /// the result is simply the `k` smallest ids returned by [`Self::query`].
    ///
    /// # Panics
    /// Panics if `vec.len()` differs from the index dimensionality.
    #[must_use]
    pub fn approximate_neighbors(&self, vec: &[f32], k: usize) -> Vec<u64> {
        let mut candidates = self.query(vec);
        candidates.truncate(k);
        candidates
    }

    /// Computes occupancy statistics over every bucket of every table.
    ///
    /// With no buckets, the average and the maximum are both reported as zero.
    // usize -> f64 is only used for the average; precision loss is acceptable.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn bucket_stats(&self) -> BucketStats {
        let (bucket_count, total_items, max_size) = self
            .tables
            .iter()
            .flat_map(|table| table.values())
            .map(LshBucket::size)
            .fold((0usize, 0usize, 0usize), |(count, total, max), size| {
                (count + 1, total + size, max.max(size))
            });
        let avg_size = if bucket_count == 0 {
            0.0
        } else {
            total_items as f64 / bucket_count as f64
        };
        BucketStats {
            bucket_count,
            avg_size,
            max_size,
            total_items,
        }
    }

    /// Number of hash tables in the index.
    #[must_use]
    pub fn num_tables(&self) -> usize {
        self.num_tables
    }

    /// Dimensionality the index was created for.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// Bit-sampling LSH index over 64-bit hashes.
///
/// Each table samples a fixed set of bit positions from the hash; the sampled
/// bits, packed together, form the bucket key. Every inserted `(id, hash)` is
/// stored in exactly one bucket of every table.
#[derive(Debug)]
pub struct BitLshIndex {
    /// Number of independent hash tables.
    num_tables: usize,
    /// Number of sampled bit positions per table (at most 64).
    bits_per_table: usize,
    /// `bit_masks[table]` lists that table's sampled bit positions, ascending.
    bit_masks: Vec<Vec<u8>>,
    /// One bucket map per table; buckets hold `(id, hash)` in insertion order.
    tables: Vec<HashMap<u64, Vec<(u64, u64)>>>,
}

impl BitLshIndex {
    /// Creates an empty index. `bits_per_table` is clamped to 64, the number
    /// of bits in a hash. Bit positions are derived deterministically from
    /// `seed`.
    #[must_use]
    pub fn new(num_tables: usize, bits_per_table: usize, seed: u64) -> Self {
        let bits_per_table = bits_per_table.min(64);
        Self {
            num_tables,
            bits_per_table,
            bit_masks: Self::generate_bit_masks(num_tables, bits_per_table, seed),
            tables: vec![HashMap::new(); num_tables],
        }
    }

    /// Picks `bits` distinct bit positions (0..64) per table using a 64-bit
    /// LCG seeded with `seed + 1`; each table's positions are sorted.
    fn generate_bit_masks(num_tables: usize, bits: usize, seed: u64) -> Vec<Vec<u8>> {
        // Only 64 distinct positions exist; asking for more could never finish.
        let bits = bits.min(64);
        let mut state = seed.wrapping_add(1);
        let lcg_next = |s: &mut u64| -> u64 {
            *s = s
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            *s
        };
        (0..num_tables)
            .map(|_| {
                let mut positions: Vec<u8> = Vec::with_capacity(bits);
                while positions.len() < bits {
                    // NOTE(review): `% 64` keeps only the low six bits of the
                    // LCG state, which are the weakest bits of a
                    // power-of-two-modulus LCG. Sampling high bits would give
                    // better masks but would change the masks produced for
                    // existing seeds, so it is left as is.
                    let candidate = (lcg_next(&mut state) % 64) as u8;
                    if !positions.contains(&candidate) {
                        positions.push(candidate);
                    }
                }
                positions.sort_unstable();
                positions
            })
            .collect()
    }

    /// Packs the bits of `hash` found at `positions` into the low bits of the
    /// returned key (position `i` of the list becomes bit `i` of the key).
    fn sample_bits(positions: &[u8], hash: u64) -> u64 {
        positions
            .iter()
            .enumerate()
            .filter(|&(_, &bit_pos)| hash & (1u64 << bit_pos) != 0)
            .fold(0u64, |key, (i, _)| key | (1u64 << i))
    }

    /// Bucket key of `hash` in table `table_idx`.
    ///
    /// # Panics
    /// Panics if `table_idx` is not a valid table index.
    fn bucket_key(&self, hash: u64, table_idx: usize) -> u64 {
        Self::sample_bits(&self.bit_masks[table_idx], hash)
    }

    /// Stores `(id, hash)` in the matching bucket of every table.
    ///
    /// No deduplication is performed: inserting the same id twice stores it
    /// twice.
    pub fn insert(&mut self, id: u64, hash: u64) {
        for (positions, table) in self.bit_masks.iter().zip(self.tables.iter_mut()) {
            let key = Self::sample_bits(positions, hash);
            table.entry(key).or_default().push((id, hash));
        }
    }

    /// Returns every stored `(id, hash)` that shares a bucket with `hash` in
    /// at least one table. Each id appears once, in order of first discovery
    /// (tables in index order, buckets in insertion order).
    #[must_use]
    pub fn query_candidates(&self, hash: u64) -> Vec<(u64, u64)> {
        let mut seen = std::collections::HashSet::new();
        let mut candidates = Vec::new();
        for (positions, table) in self.bit_masks.iter().zip(self.tables.iter()) {
            let key = Self::sample_bits(positions, hash);
            if let Some(bucket) = table.get(&key) {
                for &(id, h) in bucket {
                    if seen.insert(id) {
                        candidates.push((id, h));
                    }
                }
            }
        }
        candidates
    }

    /// Finds all candidate pairs whose hashes differ in at most
    /// `max_distance` bits.
    ///
    /// Returns `(lo_id, hi_id, hamming_distance)` triples with
    /// `lo_id < hi_id`, each pair at most once, sorted ascending so the
    /// output does not depend on hash-map iteration order.
    #[must_use]
    pub fn find_near_duplicates(&self, max_distance: u32) -> Vec<(u64, u64, u32)> {
        // `insert` writes every item into every table, so the first table
        // alone already lists all stored items.
        let mut all_items: Vec<(u64, u64)> = self
            .tables
            .first()
            .into_iter()
            .flat_map(|table| table.values())
            .flat_map(|bucket| bucket.iter().copied())
            .collect();
        // Sorting makes the scan order deterministic; keep one entry per id.
        all_items.sort_unstable();
        all_items.dedup_by_key(|item| item.0);

        let mut pairs = std::collections::HashSet::new();
        let mut results = Vec::new();
        for &(id, hash) in &all_items {
            for (cid, chash) in self.query_candidates(hash) {
                if cid == id {
                    continue;
                }
                let pair = (id.min(cid), id.max(cid));
                // A pair is evaluated only the first time it is encountered.
                if pairs.insert(pair) {
                    let dist = (hash ^ chash).count_ones();
                    if dist <= max_distance {
                        results.push((pair.0, pair.1, dist));
                    }
                }
            }
        }
        results.sort_unstable();
        results
    }

    /// Number of hash tables in the index.
    #[must_use]
    pub fn num_tables(&self) -> usize {
        self.num_tables
    }

    /// Number of sampled bits per table, after clamping to 64.
    #[must_use]
    pub fn bits_per_table(&self) -> usize {
        self.bits_per_table
    }
}
/// Outcome of an LSH-based near-duplicate pass.
#[derive(Debug, Clone)]
pub struct LshDedupResult {
    /// Accepted pairs as `(id_a, id_b, hamming_distance)`.
    pub pairs: Vec<(u64, u64, u32)>,
    /// Number of distinct candidate pairs that were compared.
    pub candidates_checked: usize,
    /// Number of items that were fed into the pass.
    pub total_items: usize,
}

impl LshDedupResult {
    /// Fraction of the exhaustive `n * (n - 1) / 2` pairwise comparisons that
    /// were actually performed, where `n` is `total_items`.
    ///
    /// Returns `0.0` when there are fewer than two items. The pair count is
    /// computed in floating point so it cannot overflow `usize` for large
    /// `n` (the integer product overflows on 32-bit targets from roughly 65k
    /// items).
    // usize -> f64 conversions are intentional; the result is only a ratio.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn comparison_ratio(&self) -> f64 {
        if self.total_items < 2 {
            return 0.0;
        }
        let n = self.total_items as f64;
        // With n >= 2 this is always >= 1, so the division below is safe.
        let full_pairs = n * (n - 1.0) / 2.0;
        self.candidates_checked as f64 / full_pairs
    }

    /// Number of accepted pairs.
    #[must_use]
    pub fn num_pairs(&self) -> usize {
        self.pairs.len()
    }
}
/// Runs a one-shot near-duplicate search over `(id, hash)` pairs.
///
/// All hashes are inserted into a fresh [`BitLshIndex`] built from
/// `num_tables`, `bits_per_table` and `seed`. Each item is then compared
/// against its bucket-mates; every unordered id pair is compared at most
/// once, and pairs whose Hamming distance is `<= max_distance` are reported
/// as `(lo_id, hi_id, distance)`.
///
/// Inputs with fewer than two items produce an empty result without building
/// an index.
#[must_use]
pub fn lsh_dedup_pass(
    hashes: &[(u64, u64)],
    max_distance: u32,
    num_tables: usize,
    bits_per_table: usize,
    seed: u64,
) -> LshDedupResult {
    let total_items = hashes.len();
    let mut pairs = Vec::new();
    let mut candidates_checked: usize = 0;

    if total_items >= 2 {
        let mut index = BitLshIndex::new(num_tables, bits_per_table, seed);
        for &(item_id, item_hash) in hashes {
            index.insert(item_id, item_hash);
        }

        // Unordered id pairs that have already been compared.
        let mut visited = std::collections::HashSet::new();
        for &(item_id, item_hash) in hashes {
            for (other_id, other_hash) in index.query_candidates(item_hash) {
                if other_id == item_id {
                    continue;
                }
                let ordered = (item_id.min(other_id), item_id.max(other_id));
                if !visited.insert(ordered) {
                    continue;
                }
                candidates_checked += 1;
                let dist = (item_hash ^ other_hash).count_ones();
                if dist <= max_distance {
                    pairs.push((ordered.0, ordered.1, dist));
                }
            }
        }
    }

    LshDedupResult {
        pairs,
        candidates_checked,
        total_items,
    }
}
/// Groups ids into connected components induced by `pairs` (union-find).
///
/// Two ids end up in the same group when they are linked directly or
/// transitively by a pair; the distance component of each pair is ignored.
/// Only ids listed in `all_ids` are emitted, each at most once, and only
/// groups with more than one member are returned. Ids that occur in `pairs`
/// but not in `all_ids` still connect components but are never output.
///
/// The result is deterministic: groups are ordered by the first appearance
/// of any of their members in `all_ids`, and members keep `all_ids` order.
/// Returns an empty vector when `pairs` is empty.
#[must_use]
pub fn group_by_lsh_pairs(pairs: &[(u64, u64, u32)], all_ids: &[u64]) -> Vec<Vec<u64>> {
    use std::collections::{HashMap, HashSet};

    /// Returns the representative of `x`, compressing the walked path.
    /// Iterative so that long parent chains cannot exhaust the call stack.
    fn find_root(parent: &mut HashMap<u64, u64>, x: u64) -> u64 {
        let mut root = x;
        while let Some(&up) = parent.get(&root) {
            if up == root {
                break;
            }
            root = up;
        }
        // Second pass: point every node on the path straight at the root.
        let mut cur = x;
        while cur != root {
            cur = parent.insert(cur, root).unwrap_or(root);
        }
        root
    }

    if pairs.is_empty() {
        return Vec::new();
    }

    let mut parent: HashMap<u64, u64> = all_ids.iter().map(|&id| (id, id)).collect();
    for &(a, b, _) in pairs {
        let ra = find_root(&mut parent, a);
        let rb = find_root(&mut parent, b);
        if ra != rb {
            parent.insert(ra, rb);
        }
    }

    // Guards against duplicate entries in `all_ids`, which would otherwise
    // turn a lone id into a spurious two-member group.
    let mut emitted: HashSet<u64> = HashSet::with_capacity(all_ids.len());
    let mut slot_of_root: HashMap<u64, usize> = HashMap::new();
    let mut groups: Vec<Vec<u64>> = Vec::new();
    for &id in all_ids {
        if !emitted.insert(id) {
            continue;
        }
        let root = find_root(&mut parent, id);
        let slot = *slot_of_root.entry(root).or_insert_with(|| {
            groups.push(Vec::new());
            groups.len() - 1
        });
        groups[slot].push(id);
    }
    groups.retain(|group| group.len() > 1);
    groups
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot vector of length `dim` with its 1.0 at index `hot % dim`.
    fn unit_vec(dim: usize, hot: usize) -> Vec<f32> {
        let hot_idx = hot % dim;
        (0..dim)
            .map(|i| if i == hot_idx { 1.0 } else { 0.0 })
            .collect()
    }

    // ---- LshBucket ----

    #[test]
    fn test_lsh_bucket_size_empty() {
        assert_eq!(LshBucket::new().size(), 0);
    }

    #[test]
    fn test_lsh_bucket_insert_and_size() {
        let mut bucket = LshBucket::new();
        // The repeated id must not be counted twice.
        for &id in &[1u64, 2, 1] {
            bucket.insert(id);
        }
        assert_eq!(bucket.size(), 2);
    }

    #[test]
    fn test_lsh_bucket_items() {
        let mut bucket = LshBucket::new();
        bucket.insert(42);
        bucket.insert(99);
        let stored = bucket.items();
        assert!(stored.contains(&42));
        assert!(stored.contains(&99));
    }

    // ---- LshIndex ----

    #[test]
    fn test_lsh_index_creation() {
        let index = LshIndex::new(8, 4, 6, 42);
        assert_eq!(index.dim(), 8);
        assert_eq!(index.num_tables(), 4);
    }

    #[test]
    fn test_lsh_index_insert_and_query_self() {
        let mut index = LshIndex::new(4, 3, 4, 7);
        let probe = vec![1.0f32, 0.0, 0.0, 0.0];
        index.insert(1, &probe);
        assert!(index.query(&probe).contains(&1));
    }

    #[test]
    fn test_lsh_query_returns_sorted() {
        let mut index = LshIndex::new(4, 2, 4, 13);
        let probe = vec![1.0f32, 1.0, 1.0, 1.0];
        for &id in &[5u64, 3, 7] {
            index.insert(id, &probe);
        }
        let found = index.query(&probe);
        // Sorted means every adjacent pair is in non-decreasing order.
        assert!(found.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn test_lsh_similar_vectors_in_same_bucket() {
        let mut index = LshIndex::new(8, 6, 6, 99);
        let first = vec![1.0f32, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let second = vec![1.0f32, 1.0, 1.0, 0.9, 0.0, 0.0, 0.0, 0.0];
        index.insert(10, &first);
        index.insert(11, &second);
        let found = index.query(&first);
        assert!(found.contains(&10));
    }

    #[test]
    fn test_lsh_approximate_neighbors_k_limit() {
        let mut index = LshIndex::new(4, 2, 4, 17);
        let probe = vec![1.0f32, 0.0, 0.0, 0.0];
        (0..20u64).for_each(|id| index.insert(id, &probe));
        let neighbors = index.approximate_neighbors(&probe, 5);
        assert!(neighbors.len() <= 5);
    }

    #[test]
    fn test_bucket_stats_empty() {
        let stats = LshIndex::new(4, 3, 4, 0).bucket_stats();
        assert_eq!(stats.bucket_count, 0);
        assert_eq!(stats.max_size(), 0);
        assert_eq!(stats.avg_size(), 0.0);
    }

    #[test]
    fn test_bucket_stats_after_inserts() {
        let mut index = LshIndex::new(4, 2, 4, 55);
        let probe = vec![0.5f32, 0.5, 0.5, 0.5];
        index.insert(1, &probe);
        index.insert(2, &probe);
        let stats = index.bucket_stats();
        assert!(stats.bucket_count > 0);
        assert!(stats.max_size() >= 1);
        assert!(stats.avg_size() > 0.0);
    }

    #[test]
    fn test_unit_vectors_different_dimensions() {
        let mut index = LshIndex::new(8, 4, 5, 77);
        for hot in 0..8usize {
            index.insert(hot as u64, &unit_vec(8, hot));
        }
        assert_eq!(index.dim(), 8);
    }

    #[test]
    fn test_insert_multiple_tables() {
        let mut index = LshIndex::new(4, 5, 4, 11);
        let probe = vec![0.1f32, 0.2, 0.3, 0.4];
        index.insert(100, &probe);
        assert!(index.query(&probe).contains(&100));
    }

    #[test]
    fn test_bucket_stats_avg_max() {
        let stats = BucketStats {
            bucket_count: 3,
            avg_size: 2.5,
            max_size: 5,
            total_items: 7,
        };
        assert_eq!(stats.avg_size(), 2.5);
        assert_eq!(stats.max_size(), 5);
    }

    // ---- BitLshIndex ----

    #[test]
    fn test_bit_lsh_creation() {
        let index = BitLshIndex::new(4, 8, 42);
        assert_eq!(index.num_tables(), 4);
        assert_eq!(index.bits_per_table(), 8);
    }

    #[test]
    fn test_bit_lsh_insert_and_self_query() {
        let mut index = BitLshIndex::new(4, 8, 42);
        let hash = 0xDEAD_BEEF_CAFE_BABEu64;
        index.insert(1, hash);
        let hit = index
            .query_candidates(hash)
            .iter()
            .any(|&(id, _)| id == 1);
        assert!(hit);
    }

    #[test]
    fn test_bit_lsh_identical_hashes_found() {
        let mut index = BitLshIndex::new(6, 10, 99);
        let hash = 0x1234_5678_9ABC_DEF0u64;
        index.insert(1, hash);
        index.insert(2, hash);
        let pairs = index.find_near_duplicates(0);
        assert_eq!(pairs, vec![(1, 2, 0)]);
    }

    #[test]
    fn test_bit_lsh_near_duplicates_within_distance() {
        let mut index = BitLshIndex::new(8, 6, 77);
        let base = 0xFFFF_FFFF_FFFF_FFFFu64;
        // Flip the three lowest bits: Hamming distance 3 from `base`.
        let similar = base ^ 0b111;
        index.insert(10, base);
        index.insert(11, similar);
        let pairs = index.find_near_duplicates(5);
        assert!(!pairs.is_empty());
        let found = pairs.contains(&(10, 11, 3));
        assert!(found, "Should find pair with distance 3");
    }

    #[test]
    fn test_bit_lsh_distant_hashes_not_paired() {
        let mut index = BitLshIndex::new(4, 16, 55);
        index.insert(1, 0x0000_0000_0000_0000);
        index.insert(2, 0xFFFF_FFFF_FFFF_FFFF);
        assert!(index.find_near_duplicates(5).is_empty());
    }

    #[test]
    fn test_bit_lsh_many_items() {
        let mut index = BitLshIndex::new(4, 8, 42);
        (0..100u64).for_each(|value| index.insert(value, value));
        let hit = index.query_candidates(50).iter().any(|&(id, _)| id == 50);
        assert!(hit);
    }

    #[test]
    fn test_bit_lsh_deduplicated_candidates() {
        let mut index = BitLshIndex::new(8, 6, 42);
        let hash = 0xAAAA_BBBB_CCCC_DDDDu64;
        index.insert(1, hash);
        let occurrences = index
            .query_candidates(hash)
            .into_iter()
            .filter(|&(id, _)| id == 1)
            .count();
        assert_eq!(occurrences, 1);
    }

    // ---- lsh_dedup_pass ----

    #[test]
    fn test_lsh_dedup_pass_empty() {
        let outcome = lsh_dedup_pass(&[], 5, 4, 8, 42);
        assert!(outcome.pairs.is_empty());
        assert_eq!(outcome.total_items, 0);
    }

    #[test]
    fn test_lsh_dedup_pass_single() {
        let outcome = lsh_dedup_pass(&[(1, 0xDEAD)], 5, 4, 8, 42);
        assert!(outcome.pairs.is_empty());
        assert_eq!(outcome.total_items, 1);
    }

    #[test]
    fn test_lsh_dedup_pass_identical() {
        let hash = 0xDEAD_BEEF_CAFE_BABEu64;
        let input = vec![(1, hash), (2, hash), (3, hash)];
        let outcome = lsh_dedup_pass(&input, 0, 6, 8, 42);
        assert_eq!(outcome.pairs.len(), 3);
        assert!(outcome.pairs.iter().all(|&(_, _, dist)| dist == 0));
    }

    #[test]
    fn test_lsh_dedup_pass_near_duplicates() {
        let base = 0xFFFF_FFFF_FFFF_FFFFu64;
        // Three flipped bits: Hamming distance 3 from `base`.
        let similar = base ^ 0b111;
        let input = vec![(10, base), (20, similar)];
        let outcome = lsh_dedup_pass(&input, 5, 8, 6, 77);
        assert!(!outcome.pairs.is_empty(), "Should find near-duplicate pair");
        let found = outcome.pairs.contains(&(10, 20, 3));
        assert!(found);
    }

    #[test]
    fn test_lsh_dedup_pass_distant_not_paired() {
        let input = vec![
            (1, 0x0000_0000_0000_0000u64),
            (2, 0xFFFF_FFFF_FFFF_FFFFu64),
        ];
        let outcome = lsh_dedup_pass(&input, 5, 4, 16, 55);
        assert!(outcome.pairs.is_empty());
    }

    #[test]
    fn test_lsh_dedup_pass_comparison_ratio() {
        let hash = 0xABCDu64;
        let input: Vec<(u64, u64)> = (0..100u64).map(|id| (id, hash)).collect();
        let outcome = lsh_dedup_pass(&input, 0, 4, 8, 42);
        assert!(outcome.comparison_ratio() <= 1.0);
        assert!(outcome.comparison_ratio() > 0.0);
    }

    #[test]
    fn test_lsh_dedup_result_num_pairs() {
        let outcome = LshDedupResult {
            pairs: vec![(1, 2, 0), (1, 3, 1)],
            candidates_checked: 5,
            total_items: 3,
        };
        assert_eq!(outcome.num_pairs(), 2);
    }

    // ---- group_by_lsh_pairs ----

    #[test]
    fn test_group_by_lsh_pairs_empty() {
        assert!(group_by_lsh_pairs(&[], &[1, 2, 3]).is_empty());
    }

    #[test]
    fn test_group_by_lsh_pairs_single_pair() {
        let links = vec![(1, 2, 0)];
        let groups = group_by_lsh_pairs(&links, &[1, 2, 3]);
        assert_eq!(groups.len(), 1);
        let only = &groups[0];
        assert!(only.contains(&1));
        assert!(only.contains(&2));
        assert!(!only.contains(&3));
    }

    #[test]
    fn test_group_by_lsh_pairs_transitive() {
        let links = vec![(1, 2, 3), (2, 3, 2)];
        let groups = group_by_lsh_pairs(&links, &[1, 2, 3]);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 3);
    }

    #[test]
    fn test_group_by_lsh_pairs_two_groups() {
        let links = vec![(1, 2, 0), (3, 4, 1)];
        let groups = group_by_lsh_pairs(&links, &[1, 2, 3, 4, 5]);
        assert_eq!(groups.len(), 2);
    }

    #[test]
    fn test_lsh_dedup_pass_many_items_sparse() {
        let input: Vec<(u64, u64)> = (0..200u64)
            .map(|id| (id, id * 0x0101_0101_0101_0101))
            .collect();
        let outcome = lsh_dedup_pass(&input, 3, 4, 12, 42);
        assert_eq!(outcome.total_items, 200);
    }
}