//! graphrag_core/vector/mod.rs
//!
//! Vector index, hash-based embedding generation, and vector math utilities.
1use crate::{GraphRAGError, Result};
2#[cfg(feature = "parallel-processing")]
3use crate::parallel::ParallelProcessor;
4use std::collections::hash_map::DefaultHasher;
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7
8#[cfg(feature = "vector-hnsw")]
9use instant_distance::{Builder, Point, Search};
10
11// Voy vector store module (WASM-optimized)
12// TODO: Re-enable when voy crate is properly configured
13// #[cfg(feature = "wasm")]
14// pub mod voy_store;
15
16// #[cfg(feature = "wasm")]
17// pub use voy_store::{VoyStore, VoyStoreStatistics};
18
/// Wrapper for Vec<f32> to implement Point trait for vector operations
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);

impl Vector {
    /// Construct a vector wrapper around raw `f32` components.
    pub fn new(vector_data: Vec<f32>) -> Self {
        Vector(vector_data)
    }

    /// Borrow the underlying components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.0.as_slice()
    }
}
34
#[cfg(feature = "vector-hnsw")]
impl Point for Vector {
    /// Euclidean (L2) distance; vectors of different lengths are treated as
    /// infinitely far apart.
    fn distance(&self, other: &Self) -> f32 {
        if self.0.len() != other.0.len() {
            return f32::INFINITY;
        }

        let squared_sum: f32 = self
            .0
            .iter()
            .zip(other.0.iter())
            .map(|(lhs, rhs)| (lhs - rhs).powi(2))
            .sum();
        squared_sum.sqrt()
    }
}
51
/// Vector index for semantic search
pub struct VectorIndex {
    /// HNSW index over all inserted vectors; `None` until `build_index` succeeds.
    #[cfg(feature = "vector-hnsw")]
    index: Option<instant_distance::HnswMap<Vector, String>>,
    /// Stand-in for the index when HNSW support is compiled out; becomes
    /// `Some(())` after `build_index` so `search` can still detect "built" state.
    #[cfg(not(feature = "vector-hnsw"))]
    index: Option<()>, // Placeholder when HNSW is not available
    /// Source of truth: id -> raw embedding; retained so the index can be rebuilt.
    embeddings: HashMap<String, Vec<f32>>,
    /// Optional helper that decides when batch operations should run on rayon.
    #[cfg(feature = "parallel-processing")]
    parallel_processor: Option<ParallelProcessor>,
}
62
63impl VectorIndex {
    /// Create a new vector index
    ///
    /// The index starts empty: no embeddings are stored and no search index
    /// has been built yet (`search` errors until `build_index` is called).
    pub fn new() -> Self {
        Self {
            index: None,
            embeddings: HashMap::new(),
            // Parallelism is opt-in via `with_parallel_processing`.
            #[cfg(feature = "parallel-processing")]
            parallel_processor: None,
        }
    }
73
    /// Create a new vector index with parallel processing support
    ///
    /// Batch operations (`batch_add_vectors`, `batch_search`,
    /// `compute_all_similarities`) consult `parallel_processor` to decide
    /// whether a workload is large enough to run on rayon.
    #[cfg(feature = "parallel-processing")]
    pub fn with_parallel_processing(parallel_processor: ParallelProcessor) -> Self {
        Self {
            index: None,
            embeddings: HashMap::new(),
            parallel_processor: Some(parallel_processor),
        }
    }
83
84    /// Add a vector to the index
85    pub fn add_vector(&mut self, id: String, embedding: Vec<f32>) -> Result<()> {
86        if embedding.is_empty() {
87            return Err(GraphRAGError::VectorSearch {
88                message: "Empty embedding vector".to_string(),
89            });
90        }
91
92        self.embeddings.insert(id, embedding);
93        Ok(())
94    }
95
96    /// Build the index from all added vectors
97    pub fn build_index(&mut self) -> Result<()> {
98        if self.embeddings.is_empty() {
99            return Err(GraphRAGError::VectorSearch {
100                message: "No embeddings to build index from".to_string(),
101            });
102        }
103
104        #[cfg(feature = "vector-hnsw")]
105        {
106            let points: Vec<Vector> = self
107                .embeddings
108                .values()
109                .map(|v| Vector::new(v.clone()))
110                .collect();
111
112            let values: Vec<String> = self.embeddings.keys().cloned().collect();
113
114            let builder = Builder::default();
115            let index = builder.build(points, values);
116
117            self.index = Some(index);
118        }
119
120        #[cfg(not(feature = "vector-hnsw"))]
121        {
122            println!(
123                "Warning: HNSW vector indexing not available. Install with --features vector-hnsw"
124            );
125            self.index = Some(());
126        }
127
128        Ok(())
129    }
130
131    /// Search for similar vectors
132    pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(String, f32)>> {
133        let _index = self
134            .index
135            .as_ref()
136            .ok_or_else(|| GraphRAGError::VectorSearch {
137                message: "Index not built. Call build_index() first.".to_string(),
138            })?;
139
140        #[cfg(feature = "vector-hnsw")]
141        {
142            let query_point = Vector::new(query_embedding.to_vec());
143            let mut search = Search::default();
144
145            let results = _index.search(&query_point, &mut search);
146
147            let mut scored_results = Vec::new();
148            for item in results.into_iter().take(top_k) {
149                let distance = item.distance;
150                // Convert distance to similarity using exponential decay for better score distribution
151                let similarity = (-distance).exp().clamp(0.0, 1.0);
152                scored_results.push((item.value.clone(), similarity));
153            }
154
155            Ok(scored_results)
156        }
157
158        #[cfg(not(feature = "vector-hnsw"))]
159        {
160            // Fallback to brute force similarity search
161            let query_vec = query_embedding;
162            let mut scored_results = Vec::new();
163
164            for (id, embedding) in &self.embeddings {
165                let similarity = self.cosine_similarity(query_vec, embedding);
166                scored_results.push((id.clone(), similarity));
167            }
168
169            // Sort by similarity (highest first) and take top_k
170            scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
171            scored_results.truncate(top_k);
172
173            Ok(scored_results)
174        }
175    }
176
177    /// Calculate cosine similarity between two vectors (fallback when HNSW is not available)
178    #[cfg(not(feature = "vector-hnsw"))]
179    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
180        if a.len() != b.len() {
181            return 0.0;
182        }
183
184        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
185        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
186        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
187
188        if norm_a == 0.0 || norm_b == 0.0 {
189            0.0
190        } else {
191            dot_product / (norm_a * norm_b)
192        }
193    }
194
    /// Get the number of vectors in the index
    ///
    /// Counts stored embeddings, whether or not `build_index` has run.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Check if the index is empty
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Get embedding dimension (assuming all embeddings have the same dimension)
    ///
    /// Returns `None` when no vectors have been added. The dimension is read
    /// from an arbitrary stored embedding; mixed dimensions are not detected.
    pub fn dimension(&self) -> Option<usize> {
        self.embeddings.values().next().map(|v| v.len())
    }
209
210    /// Remove a vector from the index
211    pub fn remove_vector(&mut self, id: &str) -> Result<()> {
212        self.embeddings.remove(id);
213        // Note: instant-distance doesn't support removal, so we need to rebuild
214        if !self.embeddings.is_empty() {
215            self.build_index()?;
216        } else {
217            self.index = None;
218        }
219        Ok(())
220    }
221
222    /// Get all vector IDs
223    pub fn get_ids(&self) -> Vec<String> {
224        self.embeddings.keys().cloned().collect()
225    }
226
227    /// Check if a vector exists
228    pub fn contains(&self, id: &str) -> bool {
229        self.embeddings.contains_key(id)
230    }
231
232    /// Get embedding by ID
233    pub fn get_embedding(&self, id: &str) -> Option<&Vec<f32>> {
234        self.embeddings.get(id)
235    }
236
    /// Batch add multiple vectors in parallel with proper synchronization
    ///
    /// Delegates to the parallel path only when a processor was configured;
    /// otherwise inserts sequentially via `add_vector`.
    pub fn batch_add_vectors(&mut self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
        #[cfg(feature = "parallel-processing")]
        // Clone the processor handle so `self` is free to be mutably borrowed
        // by the parallel helper.
        if let Some(processor) = self.parallel_processor.clone() {
            return self.batch_add_vectors_parallel(vectors, &processor);
        }

        // Sequential fallback
        for (id, embedding) in vectors {
            self.add_vector(id, embedding)?;
        }
        Ok(())
    }
250
251    /// Parallel batch vector addition with conflict detection and chunked processing
252    #[cfg(feature = "parallel-processing")]
253    fn batch_add_vectors_parallel(
254        &mut self,
255        vectors: Vec<(String, Vec<f32>)>,
256        processor: &ParallelProcessor,
257    ) -> Result<()> {
258        if !processor.should_use_parallel(vectors.len()) {
259            // Use sequential processing for small batches
260            for (id, embedding) in vectors {
261                self.add_vector(id, embedding)?;
262            }
263            return Ok(());
264        }
265
266        #[cfg(feature = "parallel-processing")]
267        {
268            use rayon::prelude::*;
269            use std::collections::HashMap;
270
271            // Pre-validate all vectors in parallel
272            let validation_results: std::result::Result<Vec<_>, crate::GraphRAGError> = vectors
273                .par_iter()
274                .map(|(id, embedding)| {
275                    if embedding.is_empty() {
276                        Err(crate::GraphRAGError::VectorSearch {
277                            message: format!("Empty embedding vector for ID: {id}"),
278                        })
279                    } else {
280                        Ok((id.clone(), embedding.clone()))
281                    }
282                })
283                .collect();
284
285            let validated_vectors = match validation_results {
286                Ok(vectors) => vectors,
287                Err(e) => {
288                    eprintln!("Vector validation failed: {e}");
289                    // Fall back to sequential processing with validation
290                    for (id, embedding) in vectors {
291                        self.add_vector(id, embedding)?;
292                    }
293                    return Ok(());
294                }
295            };
296
297            // Check for duplicate IDs and resolve conflicts
298            let mut unique_vectors = HashMap::new();
299            for (id, embedding) in validated_vectors {
300                if unique_vectors.contains_key(&id) {
301                    eprintln!("Warning: Duplicate vector ID '{id}' - using latest");
302                }
303                unique_vectors.insert(id, embedding);
304            }
305
306            // Convert to vector pairs for sequential insertion
307            let vector_pairs: Vec<_> = unique_vectors.into_iter().collect();
308
309            // Vector pairs are already validated and deduplicated
310
311            // Apply the validated vectors to the embeddings map sequentially
312            for (id, embedding) in vector_pairs {
313                self.embeddings.insert(id, embedding);
314            }
315
316            println!("Added {} vectors in parallel batch", vectors.len());
317        }
318
319        #[cfg(not(feature = "parallel-processing"))]
320        {
321            // Sequential fallback when parallel processing is not available
322            for (id, embedding) in vectors {
323                self.add_vector(id, embedding)?;
324            }
325        }
326
327        Ok(())
328    }
329
    /// Batch search for multiple queries in parallel
    ///
    /// Each query is answered by `search`; results keep the order of
    /// `queries`. Runs on rayon only when a processor is configured and deems
    /// the batch large enough, otherwise searches sequentially.
    ///
    /// # Errors
    /// Fails (short-circuiting on the first failing query) if the index has
    /// not been built.
    pub fn batch_search(
        &self,
        queries: &[Vec<f32>],
        top_k: usize,
    ) -> Result<Vec<Vec<(String, f32)>>> {
        #[cfg(feature = "parallel-processing")]
        {
            if let Some(processor) = &self.parallel_processor {
                if processor.should_use_parallel(queries.len()) {
                    use rayon::prelude::*;
                    // `collect` into Result stops at the first error.
                    return queries
                        .par_iter()
                        .map(|query| self.search(query, top_k))
                        .collect();
                }
            }
        }

        // Sequential fallback
        queries
            .iter()
            .map(|query| self.search(query, top_k))
            .collect()
    }
355
    /// Parallel similarity computation between all vectors with optimized chunking
    ///
    /// Returns cosine similarity for every unordered pair of stored vectors,
    /// keyed by `(id_a, id_b)`. Pairs with similarity <= 0.1 are omitted to
    /// bound memory use.
    pub fn compute_all_similarities(&self) -> HashMap<(String, String), f32> {
        #[cfg(feature = "parallel-processing")]
        if let Some(processor) = &self.parallel_processor {
            return self.compute_similarities_parallel(processor);
        }

        // Sequential fallback
        self.compute_similarities_sequential()
    }
366
367    /// Parallel similarity computation with work-stealing and memory optimization
368    #[cfg(feature = "parallel-processing")]
369    fn compute_similarities_parallel(
370        &self,
371        processor: &ParallelProcessor,
372    ) -> HashMap<(String, String), f32> {
373        let ids: Vec<String> = self.embeddings.keys().cloned().collect();
374        let total_pairs = (ids.len() * (ids.len() - 1)) / 2;
375
376        if !processor.should_use_parallel(total_pairs) {
377            return self.compute_similarities_sequential();
378        }
379
380        #[cfg(feature = "parallel-processing")]
381        {
382            use rayon::prelude::*;
383
384            // Pre-collect embeddings for efficient parallel access
385            let embedding_vec: Vec<(String, Vec<f32>)> = ids
386                .iter()
387                .filter_map(|id| {
388                    self.embeddings.get(id).map(|emb| (id.clone(), emb.clone()))
389                })
390                .collect();
391
392            if embedding_vec.len() < 2 {
393                return HashMap::new();
394            }
395
396            // Generate pairs for parallel processing
397            let mut pairs = Vec::new();
398            for i in 0..embedding_vec.len() {
399                for j in (i + 1)..embedding_vec.len() {
400                    pairs.push((i, j));
401                }
402            }
403
404            // Parallel similarity computation with chunked processing
405            let chunk_size = processor.config().chunk_batch_size.min(pairs.len());
406            let similarities: HashMap<(String, String), f32> = pairs
407                .par_chunks(chunk_size)
408                .map(|chunk| {
409                    let mut local_similarities = HashMap::new();
410
411                    for &(i, j) in chunk {
412                        let (first_id, first_emb) = &embedding_vec[i];
413                        let (second_id, second_emb) = &embedding_vec[j];
414
415                        let similarity = VectorUtils::cosine_similarity(first_emb, second_emb);
416
417                        // Only store similarities above a threshold to save memory
418                        if similarity > 0.1 {
419                            local_similarities.insert((first_id.clone(), second_id.clone()), similarity);
420                        }
421                    }
422
423                    local_similarities
424                })
425                .reduce(
426                    HashMap::new,
427                    |mut acc, chunk_similarities| {
428                        acc.extend(chunk_similarities);
429                        acc
430                    }
431                );
432
433            println!(
434                "Computed {} similarities from {} vectors in parallel",
435                similarities.len(),
436                embedding_vec.len()
437            );
438
439            similarities
440        }
441
442        #[cfg(not(feature = "parallel-processing"))]
443        {
444            self.compute_similarities_sequential()
445        }
446    }
447
448    /// Sequential similarity computation (fallback)
449    fn compute_similarities_sequential(&self) -> HashMap<(String, String), f32> {
450        let ids: Vec<String> = self.embeddings.keys().cloned().collect();
451        let mut similarities = HashMap::new();
452
453        for (i, id1) in ids.iter().enumerate() {
454            if let Some(emb1) = self.embeddings.get(id1) {
455                for id2 in ids.iter().skip(i + 1) {
456                    if let Some(emb2) = self.embeddings.get(id2) {
457                        let sim = VectorUtils::cosine_similarity(emb1, emb2);
458                        // Only store similarities above a threshold to save memory
459                        if sim > 0.1 {
460                            similarities.insert((id1.clone(), id2.clone()), sim);
461                        }
462                    }
463                }
464            }
465        }
466
467        similarities
468    }
469
470    /// Find vectors within a similarity threshold
471    pub fn find_similar(
472        &self,
473        query_embedding: &[f32],
474        threshold: f32,
475    ) -> Result<Vec<(String, f32)>> {
476        let results = self.search(query_embedding, self.len())?;
477        Ok(results
478            .into_iter()
479            .filter(|(_, similarity)| *similarity >= threshold)
480            .collect())
481    }
482
483    /// Calculate statistics about the index
484    pub fn statistics(&self) -> VectorIndexStatistics {
485        let dimension = self.dimension().unwrap_or(0);
486        let vector_count = self.len();
487
488        // Calculate basic statistics
489        let mut min_norm = f32::INFINITY;
490        let mut max_norm: f32 = 0.0;
491        let mut sum_norm = 0.0;
492
493        for embedding in self.embeddings.values() {
494            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
495            min_norm = min_norm.min(norm);
496            max_norm = max_norm.max(norm);
497            sum_norm += norm;
498        }
499
500        let avg_norm = if vector_count > 0 {
501            sum_norm / vector_count as f32
502        } else {
503            0.0
504        };
505
506        VectorIndexStatistics {
507            vector_count,
508            dimension,
509            min_norm,
510            max_norm,
511            avg_norm,
512            index_built: self.index.is_some(),
513        }
514    }
515}
516
impl Default for VectorIndex {
    /// Equivalent to [`VectorIndex::new`]: an empty, unbuilt index.
    fn default() -> Self {
        Self::new()
    }
}
522
/// Statistics about the vector index
///
/// A plain snapshot of index state. All fields are cheap scalars, so the
/// type derives `Copy` (and `PartialEq` for easy comparison in tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorIndexStatistics {
    /// Total number of vectors in the index
    pub vector_count: usize,
    /// Dimensionality of vectors
    pub dimension: usize,
    /// Minimum vector norm
    pub min_norm: f32,
    /// Maximum vector norm
    pub max_norm: f32,
    /// Average vector norm
    pub avg_norm: f32,
    /// Whether the index has been built
    pub index_built: bool,
}
539
540impl VectorIndexStatistics {
541    /// Print statistics
542    pub fn print(&self) {
543        println!("Vector Index Statistics:");
544        println!("  Vector count: {}", self.vector_count);
545        println!("  Dimension: {}", self.dimension);
546        println!("  Index built: {}", self.index_built);
547        if self.vector_count > 0 {
548            println!("  Vector norms:");
549            println!("    Min: {:.4}", self.min_norm);
550            println!("    Max: {:.4}", self.max_norm);
551            println!("    Average: {:.4}", self.avg_norm);
552        }
553    }
554}
555
/// Utility functions for vector operations
///
/// Stateless namespace: all functionality lives in associated functions on
/// this zero-sized type.
pub struct VectorUtils;
558
/// Simple embedding generator using hash-based approach for consistent vectors
pub struct EmbeddingGenerator {
    /// Length of every produced embedding.
    dimension: usize,
    /// Cache of per-word unit vectors so repeated words are hashed only once.
    word_vectors: HashMap<String, Vec<f32>>,
}
564
565impl EmbeddingGenerator {
    /// Create a new embedding generator with specified dimension
    ///
    /// Starts with an empty word-vector cache.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            word_vectors: HashMap::new(),
        }
    }
573
    /// Create a new embedding generator with parallel processing support
    ///
    /// NOTE(review): the processor is currently unused — generation is
    /// sequential regardless; the parameter is kept for API symmetry with
    /// `VectorIndex::with_parallel_processing`.
    #[cfg(feature = "parallel-processing")]
    pub fn with_parallel_processing(
        dimension: usize,
        _parallel_processor: ParallelProcessor,
    ) -> Self {
        Self {
            dimension,
            word_vectors: HashMap::new(),
        }
    }
585
586    /// Generate embedding for a text string
587    pub fn generate_embedding(&mut self, text: &str) -> Vec<f32> {
588        let words: Vec<&str> = text.split_whitespace().collect();
589        if words.is_empty() {
590            return vec![0.0; self.dimension];
591        }
592
593        // Get or create word vectors
594        let mut word_embeddings = Vec::new();
595        for word in &words {
596            let normalized_word = word.to_lowercase();
597            if !self.word_vectors.contains_key(&normalized_word) {
598                self.word_vectors.insert(
599                    normalized_word.clone(),
600                    self.generate_word_vector(&normalized_word),
601                );
602            }
603            word_embeddings.push(self.word_vectors[&normalized_word].clone());
604        }
605
606        // Average the word vectors
607        let mut result = vec![0.0; self.dimension];
608        for word_vec in word_embeddings {
609            for (i, value) in word_vec.iter().enumerate() {
610                result[i] += value;
611            }
612        }
613
614        // Normalize by number of words
615        let word_count = words.len() as f32;
616        for value in &mut result {
617            *value /= word_count;
618        }
619
620        // Normalize to unit vector
621        VectorUtils::normalize(&mut result);
622        result
623    }
624
625    /// Generate a consistent vector for a word using hash-based approach
626    fn generate_word_vector(&self, word: &str) -> Vec<f32> {
627        let mut vector = Vec::with_capacity(self.dimension);
628
629        // Use multiple hash seeds for better distribution
630        for i in 0..self.dimension {
631            let mut hasher = DefaultHasher::new();
632            word.hash(&mut hasher);
633            i.hash(&mut hasher);
634
635            let hash = hasher.finish();
636            // Convert hash to float in range [-1, 1]
637            let value = ((hash % 2000) as f32 - 1000.0) / 1000.0;
638            vector.push(value);
639        }
640
641        // Normalize to unit vector for better similarity properties
642        VectorUtils::normalize(&mut vector);
643        vector
644    }
645
646    /// Generate embeddings for multiple texts in batch with parallel processing
647    pub fn batch_generate(&mut self, texts: &[&str]) -> Vec<Vec<f32>> {
648        // Use sequential approach to avoid borrowing issues
649        let mut results = Vec::with_capacity(texts.len());
650        for text in texts {
651            results.push(self.generate_embedding(text));
652        }
653        results
654    }
655
656    /// Parallel batch generation with chunking for very large datasets
657    pub fn batch_generate_chunked(&mut self, texts: &[&str], chunk_size: usize) -> Vec<Vec<f32>> {
658        if texts.len() <= chunk_size {
659            return self.batch_generate(texts);
660        }
661
662        #[cfg(feature = "parallel-processing")]
663        {
664            use rayon::prelude::*;
665
666            // Process in chunks to manage memory usage
667            let results: Vec<Vec<f32>> = texts
668                .par_chunks(chunk_size)
669                .map(|chunk| {
670                    // Each chunk is processed with its own generator state
671                    let mut local_generator = EmbeddingGenerator::new(self.dimension);
672                    local_generator.word_vectors = self.word_vectors.clone(); // Share cached words
673
674                    chunk.iter().map(|&text| {
675                        local_generator.generate_embedding(text)
676                    }).collect::<Vec<_>>()
677                })
678                .flatten()
679                .collect();
680
681            // Update the main generator's word cache with new words from parallel processing
682            // Note: This is a simplified approach - in a more sophisticated implementation,
683            // we would merge the word caches from all parallel workers
684
685            println!(
686                "Generated {} embeddings in parallel chunks of size {}",
687                texts.len(),
688                chunk_size
689            );
690
691            results
692        }
693
694        #[cfg(not(feature = "parallel-processing"))]
695        {
696            // Sequential chunked processing when parallel is not available
697            let mut results = Vec::with_capacity(texts.len());
698
699            for chunk in texts.chunks(chunk_size) {
700                for &text in chunk {
701                    results.push(self.generate_embedding(text));
702                }
703            }
704
705            results
706        }
707    }
708
    /// Get the embedding dimension
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Get the number of cached word vectors
    pub fn cached_words(&self) -> usize {
        self.word_vectors.len()
    }

    /// Clear the word vector cache
    ///
    /// Subsequent embeddings re-derive word vectors (same values, since
    /// generation is deterministic).
    pub fn clear_cache(&mut self) {
        self.word_vectors.clear();
    }
723}
724
impl Default for EmbeddingGenerator {
    /// 128-dimensional generator with an empty word cache.
    fn default() -> Self {
        Self::new(128) // Default to 128-dimensional embeddings
    }
}
730
731impl VectorUtils {
732    /// Calculate cosine similarity between two vectors
733    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
734        if a.len() != b.len() {
735            return 0.0;
736        }
737
738        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
739        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
740        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
741
742        if norm_a == 0.0 || norm_b == 0.0 {
743            0.0
744        } else {
745            dot_product / (norm_a * norm_b)
746        }
747    }
748
749    /// Calculate Euclidean distance between two vectors
750    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
751        if a.len() != b.len() {
752            return f32::INFINITY;
753        }
754
755        a.iter()
756            .zip(b.iter())
757            .map(|(x, y)| (x - y).powi(2))
758            .sum::<f32>()
759            .sqrt()
760    }
761
762    /// Normalize a vector to unit length
763    pub fn normalize(vector: &mut [f32]) {
764        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
765        if norm > 0.0 {
766            for x in vector {
767                *x /= norm;
768            }
769        }
770    }
771
772    /// Generate a random vector (for testing)
773    pub fn random_vector(dimension: usize) -> Vec<f32> {
774        use std::collections::hash_map::DefaultHasher;
775        use std::hash::{Hash, Hasher};
776
777        let mut vector = Vec::with_capacity(dimension);
778        let mut hasher = DefaultHasher::new();
779
780        for i in 0..dimension {
781            i.hash(&mut hasher);
782            let hash = hasher.finish();
783            let value = ((hash % 1000) as f32 - 500.0) / 1000.0; // Range [-0.5, 0.5]
784            vector.push(value);
785        }
786
787        vector
788    }
789
790    /// Calculate the centroid of multiple vectors
791    pub fn centroid(vectors: &[Vec<f32>]) -> Option<Vec<f32>> {
792        if vectors.is_empty() {
793            return None;
794        }
795
796        let dimension = vectors[0].len();
797        if !vectors.iter().all(|v| v.len() == dimension) {
798            return None; // All vectors must have the same dimension
799        }
800
801        let mut centroid = vec![0.0; dimension];
802        for vector in vectors {
803            for (i, &value) in vector.iter().enumerate() {
804                centroid[i] += value;
805            }
806        }
807
808        let count = vectors.len() as f32;
809        for value in &mut centroid {
810            *value /= count;
811        }
812
813        Some(centroid)
814    }
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    // Basic lifecycle: empty index, then one insert updates len/dimension.
    #[test]
    fn test_vector_index_creation() {
        let mut index = VectorIndex::new();
        assert!(index.is_empty());

        let embedding = vec![0.1, 0.2, 0.3];
        index.add_vector("test".to_string(), embedding).unwrap();

        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);
        assert_eq!(index.dimension(), Some(3));
    }

    // End-to-end search: the nearest stored vector must rank first.
    #[test]
    fn test_vector_search() {
        let mut index = VectorIndex::new();

        // Add some test vectors
        index
            .add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .add_vector("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .add_vector("doc3".to_string(), vec![0.8, 0.2, 0.0])
            .unwrap();

        index.build_index().unwrap();

        // Search for similar vectors
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        // First result should be most similar
        assert_eq!(results[0].0, "doc1");
    }

    // Identical vectors -> similarity 1; orthogonal vectors -> 0.
    #[test]
    fn test_cosine_similarity() {
        let vec1 = vec![1.0, 0.0, 0.0];
        let vec2 = vec![1.0, 0.0, 0.0];
        let vec3 = vec![0.0, 1.0, 0.0];

        assert!((VectorUtils::cosine_similarity(&vec1, &vec2) - 1.0).abs() < 0.001);
        assert!((VectorUtils::cosine_similarity(&vec1, &vec3) - 0.0).abs() < 0.001);
    }

    // A 3-4-5 triangle normalizes to unit length.
    #[test]
    fn test_vector_normalization() {
        let mut vector = vec![3.0, 4.0];
        VectorUtils::normalize(&mut vector);

        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    // Component-wise mean of three 2-D vectors.
    #[test]
    fn test_centroid_calculation() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let centroid = VectorUtils::centroid(&vectors).unwrap();
        assert!((centroid[0] - 2.0 / 3.0).abs() < 0.001);
        assert!((centroid[1] - 2.0 / 3.0).abs() < 0.001);
    }

    // Determinism, dimensionality, and unit-norm of generated embeddings.
    #[test]
    fn test_embedding_generator() {
        let mut generator = EmbeddingGenerator::new(64);

        let text1 = "hello world";
        let text2 = "hello world";
        let text3 = "goodbye world";

        let embedding1 = generator.generate_embedding(text1);
        let embedding2 = generator.generate_embedding(text2);
        let embedding3 = generator.generate_embedding(text3);

        // Same text should produce identical embeddings
        assert_eq!(embedding1, embedding2);

        // Different text should produce different embeddings
        assert_ne!(embedding1, embedding3);

        // Check dimension
        assert_eq!(embedding1.len(), 64);

        // Check that embeddings are normalized
        let norm1 = embedding1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm1 - 1.0).abs() < 0.001);
    }

    // Batch output preserves count, dimension, and distinctness per text.
    #[test]
    fn test_batch_embedding_generation() {
        let mut generator = EmbeddingGenerator::new(32);

        let texts = vec!["first text", "second text", "third text"];
        let embeddings = generator.batch_generate(&texts);

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 32));

        // Each embedding should be different
        assert_ne!(embeddings[0], embeddings[1]);
        assert_ne!(embeddings[1], embeddings[2]);
    }

    // Word-bag embeddings: shared words score higher than disjoint words.
    #[test]
    fn test_embedding_similarity() {
        let mut generator = EmbeddingGenerator::new(64);

        let similar1 = generator.generate_embedding("machine learning artificial intelligence");
        let similar2 = generator.generate_embedding("artificial intelligence machine learning");
        let different = generator.generate_embedding("cooking recipes kitchen");

        let sim1 = VectorUtils::cosine_similarity(&similar1, &similar2);
        let sim2 = VectorUtils::cosine_similarity(&similar1, &different);

        // Similar content should have higher similarity
        assert!(sim1 > sim2);
    }
}
945}