Skip to main content

oxirs_vec/
index.rs

1//! Advanced vector indexing with HNSW and other efficient algorithms
2
3use crate::Vector;
4
5// Re-export VectorIndex trait for use by other modules
6pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15use crate::hnsw::{HnswConfig, HnswIndex};
16
/// Boxed predicate over a vector's URI, used to restrict search results.
///
/// The flat search skips every entry for which the predicate returns `false`.
/// This alias carries no `Send`/`Sync` bound, so it is what the sequential
/// search path accepts.
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// Same predicate shape as [`FilterFunction`], additionally bounded by
/// `Send + Sync`; this is the filter type the parallel flat search accepts.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;
21
/// Configuration for a vector index.
///
/// The HNSW-specific fields are forwarded to `HnswConfig` when the HNSW index
/// is built; the other index types ignore them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Index type to use
    pub index_type: IndexType,
    /// Maximum number of connections for each node (for HNSW).
    /// Passed as `m`; twice this value is passed as `m_l0`.
    pub max_connections: usize,
    /// Construction parameter (for HNSW), passed as `ef_construction`.
    pub ef_construction: usize,
    /// Search parameter (for HNSW), passed as `ef`.
    pub ef_search: usize,
    /// Distance metric used by the flat and threshold searches.
    pub distance_metric: DistanceMetric,
    /// Whether to enable parallel operations. The flat search only takes the
    /// parallel path when the index holds more than 1000 vectors and no
    /// filter is supplied.
    pub parallel: bool,
}
38
39impl Default for IndexConfig {
40    fn default() -> Self {
41        Self {
42            index_type: IndexType::Hnsw,
43            max_connections: 16,
44            ef_construction: 200,
45            ef_search: 50,
46            distance_metric: DistanceMetric::Cosine,
47            parallel: true,
48        }
49    }
50}
51
/// Available index types.
///
/// Only `Hnsw` and `Flat` can currently be built; `AdvancedVectorIndex::build`
/// returns an error for `Ivf` and `PQ`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// Hierarchical Navigable Small World
    Hnsw,
    /// Simple flat index (brute force)
    Flat,
    /// IVF (Inverted File) index (not yet implemented)
    Ivf,
    /// Product Quantization (not yet implemented)
    PQ,
}
64
/// Distance metrics supported.
///
/// Every variant is evaluated so that a smaller value means a closer match.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine distance (1 - cosine_similarity)
    Cosine,
    /// Euclidean (L2) distance
    Euclidean,
    /// Manhattan (L1) distance
    Manhattan,
    /// Dot product, negated when evaluated so that a larger dot product
    /// yields a smaller distance
    DotProduct,
}
77
78impl DistanceMetric {
79    /// Calculate distance between two vectors
80    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
81        use oxirs_core::simd::SimdOps;
82
83        match self {
84            DistanceMetric::Cosine => f32::cosine_distance(a, b),
85            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
86            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
87            DistanceMetric::DotProduct => -f32::dot(a, b), // Negative for max-heap
88        }
89    }
90
91    /// Calculate distance between two Vector objects
92    pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
93        let a_f32 = a.as_f32();
94        let b_f32 = b.as_f32();
95        self.distance(&a_f32, &b_f32)
96    }
97}
98
/// Search result with distance/score.
///
/// Results are ordered by `distance` alone (ascending), so the "greatest"
/// result is the one farthest from the query. Note that equality (`==`) still
/// compares every field, so two results can compare as `Ordering::Equal`
/// without being `==`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance from the query (smaller is closer).
    pub distance: f32,
    /// Similarity-style score associated with the hit.
    pub score: f32,
    /// Optional metadata attached to the hit.
    pub metadata: Option<HashMap<String, String>>,
}

impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Orders results by distance using the IEEE 754 total order.
    ///
    /// `total_cmp` is used instead of `partial_cmp(..).unwrap_or(Equal)`
    /// because treating NaN as equal to every value is not a total order and
    /// can corrupt `BinaryHeap` and sort results. With `total_cmp` a
    /// (positive) NaN distance simply sorts after every finite distance.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

impl PartialOrd for SearchResult {
    /// Always `Some`, delegating to the total order defined by [`Ord`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
123
/// Advanced vector index with multiple implementations.
///
/// The raw vectors are always kept in `vectors`; the HNSW structure is an
/// additional accelerator built on demand from that list.
pub struct AdvancedVectorIndex {
    /// Configuration chosen at construction time.
    config: IndexConfig,
    /// `(uri, vector)` pairs in insertion order.
    vectors: Vec<(String, Vector)>,
    /// Maps a URI to the position it was given in `vectors` when inserted.
    uri_to_id: HashMap<String, usize>,
    /// HNSW index; `None` until `build` has run for an HNSW configuration.
    hnsw_index: Option<HnswIndex>,
    /// Dimensionality fixed by the first inserted vector; `None` while empty.
    dimensions: Option<usize>,
}
132
133impl AdvancedVectorIndex {
134    pub fn new(config: IndexConfig) -> Self {
135        Self {
136            config,
137            vectors: Vec::new(),
138            uri_to_id: HashMap::new(),
139            hnsw_index: None,
140            dimensions: None,
141        }
142    }
143
144    /// Build the index after adding all vectors
145    pub fn build(&mut self) -> Result<()> {
146        if self.vectors.is_empty() {
147            return Ok(());
148        }
149
150        match self.config.index_type {
151            IndexType::Hnsw => {
152                self.build_hnsw_index()?;
153            }
154            IndexType::Flat => {
155                // No special building needed for flat index
156            }
157            IndexType::Ivf | IndexType::PQ => {
158                return Err(anyhow!("IVF and PQ indices not yet implemented"));
159            }
160        }
161
162        Ok(())
163    }
164
165    fn build_hnsw_index(&mut self) -> Result<()> {
166        if self.dimensions.is_some() {
167            let hnsw_config = HnswConfig {
168                m: self.config.max_connections,
169                m_l0: self.config.max_connections * 2,
170                ef_construction: self.config.ef_construction,
171                ef: self.config.ef_search,
172                ..HnswConfig::default()
173            };
174
175            let mut hnsw = HnswIndex::new_cpu_only(hnsw_config);
176
177            for (uri, vector) in &self.vectors {
178                hnsw.insert(uri.clone(), vector.clone())?;
179            }
180
181            self.hnsw_index = Some(hnsw);
182        }
183
184        Ok(())
185    }
186
187    /// Add metadata to a vector
188    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
189        // For now, we'll store metadata separately
190        // In a full implementation, this would be integrated with the index
191        Ok(())
192    }
193
194    /// Search with advanced parameters
195    pub fn search_advanced(
196        &self,
197        query: &Vector,
198        k: usize,
199        _ef: Option<usize>,
200        filter: Option<FilterFunction>,
201    ) -> Result<Vec<SearchResult>> {
202        match self.config.index_type {
203            IndexType::Hnsw => self.search_hnsw(query, k),
204            _ => self.search_flat(query, k, filter),
205        }
206    }
207
208    fn search_hnsw(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
209        if let Some(ref hnsw) = self.hnsw_index {
210            let results = hnsw.search_knn(query, k)?;
211
212            Ok(results
213                .into_iter()
214                .map(|(uri, distance)| SearchResult {
215                    uri,
216                    distance,
217                    score: 1.0 - distance,
218                    metadata: None,
219                })
220                .collect())
221        } else {
222            Err(anyhow!("HNSW index not built"))
223        }
224    }
225
226    fn search_flat(
227        &self,
228        query: &Vector,
229        k: usize,
230        filter: Option<FilterFunction>,
231    ) -> Result<Vec<SearchResult>> {
232        if self.config.parallel && self.vectors.len() > 1000 {
233            // For parallel search, we need Send + Sync filter
234            if filter.is_some() {
235                // Fall back to sequential if filter is present but not Send + Sync
236                self.search_flat_sequential(query, k, filter)
237            } else {
238                self.search_flat_parallel(query, k, None)
239            }
240        } else {
241            self.search_flat_sequential(query, k, filter)
242        }
243    }
244
245    fn search_flat_sequential(
246        &self,
247        query: &Vector,
248        k: usize,
249        filter: Option<FilterFunction>,
250    ) -> Result<Vec<SearchResult>> {
251        let mut heap = BinaryHeap::new();
252
253        for (uri, vector) in &self.vectors {
254            if let Some(ref filter_fn) = filter {
255                if !filter_fn(uri) {
256                    continue;
257                }
258            }
259
260            let distance = self.config.distance_metric.distance_vectors(query, vector);
261
262            if heap.len() < k {
263                heap.push(std::cmp::Reverse(SearchResult {
264                    uri: uri.clone(),
265                    distance,
266                    score: 1.0 - distance, // Convert distance to similarity score
267                    metadata: None,
268                }));
269            } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
270                if distance < worst.distance {
271                    heap.pop();
272                    heap.push(std::cmp::Reverse(SearchResult {
273                        uri: uri.clone(),
274                        distance,
275                        score: 1.0 - distance, // Convert distance to similarity score
276                        metadata: None,
277                    }));
278                }
279            }
280        }
281
282        let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
283        results.sort_by(|a, b| {
284            a.distance
285                .partial_cmp(&b.distance)
286                .unwrap_or(std::cmp::Ordering::Equal)
287        });
288
289        Ok(results)
290    }
291
292    fn search_flat_parallel(
293        &self,
294        query: &Vector,
295        k: usize,
296        filter: Option<FilterFunctionSync>,
297    ) -> Result<Vec<SearchResult>> {
298        // Split vectors into chunks for parallel processing
299        let chunk_size = (self.vectors.len() / num_threads()).max(100);
300
301        // Use Arc for thread-safe sharing of the filter
302        let filter_arc = filter.map(Arc::new);
303
304        // Process chunks in parallel and collect top-k from each
305        let partial_results: Vec<Vec<SearchResult>> = self
306            .vectors
307            .par_chunks(chunk_size)
308            .map(|chunk| {
309                let mut local_heap = BinaryHeap::new();
310                let filter_ref = filter_arc.as_ref();
311
312                for (uri, vector) in chunk {
313                    if let Some(filter_fn) = filter_ref {
314                        if !filter_fn(uri) {
315                            continue;
316                        }
317                    }
318
319                    let distance = self.config.distance_metric.distance_vectors(query, vector);
320
321                    if local_heap.len() < k {
322                        local_heap.push(std::cmp::Reverse(SearchResult {
323                            uri: uri.clone(),
324                            distance,
325                            score: 1.0 - distance, // Convert distance to similarity score
326                            metadata: None,
327                        }));
328                    } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
329                        if distance < worst.distance {
330                            local_heap.pop();
331                            local_heap.push(std::cmp::Reverse(SearchResult {
332                                uri: uri.clone(),
333                                distance,
334                                score: 1.0 - distance, // Convert distance to similarity score
335                                metadata: None,
336                            }));
337                        }
338                    }
339                }
340
341                local_heap
342                    .into_sorted_vec()
343                    .into_iter()
344                    .map(|r| r.0)
345                    .collect()
346            })
347            .collect();
348
349        // Merge results from all chunks
350        let mut final_heap = BinaryHeap::new();
351        for partial in partial_results {
352            for result in partial {
353                if final_heap.len() < k {
354                    final_heap.push(std::cmp::Reverse(result));
355                } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
356                    if result.distance < worst.distance {
357                        final_heap.pop();
358                        final_heap.push(std::cmp::Reverse(result));
359                    }
360                }
361            }
362        }
363
364        let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
365        results.sort_by(|a, b| {
366            a.distance
367                .partial_cmp(&b.distance)
368                .unwrap_or(std::cmp::Ordering::Equal)
369        });
370
371        Ok(results)
372    }
373
374    /// Get index statistics
375    pub fn stats(&self) -> IndexStats {
376        IndexStats {
377            num_vectors: self.vectors.len(),
378            dimensions: self.dimensions.unwrap_or(0),
379            index_type: self.config.index_type,
380            memory_usage: self.estimate_memory_usage(),
381        }
382    }
383
384    fn estimate_memory_usage(&self) -> usize {
385        let vector_memory = self.vectors.len()
386            * (std::mem::size_of::<String>()
387                + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
388
389        let uri_map_memory =
390            self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
391
392        vector_memory + uri_map_memory
393    }
394
395    /// Get the number of vectors in the index
396    pub fn len(&self) -> usize {
397        self.vectors.len()
398    }
399
400    /// Check if the index is empty
401    pub fn is_empty(&self) -> bool {
402        self.vectors.is_empty()
403    }
404
405    /// Add a vector with RDF triple and metadata (for compatibility with tests)
406    pub fn add(
407        &mut self,
408        id: String,
409        vector: Vec<f32>,
410        _triple: Triple,
411        _metadata: HashMap<String, String>,
412    ) -> Result<()> {
413        let vector_obj = Vector::new(vector);
414        self.insert(id, vector_obj)
415    }
416
417    /// Search for nearest neighbors (for compatibility with tests)
418    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
419        let query_vector = Vector::new(query.to_vec());
420        let results = self.search_advanced(&query_vector, k, None, None)?;
421        Ok(results)
422    }
423}
424
425impl VectorIndex for AdvancedVectorIndex {
426    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
427        if let Some(dims) = self.dimensions {
428            if vector.dimensions != dims {
429                return Err(anyhow!(
430                    "Vector dimensions ({}) don't match index dimensions ({})",
431                    vector.dimensions,
432                    dims
433                ));
434            }
435        } else {
436            self.dimensions = Some(vector.dimensions);
437        }
438
439        let id = self.vectors.len();
440        self.uri_to_id.insert(uri.clone(), id);
441        self.vectors.push((uri, vector));
442
443        Ok(())
444    }
445
446    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
447        let results = self.search_advanced(query, k, None, None)?;
448        Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
449    }
450
451    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
452        let mut results = Vec::new();
453
454        for (uri, vector) in &self.vectors {
455            let distance = self.config.distance_metric.distance_vectors(query, vector);
456            if distance <= threshold {
457                results.push((uri.clone(), distance));
458            }
459        }
460
461        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
462        Ok(results)
463    }
464
465    fn get_vector(&self, uri: &str) -> Option<&Vector> {
466        // For AdvancedVectorIndex, vectors are stored in the vectors field
467        // regardless of the index type being used
468        self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
469    }
470}
471
/// Index statistics as reported by `AdvancedVectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Dimensionality of the stored vectors (0 when nothing was inserted yet).
    pub dimensions: usize,
    /// Index type the index was configured with.
    pub index_type: IndexType,
    /// Estimated memory usage in bytes.
    pub memory_usage: usize,
}
480
481/// Quantized vector index for memory efficiency
482pub struct QuantizedVectorIndex {
483    config: IndexConfig,
484    quantized_vectors: Vec<Vec<u8>>,
485    centroids: Vec<Vector>,
486    uri_to_id: HashMap<String, usize>,
487    dimensions: Option<usize>,
488}
489
490impl QuantizedVectorIndex {
491    pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
492        Self {
493            config,
494            quantized_vectors: Vec::new(),
495            centroids: Vec::with_capacity(num_centroids),
496            uri_to_id: HashMap::new(),
497            dimensions: None,
498        }
499    }
500
501    /// Train quantization centroids using k-means
502    pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
503        if training_vectors.is_empty() {
504            return Err(anyhow!("No training vectors provided"));
505        }
506
507        let dimensions = training_vectors[0].dimensions;
508        self.dimensions = Some(dimensions);
509
510        // Simple k-means clustering for centroids
511        self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
512
513        Ok(())
514    }
515
516    fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
517        let mut quantized = Vec::new();
518
519        // Find nearest centroid for each dimension chunk
520        let chunk_size = vector.dimensions / self.centroids.len().max(1);
521
522        let vector_f32 = vector.as_f32();
523        for chunk in vector_f32.chunks(chunk_size) {
524            let mut best_centroid = 0u8;
525            let mut best_distance = f32::INFINITY;
526
527            for (i, centroid) in self.centroids.iter().enumerate() {
528                let centroid_f32 = centroid.as_f32();
529                let centroid_chunk = &centroid_f32[0..chunk.len().min(centroid.dimensions)];
530                use oxirs_core::simd::SimdOps;
531                let distance = f32::euclidean_distance(chunk, centroid_chunk);
532                if distance < best_distance {
533                    best_distance = distance;
534                    best_centroid = i as u8;
535                }
536            }
537
538            quantized.push(best_centroid);
539        }
540
541        quantized
542    }
543}
544
545impl VectorIndex for QuantizedVectorIndex {
546    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
547        if self.centroids.is_empty() {
548            return Err(anyhow!(
549                "Quantization not trained. Call train_quantization first."
550            ));
551        }
552
553        let id = self.quantized_vectors.len();
554        self.uri_to_id.insert(uri.clone(), id);
555
556        let quantized = self.quantize_vector(&vector);
557        self.quantized_vectors.push(quantized);
558
559        Ok(())
560    }
561
562    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
563        let query_quantized = self.quantize_vector(query);
564        let mut results = Vec::new();
565
566        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
567            let distance = hamming_distance(&query_quantized, quantized);
568            results.push((uri.clone(), distance));
569        }
570
571        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
572        results.truncate(k);
573
574        Ok(results)
575    }
576
577    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
578        let query_quantized = self.quantize_vector(query);
579        let mut results = Vec::new();
580
581        for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
582            let distance = hamming_distance(&query_quantized, quantized);
583            if distance <= threshold {
584                results.push((uri.clone(), distance));
585            }
586        }
587
588        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
589        Ok(results)
590    }
591
592    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
593        // Quantized index doesn't store original vectors
594        // Return None as we only have quantized representations
595        None
596    }
597}
598
599// Helper functions that don't have SIMD equivalents
600
/// Counts the positions at which two byte codes differ.
///
/// Only the overlapping prefix is compared: when the slices differ in length,
/// the extra tail of the longer one is ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mut differing = 0usize;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if lhs != rhs {
            differing += 1;
        }
    }
    differing as f32
}
604
605// K-means clustering for quantization
606fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
607    if vectors.is_empty() || k == 0 {
608        return Ok(Vec::new());
609    }
610
611    let dimensions = vectors[0].dimensions;
612    let mut centroids = Vec::with_capacity(k);
613
614    // Initialize centroids randomly
615    for i in 0..k {
616        let idx = i % vectors.len();
617        centroids.push(vectors[idx].clone());
618    }
619
620    // Simple k-means iterations
621    for _ in 0..10 {
622        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
623
624        // Assign vectors to nearest centroid
625        for vector in vectors {
626            let mut best_centroid = 0;
627            let mut best_distance = f32::INFINITY;
628
629            for (i, centroid) in centroids.iter().enumerate() {
630                let vector_f32 = vector.as_f32();
631                let centroid_f32 = centroid.as_f32();
632                use oxirs_core::simd::SimdOps;
633                let distance = f32::euclidean_distance(&vector_f32, &centroid_f32);
634                if distance < best_distance {
635                    best_distance = distance;
636                    best_centroid = i;
637                }
638            }
639
640            clusters[best_centroid].push(vector);
641        }
642
643        // Update centroids
644        for (i, cluster) in clusters.iter().enumerate() {
645            if !cluster.is_empty() {
646                let mut new_centroid = vec![0.0; dimensions];
647
648                for vector in cluster {
649                    let vector_f32 = vector.as_f32();
650                    for (j, &value) in vector_f32.iter().enumerate() {
651                        new_centroid[j] += value;
652                    }
653                }
654
655                for value in &mut new_centroid {
656                    *value /= cluster.len() as f32;
657                }
658
659                centroids[i] = Vector::new(new_centroid);
660            }
661        }
662    }
663
664    Ok(centroids)
665}
666
/// Multi-index system that combines multiple index types.
///
/// Indexes are registered under a name; the `VectorIndex` operations are
/// forwarded to the index currently selected as default.
pub struct MultiIndex {
    /// Registered indexes, keyed by name.
    indices: HashMap<String, Box<dyn VectorIndex>>,
    /// Name of the default index; empty until the first index is added.
    default_index: String,
}
672
673impl MultiIndex {
674    pub fn new() -> Self {
675        Self {
676            indices: HashMap::new(),
677            default_index: String::new(),
678        }
679    }
680
681    pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
682        if self.indices.is_empty() {
683            self.default_index = name.clone();
684        }
685        self.indices.insert(name, index);
686    }
687
688    pub fn set_default(&mut self, name: &str) -> Result<()> {
689        if self.indices.contains_key(name) {
690            self.default_index = name.to_string();
691            Ok(())
692        } else {
693            Err(anyhow!("Index '{}' not found", name))
694        }
695    }
696
697    pub fn search_index(
698        &self,
699        index_name: &str,
700        query: &Vector,
701        k: usize,
702    ) -> Result<Vec<(String, f32)>> {
703        if let Some(index) = self.indices.get(index_name) {
704            index.search_knn(query, k)
705        } else {
706            Err(anyhow!("Index '{}' not found", index_name))
707        }
708    }
709}
710
impl Default for MultiIndex {
    /// Equivalent to [`MultiIndex::new`]: no registered indexes, no default.
    fn default() -> Self {
        Self::new()
    }
}
716
717impl VectorIndex for MultiIndex {
718    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
719        if let Some(index) = self.indices.get_mut(&self.default_index) {
720            index.insert(uri, vector)
721        } else {
722            Err(anyhow!("No default index set"))
723        }
724    }
725
726    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
727        if let Some(index) = self.indices.get(&self.default_index) {
728            index.search_knn(query, k)
729        } else {
730            Err(anyhow!("No default index set"))
731        }
732    }
733
734    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
735        if let Some(index) = self.indices.get(&self.default_index) {
736            index.search_threshold(query, threshold)
737        } else {
738            Err(anyhow!("No default index set"))
739        }
740    }
741
742    fn get_vector(&self, uri: &str) -> Option<&Vector> {
743        if let Some(index) = self.indices.get(&self.default_index) {
744            index.get_vector(uri)
745        } else {
746            None
747        }
748    }
749}