1use std::sync::Arc;
2
3use hashbrown::HashMap;
4use ndarray::Array1;
5use parking_lot::RwLock;
6use rand::rngs::StdRng;
7use rand::SeedableRng;
8
9use crate::distance::{self, DistanceMetric};
10use crate::error::{LshError, Result};
11use crate::hash::{multi_probe_keys, RandomProjectionHasher};
12use crate::metrics::{MetricsCollector, MetricsSnapshot, QueryTimer};
13
14#[derive(Debug, Clone)]
16#[cfg_attr(
17 feature = "persistence",
18 derive(serde::Serialize, serde::Deserialize)
19)]
20pub struct IndexConfig {
21 pub dim: usize,
23 pub num_hashes: usize,
25 pub num_tables: usize,
27 pub num_probes: usize,
29 pub distance_metric: DistanceMetric,
31 pub normalize_vectors: bool,
33 pub seed: Option<u64>,
35}
36
37impl Default for IndexConfig {
38 fn default() -> Self {
39 Self {
40 dim: 768,
41 num_hashes: 8,
42 num_tables: 16,
43 num_probes: 3,
44 distance_metric: DistanceMetric::Cosine,
45 normalize_vectors: true,
46 seed: None,
47 }
48 }
49}
50
/// One hit returned by an index query.
///
/// A small plain-value pair, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryResult {
    /// Id under which the matching vector was inserted.
    pub id: usize,
    /// Distance from the query under the index's configured metric; results are
    /// returned in ascending order of this value (smaller is closer).
    pub distance: f32,
}
59
/// Point-in-time summary of an index's size and bucket distribution.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Number of hash functions per table.
    pub num_hashes: usize,
    /// Vector dimensionality.
    pub dimension: usize,
    /// Non-empty buckets summed over all tables.
    pub total_buckets: usize,
    /// Mean number of ids per non-empty bucket (0.0 when there are no buckets).
    pub avg_bucket_size: f64,
    /// Largest bucket across all tables (0 when there are no buckets).
    pub max_bucket_size: usize,
    /// Rough memory estimate in bytes; not an exact allocator measurement.
    pub memory_estimate_bytes: usize,
}
72
73impl std::fmt::Display for IndexStats {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(
76 f,
77 "LshIndex {{ vectors: {}, tables: {}, hashes/table: {}, dim: {}, \
78 buckets: {}, avg_bucket: {:.1}, max_bucket: {}, mem: ~{:.1}MB }}",
79 self.num_vectors,
80 self.num_tables,
81 self.num_hashes,
82 self.dimension,
83 self.total_buckets,
84 self.avg_bucket_size,
85 self.max_bucket_size,
86 self.memory_estimate_bytes as f64 / (1024.0 * 1024.0),
87 )
88 }
89}
90
/// Mutable state of an [`LshIndex`], kept behind its `RwLock`.
///
/// Field order is deliberately left unchanged: with the `persistence` feature this struct
/// derives serde traits, and for non-self-describing formats field order may be part of
/// the encoding.
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub(crate) struct IndexInner {
    /// Stored vectors keyed by id (already normalized when `config.normalize_vectors`).
    pub(crate) vectors: HashMap<usize, Array1<f32>>,
    /// One map per hash table from bucket key to the ids in that bucket;
    /// `tables[i]` is keyed by the hash produced by `hashers[i]`.
    pub(crate) tables: Vec<HashMap<u64, Vec<usize>>>,
    /// One random-projection hasher per table.
    pub(crate) hashers: Vec<RandomProjectionHasher>,
    /// Parameters the index was constructed with.
    pub(crate) config: IndexConfig,
    /// One past the largest id inserted since construction or the last `clear`; also the
    /// id handed out next by `insert_auto`.
    pub(crate) next_id: usize,
}
106
/// Approximate nearest-neighbour index based on random-projection locality-sensitive hashing.
///
/// All methods take `&self`; mutation is synchronized through the internal reader-writer
/// lock, so queries share the lock while inserts/removals take it exclusively.
pub struct LshIndex {
    /// Vectors, hash tables, hashers and configuration, guarded by one lock.
    pub(crate) inner: RwLock<IndexInner>,
    /// Metrics sink; `Some` only when the index was built with metrics enabled.
    pub(crate) metrics: Option<Arc<MetricsCollector>>,
}
119
120impl std::fmt::Debug for LshIndex {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 let inner = self.inner.read();
123 f.debug_struct("LshIndex")
124 .field("num_vectors", &inner.vectors.len())
125 .field("config", &inner.config)
126 .field("has_metrics", &self.metrics.is_some())
127 .finish()
128 }
129}
130
131impl LshIndex {
132 pub fn builder() -> LshIndexBuilder {
134 LshIndexBuilder::new()
135 }
136
137 pub fn new(config: IndexConfig) -> Result<Self> {
139 Self::new_with_metrics(config, false)
140 }
141
142 fn new_with_metrics(config: IndexConfig, enable_metrics: bool) -> Result<Self> {
143 if config.dim == 0 {
144 return Err(LshError::ZeroDimension);
145 }
146 if config.num_hashes == 0 || config.num_hashes > 64 {
147 return Err(LshError::InvalidNumHashes(config.num_hashes));
148 }
149 if config.num_tables == 0 {
150 return Err(LshError::InvalidConfig(
151 "num_tables must be > 0".into(),
152 ));
153 }
154
155 let mut rng = match config.seed {
156 Some(seed) => StdRng::seed_from_u64(seed),
157 None => StdRng::from_entropy(),
158 };
159
160 let hashers: Vec<RandomProjectionHasher> = (0..config.num_tables)
161 .map(|_| RandomProjectionHasher::new(config.dim, config.num_hashes, &mut rng))
162 .collect();
163
164 let tables = (0..config.num_tables).map(|_| HashMap::new()).collect();
165
166 let inner = IndexInner {
167 vectors: HashMap::new(),
168 tables,
169 hashers,
170 config,
171 next_id: 0,
172 };
173
174 let metrics = if enable_metrics {
175 Some(Arc::new(MetricsCollector::new()))
176 } else {
177 None
178 };
179
180 Ok(Self {
181 inner: RwLock::new(inner),
182 metrics,
183 })
184 }
185
186 pub fn insert(&self, id: usize, vector: &[f32]) -> Result<()> {
194 let mut inner = self.inner.write();
195
196 if vector.len() != inner.config.dim {
197 return Err(LshError::DimensionMismatch {
198 expected: inner.config.dim,
199 got: vector.len(),
200 });
201 }
202
203 if let Some(old_vec) = inner.vectors.get(&id) {
205 let old_vec = old_vec.clone();
206 let old_hashes: Vec<u64> = inner
207 .hashers
208 .iter()
209 .map(|h| h.hash_vector_fast(&old_vec.view()))
210 .collect();
211 for (i, old_hash) in old_hashes.into_iter().enumerate() {
212 if let Some(bucket) = inner.tables[i].get_mut(&old_hash) {
213 bucket.retain(|&x| x != id);
214 if bucket.is_empty() {
215 inner.tables[i].remove(&old_hash);
216 }
217 }
218 }
219 }
220
221 let mut arr = Array1::from_vec(vector.to_vec());
222 if inner.config.normalize_vectors {
223 distance::normalize(&mut arr);
224 }
225
226 let new_hashes: Vec<u64> = inner
227 .hashers
228 .iter()
229 .map(|h| h.hash_vector_fast(&arr.view()))
230 .collect();
231 for (i, hash) in new_hashes.into_iter().enumerate() {
232 inner.tables[i].entry(hash).or_default().push(id);
233 }
234
235 inner.vectors.insert(id, arr);
236
237 if id >= inner.next_id {
238 inner.next_id = id + 1;
239 }
240
241 if let Some(ref m) = self.metrics {
242 m.record_insert();
243 }
244
245 Ok(())
246 }
247
248 pub fn insert_auto(&self, vector: &[f32]) -> Result<usize> {
253 let mut inner = self.inner.write();
254
255 if vector.len() != inner.config.dim {
256 return Err(LshError::DimensionMismatch {
257 expected: inner.config.dim,
258 got: vector.len(),
259 });
260 }
261
262 let id = inner.next_id;
263
264 let mut arr = Array1::from_vec(vector.to_vec());
265 if inner.config.normalize_vectors {
266 distance::normalize(&mut arr);
267 }
268
269 let new_hashes: Vec<u64> = inner
270 .hashers
271 .iter()
272 .map(|h| h.hash_vector_fast(&arr.view()))
273 .collect();
274 for (i, hash) in new_hashes.into_iter().enumerate() {
275 inner.tables[i].entry(hash).or_default().push(id);
276 }
277
278 inner.vectors.insert(id, arr);
279 inner.next_id = id + 1;
280
281 if let Some(ref m) = self.metrics {
282 m.record_insert();
283 }
284
285 Ok(id)
286 }
287
288 pub fn insert_batch(&self, vectors: &[(usize, &[f32])]) -> Result<()> {
290 for &(id, v) in vectors {
291 self.insert(id, v)?;
292 }
293 Ok(())
294 }
295
296 pub fn query(&self, vector: &[f32], k: usize) -> Result<Vec<QueryResult>> {
304 let timer = self.metrics.as_ref().map(|_| QueryTimer::new());
305 let inner = self.inner.read();
306
307 if vector.len() != inner.config.dim {
308 return Err(LshError::DimensionMismatch {
309 expected: inner.config.dim,
310 got: vector.len(),
311 });
312 }
313
314 if inner.vectors.is_empty() {
315 return Ok(Vec::new());
316 }
317
318 let mut query_vec = Array1::from_vec(vector.to_vec());
319 if inner.config.normalize_vectors {
320 distance::normalize(&mut query_vec);
321 }
322
323 let num_vectors = inner.vectors.len();
327 let use_bitvec = inner.next_id <= num_vectors.saturating_mul(4);
328 let mut seen = if use_bitvec {
329 vec![false; inner.next_id]
330 } else {
331 Vec::new()
332 };
333 let mut candidate_set: HashMap<usize, ()> = if use_bitvec {
334 HashMap::new() } else {
336 HashMap::with_capacity(num_vectors / 4)
337 };
338 let mut candidate_ids: Vec<usize> = Vec::new();
339
340 for (i, hasher) in inner.hashers.iter().enumerate() {
341 let (hash, margins) = hasher.hash_vector(&query_vec.view());
342
343 let probe_keys = if inner.config.num_probes > 0 {
344 multi_probe_keys(hash, &margins, inner.config.num_probes)
345 } else {
346 vec![hash]
347 };
348
349 for key in probe_keys {
350 if let Some(bucket) = inner.tables[i].get(&key) {
351 if let Some(ref m) = self.metrics {
352 m.record_bucket_hit();
353 }
354 for &id in bucket {
355 if use_bitvec {
356 if !seen[id] {
357 seen[id] = true;
358 candidate_ids.push(id);
359 }
360 } else if candidate_set.insert(id, ()).is_none() {
361 candidate_ids.push(id);
362 }
363 }
364 } else if let Some(ref m) = self.metrics {
365 m.record_bucket_miss();
366 }
367 }
368 }
369
370 let use_fast_cosine = inner.config.normalize_vectors
374 && inner.config.distance_metric == distance::DistanceMetric::Cosine;
375 let query_view = query_vec.view();
376
377 let num_candidates = candidate_ids.len();
378
379 let mut results: Vec<QueryResult> = candidate_ids
380 .iter()
381 .filter_map(|&id| {
382 inner.vectors.get(&id).map(|stored| {
383 let dist = if use_fast_cosine {
384 distance::cosine_distance_normalized(&query_view, &stored.view())
385 } else {
386 inner
387 .config
388 .distance_metric
389 .compute(&query_view, &stored.view())
390 };
391 QueryResult { id, distance: dist }
392 })
393 })
394 .collect();
395
396 results.sort_by(|a, b| {
397 a.distance
398 .partial_cmp(&b.distance)
399 .unwrap_or(std::cmp::Ordering::Equal)
400 });
401 results.truncate(k);
402
403 if let Some(ref m) = self.metrics {
404 if let Some(t) = timer {
405 m.record_query(num_candidates as u64, t.elapsed_ns());
406 }
407 }
408
409 Ok(results)
410 }
411
412 pub fn remove(&self, id: usize) -> Result<()> {
418 let mut inner = self.inner.write();
419
420 let vec = inner.vectors.remove(&id).ok_or(LshError::NotFound(id))?;
421
422 let hashes: Vec<u64> = inner
423 .hashers
424 .iter()
425 .map(|h| h.hash_vector_fast(&vec.view()))
426 .collect();
427 for (i, hash) in hashes.into_iter().enumerate() {
428 if let Some(bucket) = inner.tables[i].get_mut(&hash) {
429 bucket.retain(|&x| x != id);
430 if bucket.is_empty() {
431 inner.tables[i].remove(&hash);
432 }
433 }
434 }
435
436 Ok(())
437 }
438
439 pub fn contains(&self, id: usize) -> bool {
441 self.inner.read().vectors.contains_key(&id)
442 }
443
444 pub fn len(&self) -> usize {
450 self.inner.read().vectors.len()
451 }
452
453 pub fn is_empty(&self) -> bool {
455 self.inner.read().vectors.is_empty()
456 }
457
458 pub fn stats(&self) -> IndexStats {
460 let inner = self.inner.read();
461
462 let total_buckets: usize = inner.tables.iter().map(|t| t.len()).sum();
463 let total_entries: usize = inner
464 .tables
465 .iter()
466 .flat_map(|t| t.values())
467 .map(|v| v.len())
468 .sum();
469 let max_bucket_size = inner
470 .tables
471 .iter()
472 .flat_map(|t| t.values())
473 .map(|v| v.len())
474 .max()
475 .unwrap_or(0);
476
477 let avg_bucket_size = if total_buckets > 0 {
478 total_entries as f64 / total_buckets as f64
479 } else {
480 0.0
481 };
482
483 let vector_mem =
484 inner.vectors.len() * (inner.config.dim * 4 + std::mem::size_of::<usize>());
485 let table_mem = total_buckets * (std::mem::size_of::<u64>() + 24);
486 let entry_mem = total_entries * std::mem::size_of::<usize>();
487 let proj_mem =
488 inner.config.num_tables * inner.config.num_hashes * inner.config.dim * 4;
489
490 IndexStats {
491 num_vectors: inner.vectors.len(),
492 num_tables: inner.config.num_tables,
493 num_hashes: inner.config.num_hashes,
494 dimension: inner.config.dim,
495 total_buckets,
496 avg_bucket_size,
497 max_bucket_size,
498 memory_estimate_bytes: vector_mem + table_mem + entry_mem + proj_mem,
499 }
500 }
501
502 pub fn metrics(&self) -> Option<MetricsSnapshot> {
504 self.metrics.as_ref().map(|m| m.snapshot())
505 }
506
507 pub fn reset_metrics(&self) {
509 if let Some(ref m) = self.metrics {
510 m.reset();
511 }
512 }
513
514 pub fn clear(&self) {
516 let mut inner = self.inner.write();
517 inner.vectors.clear();
518 for table in &mut inner.tables {
519 table.clear();
520 }
521 inner.next_id = 0;
522 }
523
524 pub fn config(&self) -> IndexConfig {
526 self.inner.read().config.clone()
527 }
528}
529
#[cfg(feature = "parallel")]
impl LshIndex {
    /// Inserts (or replaces) a batch of vectors, normalizing and hashing them in parallel.
    ///
    /// Config and hashers are snapshotted under a short read lock (nothing in this file
    /// mutates them after construction). Vectors are prepared in parallel without holding
    /// the lock, then applied sequentially under one write lock. Later entries overwrite
    /// earlier ones with the same id. Each applied vector is counted in the insert metrics,
    /// as with `insert`.
    ///
    /// # Errors
    /// `LshError::DimensionMismatch` for the first vector with the wrong length; the
    /// index is unchanged in that case.
    pub fn par_insert_batch(&self, vectors: &[(usize, Vec<f32>)]) -> Result<()> {
        use rayon::prelude::*;

        if vectors.is_empty() {
            return Ok(());
        }

        let (config, hashers) = {
            let inner = self.inner.read();
            (inner.config.clone(), inner.hashers.clone())
        };

        // Validate everything up front so a bad vector leaves the index untouched.
        if let Some((_, bad)) = vectors.iter().find(|(_, v)| v.len() != config.dim) {
            return Err(LshError::DimensionMismatch {
                expected: config.dim,
                got: bad.len(),
            });
        }

        let prepared: Vec<(usize, Array1<f32>, Vec<u64>)> = vectors
            .par_iter()
            .map(|(id, v)| {
                let mut arr = Array1::from_vec(v.clone());
                if config.normalize_vectors {
                    distance::normalize(&mut arr);
                }
                let hashes: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&arr.view()))
                    .collect();
                (*id, arr, hashes)
            })
            .collect();

        let mut guard = self.inner.write();
        // Plain `&mut IndexInner` so vectors and tables can be borrowed disjointly.
        let inner = &mut *guard;
        for (id, arr, hashes) in prepared {
            if let Some(old) = inner.vectors.get(&id) {
                // Rehash the replaced vector to find and drop its existing bucket entries.
                let old_hashes: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&old.view()))
                    .collect();
                for (table, old_hash) in inner.tables.iter_mut().zip(old_hashes) {
                    if let Some(bucket) = table.get_mut(&old_hash) {
                        bucket.retain(|&x| x != id);
                        if bucket.is_empty() {
                            table.remove(&old_hash);
                        }
                    }
                }
            }

            for (table, hash) in inner.tables.iter_mut().zip(hashes) {
                table.entry(hash).or_default().push(id);
            }
            inner.vectors.insert(id, arr);
            // Saturate so `id == usize::MAX` cannot wrap `next_id` back to 0.
            inner.next_id = inner.next_id.max(id.saturating_add(1));

            if let Some(m) = &self.metrics {
                m.record_insert();
            }
        }

        Ok(())
    }

    /// Runs [`LshIndex::query`] for every query in parallel, returning results in input
    /// order.
    ///
    /// # Errors
    /// Propagates a `query` error (e.g. a dimension mismatch) if any query fails.
    pub fn par_query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<QueryResult>>> {
        use rayon::prelude::*;

        queries
            .par_iter()
            .map(|q| self.query(q, k))
            .collect()
    }
}
618
619#[derive(Default)]
625pub struct LshIndexBuilder {
626 config: IndexConfig,
627 enable_metrics: bool,
628}
629
630impl LshIndexBuilder {
631 pub fn new() -> Self {
632 Self::default()
633 }
634
635 pub fn dim(mut self, dim: usize) -> Self {
636 self.config.dim = dim;
637 self
638 }
639
640 pub fn num_hashes(mut self, n: usize) -> Self {
641 self.config.num_hashes = n;
642 self
643 }
644
645 pub fn num_tables(mut self, n: usize) -> Self {
646 self.config.num_tables = n;
647 self
648 }
649
650 pub fn num_probes(mut self, n: usize) -> Self {
651 self.config.num_probes = n;
652 self
653 }
654
655 pub fn distance_metric(mut self, m: DistanceMetric) -> Self {
656 self.config.distance_metric = m;
657 self
658 }
659
660 pub fn normalize(mut self, yes: bool) -> Self {
661 self.config.normalize_vectors = yes;
662 self
663 }
664
665 pub fn seed(mut self, seed: u64) -> Self {
666 self.config.seed = Some(seed);
667 self
668 }
669
670 pub fn enable_metrics(mut self) -> Self {
671 self.enable_metrics = true;
672 self
673 }
674
675 pub fn build(self) -> Result<LshIndex> {
677 LshIndex::new_with_metrics(self.config, self.enable_metrics)
678 }
679}