Skip to main content

oxirs_vec/
index.rs

1//! Advanced vector indexing with HNSW and other efficient algorithms
2
3use crate::Vector;
4
5// Re-export VectorIndex trait for use by other modules
6pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15use crate::hnsw::{HnswConfig, HnswIndex};
16use crate::ivf::{IvfConfig, IvfIndex};
17use crate::pq::{PQConfig, PQIndex};
18
/// Boxed URI predicate used to restrict search results (`true` = keep).
pub type FilterFunction = Box<dyn Fn(&str) -> bool + 'static>;
/// Thread-safe variant of [`FilterFunction`], usable from parallel search.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;
23
24/// Configuration for vector index
25#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
26pub struct IndexConfig {
27    /// Index type to use
28    pub index_type: IndexType,
29    /// Maximum number of connections for each node (for HNSW)
30    pub max_connections: usize,
31    /// Construction parameter (for HNSW)
32    pub ef_construction: usize,
33    /// Search parameter (for HNSW)
34    pub ef_search: usize,
35    /// Distance metric to use
36    pub distance_metric: DistanceMetric,
37    /// Whether to enable parallel operations
38    pub parallel: bool,
39}
40
41impl Default for IndexConfig {
42    fn default() -> Self {
43        Self {
44            index_type: IndexType::Hnsw,
45            max_connections: 16,
46            ef_construction: 200,
47            ef_search: 50,
48            distance_metric: DistanceMetric::Cosine,
49            parallel: true,
50        }
51    }
52}
53
54/// Available index types
55#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
56pub enum IndexType {
57    /// Hierarchical Navigable Small World
58    Hnsw,
59    /// Simple flat index (brute force)
60    Flat,
61    /// IVF (Inverted File) index
62    Ivf,
63    /// Product Quantization
64    PQ,
65}
66
67/// Distance metrics supported
68#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
69pub enum DistanceMetric {
70    /// Cosine distance (1 - cosine_similarity)
71    Cosine,
72    /// Euclidean (L2) distance
73    Euclidean,
74    /// Manhattan (L1) distance
75    Manhattan,
76    /// Dot product (negative for max-heap behavior)
77    DotProduct,
78}
79
80impl DistanceMetric {
81    /// Calculate distance between two vectors
82    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
83        use oxirs_core::simd::SimdOps;
84
85        match self {
86            DistanceMetric::Cosine => f32::cosine_distance(a, b),
87            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
88            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
89            DistanceMetric::DotProduct => -f32::dot(a, b), // Negative for max-heap
90        }
91    }
92
93    /// Calculate distance between two Vector objects
94    pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
95        let a_f32 = a.as_f32();
96        let b_f32 = b.as_f32();
97        self.distance(&a_f32, &b_f32)
98    }
99}
100
/// Search result with distance/score.
///
/// `distance` is the metric value (smaller is closer). `score` is a
/// similarity-style value, which the searches in this module set to
/// `1.0 - distance`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance from the query under the index's metric (smaller is closer).
    pub distance: f32,
    /// Similarity-style score; set to `1.0 - distance` by this module.
    pub score: f32,
    /// Optional metadata; the searches in this module leave it `None`.
    pub metadata: Option<HashMap<String, String>>,
}

// NOTE(review): `Eq` is asserted although the derived `PartialEq` compares
// `f32` fields (NaN != NaN) and `metadata`, which `Ord` cannot consider.
// Treat the ordering as "by distance", not as a full equality relation.
impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Orders by ascending `distance`, then `uri`, then `score`.
    ///
    /// `f32::total_cmp` is a genuine total order. A positive NaN distance sorts
    /// after every number instead of comparing "equal" to everything. The
    /// latter made the ordering non-transitive and can corrupt
    /// `BinaryHeap`/`sort` invariants. Breaking distance ties by URI keeps
    /// equal-distance hits in a deterministic order.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.uri.cmp(&other.uri))
            .then_with(|| self.score.total_cmp(&other.score))
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
125
126/// Advanced vector index with multiple implementations
127pub struct AdvancedVectorIndex {
128    config: IndexConfig,
129    vectors: Vec<(String, Vector)>,
130    uri_to_id: HashMap<String, usize>,
131    hnsw_index: Option<HnswIndex>,
132    /// Trained IVF index (populated after `build()` when `IndexType::Ivf`)
133    ivf_index: Option<IvfIndex>,
134    /// Trained PQ index (populated after `build()` when `IndexType::PQ`)
135    pq_index: Option<PQIndex>,
136    dimensions: Option<usize>,
137}
138
139impl AdvancedVectorIndex {
140    pub fn new(config: IndexConfig) -> Self {
141        Self {
142            config,
143            vectors: Vec::new(),
144            uri_to_id: HashMap::new(),
145            hnsw_index: None,
146            ivf_index: None,
147            pq_index: None,
148            dimensions: None,
149        }
150    }
151
152    /// Build the index after adding all vectors
153    pub fn build(&mut self) -> Result<()> {
154        if self.vectors.is_empty() {
155            return Ok(());
156        }
157
158        match self.config.index_type {
159            IndexType::Hnsw => {
160                self.build_hnsw_index()?;
161            }
162            IndexType::Flat => {
163                // No special building needed for flat index
164            }
165            IndexType::Ivf => {
166                self.build_ivf_index()?;
167            }
168            IndexType::PQ => {
169                self.build_pq_index()?;
170            }
171        }
172
173        Ok(())
174    }
175
176    fn build_hnsw_index(&mut self) -> Result<()> {
177        if self.dimensions.is_some() {
178            let hnsw_config = HnswConfig {
179                m: self.config.max_connections,
180                m_l0: self.config.max_connections * 2,
181                ef_construction: self.config.ef_construction,
182                ef: self.config.ef_search,
183                ..HnswConfig::default()
184            };
185
186            let mut hnsw = HnswIndex::new_cpu_only(hnsw_config);
187
188            for (uri, vector) in &self.vectors {
189                hnsw.insert(uri.clone(), vector.clone())?;
190            }
191
192            self.hnsw_index = Some(hnsw);
193        }
194
195        Ok(())
196    }
197
198    /// Build an IVF index by training on all stored vectors and then adding them.
199    ///
200    /// Uses the default `IvfConfig` with `n_clusters` derived from the number of
201    /// stored vectors (capped at 256) so that every cluster has at least one
202    /// training sample.
203    fn build_ivf_index(&mut self) -> Result<()> {
204        let training_vectors: Vec<Vector> = self.vectors.iter().map(|(_, v)| v.clone()).collect();
205        let n_clusters = (self.vectors.len() / 4).clamp(2, 256);
206
207        let config = IvfConfig {
208            n_clusters,
209            n_probes: (n_clusters / 8).max(1),
210            ..Default::default()
211        };
212        let mut ivf = IvfIndex::new(config)?;
213        ivf.train(&training_vectors)?;
214
215        for (uri, vector) in &self.vectors {
216            ivf.insert(uri.clone(), vector.clone())?;
217        }
218
219        self.ivf_index = Some(ivf);
220        Ok(())
221    }
222
223    /// Build a PQ index by training on all stored vectors and then encoding them.
224    ///
225    /// Chooses `n_subquantizers` so that it evenly divides the vector dimension
226    /// and stays within [1, 8].  Falls back gracefully with a descriptive error
227    /// if dimensions are not yet known.
228    fn build_pq_index(&mut self) -> Result<()> {
229        let dims = self
230            .dimensions
231            .ok_or_else(|| anyhow!("Cannot build PQ index: no vectors have been inserted yet"))?;
232
233        // Pick the largest power-of-two divisor of dims in [1, 8]
234        let n_subquantizers = [8usize, 4, 2, 1]
235            .iter()
236            .copied()
237            .find(|&s| dims % s == 0)
238            .unwrap_or(1);
239
240        let config = PQConfig {
241            n_subquantizers,
242            n_centroids: 16, // small for small test sets
243            ..Default::default()
244        };
245        let mut pq = PQIndex::new(config);
246        let training_vectors: Vec<Vector> = self.vectors.iter().map(|(_, v)| v.clone()).collect();
247        pq.train(&training_vectors)?;
248
249        for (uri, vector) in &self.vectors {
250            pq.insert(uri.clone(), vector.clone())?;
251        }
252
253        self.pq_index = Some(pq);
254        Ok(())
255    }
256
257    /// Add metadata to a vector
258    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
259        // For now, we'll store metadata separately
260        // In a full implementation, this would be integrated with the index
261        Ok(())
262    }
263
264    /// Search with advanced parameters
265    pub fn search_advanced(
266        &self,
267        query: &Vector,
268        k: usize,
269        _ef: Option<usize>,
270        filter: Option<FilterFunction>,
271    ) -> Result<Vec<SearchResult>> {
272        match self.config.index_type {
273            IndexType::Hnsw => self.search_hnsw(query, k),
274            IndexType::Ivf => self.search_ivf(query, k),
275            IndexType::PQ => self.search_pq(query, k),
276            IndexType::Flat => self.search_flat(query, k, filter),
277        }
278    }
279
280    fn search_hnsw(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
281        if let Some(ref hnsw) = self.hnsw_index {
282            let results = hnsw.search_knn(query, k)?;
283
284            Ok(results
285                .into_iter()
286                .map(|(uri, distance)| SearchResult {
287                    uri,
288                    distance,
289                    score: 1.0 - distance,
290                    metadata: None,
291                })
292                .collect())
293        } else {
294            Err(anyhow!("HNSW index not built"))
295        }
296    }
297
298    fn search_ivf(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
299        let ivf = self
300            .ivf_index
301            .as_ref()
302            .ok_or_else(|| anyhow!("IVF index not built — call build() first"))?;
303        let results = ivf.search_knn(query, k)?;
304        Ok(results
305            .into_iter()
306            .map(|(uri, distance)| SearchResult {
307                uri,
308                score: 1.0 - distance,
309                distance,
310                metadata: None,
311            })
312            .collect())
313    }
314
315    fn search_pq(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
316        let pq = self
317            .pq_index
318            .as_ref()
319            .ok_or_else(|| anyhow!("PQ index not built — call build() first"))?;
320        let results = pq.search_knn(query, k)?;
321        Ok(results
322            .into_iter()
323            .map(|(uri, distance)| SearchResult {
324                uri,
325                score: 1.0 - distance,
326                distance,
327                metadata: None,
328            })
329            .collect())
330    }
331
332    fn search_flat(
333        &self,
334        query: &Vector,
335        k: usize,
336        filter: Option<FilterFunction>,
337    ) -> Result<Vec<SearchResult>> {
338        if self.config.parallel && self.vectors.len() > 1000 {
339            // For parallel search, we need Send + Sync filter
340            if filter.is_some() {
341                // Fall back to sequential if filter is present but not Send + Sync
342                self.search_flat_sequential(query, k, filter)
343            } else {
344                self.search_flat_parallel(query, k, None)
345            }
346        } else {
347            self.search_flat_sequential(query, k, filter)
348        }
349    }
350
351    fn search_flat_sequential(
352        &self,
353        query: &Vector,
354        k: usize,
355        filter: Option<FilterFunction>,
356    ) -> Result<Vec<SearchResult>> {
357        let mut heap = BinaryHeap::new();
358
359        for (uri, vector) in &self.vectors {
360            if let Some(ref filter_fn) = filter {
361                if !filter_fn(uri) {
362                    continue;
363                }
364            }
365
366            let distance = self.config.distance_metric.distance_vectors(query, vector);
367
368            if heap.len() < k {
369                heap.push(std::cmp::Reverse(SearchResult {
370                    uri: uri.clone(),
371                    distance,
372                    score: 1.0 - distance, // Convert distance to similarity score
373                    metadata: None,
374                }));
375            } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
376                if distance < worst.distance {
377                    heap.pop();
378                    heap.push(std::cmp::Reverse(SearchResult {
379                        uri: uri.clone(),
380                        distance,
381                        score: 1.0 - distance, // Convert distance to similarity score
382                        metadata: None,
383                    }));
384                }
385            }
386        }
387
388        let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
389        results.sort_by(|a, b| {
390            a.distance
391                .partial_cmp(&b.distance)
392                .unwrap_or(std::cmp::Ordering::Equal)
393        });
394
395        Ok(results)
396    }
397
398    fn search_flat_parallel(
399        &self,
400        query: &Vector,
401        k: usize,
402        filter: Option<FilterFunctionSync>,
403    ) -> Result<Vec<SearchResult>> {
404        // Split vectors into chunks for parallel processing
405        let chunk_size = (self.vectors.len() / num_threads()).max(100);
406
407        // Use Arc for thread-safe sharing of the filter
408        let filter_arc = filter.map(Arc::new);
409
410        // Process chunks in parallel and collect top-k from each
411        let partial_results: Vec<Vec<SearchResult>> = self
412            .vectors
413            .par_chunks(chunk_size)
414            .map(|chunk| {
415                let mut local_heap = BinaryHeap::new();
416                let filter_ref = filter_arc.as_ref();
417
418                for (uri, vector) in chunk {
419                    if let Some(filter_fn) = filter_ref {
420                        if !filter_fn(uri) {
421                            continue;
422                        }
423                    }
424
425                    let distance = self.config.distance_metric.distance_vectors(query, vector);
426
427                    if local_heap.len() < k {
428                        local_heap.push(std::cmp::Reverse(SearchResult {
429                            uri: uri.clone(),
430                            distance,
431                            score: 1.0 - distance, // Convert distance to similarity score
432                            metadata: None,
433                        }));
434                    } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
435                        if distance < worst.distance {
436                            local_heap.pop();
437                            local_heap.push(std::cmp::Reverse(SearchResult {
438                                uri: uri.clone(),
439                                distance,
440                                score: 1.0 - distance, // Convert distance to similarity score
441                                metadata: None,
442                            }));
443                        }
444                    }
445                }
446
447                local_heap
448                    .into_sorted_vec()
449                    .into_iter()
450                    .map(|r| r.0)
451                    .collect()
452            })
453            .collect();
454
455        // Merge results from all chunks
456        let mut final_heap = BinaryHeap::new();
457        for partial in partial_results {
458            for result in partial {
459                if final_heap.len() < k {
460                    final_heap.push(std::cmp::Reverse(result));
461                } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
462                    if result.distance < worst.distance {
463                        final_heap.pop();
464                        final_heap.push(std::cmp::Reverse(result));
465                    }
466                }
467            }
468        }
469
470        let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
471        results.sort_by(|a, b| {
472            a.distance
473                .partial_cmp(&b.distance)
474                .unwrap_or(std::cmp::Ordering::Equal)
475        });
476
477        Ok(results)
478    }
479
480    /// Get index statistics
481    pub fn stats(&self) -> IndexStats {
482        IndexStats {
483            num_vectors: self.vectors.len(),
484            dimensions: self.dimensions.unwrap_or(0),
485            index_type: self.config.index_type,
486            memory_usage: self.estimate_memory_usage(),
487        }
488    }
489
490    fn estimate_memory_usage(&self) -> usize {
491        let vector_memory = self.vectors.len()
492            * (std::mem::size_of::<String>()
493                + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
494
495        let uri_map_memory =
496            self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
497
498        vector_memory + uri_map_memory
499    }
500
501    /// Get the number of vectors in the index
502    pub fn len(&self) -> usize {
503        self.vectors.len()
504    }
505
506    /// Check if the index is empty
507    pub fn is_empty(&self) -> bool {
508        self.vectors.is_empty()
509    }
510
511    /// Add a vector with RDF triple and metadata (for compatibility with tests)
512    pub fn add(
513        &mut self,
514        id: String,
515        vector: Vec<f32>,
516        _triple: Triple,
517        _metadata: HashMap<String, String>,
518    ) -> Result<()> {
519        let vector_obj = Vector::new(vector);
520        self.insert(id, vector_obj)
521    }
522
523    /// Search for nearest neighbors (for compatibility with tests)
524    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
525        let query_vector = Vector::new(query.to_vec());
526        let results = self.search_advanced(&query_vector, k, None, None)?;
527        Ok(results)
528    }
529}
530
531impl VectorIndex for AdvancedVectorIndex {
532    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
533        if let Some(dims) = self.dimensions {
534            if vector.dimensions != dims {
535                return Err(anyhow!(
536                    "Vector dimensions ({}) don't match index dimensions ({})",
537                    vector.dimensions,
538                    dims
539                ));
540            }
541        } else {
542            self.dimensions = Some(vector.dimensions);
543        }
544
545        let id = self.vectors.len();
546        self.uri_to_id.insert(uri.clone(), id);
547        self.vectors.push((uri, vector));
548
549        Ok(())
550    }
551
552    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
553        let results = self.search_advanced(query, k, None, None)?;
554        Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
555    }
556
557    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
558        let mut results = Vec::new();
559
560        for (uri, vector) in &self.vectors {
561            let distance = self.config.distance_metric.distance_vectors(query, vector);
562            if distance <= threshold {
563                results.push((uri.clone(), distance));
564            }
565        }
566
567        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
568        Ok(results)
569    }
570
571    fn get_vector(&self, uri: &str) -> Option<&Vector> {
572        // For AdvancedVectorIndex, vectors are stored in the vectors field
573        // regardless of the index type being used
574        self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
575    }
576}
577
578/// Index performance statistics
579#[derive(Debug, Clone)]
580pub struct IndexStats {
581    pub num_vectors: usize,
582    pub dimensions: usize,
583    pub index_type: IndexType,
584    pub memory_usage: usize,
585}
586
587/// Quantized vector index for memory efficiency
588pub struct QuantizedVectorIndex {
589    config: IndexConfig,
590    quantized_vectors: Vec<Vec<u8>>,
591    centroids: Vec<Vector>,
592    uri_to_id: HashMap<String, usize>,
593    dimensions: Option<usize>,
594}
595
596impl QuantizedVectorIndex {
597    pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
598        Self {
599            config,
600            quantized_vectors: Vec::new(),
601            centroids: Vec::with_capacity(num_centroids),
602            uri_to_id: HashMap::new(),
603            dimensions: None,
604        }
605    }
606
607    /// Train quantization centroids using k-means
608    pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
609        if training_vectors.is_empty() {
610            return Err(anyhow!("No training vectors provided"));
611        }
612
613        let dimensions = training_vectors[0].dimensions;
614        self.dimensions = Some(dimensions);
615
616        // Simple k-means clustering for centroids
617        self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
618
619        Ok(())
620    }
621
622    fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
623        let mut quantized = Vec::new();
624
625        // Find nearest centroid for each dimension chunk
626        let chunk_size = vector.dimensions / self.centroids.len().max(1);
627
628        let vector_f32 = vector.as_f32();
629        for chunk in vector_f32.chunks(chunk_size) {
630            let mut best_centroid = 0u8;
631            let mut best_distance = f32::INFINITY;
632
633            for (i, centroid) in self.centroids.iter().enumerate() {
634                let centroid_f32 = centroid.as_f32();
635                let centroid_chunk = &centroid_f32[0..chunk.len().min(centroid.dimensions)];
636                use oxirs_core::simd::SimdOps;
637                let distance = f32::euclidean_distance(chunk, centroid_chunk);
638                if distance < best_distance {
639                    best_distance = distance;
640                    best_centroid = i as u8;
641                }
642            }
643
644            quantized.push(best_centroid);
645        }
646
647        quantized
648    }
649}
650
651impl VectorIndex for QuantizedVectorIndex {
652    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
653        if self.centroids.is_empty() {
654            return Err(anyhow!(
655                "Quantization not trained. Call train_quantization first."
656            ));
657        }
658
659        let id = self.quantized_vectors.len();
660        self.uri_to_id.insert(uri.clone(), id);
661
662        let quantized = self.quantize_vector(&vector);
663        self.quantized_vectors.push(quantized);
664
665        Ok(())
666    }
667
668    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
669        let query_quantized = self.quantize_vector(query);
670        let mut results = Vec::new();
671
672        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
673            let distance = hamming_distance(&query_quantized, quantized);
674            results.push((uri.clone(), distance));
675        }
676
677        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
678        results.truncate(k);
679
680        Ok(results)
681    }
682
683    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
684        let query_quantized = self.quantize_vector(query);
685        let mut results = Vec::new();
686
687        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
688            let distance = hamming_distance(&query_quantized, quantized);
689            if distance <= threshold {
690                results.push((uri.clone(), distance));
691            }
692        }
693
694        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
695        Ok(results)
696    }
697
698    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
699        // Quantized index doesn't store original vectors
700        // Return None as we only have quantized representations
701        None
702    }
703}
704
705// Helper functions that don't have SIMD equivalents
706
/// Number of positions at which two code sequences differ, as `f32`.
///
/// Only the overlapping prefix is compared; any extra positions in the longer
/// input are ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mut mismatches = 0usize;
    for (left, right) in a.iter().zip(b.iter()) {
        if left != right {
            mismatches += 1;
        }
    }
    mismatches as f32
}
710
711// K-means clustering for quantization
712fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
713    if vectors.is_empty() || k == 0 {
714        return Ok(Vec::new());
715    }
716
717    let dimensions = vectors[0].dimensions;
718    let mut centroids = Vec::with_capacity(k);
719
720    // Initialize centroids randomly
721    for i in 0..k {
722        let idx = i % vectors.len();
723        centroids.push(vectors[idx].clone());
724    }
725
726    // Simple k-means iterations
727    for _ in 0..10 {
728        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
729
730        // Assign vectors to nearest centroid
731        for vector in vectors {
732            let mut best_centroid = 0;
733            let mut best_distance = f32::INFINITY;
734
735            for (i, centroid) in centroids.iter().enumerate() {
736                let vector_f32 = vector.as_f32();
737                let centroid_f32 = centroid.as_f32();
738                use oxirs_core::simd::SimdOps;
739                let distance = f32::euclidean_distance(&vector_f32, &centroid_f32);
740                if distance < best_distance {
741                    best_distance = distance;
742                    best_centroid = i;
743                }
744            }
745
746            clusters[best_centroid].push(vector);
747        }
748
749        // Update centroids
750        for (i, cluster) in clusters.iter().enumerate() {
751            if !cluster.is_empty() {
752                let mut new_centroid = vec![0.0; dimensions];
753
754                for vector in cluster {
755                    let vector_f32 = vector.as_f32();
756                    for (j, &value) in vector_f32.iter().enumerate() {
757                        new_centroid[j] += value;
758                    }
759                }
760
761                for value in &mut new_centroid {
762                    *value /= cluster.len() as f32;
763                }
764
765                centroids[i] = Vector::new(new_centroid);
766            }
767        }
768    }
769
770    Ok(centroids)
771}
772
773/// Multi-index system that combines multiple index types
774pub struct MultiIndex {
775    indices: HashMap<String, Box<dyn VectorIndex>>,
776    default_index: String,
777}
778
779impl MultiIndex {
780    pub fn new() -> Self {
781        Self {
782            indices: HashMap::new(),
783            default_index: String::new(),
784        }
785    }
786
787    pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
788        if self.indices.is_empty() {
789            self.default_index = name.clone();
790        }
791        self.indices.insert(name, index);
792    }
793
794    pub fn set_default(&mut self, name: &str) -> Result<()> {
795        if self.indices.contains_key(name) {
796            self.default_index = name.to_string();
797            Ok(())
798        } else {
799            Err(anyhow!("Index '{}' not found", name))
800        }
801    }
802
803    pub fn search_index(
804        &self,
805        index_name: &str,
806        query: &Vector,
807        k: usize,
808    ) -> Result<Vec<(String, f32)>> {
809        if let Some(index) = self.indices.get(index_name) {
810            index.search_knn(query, k)
811        } else {
812            Err(anyhow!("Index '{}' not found", index_name))
813        }
814    }
815}
816
817impl Default for MultiIndex {
818    fn default() -> Self {
819        Self::new()
820    }
821}
822
823impl VectorIndex for MultiIndex {
824    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
825        if let Some(index) = self.indices.get_mut(&self.default_index) {
826            index.insert(uri, vector)
827        } else {
828            Err(anyhow!("No default index set"))
829        }
830    }
831
832    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
833        if let Some(index) = self.indices.get(&self.default_index) {
834            index.search_knn(query, k)
835        } else {
836            Err(anyhow!("No default index set"))
837        }
838    }
839
840    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
841        if let Some(index) = self.indices.get(&self.default_index) {
842            index.search_threshold(query, threshold)
843        } else {
844            Err(anyhow!("No default index set"))
845        }
846    }
847
848    fn get_vector(&self, uri: &str) -> Option<&Vector> {
849        if let Some(index) = self.indices.get(&self.default_index) {
850            index.get_vector(uri)
851        } else {
852            None
853        }
854    }
855}
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight 4-dimensional fixture vectors keyed by example URIs.
    fn sample_vectors() -> Vec<(&'static str, Vector)> {
        const FIXTURES: [(&str, [f32; 4]); 8] = [
            ("http://example.org/a", [1.0, 0.0, 0.0, 1.0]),
            ("http://example.org/b", [0.0, 1.0, 1.0, 0.0]),
            ("http://example.org/c", [-1.0, 0.0, 0.0, -1.0]),
            ("http://example.org/d", [0.0, -1.0, -1.0, 0.0]),
            ("http://example.org/e", [0.5, 0.5, 0.5, 0.5]),
            ("http://example.org/f", [-0.5, 0.5, -0.5, 0.5]),
            ("http://example.org/g", [1.0, 1.0, 0.0, 0.0]),
            ("http://example.org/h", [0.0, 0.0, 1.0, 1.0]),
        ];

        FIXTURES
            .iter()
            .map(|(uri, values)| (*uri, Vector::new(values.to_vec())))
            .collect()
    }

    /// Inserts every fixture into a fresh index of `index_type` and builds it.
    fn build_index(index_type: IndexType) -> Result<AdvancedVectorIndex> {
        let mut index = AdvancedVectorIndex::new(IndexConfig {
            index_type,
            ..Default::default()
        });
        sample_vectors()
            .into_iter()
            .try_for_each(|(uri, vector)| index.insert(uri.to_string(), vector))?;
        index.build()?;
        Ok(index)
    }

    #[test]
    fn test_ivf_build_and_search() -> Result<()> {
        let index = build_index(IndexType::Ivf)?;
        assert!(index.ivf_index.is_some(), "IVF index should be built");

        let query = Vector::new(vec![1.0, 0.0, 0.0, 1.0]);
        let hits = index.search(&query.as_f32(), 3)?;
        assert!(!hits.is_empty(), "IVF search should return results");
        Ok(())
    }

    #[test]
    fn test_pq_build_and_search() -> Result<()> {
        let index = build_index(IndexType::PQ)?;
        assert!(index.pq_index.is_some(), "PQ index should be built");

        let query = Vector::new(vec![0.0, 1.0, 1.0, 0.0]);
        let hits = index.search(&query.as_f32(), 3)?;
        assert!(!hits.is_empty(), "PQ search should return results");
        Ok(())
    }

    #[test]
    fn test_flat_search_unchanged() -> Result<()> {
        let index = build_index(IndexType::Flat)?;
        let query = Vector::new(vec![1.0, 0.0, 0.0, 1.0]);
        let hits = index.search(&query.as_f32(), 2)?;
        assert_eq!(
            hits.len(),
            2,
            "Flat search should return exactly k results"
        );
        Ok(())
    }
}