use crate::distance;
use crate::RetrieveError;
use sketchir::cross_polytope::{self, CrossPolytopeHasher};
use std::collections::HashMap;
/// Tuning knobs for [`CrossPolytopeLSHIndex`].
#[derive(Clone, Debug)]
pub struct LSHParams {
    /// Number of hash tables; `build` creates one hasher per table. Must be > 0.
    pub num_tables: usize,
    /// Number of ranked buckets probed in each table at query time. Must be > 0.
    pub num_probes: usize,
    /// Base seed for the hashers; `None` draws a fresh random seed on every `build`.
    pub seed: Option<u64>,
}

impl Default for LSHParams {
    /// Eight tables, four probes per table, and no fixed seed.
    fn default() -> Self {
        const DEFAULT_TABLES: usize = 8;
        const DEFAULT_PROBES: usize = 4;
        LSHParams {
            seed: None,
            num_probes: DEFAULT_PROBES,
            num_tables: DEFAULT_TABLES,
        }
    }
}
/// Approximate nearest-neighbour index built from cross-polytope LSH tables.
/// Queries probe several ranked buckets per table, and the resulting candidates
/// are re-ranked with `distance::l2_distance`.
#[derive(Debug)]
pub struct CrossPolytopeLSHIndex {
    /// All vectors stored contiguously: vector `i` occupies
    /// `vectors[i * dimension..(i + 1) * dimension]`.
    vectors: Vec<f32>,
    /// Length of every stored vector and every query (validated > 0 in `new`).
    dimension: usize,
    /// Number of vectors held in `vectors`; ids are `0..num_vectors`.
    num_vectors: usize,
    /// Table count, probe count and seed used by `build` and `search`.
    params: LSHParams,
    /// One hasher per table; empty until `build` runs.
    hashers: Vec<CrossPolytopeHasher>,
    /// Per-table map from bucket id to the ids of the vectors hashed into it.
    tables: Vec<HashMap<u32, Vec<u32>>>,
    /// Whether `hashers`/`tables` reflect the stored vectors; `search` requires it.
    /// `add_vectors` clears it, while `insert` keeps the tables up to date.
    built: bool,
}
impl CrossPolytopeLSHIndex {
    /// Creates an empty, unbuilt index for vectors with `dimension` components.
    ///
    /// No hashers are created here; they are generated by [`build`](Self::build).
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::InvalidParameter`] if `dimension`,
    /// `params.num_tables`, or `params.num_probes` is zero.
    pub fn new(dimension: usize, params: LSHParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be > 0".into(),
            ));
        }
        if params.num_tables == 0 {
            return Err(RetrieveError::InvalidParameter(
                "num_tables must be > 0".into(),
            ));
        }
        if params.num_probes == 0 {
            return Err(RetrieveError::InvalidParameter(
                "num_probes must be > 0".into(),
            ));
        }
        Ok(Self {
            vectors: Vec::new(),
            dimension,
            num_vectors: 0,
            params,
            hashers: Vec::new(),
            tables: Vec::new(),
            built: false,
        })
    }

    /// Appends a batch of vectors laid out back-to-back in `vectors`.
    ///
    /// The batch holds `vectors.len() / dimension` vectors, which receive the
    /// next consecutive ids. The hash tables are *not* updated. A non-empty
    /// batch marks the index as unbuilt, so [`build`](Self::build) must run
    /// again before [`search`](Self::search).
    ///
    /// An empty batch is a no-op and leaves a built index searchable. Use
    /// [`insert`](Self::insert) to add a single vector without invalidating
    /// the tables.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::DimensionMismatch`] if `vectors.len()` is not a
    ///   multiple of `dimension`. In that case `query_dim` carries the slice
    ///   length.
    /// - [`RetrieveError::InvalidParameter`] if the new vectors' ids would not
    ///   fit in `u32`.
    pub fn add_vectors(&mut self, vectors: &[f32]) -> Result<(), RetrieveError> {
        if !vectors.len().is_multiple_of(self.dimension) {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vectors.len(),
                doc_dim: self.dimension,
            });
        }
        if vectors.is_empty() {
            // Nothing was added, so the existing tables are still valid.
            return Ok(());
        }
        let added = vectors.len() / self.dimension;
        self.ensure_ids_fit(added)?;
        self.vectors.extend_from_slice(vectors);
        self.num_vectors += added;
        self.built = false;
        Ok(())
    }

    /// Appends one vector and returns its id.
    ///
    /// If the index is already built, the vector is hashed into every table
    /// immediately, so it is searchable without a rebuild. Otherwise it is
    /// picked up by the next [`build`](Self::build).
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::DimensionMismatch`] if `vector.len() != dimension`.
    /// - [`RetrieveError::InvalidParameter`] if the id would not fit in `u32`.
    pub fn insert(&mut self, vector: &[f32]) -> Result<u32, RetrieveError> {
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        self.ensure_ids_fit(1)?;
        // Cannot truncate: `ensure_ids_fit` guarantees every id is <= u32::MAX.
        let id = self.num_vectors as u32;
        self.vectors.extend_from_slice(vector);
        self.num_vectors += 1;
        if self.built {
            for (table, hasher) in self.tables.iter_mut().zip(&self.hashers) {
                // Mirrors `build`: a vector the hasher rejects is left out of this table.
                if let Ok(bucket) = hasher.hash(vector) {
                    table.entry(bucket).or_default().push(id);
                }
            }
        }
        Ok(id)
    }

    /// Regenerates the hashers and rebuilds every hash table from scratch.
    ///
    /// Hashers come from `cross_polytope::multi_hasher`. Their base seed is
    /// `params.seed`, or a freshly drawn random seed when that is `None`, so
    /// unseeded rebuilds produce different tables each time.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::EmptyIndex`] if no vectors have been added.
    /// - [`RetrieveError::InvalidParameter`] wrapping any hasher-construction
    ///   error from `sketchir`.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        let base_seed = self.params.seed.unwrap_or_else(|| {
            use rand::RngCore;
            rand::rng().next_u64()
        });
        self.hashers =
            cross_polytope::multi_hasher(self.dimension, self.params.num_tables, base_seed)
                .map_err(|e| RetrieveError::InvalidParameter(format!("sketchir: {e}")))?;
        self.tables = vec![HashMap::new(); self.params.num_tables];
        for (vec_idx, vec) in self.vectors.chunks_exact(self.dimension).enumerate() {
            // Cannot truncate: `ensure_ids_fit` caps every id at u32::MAX.
            let id = vec_idx as u32;
            for (table, hasher) in self.tables.iter_mut().zip(&self.hashers) {
                // Vectors the hasher rejects are skipped and are therefore never
                // returned by `search` for this table.
                if let Ok(bucket) = hasher.hash(vec) {
                    table.entry(bucket).or_default().push(id);
                }
            }
        }
        self.built = true;
        Ok(())
    }

    /// Returns up to `k` approximate nearest neighbours of `query`.
    ///
    /// For each table, the `num_probes` highest-ranked buckets for `query`
    /// (as given by `hash_ranked`) are scanned. The union of their members is
    /// then re-ranked with `distance::l2_distance`.
    ///
    /// Results are `(id, distance)` pairs in ascending distance order. Fewer
    /// than `k` pairs are returned when the probed buckets hold fewer
    /// candidates.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::InvalidParameter`] if the index has not been built,
    ///   or was invalidated by [`add_vectors`](Self::add_vectors).
    /// - [`RetrieveError::DimensionMismatch`] if `query.len() != dimension`.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter("index not built".into()));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        let mut candidates = Vec::new();
        let mut seen = vec![false; self.num_vectors];
        for (table, hasher) in self.tables.iter().zip(&self.hashers) {
            // A query the hasher cannot rank contributes no candidates from this table.
            let probes = hasher
                .hash_ranked(query, self.params.num_probes)
                .unwrap_or_default();
            for bucket_id in probes {
                let Some(ids) = table.get(&bucket_id) else {
                    continue;
                };
                for &id in ids {
                    let slot = &mut seen[id as usize];
                    if !*slot {
                        *slot = true;
                        candidates.push(id);
                    }
                }
            }
        }
        let by_distance = |a: &(u32, f32), b: &(u32, f32)| a.1.total_cmp(&b.1);
        let mut results: Vec<(u32, f32)> = candidates
            .into_iter()
            .map(|id| (id, distance::l2_distance(query, self.get_vector(id as usize))))
            .collect();
        if k < results.len() {
            // Partition so the k nearest come first; only those need sorting.
            results.select_nth_unstable_by(k, by_distance);
            results.truncate(k);
        }
        results.sort_unstable_by(by_distance);
        Ok(results)
    }

    /// Checks that adding `additional` vectors keeps every id representable as `u32`.
    fn ensure_ids_fit(&self, additional: usize) -> Result<(), RetrieveError> {
        // After the addition, the largest id is `num_vectors + additional - 1`.
        let fits = self
            .num_vectors
            .checked_add(additional)
            .is_some_and(|total| u32::try_from(total.saturating_sub(1)).is_ok());
        if fits {
            Ok(())
        } else {
            Err(RetrieveError::InvalidParameter(
                "too many vectors: ids must fit in u32".into(),
            ))
        }
    }

    /// Test helper: bucket of `vector` in table `table_idx`, or 0 if hashing fails.
    #[cfg(test)]
    fn hash_vector(&self, vector: &[f32], table_idx: usize) -> u32 {
        self.hashers[table_idx].hash(vector).unwrap_or(0)
    }

    /// Borrows the stored vector with id `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= num_vectors`.
    #[inline]
    fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        &self.vectors[start..start + self.dimension]
    }

    /// Summarizes table occupancy.
    ///
    /// `avg_bucket_size` is the total number of table entries divided by the
    /// number of occupied buckets, both summed over all tables. It is `0.0`
    /// when no bucket is occupied, for example before `build`.
    pub fn stats(&self) -> LSHStats {
        let total_entries: usize = self
            .tables
            .iter()
            .map(|t| t.values().map(|v| v.len()).sum::<usize>())
            .sum();
        let num_buckets: usize = self.tables.iter().map(|t| t.len()).sum();
        LSHStats {
            num_vectors: self.num_vectors,
            num_tables: self.params.num_tables,
            num_probes: self.params.num_probes,
            num_occupied_buckets: num_buckets,
            avg_bucket_size: if num_buckets > 0 {
                total_entries as f32 / num_buckets as f32
            } else {
                0.0
            },
        }
    }
}
/// Occupancy summary produced by `CrossPolytopeLSHIndex::stats`.
#[derive(Debug, Clone)]
pub struct LSHStats {
    /// Number of vectors stored in the index, whether or not it is built.
    pub num_vectors: usize,
    /// Configured table count (`LSHParams::num_tables`), reported even before `build`.
    pub num_tables: usize,
    /// Configured number of buckets probed per table at query time.
    pub num_probes: usize,
    /// Number of non-empty buckets, summed over all tables.
    pub num_occupied_buckets: usize,
    /// Mean entries per occupied bucket across all tables; `0.0` when none are occupied.
    pub avg_bucket_size: f32,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Generates `n_clusters` tight clusters of `points_per_cluster` points each.
    /// Cluster centres are spaced about 10.0 apart per coordinate, and each point
    /// is jittered by up to 0.5 per coordinate. The data is deterministic (seed 42).
    fn clustered_data(n_clusters: usize, points_per_cluster: usize, dim: usize) -> Vec<f32> {
        use rand::prelude::*;
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = Vec::new();
        for c in 0..n_clusters {
            let center: Vec<f32> = (0..dim)
                .map(|_| (c as f32) * 10.0 + rng.random::<f32>())
                .collect();
            for _ in 0..points_per_cluster {
                for val in &center {
                    data.push(val + rng.random::<f32>() * 0.5);
                }
            }
        }
        data
    }

    /// Exact k-NN by linear scan, used as ground truth for recall checks.
    fn brute_force_knn(data: &[f32], dim: usize, query: &[f32], k: usize) -> Vec<(usize, f32)> {
        let n = data.len() / dim;
        let mut dists: Vec<(usize, f32)> = (0..n)
            .map(|i| {
                let v = &data[i * dim..(i + 1) * dim];
                (i, distance::l2_distance(query, v))
            })
            .collect();
        dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        dists.truncate(k);
        dists
    }

    #[test]
    fn test_build_and_search() {
        let dim = 16;
        let data = clustered_data(5, 40, dim);
        let n = data.len() / dim;
        let params = LSHParams {
            num_tables: 12,
            num_probes: 6,
            seed: Some(42),
        };
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        index.add_vectors(&data).unwrap();
        index.build().unwrap();
        assert_eq!(index.num_vectors, n);
        let query = &data[0..dim];
        let results = index.search(query, 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0, "should find the query point itself");
        assert!(results[0].1 < 1e-6, "self-distance should be ~0");
    }

    #[test]
    fn test_recall() {
        let dim = 16;
        let data = clustered_data(5, 100, dim);
        let params = LSHParams {
            num_tables: 16,
            num_probes: 8,
            seed: Some(42),
        };
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        index.add_vectors(&data).unwrap();
        index.build().unwrap();
        use rand::prelude::*;
        let mut rng = StdRng::seed_from_u64(123);
        let n = data.len() / dim;
        let num_queries = 30;
        let k = 10;
        let mut total_recall = 0.0;
        for _ in 0..num_queries {
            let query_idx = rng.random_range(0..n);
            let query = &data[query_idx * dim..(query_idx + 1) * dim];
            let results = index.search(query, k).unwrap();
            let gt = brute_force_knn(&data, dim, query, k);
            let gt_set: std::collections::HashSet<u32> =
                gt.iter().map(|&(id, _)| id as u32).collect();
            let result_set: std::collections::HashSet<u32> =
                results.iter().map(|&(id, _)| id).collect();
            let hits = gt_set.intersection(&result_set).count();
            total_recall += hits as f32 / k as f32;
        }
        let avg_recall = total_recall / num_queries as f32;
        assert!(
            avg_recall > 0.3,
            "LSH recall too low: {:.2}% (expected >30%)",
            avg_recall * 100.0
        );
    }

    #[test]
    fn test_multiprobe_improves_recall() {
        let dim = 16;
        let data = clustered_data(5, 80, dim);
        let k = 10;
        use rand::prelude::*;
        let mut rng = StdRng::seed_from_u64(123);
        let n = data.len() / dim;
        let queries: Vec<usize> = (0..20).map(|_| rng.random_range(0..n)).collect();
        let params1 = LSHParams {
            num_tables: 8,
            num_probes: 1,
            seed: Some(42),
        };
        let mut idx1 = CrossPolytopeLSHIndex::new(dim, params1).unwrap();
        idx1.add_vectors(&data).unwrap();
        idx1.build().unwrap();
        let params4 = LSHParams {
            num_tables: 8,
            num_probes: 6,
            seed: Some(42),
        };
        let mut idx4 = CrossPolytopeLSHIndex::new(dim, params4).unwrap();
        idx4.add_vectors(&data).unwrap();
        idx4.build().unwrap();
        let mut recall1 = 0.0;
        let mut recall4 = 0.0;
        for &qi in &queries {
            let query = &data[qi * dim..(qi + 1) * dim];
            let gt = brute_force_knn(&data, dim, query, k);
            let gt_set: std::collections::HashSet<u32> =
                gt.iter().map(|&(id, _)| id as u32).collect();
            let r1 = idx1.search(query, k).unwrap();
            let r4 = idx4.search(query, k).unwrap();
            let s1: std::collections::HashSet<u32> = r1.iter().map(|&(id, _)| id).collect();
            let s4: std::collections::HashSet<u32> = r4.iter().map(|&(id, _)| id).collect();
            recall1 += gt_set.intersection(&s1).count() as f32 / k as f32;
            recall4 += gt_set.intersection(&s4).count() as f32 / k as f32;
        }
        recall1 /= queries.len() as f32;
        recall4 /= queries.len() as f32;
        assert!(
            recall4 >= recall1,
            "Multiprobe recall ({:.2}%) should be >= single probe ({:.2}%)",
            recall4 * 100.0,
            recall1 * 100.0
        );
    }

    #[test]
    fn test_online_insert() {
        let dim = 8;
        let params = LSHParams {
            num_tables: 4,
            num_probes: 2,
            seed: Some(42),
        };
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        let initial: Vec<f32> = (0..20).flat_map(|i| vec![i as f32; dim]).collect();
        index.add_vectors(&initial).unwrap();
        index.build().unwrap();
        let new_vec = vec![5.5; dim];
        let new_id = index.insert(&new_vec).unwrap();
        let results = index.search(&new_vec, 3).unwrap();
        assert!(
            results.iter().any(|&(id, _)| id == new_id),
            "newly inserted vector should be found"
        );
    }

    #[test]
    fn test_hash_determinism() {
        let dim = 8;
        let params = LSHParams {
            num_tables: 4,
            num_probes: 2,
            seed: Some(999),
        };
        let data: Vec<f32> = (0..80).map(|i| (i as f32) * 0.1).collect();
        let mut idx1 = CrossPolytopeLSHIndex::new(dim, params.clone()).unwrap();
        idx1.add_vectors(&data).unwrap();
        idx1.build().unwrap();
        let mut idx2 = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        idx2.add_vectors(&data).unwrap();
        idx2.build().unwrap();
        let query = &data[0..dim];
        let r1 = idx1.search(query, 5).unwrap();
        let r2 = idx2.search(query, 5).unwrap();
        assert_eq!(r1, r2, "same seed should produce identical results");
    }

    #[test]
    fn test_similar_vectors_hash_together() {
        let dim = 32;
        let params = LSHParams {
            num_tables: 1,
            num_probes: 1,
            seed: Some(42),
        };
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        let v1: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let v2: Vec<f32> = (0..dim).map(|i| i as f32 + 0.001).collect();
        let v3: Vec<f32> = (0..dim).map(|i| -(i as f32) * 10.0).collect();
        let mut all = v1.clone();
        all.extend(&v2);
        all.extend(&v3);
        index.add_vectors(&all).unwrap();
        index.build().unwrap();
        let h1 = index.hash_vector(&v1, 0);
        let h2 = index.hash_vector(&v2, 0);
        let h3 = index.hash_vector(&v3, 0);
        assert_eq!(h1, h2, "similar vectors should hash to same bucket");
        // LSH only makes a probabilistic claim about dissimilar vectors, so v3's
        // bucket is computed (exercising the hasher) but deliberately not asserted.
        let _ = h3;
    }

    #[test]
    fn test_cross_polytope_vertex_basic() {
        let hasher = CrossPolytopeHasher::new(3, 0).unwrap();
        let v = vec![1.0_f32, 0.0, 0.0];
        let bucket = hasher.hash(&v).unwrap();
        assert!(bucket < 6, "bucket should be in 0..2*dim");
    }

    #[test]
    fn test_hadamard_search_recall() {
        let dim = 128;
        let n_clusters = 5;
        let points_per_cluster = 60;
        use rand::prelude::*;
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = Vec::new();
        for c in 0..n_clusters {
            let center: Vec<f32> = (0..dim)
                .map(|_| (c as f32) * 5.0 + rng.random::<f32>())
                .collect();
            for _ in 0..points_per_cluster {
                for val in &center {
                    data.push(val + rng.random::<f32>() * 0.3);
                }
            }
        }
        let n = data.len() / dim;
        let params = LSHParams {
            num_tables: 16,
            num_probes: 8,
            seed: Some(42),
        };
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        index.add_vectors(&data).unwrap();
        index.build().unwrap();
        let mut total_recall = 0.0;
        let num_queries = 20;
        let k = 10;
        for qi in 0..num_queries {
            let query = &data[qi * dim..(qi + 1) * dim];
            let results = index.search(query, k).unwrap();
            let gt = brute_force_knn(&data, dim, query, k);
            let gt_set: std::collections::HashSet<u32> =
                gt.iter().map(|&(id, _)| id as u32).collect();
            let result_set: std::collections::HashSet<u32> =
                results.iter().map(|&(id, _)| id).collect();
            total_recall += gt_set.intersection(&result_set).count() as f32 / k as f32;
        }
        let avg_recall = total_recall / num_queries as f32;
        assert!(
            avg_recall > 0.3,
            "Hadamard LSH recall too low: {:.1}% on {}d data (n={})",
            avg_recall * 100.0,
            dim,
            n
        );
    }

    #[test]
    fn test_empty_index() {
        let params = LSHParams::default();
        let mut index = CrossPolytopeLSHIndex::new(4, params).unwrap();
        assert!(index.build().is_err());
    }

    #[test]
    fn test_dimension_mismatch() {
        let params = LSHParams::default();
        let mut index = CrossPolytopeLSHIndex::new(4, params).unwrap();
        index.add_vectors(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        index.build().unwrap();
        let result = index.search(&[1.0, 2.0], 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_stats() {
        let dim = 8;
        let params = LSHParams {
            num_tables: 4,
            num_probes: 2,
            seed: Some(42),
        };
        let data: Vec<f32> = (0..240).map(|i| (i as f32) * 0.1).collect();
        let mut index = CrossPolytopeLSHIndex::new(dim, params).unwrap();
        index.add_vectors(&data).unwrap();
        index.build().unwrap();
        let stats = index.stats();
        assert_eq!(stats.num_vectors, 30);
        assert_eq!(stats.num_tables, 4);
        assert!(stats.num_occupied_buckets > 0);
        assert!(stats.avg_bucket_size > 0.0);
    }
}