Skip to main content

oxirs_vec/
index.rs

1//! Advanced vector indexing with HNSW and other efficient algorithms
2
3use crate::Vector;
4
5// Re-export VectorIndex trait for use by other modules
6pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15#[cfg(feature = "hnsw")]
16use hnsw_rs::prelude::*;
17
/// Boxed predicate over a result URI (`&str`), as accepted by
/// `AdvancedVectorIndex::search_advanced`.
///
/// This alias carries no `Send`/`Sync` bound.
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// Same predicate shape as [`FilterFunction`], additionally bounded by
/// `Send + Sync` so it can be shared across threads.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;
22
23/// Configuration for vector index
24#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
25pub struct IndexConfig {
26    /// Index type to use
27    pub index_type: IndexType,
28    /// Maximum number of connections for each node (for HNSW)
29    pub max_connections: usize,
30    /// Construction parameter (for HNSW)
31    pub ef_construction: usize,
32    /// Search parameter (for HNSW)
33    pub ef_search: usize,
34    /// Distance metric to use
35    pub distance_metric: DistanceMetric,
36    /// Whether to enable parallel operations
37    pub parallel: bool,
38}
39
40impl Default for IndexConfig {
41    fn default() -> Self {
42        Self {
43            index_type: IndexType::Hnsw,
44            max_connections: 16,
45            ef_construction: 200,
46            ef_search: 50,
47            distance_metric: DistanceMetric::Cosine,
48            parallel: true,
49        }
50    }
51}
52
53/// Available index types
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
55pub enum IndexType {
56    /// Hierarchical Navigable Small World
57    Hnsw,
58    /// Simple flat index (brute force)
59    Flat,
60    /// IVF (Inverted File) index
61    Ivf,
62    /// Product Quantization
63    PQ,
64}
65
66/// Distance metrics supported
67#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
68pub enum DistanceMetric {
69    /// Cosine distance (1 - cosine_similarity)
70    Cosine,
71    /// Euclidean (L2) distance
72    Euclidean,
73    /// Manhattan (L1) distance
74    Manhattan,
75    /// Dot product (negative for max-heap behavior)
76    DotProduct,
77}
78
79impl DistanceMetric {
80    /// Calculate distance between two vectors
81    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
82        use oxirs_core::simd::SimdOps;
83
84        match self {
85            DistanceMetric::Cosine => f32::cosine_distance(a, b),
86            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
87            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
88            DistanceMetric::DotProduct => -f32::dot(a, b), // Negative for max-heap
89        }
90    }
91
92    /// Calculate distance between two Vector objects
93    pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
94        let a_f32 = a.as_f32();
95        let b_f32 = b.as_f32();
96        self.distance(&a_f32, &b_f32)
97    }
98}
99
/// A single search hit together with its distance and similarity score.
///
/// Ordering (`Ord` / `PartialOrd`) looks at `distance` only: a smaller
/// distance sorts first. Equality (`PartialEq`) is derived and compares every
/// field, so two hits can compare as `Ordering::Equal` without being `==`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matching vector.
    pub uri: String,
    /// Distance from the query; smaller is closer.
    pub distance: f32,
    /// Similarity score derived from the distance by the search routine.
    pub score: f32,
    /// Optional key/value metadata attached to the hit.
    pub metadata: Option<HashMap<String, String>>,
}

impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Total order on `distance`, with NaN distances sorting after every
    /// non-NaN distance and comparing equal to each other.
    ///
    /// Treating NaN as "equal to everything" would not be a total order
    /// (NaN == 1.0 and NaN == 2.0, yet 1.0 < 2.0), which breaks the contract
    /// `sort` and `BinaryHeap` rely on. Ordering of non-NaN distances is the
    /// plain numeric one.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            // `partial_cmp` is `None` only when at least one side is NaN;
            // `false < true`, so the NaN side ends up greater (i.e. last).
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
124
125/// Advanced vector index with multiple implementations
126pub struct AdvancedVectorIndex {
127    config: IndexConfig,
128    vectors: Vec<(String, Vector)>,
129    uri_to_id: HashMap<String, usize>,
130    #[cfg(feature = "hnsw")]
131    hnsw_index: Option<Hnsw<'static, f32, DistCosine>>,
132    dimensions: Option<usize>,
133}
134
135impl AdvancedVectorIndex {
136    pub fn new(config: IndexConfig) -> Self {
137        Self {
138            config,
139            vectors: Vec::new(),
140            uri_to_id: HashMap::new(),
141            #[cfg(feature = "hnsw")]
142            hnsw_index: None,
143            dimensions: None,
144        }
145    }
146
147    /// Build the index after adding all vectors
148    pub fn build(&mut self) -> Result<()> {
149        if self.vectors.is_empty() {
150            return Ok(());
151        }
152
153        match self.config.index_type {
154            IndexType::Hnsw => {
155                #[cfg(feature = "hnsw")]
156                {
157                    self.build_hnsw_index()?;
158                }
159                #[cfg(not(feature = "hnsw"))]
160                {
161                    return Err(anyhow!("HNSW feature not enabled"));
162                }
163            }
164            IndexType::Flat => {
165                // No special building needed for flat index
166            }
167            IndexType::Ivf | IndexType::PQ => {
168                return Err(anyhow!("IVF and PQ indices not yet implemented"));
169            }
170        }
171
172        Ok(())
173    }
174
175    #[cfg(feature = "hnsw")]
176    fn build_hnsw_index(&mut self) -> Result<()> {
177        if let Some(_dimensions) = self.dimensions {
178            let hnsw = Hnsw::<f32, DistCosine>::new(
179                self.config.max_connections,
180                self.vectors.len(),
181                16, // layer factor
182                self.config.ef_construction,
183                DistCosine,
184            );
185
186            for (id, (_, vector)) in self.vectors.iter().enumerate() {
187                let vector_f32 = vector.as_f32();
188                hnsw.insert((&vector_f32, id));
189            }
190
191            self.hnsw_index = Some(hnsw);
192        }
193
194        Ok(())
195    }
196
197    /// Add metadata to a vector
198    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
199        // For now, we'll store metadata separately
200        // In a full implementation, this would be integrated with the index
201        Ok(())
202    }
203
204    /// Search with advanced parameters
205    pub fn search_advanced(
206        &self,
207        query: &Vector,
208        k: usize,
209        ef: Option<usize>,
210        filter: Option<FilterFunction>,
211    ) -> Result<Vec<SearchResult>> {
212        match self.config.index_type {
213            IndexType::Hnsw => {
214                #[cfg(feature = "hnsw")]
215                {
216                    self.search_hnsw(query, k, ef)
217                }
218                #[cfg(not(feature = "hnsw"))]
219                {
220                    let _ = ef;
221                    self.search_flat(query, k, filter)
222                }
223            }
224            _ => self.search_flat(query, k, filter),
225        }
226    }
227
228    #[cfg(feature = "hnsw")]
229    fn search_hnsw(
230        &self,
231        query: &Vector,
232        k: usize,
233        ef: Option<usize>,
234    ) -> Result<Vec<SearchResult>> {
235        if let Some(ref hnsw) = self.hnsw_index {
236            let search_ef = ef.unwrap_or(self.config.ef_search);
237            let query_f32 = query.as_f32();
238            let results = hnsw.search(&query_f32, k, search_ef);
239
240            Ok(results
241                .into_iter()
242                .map(|result| SearchResult {
243                    uri: self.vectors[result.d_id].0.clone(),
244                    distance: result.distance,
245                    score: 1.0 - result.distance, // Convert distance to similarity score
246                    metadata: None,
247                })
248                .collect())
249        } else {
250            Err(anyhow!("HNSW index not built"))
251        }
252    }
253
254    fn search_flat(
255        &self,
256        query: &Vector,
257        k: usize,
258        filter: Option<FilterFunction>,
259    ) -> Result<Vec<SearchResult>> {
260        if self.config.parallel && self.vectors.len() > 1000 {
261            // For parallel search, we need Send + Sync filter
262            if filter.is_some() {
263                // Fall back to sequential if filter is present but not Send + Sync
264                self.search_flat_sequential(query, k, filter)
265            } else {
266                self.search_flat_parallel(query, k, None)
267            }
268        } else {
269            self.search_flat_sequential(query, k, filter)
270        }
271    }
272
273    fn search_flat_sequential(
274        &self,
275        query: &Vector,
276        k: usize,
277        filter: Option<FilterFunction>,
278    ) -> Result<Vec<SearchResult>> {
279        let mut heap = BinaryHeap::new();
280
281        for (uri, vector) in &self.vectors {
282            if let Some(ref filter_fn) = filter {
283                if !filter_fn(uri) {
284                    continue;
285                }
286            }
287
288            let distance = self.config.distance_metric.distance_vectors(query, vector);
289
290            if heap.len() < k {
291                heap.push(std::cmp::Reverse(SearchResult {
292                    uri: uri.clone(),
293                    distance,
294                    score: 1.0 - distance, // Convert distance to similarity score
295                    metadata: None,
296                }));
297            } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
298                if distance < worst.distance {
299                    heap.pop();
300                    heap.push(std::cmp::Reverse(SearchResult {
301                        uri: uri.clone(),
302                        distance,
303                        score: 1.0 - distance, // Convert distance to similarity score
304                        metadata: None,
305                    }));
306                }
307            }
308        }
309
310        let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
311        results.sort_by(|a, b| {
312            a.distance
313                .partial_cmp(&b.distance)
314                .unwrap_or(std::cmp::Ordering::Equal)
315        });
316
317        Ok(results)
318    }
319
320    fn search_flat_parallel(
321        &self,
322        query: &Vector,
323        k: usize,
324        filter: Option<FilterFunctionSync>,
325    ) -> Result<Vec<SearchResult>> {
326        // Split vectors into chunks for parallel processing
327        let chunk_size = (self.vectors.len() / num_threads()).max(100);
328
329        // Use Arc for thread-safe sharing of the filter
330        let filter_arc = filter.map(Arc::new);
331
332        // Process chunks in parallel and collect top-k from each
333        let partial_results: Vec<Vec<SearchResult>> = self
334            .vectors
335            .par_chunks(chunk_size)
336            .map(|chunk| {
337                let mut local_heap = BinaryHeap::new();
338                let filter_ref = filter_arc.as_ref();
339
340                for (uri, vector) in chunk {
341                    if let Some(filter_fn) = filter_ref {
342                        if !filter_fn(uri) {
343                            continue;
344                        }
345                    }
346
347                    let distance = self.config.distance_metric.distance_vectors(query, vector);
348
349                    if local_heap.len() < k {
350                        local_heap.push(std::cmp::Reverse(SearchResult {
351                            uri: uri.clone(),
352                            distance,
353                            score: 1.0 - distance, // Convert distance to similarity score
354                            metadata: None,
355                        }));
356                    } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
357                        if distance < worst.distance {
358                            local_heap.pop();
359                            local_heap.push(std::cmp::Reverse(SearchResult {
360                                uri: uri.clone(),
361                                distance,
362                                score: 1.0 - distance, // Convert distance to similarity score
363                                metadata: None,
364                            }));
365                        }
366                    }
367                }
368
369                local_heap
370                    .into_sorted_vec()
371                    .into_iter()
372                    .map(|r| r.0)
373                    .collect()
374            })
375            .collect();
376
377        // Merge results from all chunks
378        let mut final_heap = BinaryHeap::new();
379        for partial in partial_results {
380            for result in partial {
381                if final_heap.len() < k {
382                    final_heap.push(std::cmp::Reverse(result));
383                } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
384                    if result.distance < worst.distance {
385                        final_heap.pop();
386                        final_heap.push(std::cmp::Reverse(result));
387                    }
388                }
389            }
390        }
391
392        let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
393        results.sort_by(|a, b| {
394            a.distance
395                .partial_cmp(&b.distance)
396                .unwrap_or(std::cmp::Ordering::Equal)
397        });
398
399        Ok(results)
400    }
401
402    /// Get index statistics
403    pub fn stats(&self) -> IndexStats {
404        IndexStats {
405            num_vectors: self.vectors.len(),
406            dimensions: self.dimensions.unwrap_or(0),
407            index_type: self.config.index_type,
408            memory_usage: self.estimate_memory_usage(),
409        }
410    }
411
412    fn estimate_memory_usage(&self) -> usize {
413        let vector_memory = self.vectors.len()
414            * (std::mem::size_of::<String>()
415                + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
416
417        let uri_map_memory =
418            self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
419
420        vector_memory + uri_map_memory
421    }
422
423    /// Get the number of vectors in the index
424    pub fn len(&self) -> usize {
425        self.vectors.len()
426    }
427
428    /// Check if the index is empty
429    pub fn is_empty(&self) -> bool {
430        self.vectors.is_empty()
431    }
432
433    /// Add a vector with RDF triple and metadata (for compatibility with tests)
434    pub fn add(
435        &mut self,
436        id: String,
437        vector: Vec<f32>,
438        _triple: Triple,
439        _metadata: HashMap<String, String>,
440    ) -> Result<()> {
441        let vector_obj = Vector::new(vector);
442        self.insert(id, vector_obj)
443    }
444
445    /// Search for nearest neighbors (for compatibility with tests)
446    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
447        let query_vector = Vector::new(query.to_vec());
448        let results = self.search_advanced(&query_vector, k, None, None)?;
449        Ok(results)
450    }
451}
452
453impl VectorIndex for AdvancedVectorIndex {
454    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
455        if let Some(dims) = self.dimensions {
456            if vector.dimensions != dims {
457                return Err(anyhow!(
458                    "Vector dimensions ({}) don't match index dimensions ({})",
459                    vector.dimensions,
460                    dims
461                ));
462            }
463        } else {
464            self.dimensions = Some(vector.dimensions);
465        }
466
467        let id = self.vectors.len();
468        self.uri_to_id.insert(uri.clone(), id);
469        self.vectors.push((uri, vector));
470
471        Ok(())
472    }
473
474    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
475        let results = self.search_advanced(query, k, None, None)?;
476        Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
477    }
478
479    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
480        let mut results = Vec::new();
481
482        for (uri, vector) in &self.vectors {
483            let distance = self.config.distance_metric.distance_vectors(query, vector);
484            if distance <= threshold {
485                results.push((uri.clone(), distance));
486            }
487        }
488
489        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
490        Ok(results)
491    }
492
493    fn get_vector(&self, uri: &str) -> Option<&Vector> {
494        // For AdvancedVectorIndex, vectors are stored in the vectors field
495        // regardless of the index type being used
496        self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
497    }
498}
499
500/// Index performance statistics
501#[derive(Debug, Clone)]
502pub struct IndexStats {
503    pub num_vectors: usize,
504    pub dimensions: usize,
505    pub index_type: IndexType,
506    pub memory_usage: usize,
507}
508
509/// Quantized vector index for memory efficiency
510pub struct QuantizedVectorIndex {
511    config: IndexConfig,
512    quantized_vectors: Vec<Vec<u8>>,
513    centroids: Vec<Vector>,
514    uri_to_id: HashMap<String, usize>,
515    dimensions: Option<usize>,
516}
517
518impl QuantizedVectorIndex {
519    pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
520        Self {
521            config,
522            quantized_vectors: Vec::new(),
523            centroids: Vec::with_capacity(num_centroids),
524            uri_to_id: HashMap::new(),
525            dimensions: None,
526        }
527    }
528
529    /// Train quantization centroids using k-means
530    pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
531        if training_vectors.is_empty() {
532            return Err(anyhow!("No training vectors provided"));
533        }
534
535        let dimensions = training_vectors[0].dimensions;
536        self.dimensions = Some(dimensions);
537
538        // Simple k-means clustering for centroids
539        self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
540
541        Ok(())
542    }
543
544    fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
545        let mut quantized = Vec::new();
546
547        // Find nearest centroid for each dimension chunk
548        let chunk_size = vector.dimensions / self.centroids.len().max(1);
549
550        let vector_f32 = vector.as_f32();
551        for chunk in vector_f32.chunks(chunk_size) {
552            let mut best_centroid = 0u8;
553            let mut best_distance = f32::INFINITY;
554
555            for (i, centroid) in self.centroids.iter().enumerate() {
556                let centroid_f32 = centroid.as_f32();
557                let centroid_chunk = &centroid_f32[0..chunk.len().min(centroid.dimensions)];
558                use oxirs_core::simd::SimdOps;
559                let distance = f32::euclidean_distance(chunk, centroid_chunk);
560                if distance < best_distance {
561                    best_distance = distance;
562                    best_centroid = i as u8;
563                }
564            }
565
566            quantized.push(best_centroid);
567        }
568
569        quantized
570    }
571}
572
573impl VectorIndex for QuantizedVectorIndex {
574    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
575        if self.centroids.is_empty() {
576            return Err(anyhow!(
577                "Quantization not trained. Call train_quantization first."
578            ));
579        }
580
581        let id = self.quantized_vectors.len();
582        self.uri_to_id.insert(uri.clone(), id);
583
584        let quantized = self.quantize_vector(&vector);
585        self.quantized_vectors.push(quantized);
586
587        Ok(())
588    }
589
590    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
591        let query_quantized = self.quantize_vector(query);
592        let mut results = Vec::new();
593
594        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
595            let distance = hamming_distance(&query_quantized, quantized);
596            results.push((uri.clone(), distance));
597        }
598
599        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
600        results.truncate(k);
601
602        Ok(results)
603    }
604
605    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
606        let query_quantized = self.quantize_vector(query);
607        let mut results = Vec::new();
608
609        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
610            let distance = hamming_distance(&query_quantized, quantized);
611            if distance <= threshold {
612                results.push((uri.clone(), distance));
613            }
614        }
615
616        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
617        Ok(results)
618    }
619
620    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
621        // Quantized index doesn't store original vectors
622        // Return None as we only have quantized representations
623        None
624    }
625}
626
627// Helper functions that don't have SIMD equivalents
628
/// Counts the positions at which `a` and `b` hold different bytes.
///
/// Only the common prefix is compared: positions beyond the end of the
/// shorter slice are ignored. The count is returned as an `f32`.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mut mismatches: usize = 0;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if lhs != rhs {
            mismatches += 1;
        }
    }
    mismatches as f32
}
632
633// K-means clustering for quantization
634fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
635    if vectors.is_empty() || k == 0 {
636        return Ok(Vec::new());
637    }
638
639    let dimensions = vectors[0].dimensions;
640    let mut centroids = Vec::with_capacity(k);
641
642    // Initialize centroids randomly
643    for i in 0..k {
644        let idx = i % vectors.len();
645        centroids.push(vectors[idx].clone());
646    }
647
648    // Simple k-means iterations
649    for _ in 0..10 {
650        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
651
652        // Assign vectors to nearest centroid
653        for vector in vectors {
654            let mut best_centroid = 0;
655            let mut best_distance = f32::INFINITY;
656
657            for (i, centroid) in centroids.iter().enumerate() {
658                let vector_f32 = vector.as_f32();
659                let centroid_f32 = centroid.as_f32();
660                use oxirs_core::simd::SimdOps;
661                let distance = f32::euclidean_distance(&vector_f32, &centroid_f32);
662                if distance < best_distance {
663                    best_distance = distance;
664                    best_centroid = i;
665                }
666            }
667
668            clusters[best_centroid].push(vector);
669        }
670
671        // Update centroids
672        for (i, cluster) in clusters.iter().enumerate() {
673            if !cluster.is_empty() {
674                let mut new_centroid = vec![0.0; dimensions];
675
676                for vector in cluster {
677                    let vector_f32 = vector.as_f32();
678                    for (j, &value) in vector_f32.iter().enumerate() {
679                        new_centroid[j] += value;
680                    }
681                }
682
683                for value in &mut new_centroid {
684                    *value /= cluster.len() as f32;
685                }
686
687                centroids[i] = Vector::new(new_centroid);
688            }
689        }
690    }
691
692    Ok(centroids)
693}
694
695/// Multi-index system that combines multiple index types
696pub struct MultiIndex {
697    indices: HashMap<String, Box<dyn VectorIndex>>,
698    default_index: String,
699}
700
701impl MultiIndex {
702    pub fn new() -> Self {
703        Self {
704            indices: HashMap::new(),
705            default_index: String::new(),
706        }
707    }
708
709    pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
710        if self.indices.is_empty() {
711            self.default_index = name.clone();
712        }
713        self.indices.insert(name, index);
714    }
715
716    pub fn set_default(&mut self, name: &str) -> Result<()> {
717        if self.indices.contains_key(name) {
718            self.default_index = name.to_string();
719            Ok(())
720        } else {
721            Err(anyhow!("Index '{}' not found", name))
722        }
723    }
724
725    pub fn search_index(
726        &self,
727        index_name: &str,
728        query: &Vector,
729        k: usize,
730    ) -> Result<Vec<(String, f32)>> {
731        if let Some(index) = self.indices.get(index_name) {
732            index.search_knn(query, k)
733        } else {
734            Err(anyhow!("Index '{}' not found", index_name))
735        }
736    }
737}
738
739impl Default for MultiIndex {
740    fn default() -> Self {
741        Self::new()
742    }
743}
744
745impl VectorIndex for MultiIndex {
746    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
747        if let Some(index) = self.indices.get_mut(&self.default_index) {
748            index.insert(uri, vector)
749        } else {
750            Err(anyhow!("No default index set"))
751        }
752    }
753
754    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
755        if let Some(index) = self.indices.get(&self.default_index) {
756            index.search_knn(query, k)
757        } else {
758            Err(anyhow!("No default index set"))
759        }
760    }
761
762    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
763        if let Some(index) = self.indices.get(&self.default_index) {
764            index.search_threshold(query, threshold)
765        } else {
766            Err(anyhow!("No default index set"))
767        }
768    }
769
770    fn get_vector(&self, uri: &str) -> Option<&Vector> {
771        if let Some(index) = self.indices.get(&self.default_index) {
772            index.get_vector(uri)
773        } else {
774            None
775        }
776    }
777}