1use anyhow::{anyhow, Result};
50use rand::rngs::StdRng;
51use rand::{Rng, SeedableRng};
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use tracing::{debug, info};
55
56use crate::simd::cosine_similarity_simd;
57use crate::types::SearchResult;
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct LshConfig {
62 pub num_tables: usize,
64 pub num_bits: usize,
66 pub num_probes: usize,
68 pub seed: u64,
70}
71
72impl Default for LshConfig {
73 fn default() -> Self {
74 Self {
75 num_tables: 10,
76 num_bits: 16,
77 num_probes: 3,
78 seed: 42,
79 }
80 }
81}
82
83impl LshConfig {
84 pub fn high_recall() -> Self {
86 Self {
87 num_tables: 20,
88 num_bits: 20,
89 num_probes: 10,
90 seed: 42,
91 }
92 }
93
94 pub fn fast() -> Self {
96 Self {
97 num_tables: 5,
98 num_bits: 12,
99 num_probes: 1,
100 seed: 42,
101 }
102 }
103
104 pub fn memory_efficient() -> Self {
106 Self {
107 num_tables: 5,
108 num_bits: 10,
109 num_probes: 5,
110 seed: 42,
111 }
112 }
113}
114
/// Bucket key: bit `i` records the sign of the dot product with projection `i`, so a key can
/// hold at most 64 projections.
type HashValue = u64;

/// One LSH table: a set of random projection vectors plus the buckets they induce.
#[derive(Clone, Debug)]
struct HashTable {
    /// One projection vector per hash bit, each as long as the indexed vectors.
    projections: Vec<Vec<f32>>,
    /// Bucket key -> positions (in the owning index's vector list) of the entries hashed there.
    buckets: HashMap<HashValue, Vec<usize>>,
}
126
127impl HashTable {
128 fn new(num_bits: usize, dimensions: usize, rng: &mut impl Rng) -> Self {
129 let projections: Vec<Vec<f32>> = (0..num_bits)
131 .map(|_| {
132 (0..dimensions)
133 .map(|_| rng.random_range(-1.0..1.0))
134 .collect()
135 })
136 .collect();
137
138 Self {
139 projections,
140 buckets: HashMap::new(),
141 }
142 }
143
144 fn hash(&self, vector: &[f32]) -> HashValue {
146 let mut hash_val: HashValue = 0;
147
148 for (i, projection) in self.projections.iter().enumerate() {
149 let dot: f32 = vector
151 .iter()
152 .zip(projection.iter())
153 .map(|(v, p)| v * p)
154 .sum();
155
156 if dot > 0.0 {
158 hash_val |= 1u64 << i;
159 }
160 }
161
162 hash_val
163 }
164
165 fn insert(&mut self, vector: &[f32], index: usize) {
167 let hash_val = self.hash(vector);
168 self.buckets.entry(hash_val).or_default().push(index);
169 }
170
171 fn query(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
173 let hash_val = self.hash(vector);
174 let mut candidates = Vec::new();
175
176 if let Some(bucket) = self.buckets.get(&hash_val) {
178 candidates.extend(bucket);
179 }
180
181 if num_probes > 1 {
183 for probe in 1..num_probes.min(self.projections.len()) {
184 let flipped_hash = hash_val ^ (1u64 << probe);
186 if let Some(bucket) = self.buckets.get(&flipped_hash) {
187 candidates.extend(bucket);
188 }
189 }
190 }
191
192 candidates
193 }
194}
195
196#[derive(Debug, Clone)]
198pub struct LshIndex {
199 config: LshConfig,
200 tables: Vec<HashTable>,
201 vectors: Vec<Vec<f32>>,
202 entity_ids: Vec<String>,
203 dimensions: usize,
204 is_built: bool,
205}
206
207impl LshIndex {
208 pub fn new(config: LshConfig) -> Self {
210 info!(
211 "Initialized LSH index: num_tables={}, num_bits={}, num_probes={}",
212 config.num_tables, config.num_bits, config.num_probes
213 );
214
215 Self {
216 config,
217 tables: Vec::new(),
218 vectors: Vec::new(),
219 entity_ids: Vec::new(),
220 dimensions: 0,
221 is_built: false,
222 }
223 }
224
225 pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
227 if embeddings.is_empty() {
228 return Err(anyhow!("Cannot build index from empty embeddings"));
229 }
230
231 info!("Building LSH index for {} entities", embeddings.len());
232
233 self.dimensions = embeddings.values().next().unwrap().len();
235
236 for (id, vec) in embeddings {
238 if vec.len() != self.dimensions {
239 return Err(anyhow!(
240 "Dimension mismatch for entity {}: expected {}, got {}",
241 id,
242 self.dimensions,
243 vec.len()
244 ));
245 }
246 }
247
248 self.vectors.clear();
250 self.entity_ids.clear();
251 for (id, vec) in embeddings {
252 self.vectors.push(vec.clone());
253 self.entity_ids.push(id.clone());
254 }
255
256 let mut rng = StdRng::seed_from_u64(self.config.seed);
258
259 self.tables.clear();
261 for table_idx in 0..self.config.num_tables {
262 debug!(
263 "Building hash table {}/{}",
264 table_idx + 1,
265 self.config.num_tables
266 );
267
268 let mut table = HashTable::new(self.config.num_bits, self.dimensions, &mut rng);
269
270 for (idx, vector) in self.vectors.iter().enumerate() {
272 table.insert(vector, idx);
273 }
274
275 self.tables.push(table);
276 }
277
278 self.is_built = true;
279 info!("LSH index built successfully");
280 Ok(())
281 }
282
283 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
285 if !self.is_built {
286 return Err(anyhow!("Index not built. Call build() first"));
287 }
288
289 if query.len() != self.dimensions {
290 return Err(anyhow!(
291 "Query dimension mismatch: expected {}, got {}",
292 self.dimensions,
293 query.len()
294 ));
295 }
296
297 debug!("LSH search for k={}", k);
298
299 let mut candidate_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
301
302 for table in &self.tables {
303 let candidates = table.query(query, self.config.num_probes);
304 candidate_set.extend(candidates);
305 }
306
307 debug!("Found {} unique candidates", candidate_set.len());
308
309 let mut scored_candidates: Vec<(usize, f32)> = candidate_set
311 .into_iter()
312 .map(|idx| {
313 let similarity = cosine_similarity_simd(query, &self.vectors[idx]);
314 (idx, similarity)
315 })
316 .collect();
317
318 scored_candidates
320 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
321
322 let results: Vec<SearchResult> = scored_candidates
324 .into_iter()
325 .take(k)
326 .enumerate()
327 .map(|(rank, (idx, score))| SearchResult {
328 entity_id: self.entity_ids[idx].clone(),
329 score,
330 distance: 1.0 - score,
331 rank: rank + 1,
332 })
333 .collect();
334
335 debug!("Returning {} results", results.len());
336 Ok(results)
337 }
338
339 pub fn len(&self) -> usize {
341 self.vectors.len()
342 }
343
344 pub fn is_empty(&self) -> bool {
346 self.vectors.is_empty()
347 }
348
349 pub fn stats(&self) -> LshStats {
351 let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
352 let avg_bucket_size: f32 = if total_buckets > 0 {
353 let total_entries: usize = self
354 .tables
355 .iter()
356 .flat_map(|t| t.buckets.values())
357 .map(|b| b.len())
358 .sum();
359 total_entries as f32 / total_buckets as f32
360 } else {
361 0.0
362 };
363
364 let max_bucket_size: usize = self
365 .tables
366 .iter()
367 .flat_map(|t| t.buckets.values())
368 .map(|b| b.len())
369 .max()
370 .unwrap_or(0);
371
372 LshStats {
373 num_vectors: self.vectors.len(),
374 num_tables: self.tables.len(),
375 num_bits: self.config.num_bits,
376 total_buckets,
377 avg_bucket_size,
378 max_bucket_size,
379 dimensions: self.dimensions,
380 }
381 }
382}
383
384#[derive(Debug, Clone, Serialize, Deserialize)]
386pub struct LshStats {
387 pub num_vectors: usize,
389 pub num_tables: usize,
391 pub num_bits: usize,
393 pub total_buckets: usize,
395 pub avg_bucket_size: f32,
397 pub max_bucket_size: usize,
399 pub dimensions: usize,
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: `n` embeddings of length `dims`, keyed `doc0`..`doc{n-1}`, where
    /// component `d` of `doc{i}` is `sin(0.01 * i * d)` (so `doc0` is the all-zero vector).
    fn create_test_embeddings(n: usize, dims: usize) -> HashMap<String, Vec<f32>> {
        (0..n)
            .map(|i| {
                let vec: Vec<f32> = (0..dims).map(|d| ((i * d) as f32 * 0.01).sin()).collect();
                (format!("doc{}", i), vec)
            })
            .collect()
    }

    /// Builds an index over `embeddings` with `config`, panicking if the build fails.
    fn built_index(config: LshConfig, embeddings: &HashMap<String, Vec<f32>>) -> LshIndex {
        let mut index = LshIndex::new(config);
        index.build(embeddings).unwrap();
        index
    }

    #[test]
    fn test_lsh_build() {
        let embeddings = create_test_embeddings(100, 64);
        let mut index = LshIndex::new(LshConfig::default());

        assert!(index.build(&embeddings).is_ok());
        assert_eq!(index.len(), 100);
        assert!(index.is_built);
    }

    #[test]
    fn test_lsh_search() {
        let index = built_index(LshConfig::default(), &create_test_embeddings(100, 64));

        // Same values as the fixture vector for `doc1`.
        let query: Vec<f32> = (0..64).map(|d| (d as f32 * 0.01).sin()).collect();
        let results = index.search(&query, 10).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        // With two or more results, the first must score at least as high as the last.
        if let [first, .., last] = results.as_slice() {
            assert!(first.score >= last.score);
        }
    }

    #[test]
    fn test_lsh_empty_embeddings() {
        let mut index = LshIndex::new(LshConfig::default());
        assert!(index.build(&HashMap::new()).is_err());
    }

    #[test]
    fn test_lsh_dimension_mismatch() {
        let embeddings = HashMap::from([
            ("doc1".to_string(), vec![1.0, 2.0, 3.0]),
            ("doc2".to_string(), vec![1.0, 2.0]),
        ]);
        let mut index = LshIndex::new(LshConfig::default());

        assert!(index.build(&embeddings).is_err());
    }

    #[test]
    fn test_lsh_search_before_build() {
        let index = LshIndex::new(LshConfig::default());
        assert!(index.search(&[1.0, 2.0, 3.0], 10).is_err());
    }

    #[test]
    fn test_lsh_query_dimension_mismatch() {
        let index = built_index(LshConfig::default(), &create_test_embeddings(100, 64));
        assert!(index.search(&[1.0, 2.0], 10).is_err());
    }

    #[test]
    fn test_lsh_stats() {
        let stats = built_index(LshConfig::default(), &create_test_embeddings(100, 64)).stats();

        assert_eq!(stats.num_vectors, 100);
        assert_eq!(stats.num_tables, 10);
        assert_eq!(stats.dimensions, 64);
        assert!(stats.total_buckets > 0);
        assert!(stats.avg_bucket_size > 0.0);
    }

    #[test]
    fn test_lsh_config_presets() {
        let high_recall = LshConfig::high_recall();
        assert_eq!((high_recall.num_tables, high_recall.num_probes), (20, 10));

        let fast = LshConfig::fast();
        assert_eq!((fast.num_tables, fast.num_probes), (5, 1));

        let memory = LshConfig::memory_efficient();
        assert_eq!((memory.num_tables, memory.num_bits), (5, 10));
    }

    #[test]
    fn test_hash_table_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let table = HashTable::new(8, 3, &mut rng);

        let positive: Vec<f32> = vec![1.0, 2.0, 3.0];
        let positive_copy = positive.clone();
        let negative: Vec<f32> = positive.iter().map(|x| -x).collect();

        // Equal inputs hash identically...
        assert_eq!(table.hash(&positive), table.hash(&positive_copy));
        // ...while a vector and its negation land in different buckets.
        assert_ne!(table.hash(&positive), table.hash(&negative));
    }

    #[test]
    fn test_multiprobe_increases_candidates() {
        let embeddings = create_test_embeddings(50, 32);
        let config_with_probes = |num_probes| LshConfig {
            num_tables: 5,
            num_bits: 10,
            num_probes,
            seed: 42,
        };

        let index_1probe = built_index(config_with_probes(1), &embeddings);
        let index_5probe = built_index(config_with_probes(5), &embeddings);

        let query: Vec<f32> = (0..32).map(|d| (d as f32 * 0.02).cos()).collect();
        let results_1probe = index_1probe.search(&query, 20).unwrap();
        let results_5probe = index_5probe.search(&query, 20).unwrap();

        // Same seed means identical tables, so extra probes can only add candidates.
        assert!(results_5probe.len() >= results_1probe.len());
    }
}