1use std::collections::HashMap;
11use std::marker::PhantomData;
12use std::sync::{Arc, RwLock};
13
14use manifoldb_core::PointId;
15use manifoldb_storage::{StorageEngine, Transaction};
16
17use crate::distance::DistanceMetric;
18use crate::error::VectorError;
19use crate::types::{CollectionName, Embedding, NamedVector};
20
21use super::config::HnswConfig;
22use super::hnsw::HnswIndex;
23use super::registry::{HnswIndexEntry, HnswRegistry};
24use super::traits::VectorIndex;
25
26pub struct HnswIndexManager<E: StorageEngine> {
36 indexes: RwLock<HashMap<String, Arc<RwLock<HnswIndex<E>>>>>,
39 _phantom: PhantomData<E>,
41}
42
43impl<E: StorageEngine> HnswIndexManager<E> {
44 #[must_use]
49 pub fn new() -> Self {
50 Self { indexes: RwLock::new(HashMap::new()), _phantom: PhantomData }
51 }
52
53 pub fn create_index_for_vector(
72 &self,
73 engine: E,
74 collection: &CollectionName,
75 vector_name: &str,
76 dimension: usize,
77 distance_metric: DistanceMetric,
78 config: &HnswConfig,
79 ) -> Result<String, VectorError> {
80 let collection_str = collection.as_str();
81 let index_name = HnswRegistry::index_name_for_vector(collection_str, vector_name);
82
83 {
85 let tx = engine.begin_read()?;
86 if HnswRegistry::exists(&tx, &index_name)? {
87 return Err(VectorError::InvalidName(format!(
88 "index '{}' already exists",
89 index_name
90 )));
91 }
92 }
93
94 let entry = HnswIndexEntry::for_named_vector(
96 collection_str,
97 vector_name,
98 dimension,
99 distance_metric,
100 config,
101 );
102
103 {
104 let mut tx = engine.begin_write()?;
105 HnswRegistry::register(&mut tx, &entry)?;
106 tx.commit()?;
107 }
108
109 let hnsw = HnswIndex::new(engine, &index_name, dimension, distance_metric, config.clone())?;
111
112 {
114 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
115 indexes.insert(index_name.clone(), Arc::new(RwLock::new(hnsw)));
116 }
117
118 Ok(index_name)
119 }
120
121 pub fn drop_indexes_for_collection(
128 &self,
129 engine: &E,
130 collection: &CollectionName,
131 ) -> Result<Vec<String>, VectorError> {
132 let collection_str = collection.as_str();
133 let mut dropped = Vec::new();
134
135 let entries = {
137 let tx = engine.begin_read()?;
138 HnswRegistry::list_for_collection(&tx, collection_str)?
139 };
140
141 for entry in entries {
143 {
145 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
146 indexes.remove(&entry.name);
147 }
148
149 {
151 let mut tx = engine.begin_write()?;
152 HnswRegistry::drop(&mut tx, &entry.name)?;
153 tx.commit()?;
155 }
156
157 dropped.push(entry.name);
158 }
159
160 Ok(dropped)
161 }
162
163 pub fn drop_index(&self, engine: &E, index_name: &str) -> Result<bool, VectorError> {
165 {
167 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
168 indexes.remove(index_name);
169 }
170
171 let mut tx = engine.begin_write()?;
173 let existed = HnswRegistry::drop(&mut tx, index_name)?;
174 if existed {
175 super::persistence::clear_index_tx(
177 &mut tx,
178 &super::persistence::table_name(index_name),
179 )?;
180 }
181 tx.commit()?;
182
183 Ok(existed)
184 }
185
186 pub fn on_point_upsert(
200 &self,
201 collection: &CollectionName,
202 point_id: PointId,
203 vectors: &HashMap<String, NamedVector>,
204 ) -> Result<(), VectorError> {
205 let collection_str = collection.as_str();
206
207 for (vector_name, vector) in vectors {
208 if let NamedVector::Dense(data) = vector {
210 let index_name = HnswRegistry::index_name_for_vector(collection_str, vector_name);
212
213 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
214 if let Some(index) = indexes.get(&index_name) {
215 let embedding = Embedding::new(data.clone())?;
217 let entity_id = manifoldb_core::EntityId::new(point_id.as_u64());
219
220 let mut index_guard = index.write().map_err(|_| VectorError::LockPoisoned)?;
221 index_guard.insert(entity_id, &embedding)?;
222 }
223 }
224 }
225
226 Ok(())
227 }
228
229 pub fn on_vector_update(
231 &self,
232 collection: &CollectionName,
233 point_id: PointId,
234 vector_name: &str,
235 vector: &NamedVector,
236 ) -> Result<(), VectorError> {
237 if let NamedVector::Dense(data) = vector {
239 let collection_str = collection.as_str();
240 let index_name = HnswRegistry::index_name_for_vector(collection_str, vector_name);
241
242 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
243 if let Some(index) = indexes.get(&index_name) {
244 let embedding = Embedding::new(data.clone())?;
245 let entity_id = manifoldb_core::EntityId::new(point_id.as_u64());
246
247 let mut index_guard = index.write().map_err(|_| VectorError::LockPoisoned)?;
248 index_guard.insert(entity_id, &embedding)?;
249 }
250 }
251
252 Ok(())
253 }
254
255 pub fn on_point_delete(
259 &self,
260 collection: &CollectionName,
261 point_id: PointId,
262 ) -> Result<(), VectorError> {
263 let collection_str = collection.as_str();
264 let entity_id = manifoldb_core::EntityId::new(point_id.as_u64());
265
266 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
267
268 for (name, index) in indexes.iter() {
270 if name.starts_with(&format!("{}_", collection_str)) && name.ends_with("_hnsw") {
271 let mut index_guard = index.write().map_err(|_| VectorError::LockPoisoned)?;
272 let _ = index_guard.delete(entity_id)?;
273 }
274 }
275
276 Ok(())
277 }
278
279 pub fn on_vector_delete(
281 &self,
282 collection: &CollectionName,
283 point_id: PointId,
284 vector_name: &str,
285 ) -> Result<bool, VectorError> {
286 let collection_str = collection.as_str();
287 let index_name = HnswRegistry::index_name_for_vector(collection_str, vector_name);
288 let entity_id = manifoldb_core::EntityId::new(point_id.as_u64());
289
290 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
291 if let Some(index) = indexes.get(&index_name) {
292 let mut index_guard = index.write().map_err(|_| VectorError::LockPoisoned)?;
293 return index_guard.delete(entity_id);
294 }
295
296 Ok(false)
297 }
298
299 pub fn get_index(
307 &self,
308 collection: &str,
309 vector_name: &str,
310 ) -> Result<Option<Arc<RwLock<HnswIndex<E>>>>, VectorError> {
311 let index_name = HnswRegistry::index_name_for_vector(collection, vector_name);
312 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
313 Ok(indexes.get(&index_name).cloned())
314 }
315
316 pub fn get_index_by_name(
318 &self,
319 index_name: &str,
320 ) -> Result<Option<Arc<RwLock<HnswIndex<E>>>>, VectorError> {
321 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
322 Ok(indexes.get(index_name).cloned())
323 }
324
325 pub fn load_index(&self, engine: E, index_name: &str) -> Result<(), VectorError> {
329 {
331 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
332 if indexes.contains_key(index_name) {
333 return Ok(());
334 }
335 }
336
337 let hnsw = HnswIndex::open(engine, index_name)?;
339
340 {
342 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
343 indexes.insert(index_name.to_string(), Arc::new(RwLock::new(hnsw)));
344 }
345
346 Ok(())
347 }
348
349 pub fn has_index(
351 &self,
352 engine: &E,
353 collection: &str,
354 vector_name: &str,
355 ) -> Result<bool, VectorError> {
356 let tx = engine.begin_read()?;
357 HnswRegistry::exists_for_named_vector(&tx, collection, vector_name)
358 }
359
360 pub fn is_index_loaded(
362 &self,
363 collection: &str,
364 vector_name: &str,
365 ) -> Result<bool, VectorError> {
366 let index_name = HnswRegistry::index_name_for_vector(collection, vector_name);
367 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
368 Ok(indexes.contains_key(&index_name))
369 }
370
371 pub fn list_indexes(
373 &self,
374 engine: &E,
375 collection: &str,
376 ) -> Result<Vec<HnswIndexEntry>, VectorError> {
377 let tx = engine.begin_read()?;
378 HnswRegistry::list_for_collection(&tx, collection)
379 }
380
381 pub fn rebuild_index<I>(
395 &self,
396 collection: &str,
397 vector_name: &str,
398 points: I,
399 ) -> Result<usize, VectorError>
400 where
401 I: IntoIterator<Item = (PointId, Vec<f32>)>,
402 {
403 let index_name = HnswRegistry::index_name_for_vector(collection, vector_name);
404
405 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
407 let index = indexes.get(&index_name).ok_or_else(|| {
408 VectorError::SpaceNotFound(format!("index '{}' not found in cache", index_name))
409 })?;
410
411 let mut index_guard = index.write().map_err(|_| VectorError::LockPoisoned)?;
412
413 let embeddings: Vec<(manifoldb_core::EntityId, Embedding)> = points
415 .into_iter()
416 .map(|(pid, data)| {
417 let entity_id = manifoldb_core::EntityId::new(pid.as_u64());
418 let embedding = Embedding::new(data)?;
419 Ok((entity_id, embedding))
420 })
421 .collect::<Result<Vec<_>, VectorError>>()?;
422
423 let count = embeddings.len();
424
425 let refs: Vec<(manifoldb_core::EntityId, &Embedding)> =
427 embeddings.iter().map(|(id, emb)| (*id, emb)).collect();
428
429 index_guard.insert_batch(&refs)?;
430 index_guard.flush()?;
431
432 Ok(count)
433 }
434
435 pub fn load_indexes_for_collection<F>(
450 &self,
451 engine: &E,
452 engine_factory: F,
453 collection: &str,
454 ) -> Result<Vec<(String, Result<(), VectorError>)>, VectorError>
455 where
456 F: Fn() -> Result<E, VectorError>,
457 {
458 let entries = {
460 let tx = engine.begin_read()?;
461 HnswRegistry::list_for_collection(&tx, collection)?
462 };
463
464 let mut results = Vec::with_capacity(entries.len());
465
466 for entry in entries {
467 let index_name = entry.name.clone();
468
469 let load_result = (|| -> Result<(), VectorError> {
471 {
473 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
474 if indexes.contains_key(&index_name) {
475 return Ok(());
476 }
477 }
478
479 let new_engine = engine_factory()?;
481
482 let hnsw = HnswIndex::open(new_engine, &index_name)?;
484
485 {
487 let mut indexes =
488 self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
489 indexes.insert(index_name.clone(), Arc::new(RwLock::new(hnsw)));
490 }
491
492 Ok(())
493 })();
494
495 results.push((index_name, load_result));
496 }
497
498 Ok(results)
499 }
500
501 pub fn verify_index(
516 &self,
517 collection: &str,
518 vector_name: &str,
519 expected_point_count: usize,
520 ) -> Result<RecoveryStatus, VectorError> {
521 let index_name = HnswRegistry::index_name_for_vector(collection, vector_name);
522
523 let indexes = self.indexes.read().map_err(|_| VectorError::LockPoisoned)?;
524 let index = match indexes.get(&index_name) {
525 Some(idx) => idx,
526 None => return Ok(RecoveryStatus::NotLoaded),
527 };
528
529 let guard = index.read().map_err(|_| VectorError::LockPoisoned)?;
530 let actual_count = guard.len()?;
531
532 if actual_count == expected_point_count {
533 Ok(RecoveryStatus::Valid)
534 } else {
535 Ok(RecoveryStatus::NeedsRebuild {
536 expected: expected_point_count,
537 actual: actual_count,
538 })
539 }
540 }
541
542 pub fn rebuild_index_from_scratch<I>(
558 &self,
559 engine: E,
560 collection: &str,
561 vector_name: &str,
562 points: I,
563 ) -> Result<usize, VectorError>
564 where
565 I: IntoIterator<Item = (PointId, Vec<f32>)>,
566 {
567 let index_name = HnswRegistry::index_name_for_vector(collection, vector_name);
568
569 let entry = {
571 let tx = engine.begin_read()?;
572 HnswRegistry::get(&tx, &index_name)?.ok_or_else(|| {
573 VectorError::SpaceNotFound(format!("index '{}' not in registry", index_name))
574 })?
575 };
576
577 {
579 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
580 indexes.remove(&index_name);
581 }
582
583 {
585 let mut tx = engine.begin_write()?;
586 super::persistence::clear_index_tx(
587 &mut tx,
588 &super::persistence::table_name(&index_name),
589 )?;
590 tx.commit()?;
591 }
592
593 let config = entry.config();
595 let distance_metric = entry.distance_metric.into();
596 let hnsw = HnswIndex::new(engine, &index_name, entry.dimension, distance_metric, config)?;
597
598 let embeddings: Vec<(manifoldb_core::EntityId, Embedding)> = points
600 .into_iter()
601 .map(|(pid, data)| {
602 let entity_id = manifoldb_core::EntityId::new(pid.as_u64());
603 let embedding = Embedding::new(data)?;
604 Ok((entity_id, embedding))
605 })
606 .collect::<Result<Vec<_>, VectorError>>()?;
607
608 let count = embeddings.len();
609
610 if embeddings.is_empty() {
611 {
613 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
614 indexes.insert(index_name, Arc::new(RwLock::new(hnsw)));
615 }
616 } else {
617 let refs: Vec<(manifoldb_core::EntityId, &Embedding)> =
618 embeddings.iter().map(|(id, emb)| (*id, emb)).collect();
619 let mut hnsw_guard = hnsw;
620 hnsw_guard.insert_batch(&refs)?;
621 hnsw_guard.flush()?;
622
623 {
625 let mut indexes = self.indexes.write().map_err(|_| VectorError::LockPoisoned)?;
626 indexes.insert(index_name, Arc::new(RwLock::new(hnsw_guard)));
627 }
628 }
629
630 Ok(count)
631 }
632}
633
/// Outcome of checking a cached HNSW index against an expected point count
/// (see `HnswIndexManager::verify_index`).
///
/// All payloads are plain counts, so the type is `Copy` and `Hash`. It can be
/// passed around freely and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoveryStatus {
    /// The cached index holds exactly the expected number of points.
    Valid,
    /// No index with the requested name is in the in-memory cache.
    NotLoaded,
    /// The cached index's point count differs from the expected count.
    NeedsRebuild {
        /// Number of points the caller expected the index to hold.
        expected: usize,
        /// Number of points the index actually reports.
        actual: usize,
    },
}
649
650impl<E: StorageEngine> Default for HnswIndexManager<E> {
651 fn default() -> Self {
652 Self::new()
653 }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;

    fn create_test_manager() -> HnswIndexManager<RedbEngine> {
        HnswIndexManager::new()
    }

    /// Builds a manager holding one freshly created Euclidean index for
    /// `collection`/`vector_name`, backed by its own in-memory engine.
    fn manager_with_index(
        collection: &CollectionName,
        vector_name: &str,
        dimension: usize,
    ) -> HnswIndexManager<RedbEngine> {
        let manager = create_test_manager();
        let engine = RedbEngine::in_memory().unwrap();
        manager
            .create_index_for_vector(
                engine,
                collection,
                vector_name,
                dimension,
                DistanceMetric::Euclidean,
                &HnswConfig::default(),
            )
            .unwrap();
        manager
    }

    #[test]
    fn test_create_index_for_vector() {
        let manager = create_test_manager();
        let engine = RedbEngine::in_memory().unwrap();
        let collection = CollectionName::new("documents").unwrap();

        let index_name = manager
            .create_index_for_vector(
                engine,
                &collection,
                "embedding",
                384,
                DistanceMetric::Cosine,
                &HnswConfig::default(),
            )
            .unwrap();

        assert_eq!(index_name, "documents_embedding_hnsw");
        assert!(manager.is_index_loaded("documents", "embedding").unwrap());
    }

    #[test]
    fn test_point_upsert_and_delete() {
        let collection = CollectionName::new("documents").unwrap();
        let manager = manager_with_index(&collection, "embedding", 4);

        let point_id = PointId::new(1);
        let mut vectors = HashMap::new();
        vectors.insert("embedding".to_string(), NamedVector::Dense(vec![1.0, 2.0, 3.0, 4.0]));

        manager.on_point_upsert(&collection, point_id, &vectors).unwrap();

        let index = manager.get_index("documents", "embedding").unwrap().unwrap();
        let guard = index.read().unwrap();
        assert!(guard.contains(manifoldb_core::EntityId::new(1)).unwrap());
        drop(guard);

        manager.on_point_delete(&collection, point_id).unwrap();

        let guard = index.read().unwrap();
        assert!(!guard.contains(manifoldb_core::EntityId::new(1)).unwrap());
    }

    #[test]
    fn test_vector_update() {
        let collection = CollectionName::new("documents").unwrap();
        let manager = manager_with_index(&collection, "embedding", 4);

        let point_id = PointId::new(1);

        let vector = NamedVector::Dense(vec![1.0, 2.0, 3.0, 4.0]);
        manager.on_vector_update(&collection, point_id, "embedding", &vector).unwrap();

        let new_vector = NamedVector::Dense(vec![5.0, 6.0, 7.0, 8.0]);
        manager.on_vector_update(&collection, point_id, "embedding", &new_vector).unwrap();

        let index = manager.get_index("documents", "embedding").unwrap().unwrap();
        let guard = index.read().unwrap();
        assert!(guard.contains(manifoldb_core::EntityId::new(1)).unwrap());
        assert_eq!(guard.len().unwrap(), 1);
    }

    #[test]
    fn test_sparse_vector_ignored() {
        let collection = CollectionName::new("documents").unwrap();
        let manager = manager_with_index(&collection, "embedding", 4);

        // Use the indexed vector name. Otherwise the point would be skipped for
        // having no index, not for being sparse.
        let mut vectors = HashMap::new();
        vectors.insert("embedding".to_string(), NamedVector::Sparse(vec![(0, 1.0), (5, 0.5)]));

        manager.on_point_upsert(&collection, PointId::new(1), &vectors).unwrap();

        let index = manager.get_index("documents", "embedding").unwrap().unwrap();
        let guard = index.read().unwrap();
        assert!(!guard.contains(manifoldb_core::EntityId::new(1)).unwrap());
    }

    #[test]
    fn test_get_index() {
        let manager = create_test_manager();
        let engine = RedbEngine::in_memory().unwrap();
        let collection = CollectionName::new("test").unwrap();

        assert!(manager.get_index("test", "v1").unwrap().is_none());
        assert!(!manager.is_index_loaded("test", "v1").unwrap());

        manager
            .create_index_for_vector(
                engine,
                &collection,
                "v1",
                64,
                DistanceMetric::Cosine,
                &HnswConfig::default(),
            )
            .unwrap();

        let index = manager.get_index("test", "v1").unwrap();
        assert!(index.is_some());
    }

    #[test]
    fn test_verify_index() {
        let collection = CollectionName::new("documents").unwrap();
        let manager = manager_with_index(&collection, "embedding", 4);

        assert_eq!(
            manager.verify_index("documents", "missing", 0).unwrap(),
            RecoveryStatus::NotLoaded
        );

        let mut vectors = HashMap::new();
        vectors.insert("embedding".to_string(), NamedVector::Dense(vec![1.0, 2.0, 3.0, 4.0]));
        manager.on_point_upsert(&collection, PointId::new(1), &vectors).unwrap();

        assert_eq!(
            manager.verify_index("documents", "embedding", 1).unwrap(),
            RecoveryStatus::Valid
        );
        assert_eq!(
            manager.verify_index("documents", "embedding", 2).unwrap(),
            RecoveryStatus::NeedsRebuild { expected: 2, actual: 1 }
        );
    }

    #[test]
    fn test_rebuild_index() {
        let collection = CollectionName::new("documents").unwrap();
        let manager = manager_with_index(&collection, "embedding", 4);

        let points = vec![
            (PointId::new(1), vec![1.0, 0.0, 0.0, 0.0]),
            (PointId::new(2), vec![0.0, 1.0, 0.0, 0.0]),
            (PointId::new(3), vec![0.0, 0.0, 1.0, 0.0]),
        ];
        assert_eq!(manager.rebuild_index("documents", "embedding", points).unwrap(), 3);

        let index = manager.get_index("documents", "embedding").unwrap().unwrap();
        let guard = index.read().unwrap();
        assert_eq!(guard.len().unwrap(), 3);
        for id in 1..=3 {
            assert!(guard.contains(manifoldb_core::EntityId::new(id)).unwrap());
        }
    }

    #[test]
    fn test_rebuild_index_missing_index() {
        let manager = create_test_manager();
        let result =
            manager.rebuild_index("documents", "embedding", Vec::<(PointId, Vec<f32>)>::new());
        assert!(matches!(result, Err(VectorError::SpaceNotFound(_))));
    }

    #[test]
    fn test_on_vector_delete_without_index() {
        let manager = create_test_manager();
        let collection = CollectionName::new("documents").unwrap();
        assert!(!manager.on_vector_delete(&collection, PointId::new(1), "embedding").unwrap());
    }
}