// oxirs_vec/index.rs

//! Advanced vector indexing with HNSW and other efficient algorithms

use crate::Vector;

// Re-export VectorIndex trait for use by other modules
pub use crate::VectorIndex;
use anyhow::{anyhow, Result};
use oxirs_core::parallel::*;
use oxirs_core::Triple;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;

#[cfg(feature = "hnsw")]
use hnsw_rs::prelude::*;

/// Type alias for filter functions
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// Type alias for filter functions with Send + Sync
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;

/// Configuration for vector index
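///
/// A small construction sketch; it relies only on the fields and `Default`
/// implementation defined in this file:
///
/// ```ignore
/// // Start from the defaults and switch to a brute-force flat index.
/// let config = IndexConfig {
///     index_type: IndexType::Flat,
///     distance_metric: DistanceMetric::Euclidean,
///     ..IndexConfig::default()
/// };
/// ```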
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Index type to use
    pub index_type: IndexType,
    /// Maximum number of connections for each node (for HNSW)
    pub max_connections: usize,
    /// Construction parameter (for HNSW)
    pub ef_construction: usize,
    /// Search parameter (for HNSW)
    pub ef_search: usize,
    /// Distance metric to use
    pub distance_metric: DistanceMetric,
    /// Whether to enable parallel operations
    pub parallel: bool,
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            index_type: IndexType::Hnsw,
            max_connections: 16,
            ef_construction: 200,
            ef_search: 50,
            distance_metric: DistanceMetric::Cosine,
            parallel: true,
        }
    }
}

/// Available index types
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// Hierarchical Navigable Small World
    Hnsw,
    /// Simple flat index (brute force)
    Flat,
    /// IVF (Inverted File) index
    Ivf,
    /// Product Quantization
    PQ,
}

/// Distance metrics supported
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine distance (1 - cosine_similarity)
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Manhattan (L1) distance
    Manhattan,
    /// Dot product (negated so that larger dot products rank as nearer)
    DotProduct,
}

impl DistanceMetric {
    /// Calculate distance between two vectors
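    ///
    /// A small behavioural sketch; it assumes the SIMD helpers in
    /// `oxirs_core::simd` match their scalar definitions:
    ///
    /// ```ignore
    /// let d = DistanceMetric::Euclidean.distance(&[0.0, 0.0], &[3.0, 4.0]);
    /// assert!((d - 5.0).abs() < 1e-6);
    /// ```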
    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        use oxirs_core::simd::SimdOps;

        match self {
            DistanceMetric::Cosine => f32::cosine_distance(a, b),
            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
            DistanceMetric::DotProduct => -f32::dot(a, b), // Negate so larger dot products sort as nearer
        }
    }

    /// Calculate distance between two Vector objects
    pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
        let a_f32 = a.as_f32();
        let b_f32 = b.as_f32();
        self.distance(&a_f32, &b_f32)
    }
}

/// Search result with distance/score
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    pub uri: String,
    pub distance: f32,
    pub score: f32,
    pub metadata: Option<HashMap<String, String>>,
}

impl Eq for SearchResult {}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Advanced vector index with multiple implementations
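///
/// A rough end-to-end sketch of the intended flow (insert, build, then query);
/// it assumes `Vector::new` accepts a `Vec<f32>` as used elsewhere in this file:
///
/// ```ignore
/// let mut index = AdvancedVectorIndex::new(IndexConfig {
///     index_type: IndexType::Flat,
///     ..IndexConfig::default()
/// });
/// index.insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))?;
/// index.insert("http://example.org/b".to_string(), Vector::new(vec![0.0, 1.0]))?;
/// index.build()?;
/// let nearest = index.search_knn(&Vector::new(vec![0.9, 0.1]), 1)?;
/// ```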
pub struct AdvancedVectorIndex {
    config: IndexConfig,
    vectors: Vec<(String, Vector)>,
    uri_to_id: HashMap<String, usize>,
    #[cfg(feature = "hnsw")]
    hnsw_index: Option<Hnsw<'static, f32, DistCosine>>,
    dimensions: Option<usize>,
}

impl AdvancedVectorIndex {
    pub fn new(config: IndexConfig) -> Self {
        Self {
            config,
            vectors: Vec::new(),
            uri_to_id: HashMap::new(),
            #[cfg(feature = "hnsw")]
            hnsw_index: None,
            dimensions: None,
        }
    }

    /// Build the index after adding all vectors
    pub fn build(&mut self) -> Result<()> {
        if self.vectors.is_empty() {
            return Ok(());
        }

        match self.config.index_type {
            IndexType::Hnsw => {
                #[cfg(feature = "hnsw")]
                {
                    self.build_hnsw_index()?;
                }
                #[cfg(not(feature = "hnsw"))]
                {
                    return Err(anyhow!("HNSW feature not enabled"));
                }
            }
            IndexType::Flat => {
                // No special building needed for flat index
            }
            IndexType::Ivf | IndexType::PQ => {
                return Err(anyhow!("IVF and PQ indices not yet implemented"));
            }
        }

        Ok(())
    }

    #[cfg(feature = "hnsw")]
    fn build_hnsw_index(&mut self) -> Result<()> {
        if self.dimensions.is_some() {
            let mut hnsw = Hnsw::<f32, DistCosine>::new(
                self.config.max_connections,
                self.vectors.len(),
                16, // maximum number of layers
                self.config.ef_construction,
                DistCosine,
            );

            for (id, (_, vector)) in self.vectors.iter().enumerate() {
                let vector_f32 = vector.as_f32();
                hnsw.insert((&vector_f32, id));
            }

            self.hnsw_index = Some(hnsw);
        }

        Ok(())
    }

    /// Add metadata to a vector
    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
        // For now, we'll store metadata separately
        // In a full implementation, this would be integrated with the index
        Ok(())
    }

    /// Search with advanced parameters
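    ///
    /// A usage sketch with an optional URI filter (the URIs are illustrative
    /// placeholders, and the example assumes a populated, built index):
    ///
    /// ```ignore
    /// let filter: FilterFunction = Box::new(|uri| uri.starts_with("http://example.org/"));
    /// let hits = index.search_advanced(&query, 10, None, Some(filter))?;
    /// for hit in hits {
    ///     println!("{} -> {}", hit.uri, hit.distance);
    /// }
    /// ```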
    pub fn search_advanced(
        &self,
        query: &Vector,
        k: usize,
        ef: Option<usize>,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        match self.config.index_type {
            IndexType::Hnsw => {
                #[cfg(feature = "hnsw")]
                {
                    // Note: the filter is not applied on the HNSW path.
                    self.search_hnsw(query, k, ef)
                }
                #[cfg(not(feature = "hnsw"))]
                {
                    // Without the HNSW feature, fall back to a flat scan; `ef` is unused here.
                    let _ = ef;
                    self.search_flat(query, k, filter)
                }
            }
            _ => self.search_flat(query, k, filter),
        }
    }

    #[cfg(feature = "hnsw")]
    fn search_hnsw(
        &self,
        query: &Vector,
        k: usize,
        ef: Option<usize>,
    ) -> Result<Vec<SearchResult>> {
        if let Some(ref hnsw) = self.hnsw_index {
            let search_ef = ef.unwrap_or(self.config.ef_search);
            let query_f32 = query.as_f32();
            let results = hnsw.search(&query_f32, k, search_ef);

            Ok(results
                .into_iter()
                .map(|result| SearchResult {
                    uri: self.vectors[result.d_id].0.clone(),
                    distance: result.distance,
                    score: 1.0 - result.distance, // Convert distance to similarity score
                    metadata: None,
                })
                .collect())
        } else {
            Err(anyhow!("HNSW index not built"))
        }
    }

    fn search_flat(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        if self.config.parallel && self.vectors.len() > 1000 {
            // The parallel path requires a Send + Sync filter, but `FilterFunction`
            // is not Send + Sync, so fall back to the sequential scan whenever a
            // filter is supplied.
            if filter.is_some() {
                self.search_flat_sequential(query, k, filter)
            } else {
                self.search_flat_parallel(query, k, None)
            }
        } else {
            self.search_flat_sequential(query, k, filter)
        }
    }

    fn search_flat_sequential(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        // Max-heap keyed by distance: the root is the worst of the k best seen so far.
        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();

        for (uri, vector) in &self.vectors {
            if let Some(ref filter_fn) = filter {
                if !filter_fn(uri) {
                    continue;
                }
            }

            let distance = self.config.distance_metric.distance_vectors(query, vector);

            if heap.len() < k {
                heap.push(SearchResult {
                    uri: uri.clone(),
                    distance,
                    score: 1.0 - distance, // Convert distance to similarity score
                    metadata: None,
                });
            } else if let Some(worst) = heap.peek() {
                if distance < worst.distance {
                    // Replace the current worst result with the closer one
                    heap.pop();
                    heap.push(SearchResult {
                        uri: uri.clone(),
                        distance,
                        score: 1.0 - distance, // Convert distance to similarity score
                        metadata: None,
                    });
                }
            }
        }

        let mut results: Vec<SearchResult> = heap.into_vec();
        results.sort();

        Ok(results)
    }

    fn search_flat_parallel(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunctionSync>,
    ) -> Result<Vec<SearchResult>> {
        // Split vectors into chunks for parallel processing
        let chunk_size = (self.vectors.len() / num_threads()).max(100);

        // Use Arc for thread-safe sharing of the filter
        let filter_arc = filter.map(Arc::new);

        // Process chunks in parallel and collect the top-k from each
        let partial_results: Vec<Vec<SearchResult>> = self
            .vectors
            .par_chunks(chunk_size)
            .map(|chunk| {
                // Max-heap keyed by distance: the root is the worst of the local top-k.
                let mut local_heap: BinaryHeap<SearchResult> = BinaryHeap::new();
                let filter_ref = filter_arc.as_ref();

                for (uri, vector) in chunk {
                    if let Some(filter_fn) = filter_ref {
                        if !filter_fn(uri) {
                            continue;
                        }
                    }

                    let distance = self.config.distance_metric.distance_vectors(query, vector);

                    if local_heap.len() < k {
                        local_heap.push(SearchResult {
                            uri: uri.clone(),
                            distance,
                            score: 1.0 - distance, // Convert distance to similarity score
                            metadata: None,
                        });
                    } else if let Some(worst) = local_heap.peek() {
                        if distance < worst.distance {
                            local_heap.pop();
                            local_heap.push(SearchResult {
                                uri: uri.clone(),
                                distance,
                                score: 1.0 - distance, // Convert distance to similarity score
                                metadata: None,
                            });
                        }
                    }
                }

                // Ascending by distance
                local_heap.into_sorted_vec()
            })
            .collect();

        // Merge the per-chunk results, again keeping the k smallest distances
        let mut final_heap: BinaryHeap<SearchResult> = BinaryHeap::new();
        for partial in partial_results {
            for result in partial {
                if final_heap.len() < k {
                    final_heap.push(result);
                } else if let Some(worst) = final_heap.peek() {
                    if result.distance < worst.distance {
                        final_heap.pop();
                        final_heap.push(result);
                    }
                }
            }
        }

        let mut results: Vec<SearchResult> = final_heap.into_vec();
        results.sort();

        Ok(results)
    }

    /// Get index statistics
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            num_vectors: self.vectors.len(),
            dimensions: self.dimensions.unwrap_or(0),
            index_type: self.config.index_type,
            memory_usage: self.estimate_memory_usage(),
        }
    }

    fn estimate_memory_usage(&self) -> usize {
        let vector_memory = self.vectors.len()
            * (std::mem::size_of::<String>()
                + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());

        let uri_map_memory =
            self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());

        vector_memory + uri_map_memory
    }

    /// Get the number of vectors in the index
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Check if the index is empty
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Add a vector with RDF triple and metadata (for compatibility with tests)
    pub fn add(
        &mut self,
        id: String,
        vector: Vec<f32>,
        _triple: Triple,
        _metadata: HashMap<String, String>,
    ) -> Result<()> {
        let vector_obj = Vector::new(vector);
        self.insert(id, vector_obj)
    }

    /// Search for nearest neighbors (for compatibility with tests)
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        let query_vector = Vector::new(query.to_vec());
        self.search_advanced(&query_vector, k, None, None)
    }
}

impl VectorIndex for AdvancedVectorIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if let Some(dims) = self.dimensions {
            if vector.dimensions != dims {
                return Err(anyhow!(
                    "Vector dimensions ({}) don't match index dimensions ({})",
                    vector.dimensions,
                    dims
                ));
            }
        } else {
            self.dimensions = Some(vector.dimensions);
        }

        let id = self.vectors.len();
        self.uri_to_id.insert(uri.clone(), id);
        self.vectors.push((uri, vector));

        Ok(())
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        let results = self.search_advanced(query, k, None, None)?;
        Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        let mut results = Vec::new();

        for (uri, vector) in &self.vectors {
            let distance = self.config.distance_metric.distance_vectors(query, vector);
            if distance <= threshold {
                results.push((uri.clone(), distance));
            }
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        Ok(results)
    }

    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        // For AdvancedVectorIndex, vectors are stored in the vectors field
        // regardless of the index type being used
        self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
    }
}

/// Index performance statistics
#[derive(Debug, Clone)]
pub struct IndexStats {
    pub num_vectors: usize,
    pub dimensions: usize,
    pub index_type: IndexType,
    pub memory_usage: usize,
}

/// Quantized vector index for memory efficiency
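///
/// A rough workflow sketch (training data and sizes are illustrative; it
/// assumes `Vector::new` as used elsewhere in this file):
///
/// ```ignore
/// let mut index = QuantizedVectorIndex::new(IndexConfig::default(), 16);
/// index.train_quantization(&training_vectors)?;
/// index.insert("http://example.org/a".to_string(), Vector::new(vec![0.1; 128]))?;
/// let hits = index.search_knn(&Vector::new(vec![0.1; 128]), 5)?;
/// ```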
pub struct QuantizedVectorIndex {
    config: IndexConfig,
    quantized_vectors: Vec<Vec<u8>>,
    centroids: Vec<Vector>,
    /// Number of centroids requested at construction time
    num_centroids: usize,
    uri_to_id: HashMap<String, usize>,
    dimensions: Option<usize>,
}

impl QuantizedVectorIndex {
    pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
        Self {
            config,
            quantized_vectors: Vec::new(),
            centroids: Vec::with_capacity(num_centroids),
            num_centroids,
            uri_to_id: HashMap::new(),
            dimensions: None,
        }
    }

    /// Train quantization centroids using k-means
    pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(anyhow!("No training vectors provided"));
        }

        let dimensions = training_vectors[0].dimensions;
        self.dimensions = Some(dimensions);

        // Simple k-means clustering for centroids (use the requested count rather
        // than the Vec capacity, which is only guaranteed to be at least that large)
        self.centroids = kmeans_clustering(training_vectors, self.num_centroids)?;

        Ok(())
    }

    fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
        let mut quantized = Vec::new();

        // Find the nearest centroid for each dimension chunk. The chunk size is
        // clamped to at least 1 so `chunks` never receives zero (which would panic)
        // when there are more centroids than dimensions.
        let chunk_size = (vector.dimensions / self.centroids.len().max(1)).max(1);

        let vector_f32 = vector.as_f32();
        for chunk in vector_f32.chunks(chunk_size) {
            let mut best_centroid = 0u8;
            let mut best_distance = f32::INFINITY;

            for (i, centroid) in self.centroids.iter().enumerate() {
                let centroid_f32 = centroid.as_f32();
                let centroid_chunk = &centroid_f32[0..chunk.len().min(centroid.dimensions)];
                use oxirs_core::simd::SimdOps;
                let distance = f32::euclidean_distance(chunk, centroid_chunk);
                if distance < best_distance {
                    best_distance = distance;
                    best_centroid = i as u8;
                }
            }

            quantized.push(best_centroid);
        }

        quantized
    }
}

impl VectorIndex for QuantizedVectorIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if self.centroids.is_empty() {
            return Err(anyhow!(
                "Quantization not trained. Call train_quantization first."
            ));
        }

        let id = self.quantized_vectors.len();
        self.uri_to_id.insert(uri.clone(), id);

        let quantized = self.quantize_vector(&vector);
        self.quantized_vectors.push(quantized);

        Ok(())
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        let query_quantized = self.quantize_vector(query);
        let mut results = Vec::new();

        // Look up each URI's quantized vector by its stored id; zipping the HashMap
        // keys with `quantized_vectors` would pair them in arbitrary order.
        for (uri, &id) in &self.uri_to_id {
            let distance = hamming_distance(&query_quantized, &self.quantized_vectors[id]);
            results.push((uri.clone(), distance));
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        results.truncate(k);

        Ok(results)
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        let query_quantized = self.quantize_vector(query);
        let mut results = Vec::new();

        for (uri, &id) in &self.uri_to_id {
            let distance = hamming_distance(&query_quantized, &self.quantized_vectors[id]);
            if distance <= threshold {
                results.push((uri.clone(), distance));
            }
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        Ok(results)
    }

    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
        // Quantized index doesn't store original vectors
        // Return None as we only have quantized representations
        None
    }
}

// Helper functions that don't have SIMD equivalents

fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    a.iter().zip(b).filter(|(x, y)| x != y).count() as f32
}

// K-means clustering for quantization
fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
    if vectors.is_empty() || k == 0 {
        return Ok(Vec::new());
    }

    let dimensions = vectors[0].dimensions;
    let mut centroids = Vec::with_capacity(k);

    // Initialize centroids deterministically from the training vectors
    // (cycling through them if k exceeds the number of vectors)
    for i in 0..k {
        let idx = i % vectors.len();
        centroids.push(vectors[idx].clone());
    }

    // Simple k-means iterations
    for _ in 0..10 {
        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];

        // Assign vectors to nearest centroid
        for vector in vectors {
            let mut best_centroid = 0;
            let mut best_distance = f32::INFINITY;

            for (i, centroid) in centroids.iter().enumerate() {
                let vector_f32 = vector.as_f32();
                let centroid_f32 = centroid.as_f32();
                use oxirs_core::simd::SimdOps;
                let distance = f32::euclidean_distance(&vector_f32, &centroid_f32);
                if distance < best_distance {
                    best_distance = distance;
                    best_centroid = i;
                }
            }

            clusters[best_centroid].push(vector);
        }

        // Update centroids
        for (i, cluster) in clusters.iter().enumerate() {
            if !cluster.is_empty() {
                let mut new_centroid = vec![0.0; dimensions];

                for vector in cluster {
                    let vector_f32 = vector.as_f32();
                    for (j, &value) in vector_f32.iter().enumerate() {
                        new_centroid[j] += value;
                    }
                }

                for value in &mut new_centroid {
                    *value /= cluster.len() as f32;
                }

                centroids[i] = Vector::new(new_centroid);
            }
        }
    }

    Ok(centroids)
}

/// Multi-index system that combines multiple index types
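///
/// A small routing sketch (index names are illustrative; the boxed index can be
/// any `VectorIndex` implementation from this module):
///
/// ```ignore
/// let mut multi = MultiIndex::new();
/// multi.add_index(
///     "flat".to_string(),
///     Box::new(AdvancedVectorIndex::new(IndexConfig {
///         index_type: IndexType::Flat,
///         ..IndexConfig::default()
///     })),
/// );
/// multi.set_default("flat")?;
/// multi.insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))?;
/// let hits = multi.search_index("flat", &Vector::new(vec![1.0, 0.0]), 1)?;
/// ```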
pub struct MultiIndex {
    indices: HashMap<String, Box<dyn VectorIndex>>,
    default_index: String,
}

impl MultiIndex {
    pub fn new() -> Self {
        Self {
            indices: HashMap::new(),
            default_index: String::new(),
        }
    }

    pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
        if self.indices.is_empty() {
            self.default_index = name.clone();
        }
        self.indices.insert(name, index);
    }

    pub fn set_default(&mut self, name: &str) -> Result<()> {
        if self.indices.contains_key(name) {
            self.default_index = name.to_string();
            Ok(())
        } else {
            Err(anyhow!("Index '{}' not found", name))
        }
    }

    pub fn search_index(
        &self,
        index_name: &str,
        query: &Vector,
        k: usize,
    ) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(index_name) {
            index.search_knn(query, k)
        } else {
            Err(anyhow!("Index '{}' not found", index_name))
        }
    }
}

impl Default for MultiIndex {
    fn default() -> Self {
        Self::new()
    }
}

impl VectorIndex for MultiIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if let Some(index) = self.indices.get_mut(&self.default_index) {
            index.insert(uri, vector)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.search_knn(query, k)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.search_threshold(query, threshold)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.get_vector(uri)
        } else {
            None
        }
    }
}
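
// A small, self-contained sanity-check sketch for the flat search path and the
// quantization helper. It assumes `Vector::new` / `Vector::as_f32` behave as they
// are used throughout this file and that the flat path needs no extra features.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flat_search_returns_nearest_first() {
        let mut index = AdvancedVectorIndex::new(IndexConfig {
            index_type: IndexType::Flat,
            distance_metric: DistanceMetric::Euclidean,
            ..IndexConfig::default()
        });

        index
            .insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))
            .unwrap();
        index
            .insert("http://example.org/b".to_string(), Vector::new(vec![0.0, 1.0]))
            .unwrap();
        index
            .insert("http://example.org/c".to_string(), Vector::new(vec![0.9, 0.1]))
            .unwrap();

        let results = index
            .search_knn(&Vector::new(vec![1.0, 0.0]), 2)
            .unwrap();

        assert_eq!(results.len(), 2);
        // The exact match should come back first, the near match second.
        assert_eq!(results[0].0, "http://example.org/a");
        assert_eq!(results[1].0, "http://example.org/c");
    }

    #[test]
    fn hamming_distance_counts_differing_positions() {
        assert_eq!(hamming_distance(&[1, 2, 3], &[1, 0, 3]), 1.0);
        assert_eq!(hamming_distance(&[0, 0], &[0, 0]), 0.0);
    }
}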