use std::sync::Arc;

use hashbrown::{HashMap, HashSet};
use ndarray::Array1;
use parking_lot::RwLock;
use rand::rngs::StdRng;
use rand::SeedableRng;

use crate::distance::{self, DistanceMetric};
use crate::error::{LshError, Result};
use crate::hash::{multi_probe_keys, RandomProjectionHasher};
use crate::metrics::{MetricsCollector, MetricsSnapshot, QueryTimer};
13
/// Construction-time parameters for an `LshIndex`.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct IndexConfig {
    /// Dimensionality every inserted and queried vector must have.
    pub dim: usize,
    /// Hyperplane hashes per table; validated to be in `1..=64` at build time.
    pub num_hashes: usize,
    /// Number of independent hash tables; validated to be `> 0` at build time.
    pub num_tables: usize,
    /// Extra buckets probed per table at query time via `multi_probe_keys`;
    /// `0` probes only the exact bucket.
    pub num_probes: usize,
    /// Metric used to rank candidate vectors during queries.
    pub distance_metric: DistanceMetric,
    /// If `true`, vectors are passed through `distance::normalize` on both
    /// insert and query.
    pub normalize_vectors: bool,
    /// Seed for the hasher RNG; `None` draws from OS entropy, so hashes are
    /// only reproducible across runs when a seed is given.
    pub seed: Option<u64>,
}
36
impl Default for IndexConfig {
    /// Defaults: 768 dimensions (a common embedding width), 16 tables of
    /// 8 hashes, 3 multi-probes, cosine distance with normalization on,
    /// and a non-deterministic seed.
    fn default() -> Self {
        Self {
            dim: 768,
            num_hashes: 8,
            num_tables: 16,
            num_probes: 3,
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            seed: None,
        }
    }
}
50
/// One nearest-neighbor hit returned by `LshIndex::query`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Id the matching vector was stored under.
    pub id: usize,
    /// Distance to the query under the configured metric; results are
    /// sorted ascending, so smaller means closer.
    pub distance: f32,
}
59
/// Point-in-time occupancy snapshot produced by `LshIndex::stats`.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of vectors currently stored.
    pub num_vectors: usize,
    /// Number of hash tables (from the config).
    pub num_tables: usize,
    /// Hashes per table (from the config).
    pub num_hashes: usize,
    /// Vector dimensionality (from the config).
    pub dimension: usize,
    /// Total non-empty buckets summed over all tables.
    pub total_buckets: usize,
    /// Mean entries per non-empty bucket (0.0 when there are no buckets).
    pub avg_bucket_size: f64,
    /// Size of the largest bucket across all tables.
    pub max_bucket_size: usize,
    /// Rough memory estimate; derived from element counts plus overhead
    /// guesses, so treat it as indicative rather than exact.
    pub memory_estimate_bytes: usize,
}
72
73impl std::fmt::Display for IndexStats {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(
76 f,
77 "LshIndex {{ vectors: {}, tables: {}, hashes/table: {}, dim: {}, \
78 buckets: {}, avg_bucket: {:.1}, max_bucket: {}, mem: ~{:.1}MB }}",
79 self.num_vectors,
80 self.num_tables,
81 self.num_hashes,
82 self.dimension,
83 self.total_buckets,
84 self.avg_bucket_size,
85 self.max_bucket_size,
86 self.memory_estimate_bytes as f64 / (1024.0 * 1024.0),
87 )
88 }
89}
90
/// Mutable index state; always accessed through the `RwLock` in `LshIndex`.
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub(crate) struct IndexInner {
    /// Stored vectors keyed by caller-visible id (normalized if configured).
    pub(crate) vectors: HashMap<usize, Array1<f32>>,
    /// One bucket map per table: hash key -> ids currently in that bucket.
    pub(crate) tables: Vec<HashMap<u64, Vec<usize>>>,
    /// One hasher per table, created once at construction.
    pub(crate) hashers: Vec<RandomProjectionHasher>,
    /// Configuration the index was built with.
    pub(crate) config: IndexConfig,
    /// Next id handed out by `insert_auto`; kept strictly greater than any
    /// id inserted so far.
    pub(crate) next_id: usize,
}
106
/// A multi-table LSH index over `f32` vectors.
///
/// All methods take `&self`; mutation goes through the internal
/// `parking_lot::RwLock`, so a shared reference (e.g. behind `Arc`) can be
/// used from multiple threads.
pub struct LshIndex {
    /// All mutable state: vectors, tables, hashers, config, id counter.
    pub(crate) inner: RwLock<IndexInner>,
    /// Present only when metrics were enabled at construction time.
    pub(crate) metrics: Option<Arc<MetricsCollector>>,
}
119
120impl std::fmt::Debug for LshIndex {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 let inner = self.inner.read();
123 f.debug_struct("LshIndex")
124 .field("num_vectors", &inner.vectors.len())
125 .field("config", &inner.config)
126 .field("has_metrics", &self.metrics.is_some())
127 .finish()
128 }
129}
130
131impl LshIndex {
132 pub fn builder() -> LshIndexBuilder {
134 LshIndexBuilder::new()
135 }
136
137 pub fn new(config: IndexConfig) -> Result<Self> {
139 Self::new_with_metrics(config, false)
140 }
141
142 fn new_with_metrics(config: IndexConfig, enable_metrics: bool) -> Result<Self> {
143 if config.dim == 0 {
144 return Err(LshError::ZeroDimension);
145 }
146 if config.num_hashes == 0 || config.num_hashes > 64 {
147 return Err(LshError::InvalidNumHashes(config.num_hashes));
148 }
149 if config.num_tables == 0 {
150 return Err(LshError::InvalidConfig(
151 "num_tables must be > 0".into(),
152 ));
153 }
154
155 let mut rng = match config.seed {
156 Some(seed) => StdRng::seed_from_u64(seed),
157 None => StdRng::from_entropy(),
158 };
159
160 let hashers: Vec<RandomProjectionHasher> = (0..config.num_tables)
161 .map(|_| RandomProjectionHasher::new(config.dim, config.num_hashes, &mut rng))
162 .collect();
163
164 let tables = (0..config.num_tables).map(|_| HashMap::new()).collect();
165
166 let inner = IndexInner {
167 vectors: HashMap::new(),
168 tables,
169 hashers,
170 config,
171 next_id: 0,
172 };
173
174 let metrics = if enable_metrics {
175 Some(Arc::new(MetricsCollector::new()))
176 } else {
177 None
178 };
179
180 Ok(Self {
181 inner: RwLock::new(inner),
182 metrics,
183 })
184 }
185
186 pub fn insert(&self, id: usize, vector: &[f32]) -> Result<()> {
194 let mut inner = self.inner.write();
195
196 if vector.len() != inner.config.dim {
197 return Err(LshError::DimensionMismatch {
198 expected: inner.config.dim,
199 got: vector.len(),
200 });
201 }
202
203 if let Some(old_vec) = inner.vectors.get(&id) {
205 let old_vec = old_vec.clone();
206 let old_hashes: Vec<u64> = inner
207 .hashers
208 .iter()
209 .map(|h| h.hash_vector_fast(&old_vec.view()))
210 .collect();
211 for (i, old_hash) in old_hashes.into_iter().enumerate() {
212 if let Some(bucket) = inner.tables[i].get_mut(&old_hash) {
213 bucket.retain(|&x| x != id);
214 if bucket.is_empty() {
215 inner.tables[i].remove(&old_hash);
216 }
217 }
218 }
219 }
220
221 let mut arr = Array1::from_vec(vector.to_vec());
222 if inner.config.normalize_vectors {
223 distance::normalize(&mut arr);
224 }
225
226 let new_hashes: Vec<u64> = inner
227 .hashers
228 .iter()
229 .map(|h| h.hash_vector_fast(&arr.view()))
230 .collect();
231 for (i, hash) in new_hashes.into_iter().enumerate() {
232 inner.tables[i].entry(hash).or_default().push(id);
233 }
234
235 inner.vectors.insert(id, arr);
236
237 if id >= inner.next_id {
238 inner.next_id = id + 1;
239 }
240
241 if let Some(ref m) = self.metrics {
242 m.record_insert();
243 }
244
245 Ok(())
246 }
247
248 pub fn insert_auto(&self, vector: &[f32]) -> Result<usize> {
253 let mut inner = self.inner.write();
254
255 if vector.len() != inner.config.dim {
256 return Err(LshError::DimensionMismatch {
257 expected: inner.config.dim,
258 got: vector.len(),
259 });
260 }
261
262 let id = inner.next_id;
263
264 let mut arr = Array1::from_vec(vector.to_vec());
265 if inner.config.normalize_vectors {
266 distance::normalize(&mut arr);
267 }
268
269 let new_hashes: Vec<u64> = inner
270 .hashers
271 .iter()
272 .map(|h| h.hash_vector_fast(&arr.view()))
273 .collect();
274 for (i, hash) in new_hashes.into_iter().enumerate() {
275 inner.tables[i].entry(hash).or_default().push(id);
276 }
277
278 inner.vectors.insert(id, arr);
279 inner.next_id = id + 1;
280
281 if let Some(ref m) = self.metrics {
282 m.record_insert();
283 }
284
285 Ok(id)
286 }
287
288 pub fn insert_batch(&self, vectors: &[(usize, &[f32])]) -> Result<()> {
290 for &(id, v) in vectors {
291 self.insert(id, v)?;
292 }
293 Ok(())
294 }
295
296 pub fn query(&self, vector: &[f32], k: usize) -> Result<Vec<QueryResult>> {
304 let timer = self.metrics.as_ref().map(|_| QueryTimer::new());
305 let inner = self.inner.read();
306
307 if vector.len() != inner.config.dim {
308 return Err(LshError::DimensionMismatch {
309 expected: inner.config.dim,
310 got: vector.len(),
311 });
312 }
313
314 if inner.vectors.is_empty() {
315 return Ok(Vec::new());
316 }
317
318 let mut query_vec = Array1::from_vec(vector.to_vec());
319 if inner.config.normalize_vectors {
320 distance::normalize(&mut query_vec);
321 }
322
323 let mut candidates: HashMap<usize, ()> = HashMap::new();
325
326 for (i, hasher) in inner.hashers.iter().enumerate() {
327 let (hash, margins) = hasher.hash_vector(&query_vec.view());
328
329 let probe_keys = if inner.config.num_probes > 0 {
330 multi_probe_keys(hash, &margins, inner.config.num_probes)
331 } else {
332 vec![hash]
333 };
334
335 for key in probe_keys {
336 if let Some(bucket) = inner.tables[i].get(&key) {
337 if let Some(ref m) = self.metrics {
338 m.record_bucket_hit();
339 }
340 for &id in bucket {
341 candidates.insert(id, ());
342 }
343 } else if let Some(ref m) = self.metrics {
344 m.record_bucket_miss();
345 }
346 }
347 }
348
349 let mut results: Vec<QueryResult> = candidates
351 .keys()
352 .filter_map(|&id| {
353 inner.vectors.get(&id).map(|stored| {
354 let dist = inner
355 .config
356 .distance_metric
357 .compute(&query_vec.view(), &stored.view());
358 QueryResult { id, distance: dist }
359 })
360 })
361 .collect();
362
363 results.sort_by(|a, b| {
364 a.distance
365 .partial_cmp(&b.distance)
366 .unwrap_or(std::cmp::Ordering::Equal)
367 });
368 results.truncate(k);
369
370 if let Some(ref m) = self.metrics {
371 if let Some(t) = timer {
372 m.record_query(candidates.len() as u64, t.elapsed_ns());
373 }
374 }
375
376 Ok(results)
377 }
378
379 pub fn remove(&self, id: usize) -> Result<()> {
385 let mut inner = self.inner.write();
386
387 let vec = inner.vectors.remove(&id).ok_or(LshError::NotFound(id))?;
388
389 let hashes: Vec<u64> = inner
390 .hashers
391 .iter()
392 .map(|h| h.hash_vector_fast(&vec.view()))
393 .collect();
394 for (i, hash) in hashes.into_iter().enumerate() {
395 if let Some(bucket) = inner.tables[i].get_mut(&hash) {
396 bucket.retain(|&x| x != id);
397 if bucket.is_empty() {
398 inner.tables[i].remove(&hash);
399 }
400 }
401 }
402
403 Ok(())
404 }
405
406 pub fn contains(&self, id: usize) -> bool {
408 self.inner.read().vectors.contains_key(&id)
409 }
410
411 pub fn len(&self) -> usize {
417 self.inner.read().vectors.len()
418 }
419
420 pub fn is_empty(&self) -> bool {
422 self.inner.read().vectors.is_empty()
423 }
424
425 pub fn stats(&self) -> IndexStats {
427 let inner = self.inner.read();
428
429 let total_buckets: usize = inner.tables.iter().map(|t| t.len()).sum();
430 let total_entries: usize = inner
431 .tables
432 .iter()
433 .flat_map(|t| t.values())
434 .map(|v| v.len())
435 .sum();
436 let max_bucket_size = inner
437 .tables
438 .iter()
439 .flat_map(|t| t.values())
440 .map(|v| v.len())
441 .max()
442 .unwrap_or(0);
443
444 let avg_bucket_size = if total_buckets > 0 {
445 total_entries as f64 / total_buckets as f64
446 } else {
447 0.0
448 };
449
450 let vector_mem =
451 inner.vectors.len() * (inner.config.dim * 4 + std::mem::size_of::<usize>());
452 let table_mem = total_buckets * (std::mem::size_of::<u64>() + 24);
453 let entry_mem = total_entries * std::mem::size_of::<usize>();
454 let proj_mem =
455 inner.config.num_tables * inner.config.num_hashes * inner.config.dim * 4;
456
457 IndexStats {
458 num_vectors: inner.vectors.len(),
459 num_tables: inner.config.num_tables,
460 num_hashes: inner.config.num_hashes,
461 dimension: inner.config.dim,
462 total_buckets,
463 avg_bucket_size,
464 max_bucket_size,
465 memory_estimate_bytes: vector_mem + table_mem + entry_mem + proj_mem,
466 }
467 }
468
469 pub fn metrics(&self) -> Option<MetricsSnapshot> {
471 self.metrics.as_ref().map(|m| m.snapshot())
472 }
473
474 pub fn reset_metrics(&self) {
476 if let Some(ref m) = self.metrics {
477 m.reset();
478 }
479 }
480
481 pub fn clear(&self) {
483 let mut inner = self.inner.write();
484 inner.vectors.clear();
485 for table in &mut inner.tables {
486 table.clear();
487 }
488 inner.next_id = 0;
489 }
490
491 pub fn config(&self) -> IndexConfig {
493 self.inner.read().config.clone()
494 }
495}
496
#[cfg(feature = "parallel")]
impl LshIndex {
    /// Inserts a batch of `(id, vector)` pairs, normalizing and hashing in
    /// parallel with rayon, then splicing the results into the index under
    /// a single write lock.
    ///
    /// Re-inserted ids are updated in place, matching [`LshIndex::insert`].
    ///
    /// # Errors
    /// Fails up-front with `LshError::DimensionMismatch` if any vector has
    /// the wrong length; in that case nothing is inserted.
    pub fn par_insert_batch(&self, vectors: &[(usize, Vec<f32>)]) -> Result<()> {
        use rayon::prelude::*;

        // Snapshot config and hashers so the expensive hashing phase runs
        // without holding any lock. Neither is mutated after construction,
        // so the snapshot cannot go stale.
        let (config, hashers) = {
            let inner = self.inner.read();
            (inner.config.clone(), inner.hashers.clone())
        };

        // Validate every dimension before touching the index (all-or-nothing).
        if let Some((_, bad)) = vectors.iter().find(|(_, v)| v.len() != config.dim) {
            return Err(LshError::DimensionMismatch {
                expected: config.dim,
                got: bad.len(),
            });
        }

        // Phase 1 (parallel, lock-free): normalize and hash each vector.
        let prepared: Vec<(usize, Array1<f32>, Vec<u64>)> = vectors
            .par_iter()
            .map(|(id, v)| {
                let mut arr = Array1::from_vec(v.clone());
                if config.normalize_vectors {
                    distance::normalize(&mut arr);
                }
                let hashes: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&arr.view()))
                    .collect();
                (*id, arr, hashes)
            })
            .collect();

        // Phase 2 (serial, one write lock): apply all updates.
        let mut inner = self.inner.write();
        for (id, arr, hashes) in prepared {
            // An existing id is an update: unlink its old buckets first.
            if let Some(old_vec) = inner.vectors.get(&id) {
                let old_vec = old_vec.clone();
                let old_hashes: Vec<u64> = hashers
                    .iter()
                    .map(|h| h.hash_vector_fast(&old_vec.view()))
                    .collect();
                for (i, old_hash) in old_hashes.into_iter().enumerate() {
                    if let Some(bucket) = inner.tables[i].get_mut(&old_hash) {
                        bucket.retain(|&x| x != id);
                        if bucket.is_empty() {
                            inner.tables[i].remove(&old_hash);
                        }
                    }
                }
            }

            for (i, hash) in hashes.into_iter().enumerate() {
                inner.tables[i].entry(hash).or_default().push(id);
            }
            inner.vectors.insert(id, arr);
            if id >= inner.next_id {
                inner.next_id = id + 1;
            }

            // Record each insert so metrics stay consistent with the serial
            // `insert`/`insert_auto` paths (previously skipped here).
            if let Some(ref m) = self.metrics {
                m.record_insert();
            }
        }

        Ok(())
    }

    /// Runs [`LshIndex::query`] for each input in parallel; the output
    /// preserves input order, and the first error (if any) is returned.
    pub fn par_query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<QueryResult>>> {
        use rayon::prelude::*;

        queries
            .par_iter()
            .map(|q| self.query(q, k))
            .collect()
    }
}
585
/// Fluent builder for `LshIndex`; obtain one via `LshIndex::builder()` or
/// `LshIndexBuilder::new()`.
#[derive(Default)]
pub struct LshIndexBuilder {
    // Accumulated configuration; starts at `IndexConfig::default()`.
    config: IndexConfig,
    // Whether the built index will collect metrics (off by default).
    enable_metrics: bool,
}
596
597impl LshIndexBuilder {
598 pub fn new() -> Self {
599 Self::default()
600 }
601
602 pub fn dim(mut self, dim: usize) -> Self {
603 self.config.dim = dim;
604 self
605 }
606
607 pub fn num_hashes(mut self, n: usize) -> Self {
608 self.config.num_hashes = n;
609 self
610 }
611
612 pub fn num_tables(mut self, n: usize) -> Self {
613 self.config.num_tables = n;
614 self
615 }
616
617 pub fn num_probes(mut self, n: usize) -> Self {
618 self.config.num_probes = n;
619 self
620 }
621
622 pub fn distance_metric(mut self, m: DistanceMetric) -> Self {
623 self.config.distance_metric = m;
624 self
625 }
626
627 pub fn normalize(mut self, yes: bool) -> Self {
628 self.config.normalize_vectors = yes;
629 self
630 }
631
632 pub fn seed(mut self, seed: u64) -> Self {
633 self.config.seed = Some(seed);
634 self
635 }
636
637 pub fn enable_metrics(mut self) -> Self {
638 self.enable_metrics = true;
639 self
640 }
641
642 pub fn build(self) -> Result<LshIndex> {
644 LshIndex::new_with_metrics(self.config, self.enable_metrics)
645 }
646}