oxirs_vec/index.rs

//! Advanced vector indexing with HNSW and other efficient algorithms

use crate::Vector;

// Re-export VectorIndex trait for use by other modules
pub use crate::VectorIndex;
use anyhow::{anyhow, Result};
use oxirs_core::parallel::*;
use oxirs_core::Triple;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;

#[cfg(feature = "hnsw")]
use hnsw_rs::prelude::*;

/// Type alias for filter functions
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// Type alias for filter functions with Send + Sync
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;

/// Configuration for vector index
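///
/// # Examples
///
/// A minimal sketch of overriding selected defaults with struct update syntax
/// (marked `ignore` since the import path is assumed rather than verified):
///
/// ```ignore
/// let config = IndexConfig {
///     index_type: IndexType::Flat,
///     distance_metric: DistanceMetric::Euclidean,
///     ..Default::default()
/// };
/// assert_eq!(config.max_connections, 16);
/// ```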
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Index type to use
    pub index_type: IndexType,
    /// Maximum number of connections for each node (for HNSW)
    pub max_connections: usize,
    /// Construction parameter (for HNSW)
    pub ef_construction: usize,
    /// Search parameter (for HNSW)
    pub ef_search: usize,
    /// Distance metric to use
    pub distance_metric: DistanceMetric,
    /// Whether to enable parallel operations
    pub parallel: bool,
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            index_type: IndexType::Hnsw,
            max_connections: 16,
            ef_construction: 200,
            ef_search: 50,
            distance_metric: DistanceMetric::Cosine,
            parallel: true,
        }
    }
}

/// Available index types
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// Hierarchical Navigable Small World
    Hnsw,
    /// Simple flat index (brute force)
    Flat,
    /// IVF (Inverted File) index
    Ivf,
    /// Product Quantization
    PQ,
}

/// Distance metrics supported
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine distance (1 - cosine_similarity)
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Manhattan (L1) distance
    Manhattan,
    /// Negated dot product (smaller values indicate greater similarity)
    DotProduct,
}

impl DistanceMetric {
    /// Calculate distance between two vectors
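    ///
    /// # Examples
    ///
    /// Illustrative sketch only (`ignore`d: the crate path for `DistanceMetric`
    /// is assumed rather than verified):
    ///
    /// ```ignore
    /// let d = DistanceMetric::Euclidean.distance(&[0.0, 0.0], &[3.0, 4.0]);
    /// assert!((d - 5.0).abs() < 1e-6);
    /// ```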
    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        use oxirs_core::simd::SimdOps;

        match self {
            DistanceMetric::Cosine => f32::cosine_distance(a, b),
            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
            DistanceMetric::DotProduct => -f32::dot(a, b), // Negated so higher similarity maps to lower distance
        }
    }

    /// Calculate distance between two Vector objects
    pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
        let a_f32 = a.as_f32();
        let b_f32 = b.as_f32();
        self.distance(&a_f32, &b_f32)
    }
}

/// Search result with distance/score
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    pub uri: String,
    pub distance: f32,
    pub score: f32,
    pub metadata: Option<HashMap<String, String>>,
}

impl Eq for SearchResult {}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Advanced vector index with multiple implementations
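///
/// # Examples
///
/// End-to-end sketch of the intended flow (insert, build, query). Marked
/// `ignore` because the crate paths and the `Vector` constructor arguments are
/// assumed from this module rather than verified:
///
/// ```ignore
/// let mut index = AdvancedVectorIndex::new(IndexConfig::default());
/// index.insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))?;
/// index.insert("http://example.org/b".to_string(), Vector::new(vec![0.0, 1.0]))?;
/// index.build()?;
/// let neighbours = index.search_knn(&Vector::new(vec![1.0, 0.1]), 1)?;
/// assert_eq!(neighbours[0].0, "http://example.org/a");
/// ```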
pub struct AdvancedVectorIndex {
    config: IndexConfig,
    vectors: Vec<(String, Vector)>,
    uri_to_id: HashMap<String, usize>,
    #[cfg(feature = "hnsw")]
    hnsw_index: Option<Hnsw<'static, f32, DistCosine>>,
    dimensions: Option<usize>,
}

impl AdvancedVectorIndex {
    pub fn new(config: IndexConfig) -> Self {
        Self {
            config,
            vectors: Vec::new(),
            uri_to_id: HashMap::new(),
            #[cfg(feature = "hnsw")]
            hnsw_index: None,
            dimensions: None,
        }
    }

    /// Build the index after adding all vectors
    pub fn build(&mut self) -> Result<()> {
        if self.vectors.is_empty() {
            return Ok(());
        }

        match self.config.index_type {
            IndexType::Hnsw => {
                #[cfg(feature = "hnsw")]
                {
                    self.build_hnsw_index()?;
                }
                #[cfg(not(feature = "hnsw"))]
                {
                    return Err(anyhow!("HNSW feature not enabled"));
                }
            }
            IndexType::Flat => {
                // No special building needed for flat index
            }
            IndexType::Ivf | IndexType::PQ => {
                return Err(anyhow!("IVF and PQ indices not yet implemented"));
            }
        }

        Ok(())
    }

    #[cfg(feature = "hnsw")]
    fn build_hnsw_index(&mut self) -> Result<()> {
        if let Some(_dimensions) = self.dimensions {
            let hnsw = Hnsw::<f32, DistCosine>::new(
                self.config.max_connections,
                self.vectors.len(),
                16, // maximum number of layers
                self.config.ef_construction,
                DistCosine,
            );

            for (id, (_, vector)) in self.vectors.iter().enumerate() {
                let vector_f32 = vector.as_f32();
                hnsw.insert((&vector_f32, id));
            }

            self.hnsw_index = Some(hnsw);
        }

        Ok(())
    }

    /// Add metadata to a vector
    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
        // For now, we'll store metadata separately
        // In a full implementation, this would be integrated with the index
        Ok(())
    }

    /// Search with advanced parameters
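    ///
    /// # Examples
    ///
    /// Sketch of a filtered query (`ignore`d; paths and the query data are
    /// assumptions for illustration):
    ///
    /// ```ignore
    /// let filter: FilterFunction = Box::new(|uri: &str| uri.starts_with("http://example.org/"));
    /// let hits = index.search_advanced(&query_vector, 10, Some(128), Some(filter))?;
    /// ```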
    pub fn search_advanced(
        &self,
        query: &Vector,
        k: usize,
        ef: Option<usize>,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        match self.config.index_type {
            IndexType::Hnsw => {
                #[cfg(feature = "hnsw")]
                {
                    // The HNSW backend has no native filter support, so apply
                    // the filter to the returned candidates rather than
                    // silently ignoring it.
                    let mut results = self.search_hnsw(query, k, ef)?;
                    if let Some(ref filter_fn) = filter {
                        results.retain(|r| filter_fn(&r.uri));
                    }
                    Ok(results)
                }
                #[cfg(not(feature = "hnsw"))]
                {
                    let _ = ef;
                    self.search_flat(query, k, filter)
                }
            }
            _ => self.search_flat(query, k, filter),
        }
    }

    #[cfg(feature = "hnsw")]
    fn search_hnsw(
        &self,
        query: &Vector,
        k: usize,
        ef: Option<usize>,
    ) -> Result<Vec<SearchResult>> {
        if let Some(ref hnsw) = self.hnsw_index {
            let search_ef = ef.unwrap_or(self.config.ef_search);
            let query_f32 = query.as_f32();
            let results = hnsw.search(&query_f32, k, search_ef);

            Ok(results
                .into_iter()
                .map(|result| SearchResult {
                    uri: self.vectors[result.d_id].0.clone(),
                    distance: result.distance,
                    score: 1.0 - result.distance, // Convert distance to similarity score
                    metadata: None,
                })
                .collect())
        } else {
            Err(anyhow!("HNSW index not built"))
        }
    }

    fn search_flat(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        if self.config.parallel && self.vectors.len() > 1000 {
            // Parallel search needs a Send + Sync filter; `FilterFunction` is
            // not, so fall back to sequential whenever a filter is supplied.
            if filter.is_some() {
                self.search_flat_sequential(query, k, filter)
            } else {
                self.search_flat_parallel(query, k, None)
            }
        } else {
            self.search_flat_sequential(query, k, filter)
        }
    }

    fn search_flat_sequential(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunction>,
    ) -> Result<Vec<SearchResult>> {
        // Max-heap ordered by distance: the root is the worst of the k best
        // candidates seen so far, so it is the one to evict when a closer
        // vector turns up.
        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();

        for (uri, vector) in &self.vectors {
            if let Some(ref filter_fn) = filter {
                if !filter_fn(uri) {
                    continue;
                }
            }

            let distance = self.config.distance_metric.distance_vectors(query, vector);

            if heap.len() < k {
                heap.push(SearchResult {
                    uri: uri.clone(),
                    distance,
                    score: 1.0 - distance, // Convert distance to similarity score
                    metadata: None,
                });
            } else if let Some(worst) = heap.peek() {
                if distance < worst.distance {
                    heap.pop();
                    heap.push(SearchResult {
                        uri: uri.clone(),
                        distance,
                        score: 1.0 - distance, // Convert distance to similarity score
                        metadata: None,
                    });
                }
            }
        }

        let mut results: Vec<SearchResult> = heap.into_iter().collect();
        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

        Ok(results)
    }

    fn search_flat_parallel(
        &self,
        query: &Vector,
        k: usize,
        filter: Option<FilterFunctionSync>,
    ) -> Result<Vec<SearchResult>> {
        // Split vectors into chunks for parallel processing
        let chunk_size = (self.vectors.len() / num_threads()).max(100);

        // Use Arc for thread-safe sharing of the filter
        let filter_arc = filter.map(Arc::new);

        // Process chunks in parallel and collect top-k from each
        let partial_results: Vec<Vec<SearchResult>> = self
            .vectors
            .par_chunks(chunk_size)
            .map(|chunk| {
                // Per-chunk max-heap: the root is the worst of the k best
                // candidates in this chunk and is evicted first.
                let mut local_heap: BinaryHeap<SearchResult> = BinaryHeap::new();
                let filter_ref = filter_arc.as_ref();

                for (uri, vector) in chunk {
                    if let Some(filter_fn) = filter_ref {
                        if !filter_fn(uri) {
                            continue;
                        }
                    }

                    let distance = self.config.distance_metric.distance_vectors(query, vector);

                    if local_heap.len() < k {
                        local_heap.push(SearchResult {
                            uri: uri.clone(),
                            distance,
                            score: 1.0 - distance, // Convert distance to similarity score
                            metadata: None,
                        });
                    } else if let Some(worst) = local_heap.peek() {
                        if distance < worst.distance {
                            local_heap.pop();
                            local_heap.push(SearchResult {
                                uri: uri.clone(),
                                distance,
                                score: 1.0 - distance, // Convert distance to similarity score
                                metadata: None,
                            });
                        }
                    }
                }

                // Ascending by distance (best first) within this chunk
                local_heap.into_sorted_vec()
            })
            .collect();

        // Merge the per-chunk candidates, again keeping only the k best
        let mut final_heap: BinaryHeap<SearchResult> = BinaryHeap::new();
        for partial in partial_results {
            for result in partial {
                if final_heap.len() < k {
                    final_heap.push(result);
                } else if let Some(worst) = final_heap.peek() {
                    if result.distance < worst.distance {
                        final_heap.pop();
                        final_heap.push(result);
                    }
                }
            }
        }

        let mut results: Vec<SearchResult> = final_heap.into_iter().collect();
        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

        Ok(results)
    }

    /// Get index statistics
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            num_vectors: self.vectors.len(),
            dimensions: self.dimensions.unwrap_or(0),
            index_type: self.config.index_type,
            memory_usage: self.estimate_memory_usage(),
        }
    }

    fn estimate_memory_usage(&self) -> usize {
        let vector_memory = self.vectors.len()
            * (std::mem::size_of::<String>()
                + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());

        let uri_map_memory =
            self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());

        vector_memory + uri_map_memory
    }

    /// Get the number of vectors in the index
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Check if the index is empty
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Add a vector with RDF triple and metadata (for compatibility with tests)
    pub fn add(
        &mut self,
        id: String,
        vector: Vec<f32>,
        _triple: Triple,
        _metadata: HashMap<String, String>,
    ) -> Result<()> {
        let vector_obj = Vector::new(vector);
        self.insert(id, vector_obj)
    }

    /// Search for nearest neighbors (for compatibility with tests)
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        let query_vector = Vector::new(query.to_vec());
        let results = self.search_advanced(&query_vector, k, None, None)?;
        Ok(results)
    }
}

impl VectorIndex for AdvancedVectorIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if let Some(dims) = self.dimensions {
            if vector.dimensions != dims {
                return Err(anyhow!(
                    "Vector dimensions ({}) don't match index dimensions ({})",
                    vector.dimensions,
                    dims
                ));
            }
        } else {
            self.dimensions = Some(vector.dimensions);
        }

        let id = self.vectors.len();
        self.uri_to_id.insert(uri.clone(), id);
        self.vectors.push((uri, vector));

        Ok(())
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        let results = self.search_advanced(query, k, None, None)?;
        Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        let mut results = Vec::new();

        for (uri, vector) in &self.vectors {
            let distance = self.config.distance_metric.distance_vectors(query, vector);
            if distance <= threshold {
                results.push((uri.clone(), distance));
            }
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        Ok(results)
    }

    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        // Vectors are stored in `self.vectors` regardless of the index type;
        // `uri_to_id` gives the position directly, avoiding a linear scan.
        self.uri_to_id.get(uri).map(|&id| &self.vectors[id].1)
    }
}

/// Index performance statistics
#[derive(Debug, Clone)]
pub struct IndexStats {
    pub num_vectors: usize,
    pub dimensions: usize,
    pub index_type: IndexType,
    pub memory_usage: usize,
}

/// Quantized vector index for memory efficiency
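///
/// # Examples
///
/// Sketch of the intended workflow: train centroids first, then insert and
/// query. Marked `ignore` because crate paths and the training data are
/// assumptions for illustration:
///
/// ```ignore
/// let mut index = QuantizedVectorIndex::new(IndexConfig::default(), 16);
/// index.train_quantization(&training_vectors)?;
/// index.insert("http://example.org/a".to_string(), Vector::new(vec![0.1; 64]))?;
/// let hits = index.search_knn(&Vector::new(vec![0.1; 64]), 5)?;
/// ```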
pub struct QuantizedVectorIndex {
    config: IndexConfig,
    quantized_vectors: Vec<Vec<u8>>,
    centroids: Vec<Vector>,
    uri_to_id: HashMap<String, usize>,
    dimensions: Option<usize>,
}

impl QuantizedVectorIndex {
    pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
        Self {
            config,
            quantized_vectors: Vec::new(),
            centroids: Vec::with_capacity(num_centroids),
            uri_to_id: HashMap::new(),
            dimensions: None,
        }
    }

    /// Train quantization centroids using k-means
    pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(anyhow!("No training vectors provided"));
        }

        let dimensions = training_vectors[0].dimensions;
        self.dimensions = Some(dimensions);

        // Simple k-means clustering for centroids
        self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;

        Ok(())
    }

    fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
        let mut quantized = Vec::new();

        // Find nearest centroid for each dimension chunk
        // Guard against a zero chunk size (fewer dimensions than centroids),
        // which would make `chunks` panic below.
        let chunk_size = (vector.dimensions / self.centroids.len().max(1)).max(1);

        let vector_f32 = vector.as_f32();
        for chunk in vector_f32.chunks(chunk_size) {
            let mut best_centroid = 0u8;
            let mut best_distance = f32::INFINITY;

            for (i, centroid) in self.centroids.iter().enumerate() {
                let centroid_f32 = centroid.as_f32();
                let centroid_chunk = &centroid_f32[0..chunk.len().min(centroid.dimensions)];
                use oxirs_core::simd::SimdOps;
                let distance = f32::euclidean_distance(chunk, centroid_chunk);
                if distance < best_distance {
                    best_distance = distance;
                    best_centroid = i as u8;
                }
            }

            quantized.push(best_centroid);
        }

        quantized
    }
}

impl VectorIndex for QuantizedVectorIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if self.centroids.is_empty() {
            return Err(anyhow!(
                "Quantization not trained. Call train_quantization first."
            ));
        }

        let id = self.quantized_vectors.len();
        self.uri_to_id.insert(uri.clone(), id);

        let quantized = self.quantize_vector(&vector);
        self.quantized_vectors.push(quantized);

        Ok(())
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        let query_quantized = self.quantize_vector(query);
        let mut results = Vec::new();

        // Look up each URI's code by its stored id; HashMap iteration order is
        // arbitrary, so zipping keys against the code list would pair URIs
        // with the wrong quantized vectors.
        for (uri, &id) in &self.uri_to_id {
            let distance = hamming_distance(&query_quantized, &self.quantized_vectors[id]);
            results.push((uri.clone(), distance));
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        results.truncate(k);

        Ok(results)
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        let query_quantized = self.quantize_vector(query);
        let mut results = Vec::new();

        // Same id-based lookup as `search_knn` to keep URIs and codes aligned.
        for (uri, &id) in &self.uri_to_id {
            let distance = hamming_distance(&query_quantized, &self.quantized_vectors[id]);
            if distance <= threshold {
                results.push((uri.clone(), distance));
            }
        }

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        Ok(results)
    }

    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
        // Quantized index doesn't store original vectors
        // Return None as we only have quantized representations
        None
    }
}

// Helper functions that don't have SIMD equivalents

fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    a.iter().zip(b).filter(|(x, y)| x != y).count() as f32
}

// K-means clustering for quantization
fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
    if vectors.is_empty() || k == 0 {
        return Ok(Vec::new());
    }

    let dimensions = vectors[0].dimensions;
    let mut centroids = Vec::with_capacity(k);

    // Seed centroids deterministically by cycling through the training vectors
    for i in 0..k {
        let idx = i % vectors.len();
        centroids.push(vectors[idx].clone());
    }

    // Simple k-means iterations
    for _ in 0..10 {
        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];

        // Assign vectors to nearest centroid
        for vector in vectors {
            let mut best_centroid = 0;
            let mut best_distance = f32::INFINITY;

            for (i, centroid) in centroids.iter().enumerate() {
                let vector_f32 = vector.as_f32();
                let centroid_f32 = centroid.as_f32();
                use oxirs_core::simd::SimdOps;
                let distance = f32::euclidean_distance(&vector_f32, &centroid_f32);
                if distance < best_distance {
                    best_distance = distance;
                    best_centroid = i;
                }
            }

            clusters[best_centroid].push(vector);
        }

        // Update centroids
        for (i, cluster) in clusters.iter().enumerate() {
            if !cluster.is_empty() {
                let mut new_centroid = vec![0.0; dimensions];

                for vector in cluster {
                    let vector_f32 = vector.as_f32();
                    for (j, &value) in vector_f32.iter().enumerate() {
                        new_centroid[j] += value;
                    }
                }

                for value in &mut new_centroid {
                    *value /= cluster.len() as f32;
                }

                centroids[i] = Vector::new(new_centroid);
            }
        }
    }

    Ok(centroids)
}

/// Multi-index system that combines multiple index types
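///
/// # Examples
///
/// Sketch of registering a named index and routing a query to it (`ignore`d;
/// crate paths and the query vector are assumptions for illustration):
///
/// ```ignore
/// let mut multi = MultiIndex::new();
/// multi.add_index(
///     "flat".to_string(),
///     Box::new(AdvancedVectorIndex::new(IndexConfig::default())),
/// );
/// multi.set_default("flat")?;
/// let hits = multi.search_index("flat", &query_vector, 10)?;
/// ```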
pub struct MultiIndex {
    indices: HashMap<String, Box<dyn VectorIndex>>,
    default_index: String,
}

impl MultiIndex {
    pub fn new() -> Self {
        Self {
            indices: HashMap::new(),
            default_index: String::new(),
        }
    }

    pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
        if self.indices.is_empty() {
            self.default_index = name.clone();
        }
        self.indices.insert(name, index);
    }

    pub fn set_default(&mut self, name: &str) -> Result<()> {
        if self.indices.contains_key(name) {
            self.default_index = name.to_string();
            Ok(())
        } else {
            Err(anyhow!("Index '{}' not found", name))
        }
    }

    pub fn search_index(
        &self,
        index_name: &str,
        query: &Vector,
        k: usize,
    ) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(index_name) {
            index.search_knn(query, k)
        } else {
            Err(anyhow!("Index '{}' not found", index_name))
        }
    }
}

impl Default for MultiIndex {
    fn default() -> Self {
        Self::new()
    }
}

impl VectorIndex for MultiIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if let Some(index) = self.indices.get_mut(&self.default_index) {
            index.insert(uri, vector)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.search_knn(query, k)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.search_threshold(query, threshold)
        } else {
            Err(anyhow!("No default index set"))
        }
    }

    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        if let Some(index) = self.indices.get(&self.default_index) {
            index.get_vector(uri)
        } else {
            None
        }
    }
}
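
// A couple of small, self-contained sketches of the flat search path, added
// for illustration. They use only items defined in this module plus
// `Vector::new` / `Vector::dimensions` as already used above; the example URIs
// are hypothetical and nothing here is part of an original test-suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flat_index_returns_nearest_neighbour_first() {
        let config = IndexConfig {
            index_type: IndexType::Flat,
            ..Default::default()
        };
        let mut index = AdvancedVectorIndex::new(config);
        index
            .insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))
            .unwrap();
        index
            .insert("http://example.org/b".to_string(), Vector::new(vec![0.0, 1.0]))
            .unwrap();

        // With cosine distance, [1.0, 0.1] is much closer to "a" than to "b".
        let results = index
            .search_knn(&Vector::new(vec![1.0, 0.1]), 1)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "http://example.org/a");
    }

    #[test]
    fn insert_rejects_mismatched_dimensions() {
        let mut index = AdvancedVectorIndex::new(IndexConfig::default());
        index
            .insert("http://example.org/a".to_string(), Vector::new(vec![1.0, 0.0]))
            .unwrap();
        // A second vector with a different dimensionality must be rejected.
        assert!(index
            .insert("http://example.org/b".to_string(), Vector::new(vec![1.0, 0.0, 0.0]))
            .is_err());
    }
}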