use anyhow::{anyhow, Result};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
use crate::simd::cosine_similarity_simd;
use crate::types::SearchResult;
/// Tuning parameters for an LSH index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LshConfig {
/// Number of independent hash tables built over the stored vectors.
pub num_tables: usize,
/// Number of random projections per table; each projection contributes one bit of the bucket hash.
pub num_bits: usize,
/// Probe budget handed to every table at query time; values above 1 also look into bit-flipped neighbour buckets.
pub num_probes: usize,
/// Seed for the random number generator that draws the projection vectors.
pub seed: u64,
}
impl Default for LshConfig {
fn default() -> Self {
Self {
num_tables: 10,
num_bits: 16,
num_probes: 3,
seed: 42,
}
}
}
impl LshConfig {
    /// Seed shared by every named preset.
    const PRESET_SEED: u64 = 42;

    /// Assembles a preset from its three tuning knobs plus the shared preset seed.
    const fn preset(num_tables: usize, num_bits: usize, num_probes: usize) -> Self {
        Self {
            num_tables,
            num_bits,
            num_probes,
            seed: Self::PRESET_SEED,
        }
    }

    /// Preset with 20 tables, 20 bits per hash and 10 probes.
    pub fn high_recall() -> Self {
        Self::preset(20, 20, 10)
    }

    /// Preset with 5 tables, 12 bits per hash and a single probe.
    pub fn fast() -> Self {
        Self::preset(5, 12, 1)
    }

    /// Preset with 5 tables, 10 bits per hash and 5 probes.
    pub fn memory_efficient() -> Self {
        Self::preset(5, 10, 5)
    }
}
/// Bucket key produced by a hash table: one bit per projection, packed into a `u64`.
type HashValue = u64;
/// A single LSH table: its projection vectors and the buckets they induce.
#[derive(Debug, Clone)]
struct HashTable {
/// One projection vector per hash bit.
projections: Vec<Vec<f32>>,
/// Maps a bucket hash to the indices that were inserted under it.
buckets: HashMap<HashValue, Vec<usize>>,
}
impl HashTable {
fn new(num_bits: usize, dimensions: usize, rng: &mut impl Rng) -> Self {
let projections: Vec<Vec<f32>> = (0..num_bits)
.map(|_| {
(0..dimensions)
.map(|_| rng.random_range(-1.0..1.0))
.collect()
})
.collect();
Self {
projections,
buckets: HashMap::new(),
}
}
fn hash(&self, vector: &[f32]) -> HashValue {
let mut hash_val: HashValue = 0;
for (i, projection) in self.projections.iter().enumerate() {
let dot: f32 = vector
.iter()
.zip(projection.iter())
.map(|(v, p)| v * p)
.sum();
if dot > 0.0 {
hash_val |= 1u64 << i;
}
}
hash_val
}
fn insert(&mut self, vector: &[f32], index: usize) {
let hash_val = self.hash(vector);
self.buckets.entry(hash_val).or_default().push(index);
}
fn query(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
let hash_val = self.hash(vector);
let mut candidates = Vec::new();
if let Some(bucket) = self.buckets.get(&hash_val) {
candidates.extend(bucket);
}
if num_probes > 1 {
for probe in 1..num_probes.min(self.projections.len()) {
let flipped_hash = hash_val ^ (1u64 << probe);
if let Some(bucket) = self.buckets.get(&flipped_hash) {
candidates.extend(bucket);
}
}
}
candidates
}
}
/// Index that finds stored vectors similar to a query by looking them up in several LSH hash tables.
#[derive(Debug, Clone)]
pub struct LshIndex {
/// Parameters the index was created with.
config: LshConfig,
/// Hash tables filled by `build`, one per configured table.
tables: Vec<HashTable>,
/// Stored embedding vectors; a vector's position here is the index kept in the table buckets.
vectors: Vec<Vec<f32>>,
/// Entity id of each stored vector, parallel to `vectors`.
entity_ids: Vec<String>,
/// Length of the stored vectors; 0 in a freshly created index.
dimensions: usize,
/// Set to `true` at the end of a successful `build`; `search` refuses to run while it is `false`.
is_built: bool,
}
impl LshIndex {
pub fn new(config: LshConfig) -> Self {
info!(
"Initialized LSH index: num_tables={}, num_bits={}, num_probes={}",
config.num_tables, config.num_bits, config.num_probes
);
Self {
config,
tables: Vec::new(),
vectors: Vec::new(),
entity_ids: Vec::new(),
dimensions: 0,
is_built: false,
}
}
pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
if embeddings.is_empty() {
return Err(anyhow!("Cannot build index from empty embeddings"));
}
info!("Building LSH index for {} entities", embeddings.len());
self.dimensions = embeddings.values().next().unwrap().len();
for (id, vec) in embeddings {
if vec.len() != self.dimensions {
return Err(anyhow!(
"Dimension mismatch for entity {}: expected {}, got {}",
id,
self.dimensions,
vec.len()
));
}
}
self.vectors.clear();
self.entity_ids.clear();
for (id, vec) in embeddings {
self.vectors.push(vec.clone());
self.entity_ids.push(id.clone());
}
let mut rng = StdRng::seed_from_u64(self.config.seed);
self.tables.clear();
for table_idx in 0..self.config.num_tables {
debug!(
"Building hash table {}/{}",
table_idx + 1,
self.config.num_tables
);
let mut table = HashTable::new(self.config.num_bits, self.dimensions, &mut rng);
for (idx, vector) in self.vectors.iter().enumerate() {
table.insert(vector, idx);
}
self.tables.push(table);
}
self.is_built = true;
info!("LSH index built successfully");
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
if !self.is_built {
return Err(anyhow!("Index not built. Call build() first"));
}
if query.len() != self.dimensions {
return Err(anyhow!(
"Query dimension mismatch: expected {}, got {}",
self.dimensions,
query.len()
));
}
debug!("LSH search for k={}", k);
let mut candidate_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
for table in &self.tables {
let candidates = table.query(query, self.config.num_probes);
candidate_set.extend(candidates);
}
debug!("Found {} unique candidates", candidate_set.len());
let mut scored_candidates: Vec<(usize, f32)> = candidate_set
.into_iter()
.map(|idx| {
let similarity = cosine_similarity_simd(query, &self.vectors[idx]);
(idx, similarity)
})
.collect();
scored_candidates
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let results: Vec<SearchResult> = scored_candidates
.into_iter()
.take(k)
.enumerate()
.map(|(rank, (idx, score))| SearchResult {
entity_id: self.entity_ids[idx].clone(),
score,
distance: 1.0 - score,
rank: rank + 1,
})
.collect();
debug!("Returning {} results", results.len());
Ok(results)
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn stats(&self) -> LshStats {
let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
let avg_bucket_size: f32 = if total_buckets > 0 {
let total_entries: usize = self
.tables
.iter()
.flat_map(|t| t.buckets.values())
.map(|b| b.len())
.sum();
total_entries as f32 / total_buckets as f32
} else {
0.0
};
let max_bucket_size: usize = self
.tables
.iter()
.flat_map(|t| t.buckets.values())
.map(|b| b.len())
.max()
.unwrap_or(0);
LshStats {
num_vectors: self.vectors.len(),
num_tables: self.tables.len(),
num_bits: self.config.num_bits,
total_buckets,
avg_bucket_size,
max_bucket_size,
dimensions: self.dimensions,
}
}
}
/// Summary figures describing an LSH index, as returned by `LshIndex::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LshStats {
/// Number of vectors stored in the index.
pub num_vectors: usize,
/// Number of hash tables currently held by the index.
pub num_tables: usize,
/// Configured number of bits per hash.
pub num_bits: usize,
/// Number of buckets summed over all tables.
pub total_buckets: usize,
/// Mean number of entries per bucket over all tables; 0.0 when there are no buckets.
pub avg_bucket_size: f32,
/// Size of the largest bucket in any table; 0 when there are no buckets.
pub max_bucket_size: usize,
/// Length of the indexed vectors.
pub dimensions: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` deterministic embeddings of length `dims`, keyed `doc0`, `doc1`, ...
    ///
    /// Component `d` of vector `i` is `sin(i * d * 0.01)`.
    fn create_test_embeddings(n: usize, dims: usize) -> HashMap<String, Vec<f32>> {
        (0..n)
            .map(|i| {
                let components: Vec<f32> =
                    (0..dims).map(|d| ((i * d) as f32 * 0.01).sin()).collect();
                (format!("doc{}", i), components)
            })
            .collect()
    }

    /// Creates an index from `config` and builds it over `embeddings`, panicking on failure.
    fn built_index(config: LshConfig, embeddings: &HashMap<String, Vec<f32>>) -> LshIndex {
        let mut index = LshIndex::new(config);
        index.build(embeddings).unwrap();
        index
    }

    /// Building over valid embeddings succeeds and stores every vector.
    #[test]
    fn test_lsh_build() {
        let data = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());
        let outcome = index.build(&data);
        assert!(outcome.is_ok());
        assert_eq!(index.len(), 100);
        assert!(index.is_built);
    }

    /// A search returns between 1 and `k` results, best score first.
    #[test]
    fn test_lsh_search() {
        let data = create_test_embeddings(100, 64);
        let index = built_index(LshConfig::default(), &data);
        let query: Vec<f32> = (0..64).map(|d| (d as f32 * 0.01).sin()).collect();

        let results = index.search(&query, 10).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        if results.len() > 1 {
            let last = results.len() - 1;
            assert!(results[0].score >= results[last].score);
        }
    }

    /// An empty embedding map is rejected.
    #[test]
    fn test_lsh_empty_embeddings() {
        let nothing = HashMap::new();
        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.build(&nothing).is_err());
    }

    /// Vectors of differing lengths are rejected.
    #[test]
    fn test_lsh_dimension_mismatch() {
        let mut data = HashMap::new();
        data.insert("doc1".to_string(), vec![1.0, 2.0, 3.0]);
        data.insert("doc2".to_string(), vec![1.0, 2.0]);
        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.build(&data).is_err());
    }

    /// Searching an index that was never built is an error.
    #[test]
    fn test_lsh_search_before_build() {
        let unbuilt = LshIndex::new(LshConfig::default());
        let query = vec![1.0, 2.0, 3.0];
        assert!(unbuilt.search(&query, 10).is_err());
    }

    /// A query whose length differs from the indexed vectors is an error.
    #[test]
    fn test_lsh_query_dimension_mismatch() {
        let data = create_test_embeddings(100, 64);
        let index = built_index(LshConfig::default(), &data);
        let wrong_query = vec![1.0, 2.0];
        assert!(index.search(&wrong_query, 10).is_err());
    }

    /// Statistics reflect the built index.
    #[test]
    fn test_lsh_stats() {
        let data = create_test_embeddings(100, 64);
        let index = built_index(LshConfig::default(), &data);

        let stats = index.stats();

        assert_eq!(stats.num_vectors, 100);
        assert_eq!(stats.num_tables, 10);
        assert_eq!(stats.dimensions, 64);
        assert!(stats.total_buckets > 0);
        assert!(stats.avg_bucket_size > 0.0);
    }

    /// Spot-checks the values of the named presets.
    #[test]
    fn test_lsh_config_presets() {
        let high_recall = LshConfig::high_recall();
        assert_eq!(high_recall.num_tables, 20);
        assert_eq!(high_recall.num_probes, 10);

        let fast = LshConfig::fast();
        assert_eq!(fast.num_tables, 5);
        assert_eq!(fast.num_probes, 1);

        let memory = LshConfig::memory_efficient();
        assert_eq!(memory.num_tables, 5);
        assert_eq!(memory.num_bits, 10);
    }

    /// Equal vectors hash equally; a negated vector hashes differently.
    #[test]
    fn test_hash_table_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let table = HashTable::new(8, 3, &mut rng);

        let original = vec![1.0, 2.0, 3.0];
        let duplicate = vec![1.0, 2.0, 3.0];
        let negated = vec![-1.0, -2.0, -3.0];

        assert_eq!(table.hash(&original), table.hash(&duplicate));
        let hash_original = table.hash(&original);
        let hash_negated = table.hash(&negated);
        assert_ne!(hash_original, hash_negated);
    }

    /// A larger probe budget never yields fewer results than a single probe.
    #[test]
    fn test_multiprobe_increases_candidates() {
        let data = create_test_embeddings(50, 32);
        // Both indexes differ only in the probe budget.
        let config_with_probes = |num_probes: usize| LshConfig {
            num_tables: 5,
            num_bits: 10,
            num_probes,
            seed: 42,
        };
        let index_1probe = built_index(config_with_probes(1), &data);
        let index_5probe = built_index(config_with_probes(5), &data);

        let query: Vec<f32> = (0..32).map(|d| (d as f32 * 0.02).cos()).collect();
        let results_1probe = index_1probe.search(&query, 20).unwrap();
        let results_5probe = index_5probe.search(&query, 20).unwrap();

        assert!(results_5probe.len() >= results_1probe.len());
    }
}