use faer::{MatRef, RowRef};
use num_traits::Float;
use rand::{prelude::*, rng};
use rand_distr::StandardNormal;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use std::{cell::RefCell, collections::BinaryHeap};
use thousands::*;
use crate::prelude::*;
use crate::utils::*;
// Per-thread scratch buffers reused across queries so each query does not
// allocate a fresh candidate list / dedup set. Every user clears the buffer
// before filling it.
thread_local! {
    // Raw candidate ids gathered from hash buckets; may contain duplicates
    // because the same vector can appear in several tables and probes.
    static LSH_CANDIDATES: RefCell<Vec<usize>> = const { RefCell::new(Vec::new()) };
    // Ids already scored by `rank_candidates`, used to skip duplicates.
    static LSH_SEEN_SET: RefCell<FxHashSet<usize>> = RefCell::new(FxHashSet::default());
}
// One entry per hash table, as produced in parallel by `LSHIndex::new`:
// (bucket map from hash to vector ids, hash of every vector in that table).
type TablesAndHashes = Vec<(FxHashMap<u64, Vec<usize>>, Vec<u64>)>;
/// Random-hyperplane (sign-of-projection) locality-sensitive hashing index.
///
/// Each of `num_tables` tables maps a vector to a `bits_per_hash`-bit key,
/// one bit per random hyperplane; vectors with equal keys share a bucket.
/// Queries gather candidates from the matching bucket (and optionally from
/// probed neighbouring buckets) and rank them by exact distance.
pub struct LSHIndex<T> {
    /// Row-major copy of the indexed vectors (`n * dim` values).
    pub vectors_flat: Vec<T>,
    /// Dimensionality of every vector.
    pub dim: usize,
    /// Number of indexed vectors.
    pub n: usize,
    /// Per-vector L2 norms; populated only for `Dist::Cosine`, empty otherwise.
    norms: Vec<T>,
    /// Distance used to rank candidates.
    metric: Dist,
    /// One bucket map per table: hash -> ids of the vectors with that hash.
    hash_tables: Vec<FxHashMap<u64, Vec<usize>>>,
    /// Hyperplane normals, laid out as `[table][bit][dim]`.
    random_vecs: Vec<T>,
    /// Number of independent hash tables.
    num_tables: usize,
    /// Hyperplanes (hash bits) per table; hashes are packed into a `u64`.
    bits_per_hash: usize,
    /// Stored hash of every vector in every table, laid out as `[table][vector]`.
    vector_hashes: Vec<u64>,
    /// Identity id mapping `0..n`, exposed through `KnnValidation::original_ids`.
    original_ids: Vec<usize>,
}
impl<T> VectorDistance<T> for LSHIndex<T>
where
    T: AnnSearchFloat,
{
    /// Row-major storage of every indexed vector.
    fn vectors_flat(&self) -> &[T] {
        self.vectors_flat.as_slice()
    }

    /// Length of each stored vector.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Precomputed per-vector norms (non-empty only for cosine indexes).
    fn norms(&self) -> &[T] {
        self.norms.as_slice()
    }
}
impl<T> LSHIndex<T>
where
T: AnnSearchFloat,
{
pub fn new(
data: MatRef<T>,
metric: Dist,
num_tables: usize,
bits_per_hash: usize,
seed: usize,
) -> Self {
let (vectors_flat, n, dim) = matrix_to_flat(data);
let norms = if metric == Dist::Cosine {
(0..n)
.map(|i| {
let start = i * dim;
T::calculate_l2_norm(&vectors_flat[start..start + dim])
})
.collect()
} else {
Vec::new()
};
let mut rng = StdRng::seed_from_u64(seed as u64);
let total_random_vecs = num_tables * bits_per_hash * dim;
let mut random_vecs: Vec<T> = (0..total_random_vecs)
.map(|_| {
let val: f64 = rng.sample(StandardNormal);
T::from_f64(val).unwrap()
})
.collect();
orthogonalise_table_projections(&mut random_vecs, num_tables, bits_per_hash, dim);
let normalised: Vec<T> = if metric == Dist::Euclidean {
let mut buf = vec![T::zero(); n * dim];
for i in 0..n {
let start = i * dim;
let norm = T::calculate_l2_norm(&vectors_flat[start..start + dim]);
if norm > T::epsilon() {
for d in 0..dim {
buf[start + d] = vectors_flat[start + d] / norm;
}
}
}
buf
} else {
Vec::new()
};
let hash_source: &[T] = if metric == Dist::Euclidean {
&normalised
} else {
&vectors_flat
};
let tables_and_hashes: TablesAndHashes = (0..num_tables)
.into_par_iter()
.map(|table_idx| {
let mut table = FxHashMap::default();
let mut hashes = Vec::with_capacity(n);
for vec_idx in 0..n {
let start = vec_idx * dim;
let vec = &hash_source[start..start + dim];
let hash = compute_hash(vec, table_idx, bits_per_hash, dim, &random_vecs);
hashes.push(hash);
table.entry(hash).or_insert_with(Vec::new).push(vec_idx);
}
(table, hashes)
})
.collect();
let mut vector_hashes = vec![0u64; num_tables * n];
let mut hash_tables = Vec::with_capacity(num_tables);
for (table_idx, (table, hashes)) in tables_and_hashes.into_iter().enumerate() {
hash_tables.push(table);
let base = table_idx * n;
vector_hashes[base..base + n].copy_from_slice(&hashes);
}
Self {
vectors_flat,
dim,
n,
norms,
metric,
hash_tables,
random_vecs,
num_tables,
bits_per_hash,
vector_hashes,
original_ids: (0..n).collect(),
}
}
/// Finds up to `k` approximate nearest neighbours of `query_vec`.
///
/// Candidates are collected table by table. Each table contributes the
/// query's own bucket and, when `n_probes > 0`, up to `n_probes` neighbouring
/// buckets ordered by `generate_probes_ranked`. Collection stops once the
/// gathered id count (duplicates included) reaches `max_cand` (default
/// `self.n`). If no bucket yields anything, a random sample of up to 1000
/// indexed vectors is ranked instead.
///
/// Returns `(indices, distances, fallback_triggered)` ordered by ascending
/// distance.
///
/// # Panics
///
/// Panics if `query_vec.len() != self.dim`.
pub fn query(
    &self,
    query_vec: &[T],
    k: usize,
    max_cand: Option<usize>,
    n_probes: usize,
) -> (Vec<usize>, Vec<T>, bool) {
    assert!(
        query_vec.len() == self.dim,
        "Query vector dimensionality mismatch"
    );

    // Euclidean tables were built from unit-normalised rows, so the query is
    // hashed the same way. A query with norm <= epsilon is hashed as the zero
    // vector.
    let normalised_query: Option<Vec<T>> = (self.metric == Dist::Euclidean).then(|| {
        let norm = T::calculate_l2_norm(query_vec);
        if norm > T::epsilon() {
            query_vec.iter().map(|&v| v / norm).collect()
        } else {
            vec![T::zero(); self.dim]
        }
    });
    let hash_input: &[T] = normalised_query.as_deref().unwrap_or(query_vec);

    LSH_CANDIDATES.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.clear();
        let budget = max_cand.unwrap_or(self.n);

        for table_idx in 0..self.num_tables {
            if pool.len() >= budget {
                break;
            }
            let (hash, projections) = compute_hash_with_projections(
                hash_input,
                table_idx,
                self.bits_per_hash,
                self.dim,
                &self.random_vecs,
            );
            let table = &self.hash_tables[table_idx];
            if let Some(bucket) = table.get(&hash) {
                pool.extend_from_slice(bucket);
            }
            // Multi-probe only while the candidate budget is still open.
            if n_probes == 0 || pool.len() >= budget {
                continue;
            }
            let probes =
                generate_probes_ranked(hash, &projections, self.bits_per_hash, n_probes);
            for probe_hash in probes {
                if pool.len() >= budget {
                    break;
                }
                if let Some(bucket) = table.get(&probe_hash) {
                    pool.extend_from_slice(bucket);
                }
            }
        }

        // Nothing matched in any bucket: rank a random sample instead.
        let fallback_triggered = pool.is_empty();
        if fallback_triggered {
            let sample_size = self.n.min(1000);
            pool.extend((0..self.n).choose_multiple(&mut rng(), sample_size));
        }

        let (indices, dists) = self.rank_candidates(query_vec, &pool, k);
        (indices, dists, fallback_triggered)
    })
}
/// Convenience wrapper around `query` that accepts a faer row view.
///
/// Rows with column stride 1 are passed through without copying. Strided
/// rows are first gathered into a temporary `Vec`.
///
/// # Panics
///
/// Panics if `query_row.ncols() != self.dim`.
pub fn query_row(
    &self,
    query_row: RowRef<T>,
    k: usize,
    max_cand: Option<usize>,
    n_probes: usize,
) -> (Vec<usize>, Vec<T>, bool) {
    assert!(
        query_row.ncols() == self.dim,
        "Query row dimensionality mismatch"
    );
    if query_row.col_stride() != 1 {
        let gathered: Vec<T> = query_row.iter().cloned().collect();
        return self.query(&gathered, k, max_cand, n_probes);
    }
    // SAFETY: a column stride of 1 means the row's `ncols()` elements are
    // contiguous starting at `as_ptr()`, and `query_row` keeps that storage
    // borrowed for the duration of this call. This assumes faer returns a
    // non-null, aligned pointer even for an empty row — TODO confirm.
    let contiguous =
        unsafe { std::slice::from_raw_parts(query_row.as_ptr(), query_row.ncols()) };
    self.query(contiguous, k, max_cand, n_probes)
}
/// Computes an approximate k-nearest-neighbour list for every indexed vector.
///
/// Each vector is queried in parallel via `self_query_at`, which reuses its
/// stored hashes. When `verbose` is set, progress is printed every 100 000
/// vectors. A warning is printed if at least 1% of the vectors had to fall
/// back to random sampling.
///
/// Returns the neighbour indices per vector, plus the matching distances
/// when `return_dist` is true.
pub fn generate_knn(
    &self,
    k: usize,
    max_cand: Option<usize>,
    n_probes: usize,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>) {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // `AtomicUsize` is `Sync`, so the rayon closure can borrow it directly;
    // no `Arc` is needed.
    let processed = AtomicUsize::new(0);
    let results: Vec<(Vec<usize>, Vec<T>, bool)> = (0..self.n)
        .into_par_iter()
        .map(|vec_idx| {
            if verbose {
                let count = processed.fetch_add(1, Ordering::Relaxed) + 1;
                if count.is_multiple_of(100_000) {
                    println!(
                        " Processed {} / {} samples.",
                        count.separate_with_underscores(),
                        self.n.separate_with_underscores()
                    );
                }
            }
            self.self_query_at(vec_idx, k, max_cand, n_probes)
        })
        .collect();

    let missed = results.iter().filter(|(_, _, fallback)| *fallback).count();
    if (missed as f32) / (self.n as f32) >= 0.01 {
        println!("More than 1% of samples were not represented in the buckets.");
        println!("Please verify underlying data");
    }

    if return_dist {
        let (indices, distances): (Vec<Vec<usize>>, Vec<Vec<T>>) = results
            .into_iter()
            .map(|(idx, dist, _)| (idx, dist))
            .unzip();
        (indices, Some(distances))
    } else {
        let indices: Vec<Vec<usize>> = results.into_iter().map(|(idx, _, _)| idx).collect();
        (indices, None)
    }
}
/// Approximate memory footprint of the index, in bytes.
///
/// The total covers:
/// * the struct itself;
/// * the capacity of every owned buffer (vectors, norms, hyperplanes, stored
///   hashes, original ids);
/// * for each hash table, its slot capacity times (key + bucket `Vec`)
///   size, plus every bucket's id capacity.
///
/// Hash-map control bytes and allocator overhead are not counted.
pub fn memory_usage_bytes(&self) -> usize {
    let t_size = std::mem::size_of::<T>();
    let usize_size = std::mem::size_of::<usize>();
    let slot_size = std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>();

    let buffers = (self.vectors_flat.capacity()
        + self.norms.capacity()
        + self.random_vecs.capacity())
        * t_size
        + self.vector_hashes.capacity() * std::mem::size_of::<u64>()
        // `original_ids` holds one id per indexed vector and must be counted too.
        + self.original_ids.capacity() * usize_size;

    let table_headers =
        self.hash_tables.capacity() * std::mem::size_of::<FxHashMap<u64, Vec<usize>>>();
    let table_contents: usize = self
        .hash_tables
        .iter()
        .map(|table| {
            table.capacity() * slot_size
                + table
                    .values()
                    .map(|bucket| bucket.capacity() * usize_size)
                    .sum::<usize>()
        })
        .sum();

    std::mem::size_of_val(self) + buffers + table_headers + table_contents
}
/// Number of hash bits (random hyperplanes) used per table.
#[inline]
pub fn num_bits(&self) -> usize {
    self.bits_per_hash
}
/// Like `query`, but for the already-indexed vector `vec_idx`.
///
/// It reuses the vector's stored per-table hashes instead of re-hashing.
/// Because no projection magnitudes are available, neighbouring buckets are
/// probed in plain bit order via `generate_probes_uniform`. Returns
/// `(indices, distances, fallback_triggered)`.
fn self_query_at(
    &self,
    vec_idx: usize,
    k: usize,
    max_cand: Option<usize>,
    n_probes: usize,
) -> (Vec<usize>, Vec<T>, bool) {
    LSH_CANDIDATES.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.clear();
        let budget = max_cand.unwrap_or(self.n);

        for table_idx in 0..self.num_tables {
            if pool.len() >= budget {
                break;
            }
            // Stored hashes are laid out as `[table][vector]`.
            let hash = self.vector_hashes[table_idx * self.n + vec_idx];
            let table = &self.hash_tables[table_idx];
            if let Some(bucket) = table.get(&hash) {
                pool.extend_from_slice(bucket);
            }
            if n_probes == 0 || pool.len() >= budget {
                continue;
            }
            for probe_hash in generate_probes_uniform(hash, self.bits_per_hash, n_probes) {
                if pool.len() >= budget {
                    break;
                }
                if let Some(bucket) = table.get(&probe_hash) {
                    pool.extend_from_slice(bucket);
                }
            }
        }

        // Nothing matched in any bucket: rank a random sample instead.
        let fallback_triggered = pool.is_empty();
        if fallback_triggered {
            let sample_size = self.n.min(1000);
            pool.extend((0..self.n).choose_multiple(&mut rng(), sample_size));
        }

        let offset = vec_idx * self.dim;
        let own_vec = &self.vectors_flat[offset..offset + self.dim];
        let (indices, dists) = self.rank_candidates(own_vec, &pool, k);
        (indices, dists, fallback_triggered)
    })
}
/// Scores the unique ids in `candidates` against `query_vec` with the index
/// metric and returns the `k` closest as `(indices, distances)`.
///
/// Results are ordered by ascending distance, with ties broken by ascending
/// index. Duplicate ids (the same vector found in several tables or probes)
/// are scored once. Returns empty vectors when `k == 0`.
fn rank_candidates(
    &self,
    query_vec: &[T],
    candidates: &[usize],
    k: usize,
) -> (Vec<usize>, Vec<T>) {
    // With k == 0 the heap never receives an element, so the eviction path
    // would `peek()` an empty heap. Return early instead of panicking.
    if k == 0 {
        return (Vec::new(), Vec::new());
    }
    LSH_SEEN_SET.with(|seen_cell| {
        let mut seen = seen_cell.borrow_mut();
        seen.clear();

        // Max-heap holding the best `k` so far. Its top is the worst of them
        // and is evicted first. It never holds more than `k` or
        // `candidates.len()` items, so the capacity is capped at the smaller.
        let mut heap: BinaryHeap<(OrderedFloat<T>, usize)> =
            BinaryHeap::with_capacity(k.min(candidates.len()));
        let mut offer = |dist: T, idx: usize| {
            let item = (OrderedFloat(dist), idx);
            if heap.len() < k {
                heap.push(item);
            } else if heap.peek().is_some_and(|worst| item.0 < worst.0) {
                heap.pop();
                heap.push(item);
            }
        };

        match self.metric {
            Dist::Euclidean => {
                for &idx in candidates {
                    if seen.insert(idx) {
                        offer(self.euclidean_distance_to_query(idx, query_vec), idx);
                    }
                }
            }
            Dist::Cosine => {
                // The query norm is loop-invariant; compute it once.
                let query_norm = T::calculate_l2_norm(query_vec);
                for &idx in candidates {
                    if seen.insert(idx) {
                        offer(
                            self.cosine_distance_to_query(idx, query_vec, query_norm),
                            idx,
                        );
                    }
                }
            }
        }

        // `into_sorted_vec` orders by (distance, index) ascending.
        let (indices, dists): (Vec<usize>, Vec<T>) = heap
            .into_sorted_vec()
            .into_iter()
            .map(|(OrderedFloat(d), idx)| (idx, d))
            .unzip();
        (indices, dists)
    })
}
}
/// Hashes `vec` for table `table_idx`.
///
/// Bit `b` of the result is set when the dot product of `vec` with
/// hyperplane `b` of that table is non-negative. `random_vecs` is laid out
/// as `[table][bit][dim]`.
#[inline]
fn compute_hash<T>(
    vec: &[T],
    table_idx: usize,
    bits_per_hash: usize,
    dim: usize,
    random_vecs: &[T],
) -> u64
where
    T: AnnSearchFloat,
{
    let table_base = table_idx * bits_per_hash * dim;
    (0..bits_per_hash).fold(0u64, |acc, bit| {
        let start = table_base + bit * dim;
        let hyperplane = &random_vecs[start..start + dim];
        if T::dot_simd(vec, hyperplane) >= T::zero() {
            acc | (1u64 << bit)
        } else {
            acc
        }
    })
}
/// Same hash as `compute_hash`, plus the magnitude `|dot|` of every bit's
/// projection.
///
/// Callers use the magnitudes to rank bits by how close `vec` lies to each
/// hyperplane.
#[inline]
fn compute_hash_with_projections<T>(
    vec: &[T],
    table_idx: usize,
    bits_per_hash: usize,
    dim: usize,
    random_vecs: &[T],
) -> (u64, Vec<T>)
where
    T: AnnSearchFloat,
{
    let table_base = table_idx * bits_per_hash * dim;
    let mut magnitudes = Vec::with_capacity(bits_per_hash);
    let hash = (0..bits_per_hash).fold(0u64, |acc, bit| {
        let start = table_base + bit * dim;
        let dot = T::dot_simd(vec, &random_vecs[start..start + dim]);
        magnitudes.push(dot.abs());
        if dot >= T::zero() {
            acc | (1u64 << bit)
        } else {
            acc
        }
    });
    (hash, magnitudes)
}
/// Multi-probe sequence for a query hash, most promising buckets first.
///
/// Bits are ordered by ascending `|projection|`. A small magnitude means the
/// query lies close to that hyperplane, so the neighbouring bucket across it
/// is the most likely to hold near neighbours. The sequence yields all
/// single-bit flips in that order, then all two-bit flips (pairs taken in
/// that order), truncated to `max_probes` entries.
fn generate_probes_ranked<T: Float>(
    base_hash: u64,
    projections: &[T],
    bits_per_hash: usize,
    max_probes: usize,
) -> Vec<u64> {
    use std::cmp::Ordering;

    // NaN magnitudes (e.g. from a NaN query) are ranked last so the
    // comparator is a total order. `partial_cmp(..).unwrap_or(Equal)` is not
    // one in the presence of NaN, and `sort_unstable_by` may panic on
    // inconsistent comparators.
    let mut bit_order: Vec<usize> = (0..bits_per_hash).collect();
    bit_order.sort_unstable_by(|&a, &b| {
        let (pa, pb) = (projections[a], projections[b]);
        match (pa.is_nan(), pb.is_nan()) {
            (false, false) => pa.partial_cmp(&pb).unwrap_or(Ordering::Equal),
            (false, true) => Ordering::Less,
            (true, false) => Ordering::Greater,
            (true, true) => Ordering::Equal,
        }
    });

    let order: &[usize] = &bit_order;
    let single_flips = order.iter().map(move |&bit| base_hash ^ (1u64 << bit));
    let double_flips = order.iter().enumerate().flat_map(move |(i, &bit_i)| {
        order[i + 1..]
            .iter()
            .map(move |&bit_j| base_hash ^ (1u64 << bit_i) ^ (1u64 << bit_j))
    });
    // Collect lazily rather than preallocating `max_probes` slots, which may
    // far exceed the number of distinct probes that exist.
    single_flips.chain(double_flips).take(max_probes).collect()
}
/// Multi-probe sequence that ignores projection magnitudes.
///
/// It yields every single-bit flip of `base_hash` in bit order, then every
/// two-bit flip `(i, j)` with `i < j` in lexicographic order, truncated to
/// `max_probes` entries.
fn generate_probes_uniform(base_hash: u64, bits_per_hash: usize, max_probes: usize) -> Vec<u64> {
    let single_flips = (0..bits_per_hash).map(move |bit| base_hash ^ (1u64 << bit));
    let double_flips = (0..bits_per_hash).flat_map(move |i| {
        ((i + 1)..bits_per_hash).map(move |j| base_hash ^ (1u64 << i) ^ (1u64 << j))
    });
    let mut probes = Vec::with_capacity(max_probes);
    probes.extend(single_flips.chain(double_flips).take(max_probes));
    probes
}
/// Orthonormalises each table's hyperplane normals in place, using modified
/// Gram–Schmidt.
///
/// `vecs` holds `num_tables * bits_per_hash` vectors of length `dim`, laid
/// out as `[table][bit][dim]`. Only `dim` mutually orthogonal directions
/// exist, so each table's hyperplanes are processed in consecutive groups of
/// `dim`; a vector is orthogonalised only against the earlier vectors of its
/// own group. When `bits_per_hash <= dim` there is a single group, which is
/// plain Gram–Schmidt over the whole table. A vector whose residual norm is
/// `<= epsilon` is left unnormalised.
fn orthogonalise_table_projections<T>(
    vecs: &mut [T],
    num_tables: usize,
    bits_per_hash: usize,
    dim: usize,
) where
    T: Float,
{
    // Zero-length vectors have nothing to orthogonalise, and the `% dim`
    // below requires `dim > 0`.
    if dim == 0 {
        return;
    }
    for table_idx in 0..num_tables {
        let base = table_idx * bits_per_hash * dim;
        for i in 0..bits_per_hash {
            let i_base = base + i * dim;
            // Projecting bit `i >= dim` against every earlier bit would leave
            // only rounding noise, which would then be normalised into a
            // degenerate hyperplane. Restarting the orthogonal basis every
            // `dim` bits keeps each hyperplane a genuine random direction.
            let group_start = i - i % dim;
            for j in group_start..i {
                let j_base = base + j * dim;
                let dot = (0..dim).fold(T::zero(), |acc, d| {
                    acc + vecs[i_base + d] * vecs[j_base + d]
                });
                for d in 0..dim {
                    vecs[i_base + d] = vecs[i_base + d] - dot * vecs[j_base + d];
                }
            }
            let norm = (0..dim)
                .fold(T::zero(), |acc, d| acc + vecs[i_base + d] * vecs[i_base + d])
                .sqrt();
            if norm > T::epsilon() {
                for d in 0..dim {
                    vecs[i_base + d] = vecs[i_base + d] / norm;
                }
            }
        }
    }
}
impl<T> KnnValidation<T> for LSHIndex<T>
where
    T: AnnSearchFloat,
{
    /// Validation queries use no candidate cap and probe as many neighbouring
    /// buckets per table as there are hash bits. The fallback flag is dropped.
    fn query_for_validation(&self, query_vec: &[T], k: usize) -> (Vec<usize>, Vec<T>) {
        let n_probes = self.bits_per_hash;
        let (indices, dist, _fallback) = self.query(query_vec, k, None, n_probes);
        (indices, dist)
    }

    /// Number of indexed vectors.
    fn n(&self) -> usize {
        self.n
    }

    /// Distance metric the index ranks with.
    fn metric(&self) -> Dist {
        self.metric
    }

    /// Identity id mapping `0..n`.
    fn original_ids(&self) -> &[usize] {
        self.original_ids.as_slice()
    }
}
// Unit tests for construction, querying, multi-probe generation and k-NN
// graph generation of `LSHIndex`.
#[cfg(test)]
mod tests {
    use super::*;
    use faer::Mat;

    // Small fixture: five 3-d rows (the three unit axes, (1, 1, 0) and
    // (0.5, 0.5, 0.7)).
    fn simple_test_data() -> Mat<f32> {
        Mat::from_fn(5, 3, |i, j| match i {
            0 => [1.0, 0.0, 0.0][j],
            1 => [0.0, 1.0, 0.0][j],
            2 => [0.0, 0.0, 1.0][j],
            3 => [1.0, 1.0, 0.0][j],
            4 => [0.5, 0.5, 0.7][j],
            _ => 0.0,
        })
    }

    // A Euclidean index records its shape and stores one hash per
    // (table, vector) pair.
    #[test]
    fn test_index_creation_euclidean() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        assert_eq!(index.n, 5);
        assert_eq!(index.dim, 3);
        assert_eq!(index.num_tables, 4);
        assert_eq!(index.bits_per_hash, 8);
        assert_eq!(index.vectors_flat.len(), 15);
        assert_eq!(index.vector_hashes.len(), 4 * 5);
    }

    // A cosine index additionally precomputes one norm per vector.
    #[test]
    fn test_index_creation_cosine() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Cosine, 4, 8, 42);
        assert_eq!(index.n, 5);
        assert_eq!(index.dim, 3);
        assert_eq!(index.norms.len(), 5);
        assert_eq!(index.vector_hashes.len(), 4 * 5);
    }

    // The stored table-0 hash of vector 0 equals a fresh hash of its
    // unit-normalised row (Euclidean indexes hash normalised rows).
    #[test]
    fn test_stored_hashes_match_recomputed() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let vec = &index.vectors_flat[0..index.dim];
        let norm: f32 = vec.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let normalised: Vec<f32> = vec.iter().map(|&v| v / norm).collect();
        let recomputed = compute_hash(
            &normalised,
            0,
            index.bits_per_hash,
            index.dim,
            &index.random_vecs,
        );
        let stored = index.vector_hashes[0];
        assert_eq!(stored, recomputed);
    }

    // Without probes, a query identical to row 0 finds row 0, returns at
    // most k results, and orders distances ascending.
    #[test]
    fn test_basic_query_no_probes() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, distances, _) = index.query(&query, 3, None, 0);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 3);
        assert_eq!(indices.len(), distances.len());
        assert!(indices.contains(&0));
        for i in 1..distances.len() {
            assert!(distances[i - 1] <= distances[i]);
        }
    }

    // Same expectations as above with multi-probe enabled.
    #[test]
    fn test_basic_query_with_probes() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, distances, _) = index.query(&query, 3, None, 8);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 3);
        assert_eq!(indices.len(), distances.len());
        assert!(indices.contains(&0));
        for i in 1..distances.len() {
            assert!(distances[i - 1] <= distances[i]);
        }
    }

    // Probing extra buckets never yields fewer results than probing none.
    #[test]
    fn test_multi_probe_finds_more_candidates() {
        let n = 100;
        let dim = 50;
        let mat = Mat::from_fn(n, dim, |i, j| ((i * 7 + j * 13) % 100) as f32 / 100.0);
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 2, 12, 42);
        let query = vec![0.5; dim];
        let (idx_no_probe, _, _) = index.query(&query, 10, None, 0);
        let (idx_probed, _, _) = index.query(&query, 10, None, 12);
        assert!(idx_probed.len() >= idx_no_probe.len());
    }

    // Cosine queries (with a non-unit query) return at most k results with
    // matching distances.
    #[test]
    fn test_query_cosine() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Cosine, 4, 8, 42);
        let query = vec![2.0, 0.0, 0.0];
        let (indices, distances, _) = index.query(&query, 2, None, 0);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 2);
        assert_eq!(indices.len(), distances.len());
    }

    // Querying through a faer row view behaves like a slice query.
    #[test]
    fn test_query_row() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query_mat = Mat::from_fn(1, 3, |_, j| [1.0, 0.0, 0.0][j]);
        let (indices, distances, _) = index.query_row(query_mat.row(0), 3, None, 0);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 3);
        assert_eq!(indices.len(), distances.len());
    }

    // A small candidate budget still returns at most k results.
    #[test]
    fn test_max_cand_limit() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 10, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, _, _) = index.query(&query, 2, Some(3), 0);
        assert!(indices.len() <= 2);
    }

    // k larger than the dataset returns at most n results.
    #[test]
    fn test_k_larger_than_n() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, distances, _) = index.query(&query, 100, None, 0);
        assert!(indices.len() <= 5);
        assert_eq!(indices.len(), distances.len());
    }

    // A query slice of the wrong length panics with the documented message.
    #[test]
    #[should_panic(expected = "Query vector dimensionality mismatch")]
    fn test_dimension_mismatch() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0];
        index.query(&query, 3, None, 0);
    }

    // A query row of the wrong width panics with the documented message.
    #[test]
    #[should_panic(expected = "Query row dimensionality mismatch")]
    fn test_query_row_dimension_mismatch() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query_mat = Mat::from_fn(1, 2, |_, j| [1.0, 0.0][j]);
        index.query_row(query_mat.row(0), 3, None, 0);
    }

    // Sparse one-hot rows, 16-bit hashes and a dense all-ones query: the
    // query must still return results, whether from buckets or from the
    // random-sample fallback.
    #[test]
    fn test_fallback_mechanism() {
        let mat = Mat::from_fn(10, 100, |i, j| if j == i * 10 { 1.0 } else { 0.0 });
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 2, 16, 42);
        let query = vec![1.0; 100];
        let (indices, distances, _) = index.query(&query, 3, None, 0);
        assert!(!indices.is_empty());
        assert_eq!(indices.len(), distances.len());
    }

    // Two indexes built with the same seed answer the same query identically.
    #[test]
    fn test_deterministic_with_seed() {
        let mat = simple_test_data();
        let index1 = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let index2 = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices1, _, _) = index1.query(&query, 3, None, 4);
        let (indices2, _, _) = index2.query(&query, 3, None, 4);
        assert_eq!(indices1, indices2);
    }

    // The index also works with f64 data.
    #[test]
    fn test_f64_query() {
        let mat = Mat::from_fn(3, 3, |i, j| if i == j { 1.0f64 } else { 0.0f64 });
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, distances, _) = index.query(&query, 2, None, 0);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 2);
        assert_eq!(indices.len(), distances.len());
    }

    // Returned distances are non-decreasing.
    #[test]
    fn test_distances_sorted() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (_, distances, _) = index.query(&query, 5, None, 8);
        for i in 1..distances.len() {
            assert!(distances[i - 1] <= distances[i]);
        }
    }

    // For every k in 1..=5 the result never exceeds k entries.
    #[test]
    fn test_query_returns_k_or_fewer() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 8, 6, 42);
        let query = vec![1.0, 0.0, 0.0];
        for k in 1..=5 {
            let (indices, distances, _) = index.query(&query, k, None, 0);
            assert!(indices.len() <= k);
            assert_eq!(indices.len(), distances.len());
        }
    }

    // Ids found in several tables are returned only once.
    #[test]
    fn test_no_duplicate_results() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 8, 6, 42);
        let query = vec![1.0, 0.0, 0.0];
        let (indices, _, _) = index.query(&query, 5, None, 6);
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(indices.len(), sorted.len(), "Results contain duplicates");
    }

    // Ids found through several probed buckets are returned only once.
    #[test]
    fn test_no_duplicate_results_with_probes() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 8, 6, 42);
        let query = vec![0.5, 0.5, 0.5];
        let (indices, _, _) = index.query(&query, 5, None, 10);
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(
            indices.len(),
            sorted.len(),
            "Results contain duplicates with multi-probe"
        );
    }

    // A self-query driven by the stored hashes finds the vector itself.
    #[test]
    fn test_self_query_uses_stored_hashes() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let (indices, distances, _) = index.self_query_at(0, 3, None, 0);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 3);
        assert_eq!(indices.len(), distances.len());
        assert!(indices.contains(&0));
    }

    // generate_knn returns one neighbour list (and distance list) per vector.
    #[test]
    fn test_generate_knn() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let (knn_indices, knn_dists) = index.generate_knn(2, None, 4, true, false);
        assert_eq!(knn_indices.len(), 5);
        assert!(knn_dists.is_some());
        let dists = knn_dists.unwrap();
        assert_eq!(dists.len(), 5);
        for i in 0..5 {
            assert!(!knn_indices[i].is_empty());
            assert!(knn_indices[i].len() <= 2);
            assert_eq!(knn_indices[i].len(), dists[i].len());
        }
    }

    // With return_dist = false no distances are returned.
    #[test]
    fn test_generate_knn_no_distances() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let (knn_indices, knn_dists) = index.generate_knn(2, None, 0, false, false);
        assert_eq!(knn_indices.len(), 5);
        assert!(knn_dists.is_none());
    }

    // On a 1000 x 50 dataset, results are bounded by k and all ids are in range.
    #[test]
    fn test_larger_dataset() {
        let n = 1000;
        let dim = 50;
        let mat = Mat::from_fn(n, dim, |i, j| ((i * 7 + j * 13) % 100) as f32 / 100.0);
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 6, 10, 42);
        let query = vec![0.5; dim];
        let (indices, distances, _) = index.query(&query, 10, None, 10);
        assert!(!indices.is_empty());
        assert!(indices.len() <= 10);
        assert_eq!(indices.len(), distances.len());
        for &idx in &indices {
            assert!(idx < n);
        }
    }

    // Ranked probes flip the lowest-magnitude bits first (bit 0, then bit 2).
    #[test]
    fn test_probe_generation_ranked() {
        let projections = vec![0.1f32, 0.9, 0.3, 0.5];
        let base_hash = 0b1010u64;
        let probes = generate_probes_ranked(base_hash, &projections, 4, 6);
        assert_eq!(probes[0], base_hash ^ (1u64 << 0));
        assert_eq!(probes[1], base_hash ^ (1u64 << 2));
        assert!(probes.len() <= 6);
    }

    // Uniform probes: single flips in bit order, then pairs in
    // lexicographic order.
    #[test]
    fn test_probe_generation_uniform() {
        let base_hash = 0b101u64;
        let probes = generate_probes_uniform(base_hash, 3, 10);
        assert_eq!(probes[0], base_hash ^ (1u64 << 0));
        assert_eq!(probes[1], base_hash ^ (1u64 << 1));
        assert_eq!(probes[2], base_hash ^ (1u64 << 2));
        assert_eq!(probes[3], base_hash ^ (1u64 << 0) ^ (1u64 << 1));
        assert_eq!(probes[4], base_hash ^ (1u64 << 0) ^ (1u64 << 2));
        assert_eq!(probes[5], base_hash ^ (1u64 << 1) ^ (1u64 << 2));
        assert_eq!(probes.len(), 6);
    }

    // The probe list is truncated to max_probes.
    #[test]
    fn test_probe_respects_max() {
        let probes = generate_probes_uniform(0u64, 16, 5);
        assert_eq!(probes.len(), 5);
    }

    // The reported memory usage at least covers the stored per-vector hashes.
    #[test]
    fn test_memory_usage_includes_hashes() {
        let mat = simple_test_data();
        let index = LSHIndex::new(mat.as_ref(), Dist::Euclidean, 4, 8, 42);
        let mem = index.memory_usage_bytes();
        assert!(mem >= 4 * 5 * std::mem::size_of::<u64>());
    }
}