1use crate::error::{Error, Result};
29use crate::index::traits::{DistanceType, SearchResult, VectorIndex};
30use ahash::AHashMap;
31use parking_lot::RwLock;
32use std::sync::Arc;
33
34#[derive(Debug, Clone)]
36pub struct IndexInfo {
37 pub name: String,
39 pub dimension: usize,
41 pub distance_type: DistanceType,
43 pub size: usize,
45 pub memory_bytes: usize,
47}
48
49#[derive(Debug, Clone)]
51pub struct MultiIndexResult {
52 pub index_name: String,
54 pub results: Vec<SearchResult>,
56}
57
58#[derive(Debug, Clone, Default)]
60pub struct MultiIndexResults {
61 pub by_index: Vec<MultiIndexResult>,
63 pub total_count: usize,
65}
66
67impl MultiIndexResults {
68 #[must_use]
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn add(&mut self, index_name: String, results: Vec<SearchResult>) {
76 self.total_count += results.len();
77 self.by_index.push(MultiIndexResult {
78 index_name,
79 results,
80 });
81 }
82
83 #[must_use]
87 pub fn flatten(&self) -> Vec<(String, SearchResult)> {
88 self.by_index
89 .iter()
90 .flat_map(|mir| {
91 mir.results
92 .iter()
93 .cloned()
94 .map(|r| (mir.index_name.clone(), r))
95 })
96 .collect()
97 }
98}
99
100#[derive(Debug, Default)]
123pub struct IndexRegistry {
124 indexes: AHashMap<String, Box<dyn VectorIndex>>,
126}
127
128impl IndexRegistry {
129 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 indexes: AHashMap::new(),
134 }
135 }
136
137 #[must_use]
139 pub fn with_capacity(capacity: usize) -> Self {
140 Self {
141 indexes: AHashMap::with_capacity(capacity),
142 }
143 }
144
145 pub fn register<I: VectorIndex + 'static>(
151 &mut self,
152 name: impl Into<String>,
153 index: I,
154 ) -> Result<()> {
155 let name = name.into();
156 if self.indexes.contains_key(&name) {
157 return Err(Error::DuplicateIndex { name });
158 }
159 self.indexes.insert(name, Box::new(index));
160 Ok(())
161 }
162
163 pub fn register_or_replace<I: VectorIndex + 'static>(
167 &mut self,
168 name: impl Into<String>,
169 index: I,
170 ) -> Option<Box<dyn VectorIndex>> {
171 self.indexes.insert(name.into(), Box::new(index))
172 }
173
174 #[must_use]
176 pub fn get(&self, name: &str) -> Option<&dyn VectorIndex> {
177 self.indexes.get(name).map(AsRef::as_ref)
178 }
179
180 pub fn remove(&mut self, name: &str) -> Option<Box<dyn VectorIndex>> {
184 self.indexes.remove(name)
185 }
186
187 #[must_use]
189 pub fn contains(&self, name: &str) -> bool {
190 self.indexes.contains_key(name)
191 }
192
193 #[must_use]
195 pub fn list(&self) -> Vec<&str> {
196 self.indexes.keys().map(String::as_str).collect()
197 }
198
199 #[must_use]
201 pub fn info(&self) -> Vec<IndexInfo> {
202 self.indexes
203 .iter()
204 .map(|(name, index)| IndexInfo {
205 name: name.clone(),
206 dimension: index.dimension(),
207 distance_type: index.distance_type(),
208 size: index.len(),
209 memory_bytes: index.memory_usage(),
210 })
211 .collect()
212 }
213
214 #[must_use]
216 pub fn len(&self) -> usize {
217 self.indexes.len()
218 }
219
220 #[must_use]
222 pub fn is_empty(&self) -> bool {
223 self.indexes.is_empty()
224 }
225
226 #[must_use]
228 pub fn total_vectors(&self) -> usize {
229 self.indexes.values().map(|i| i.len()).sum()
230 }
231
232 #[must_use]
234 pub fn total_memory(&self) -> usize {
235 self.indexes.values().map(|i| i.memory_usage()).sum()
236 }
237
238 pub fn search(&self, name: &str, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
244 let index = self.indexes.get(name).ok_or_else(|| Error::IndexNotFound {
245 name: name.to_string(),
246 })?;
247 index.search(query, k)
248 }
249
250 pub fn search_all(&self, query: &[f32], k: usize) -> Result<MultiIndexResults> {
258 let mut results = MultiIndexResults::new();
259
260 for (name, index) in &self.indexes {
261 if index.dimension() != query.len() {
263 continue;
264 }
265
266 let index_results = index.search(query, k)?;
267 results.add(name.clone(), index_results);
268 }
269
270 Ok(results)
271 }
272
273 pub fn search_indexes(
285 &self,
286 names: &[&str],
287 query: &[f32],
288 k: usize,
289 ) -> Result<MultiIndexResults> {
290 let mut results = MultiIndexResults::new();
291
292 for name in names {
293 let index = self.indexes.get(*name).ok_or_else(|| Error::IndexNotFound {
294 name: (*name).to_string(),
295 })?;
296
297 if index.dimension() != query.len() {
299 return Err(Error::DimensionMismatch {
300 expected: index.dimension(),
301 got: query.len(),
302 });
303 }
304
305 let index_results = index.search(query, k)?;
306 results.add((*name).to_string(), index_results);
307 }
308
309 Ok(results)
310 }
311
312 pub fn add(&mut self, index_name: &str, id: String, vector: &[f32]) -> Result<()> {
318 let index = self
319 .indexes
320 .get_mut(index_name)
321 .ok_or_else(|| Error::IndexNotFound {
322 name: index_name.to_string(),
323 })?;
324 index.add(id, vector)
325 }
326
327 pub fn remove_vector(&mut self, index_name: &str, id: &str) -> Result<bool> {
333 let index = self
334 .indexes
335 .get_mut(index_name)
336 .ok_or_else(|| Error::IndexNotFound {
337 name: index_name.to_string(),
338 })?;
339 index.remove(id)
340 }
341
342 pub fn clear_all(&mut self) {
344 for index in self.indexes.values_mut() {
345 index.clear();
346 }
347 }
348}
349
350pub type SharedRegistry = Arc<RwLock<IndexRegistry>>;
352
353#[must_use]
355pub fn shared_registry() -> SharedRegistry {
356 Arc::new(RwLock::new(IndexRegistry::new()))
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig};

    /// Builds an empty `FlatIndex` configured for `dim`-dimensional vectors.
    fn create_test_index(dim: usize) -> FlatIndex {
        FlatIndex::new(IndexConfig::new(dim))
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = IndexRegistry::new();
        let index = create_test_index(128);

        registry.register("test", index).unwrap();

        assert!(registry.contains("test"));
        assert!(!registry.contains("other"));
        assert_eq!(registry.len(), 1);

        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.dimension(), 128);
    }

    #[test]
    fn test_duplicate_register_error() {
        let mut registry = IndexRegistry::new();

        registry.register("test", create_test_index(128)).unwrap();
        let result = registry.register("test", create_test_index(256));

        // Pin the exact variant and name, not merely "some error".
        assert!(matches!(
            result,
            Err(Error::DuplicateIndex { ref name }) if name == "test"
        ));
        // A rejected registration must leave the original index in place.
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("test").unwrap().dimension(), 128);
    }

    #[test]
    fn test_register_or_replace() {
        let mut registry = IndexRegistry::new();

        // First insertion: nothing to replace.
        let old = registry.register_or_replace("test", create_test_index(128));
        assert!(old.is_none());

        // Second insertion: the previous index is handed back.
        let old = registry.register_or_replace("test", create_test_index(256));
        assert!(old.is_some());
        assert_eq!(old.unwrap().dimension(), 128);

        // The registry now holds the replacement.
        assert_eq!(registry.get("test").unwrap().dimension(), 256);
    }

    #[test]
    fn test_remove() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(128)).unwrap();

        let removed = registry.remove("test");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().dimension(), 128);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_list_and_info() {
        let mut registry = IndexRegistry::new();
        registry.register("a", create_test_index(128)).unwrap();
        registry.register("b", create_test_index(256)).unwrap();

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));

        let info = registry.info();
        assert_eq!(info.len(), 2);

        // Iteration order is unspecified, so look entries up by name.
        let dimension_of = |wanted: &str| {
            info.iter()
                .find(|entry| entry.name == wanted)
                .map(|entry| entry.dimension)
        };
        assert_eq!(dimension_of("a"), Some(128));
        assert_eq!(dimension_of("b"), Some(256));
    }

    #[test]
    fn test_search_specific_index() {
        let mut registry = IndexRegistry::new();
        let mut index = create_test_index(4);

        index.add("v1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index.add("v2".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        registry.register("test", index).unwrap();

        let query = [1.0, 0.0, 0.0, 0.0];
        let results = registry.search("test", &query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // The query equals v1, so v1 is expected to rank first.
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn test_search_nonexistent_index() {
        let registry = IndexRegistry::new();
        let result = registry.search("nonexistent", &[1.0], 1);

        assert!(matches!(
            result,
            Err(Error::IndexNotFound { ref name }) if name == "nonexistent"
        ));
    }

    #[test]
    fn test_search_all() {
        let mut registry = IndexRegistry::new();

        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();

        let mut index2 = create_test_index(4);
        index2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        registry.register("index1", index1).unwrap();
        registry.register("index2", index2).unwrap();

        let query = [0.5, 0.5, 0.0, 0.0];
        let results = registry.search_all(&query, 10).unwrap();

        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 2);
    }

    #[test]
    fn test_search_all_skips_incompatible_dimensions() {
        let mut registry = IndexRegistry::new();

        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();

        // Eight-dimensional index: incompatible with the 4-d query below.
        let mut index2 = create_test_index(8);
        index2
            .add("b1".to_string(), &[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
            .unwrap();

        registry.register("index1", index1).unwrap();
        registry.register("index2", index2).unwrap();

        let query = [0.5, 0.5, 0.0, 0.0];
        let results = registry.search_all(&query, 10).unwrap();

        assert_eq!(results.by_index.len(), 1);
        assert_eq!(results.by_index[0].index_name, "index1");
    }

    #[test]
    fn test_search_indexes() {
        let mut registry = IndexRegistry::new();

        let mut index1 = create_test_index(4);
        index1.add("a1".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();

        let mut index2 = create_test_index(4);
        index2.add("b1".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        let mut index3 = create_test_index(4);
        index3.add("c1".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();

        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();
        registry.register("idx3", index3).unwrap();

        let query = [0.5, 0.5, 0.0, 0.0];
        // Only two of the three registered indexes are requested.
        let results = registry
            .search_indexes(&["idx1", "idx2"], &query, 10)
            .unwrap();

        assert_eq!(results.by_index.len(), 2);
        assert_eq!(results.total_count, 2);
    }

    #[test]
    fn test_search_indexes_unknown_name() {
        let mut registry = IndexRegistry::new();
        registry.register("idx1", create_test_index(4)).unwrap();

        let query = [0.5, 0.5, 0.0, 0.0];
        let result = registry.search_indexes(&["idx1", "missing"], &query, 10);

        assert!(matches!(
            result,
            Err(Error::IndexNotFound { ref name }) if name == "missing"
        ));
    }

    #[test]
    fn test_search_indexes_dimension_mismatch() {
        let mut registry = IndexRegistry::new();
        registry.register("idx1", create_test_index(4)).unwrap();

        // Explicitly requested indexes are not skipped on mismatch.
        let result = registry.search_indexes(&["idx1"], &[1.0, 0.0], 10);

        assert!(matches!(
            result,
            Err(Error::DimensionMismatch {
                expected: 4,
                got: 2
            })
        ));
    }

    #[test]
    fn test_add_to_index() {
        let mut registry = IndexRegistry::new();
        registry.register("test", create_test_index(4)).unwrap();

        registry
            .add("test", "v1".to_string(), &[1.0, 0.0, 0.0, 0.0])
            .unwrap();

        assert_eq!(registry.get("test").unwrap().len(), 1);
    }

    #[test]
    fn test_mutating_missing_index_errors() {
        let mut registry = IndexRegistry::new();

        let added = registry.add("missing", "v1".to_string(), &[1.0]);
        assert!(matches!(
            added,
            Err(Error::IndexNotFound { ref name }) if name == "missing"
        ));

        let removed = registry.remove_vector("missing", "v1");
        assert!(matches!(
            removed,
            Err(Error::IndexNotFound { ref name }) if name == "missing"
        ));
    }

    #[test]
    fn test_multi_index_results_flatten() {
        let mut results = MultiIndexResults::new();

        results.add(
            "idx1".to_string(),
            vec![SearchResult::new("a".to_string(), 0.5, DistanceType::L2)],
        );
        results.add(
            "idx2".to_string(),
            vec![SearchResult::new("b".to_string(), 0.3, DistanceType::L2)],
        );

        // Flattening keeps insertion order and tags each hit with its index.
        let flat = results.flatten();
        assert_eq!(flat.len(), 2);
        assert_eq!(flat[0].0, "idx1");
        assert_eq!(flat[0].1.id, "a");
        assert_eq!(flat[1].0, "idx2");
        assert_eq!(flat[1].1.id, "b");
    }

    #[test]
    fn test_total_vectors_and_memory() {
        let mut registry = IndexRegistry::new();

        let mut index1 = create_test_index(4);
        index1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index1.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        let mut index2 = create_test_index(4);
        index2.add("c".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();

        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();

        assert_eq!(registry.total_vectors(), 3);
        assert!(registry.total_memory() > 0);
    }

    #[test]
    fn test_clear_all() {
        let mut registry = IndexRegistry::new();

        let mut index1 = create_test_index(4);
        index1.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();

        let mut index2 = create_test_index(4);
        index2.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();

        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();

        assert_eq!(registry.total_vectors(), 2);

        registry.clear_all();

        // Contents are gone, but the indexes themselves stay registered.
        assert_eq!(registry.total_vectors(), 0);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_shared_registry() {
        let registry = shared_registry();

        // Write guard is scoped so it is released before the read below.
        {
            let mut reg = registry.write();
            reg.register("test", create_test_index(128)).unwrap();
        }

        {
            let reg = registry.read();
            assert!(reg.contains("test"));
        }
    }
}