1#[cfg(feature = "parallel-processing")]
2use crate::parallel::ParallelProcessor;
3use crate::{GraphRAGError, Result};
4use std::collections::hash_map::DefaultHasher;
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7
8#[cfg(feature = "vector-hnsw")]
9use instant_distance::{Builder, Point, Search};
10
11pub mod memory_store;
15
16pub mod store;
18
19#[cfg(feature = "lancedb")]
20pub mod lancedb;
22
23#[cfg(feature = "qdrant")]
24pub mod qdrant;
26
/// Newtype over `Vec<f32>` so embeddings can implement search-related traits.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);

impl Vector {
    /// Wraps raw embedding components without copying.
    pub fn new(vector_data: Vec<f32>) -> Self {
        Self(vector_data)
    }

    /// Borrows the underlying components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.0.as_slice()
    }
}
46
#[cfg(feature = "vector-hnsw")]
impl Point for Vector {
    /// Euclidean (L2) distance between two vectors; a dimension mismatch
    /// yields `f32::INFINITY` so mismatched points are never neighbors.
    fn distance(&self, other: &Self) -> f32 {
        if self.0.len() != other.0.len() {
            return f32::INFINITY;
        }

        let squared_sum: f32 = self
            .0
            .iter()
            .zip(&other.0)
            .map(|(lhs, rhs)| (lhs - rhs) * (lhs - rhs))
            .sum();
        squared_sum.sqrt()
    }
}
63
/// In-memory vector store with an optionally HNSW-accelerated search index.
pub struct VectorIndex {
    /// HNSW map produced by `build_index()`; `None` until built.
    #[cfg(feature = "vector-hnsw")]
    index: Option<instant_distance::HnswMap<Vector, String>>,
    /// Built-marker placeholder when HNSW support is compiled out.
    #[cfg(not(feature = "vector-hnsw"))]
    index: Option<()>,
    /// Raw embeddings keyed by document/entity id.
    embeddings: HashMap<String, Vec<f32>>,
    /// Decides when batch operations run on rayon.
    #[cfg(feature = "parallel-processing")]
    parallel_processor: Option<ParallelProcessor>,
}
74
75impl VectorIndex {
76 pub fn new() -> Self {
78 Self {
79 index: None,
80 embeddings: HashMap::new(),
81 #[cfg(feature = "parallel-processing")]
82 parallel_processor: None,
83 }
84 }
85
86 #[cfg(feature = "parallel-processing")]
88 pub fn with_parallel_processing(parallel_processor: ParallelProcessor) -> Self {
89 Self {
90 index: None,
91 embeddings: HashMap::new(),
92 parallel_processor: Some(parallel_processor),
93 }
94 }
95
96 pub fn add_vector(&mut self, id: String, embedding: Vec<f32>) -> Result<()> {
98 if embedding.is_empty() {
99 return Err(GraphRAGError::VectorSearch {
100 message: "Empty embedding vector".to_string(),
101 });
102 }
103
104 self.embeddings.insert(id, embedding);
105 Ok(())
106 }
107
108 pub fn build_index(&mut self) -> Result<()> {
110 if self.embeddings.is_empty() {
111 return Err(GraphRAGError::VectorSearch {
112 message: "No embeddings to build index from".to_string(),
113 });
114 }
115
116 #[cfg(feature = "vector-hnsw")]
117 {
118 let points: Vec<Vector> = self
119 .embeddings
120 .values()
121 .map(|v| Vector::new(v.clone()))
122 .collect();
123
124 let values: Vec<String> = self.embeddings.keys().cloned().collect();
125
126 let builder = Builder::default();
127 let index = builder.build(points, values);
128
129 self.index = Some(index);
130 }
131
132 #[cfg(not(feature = "vector-hnsw"))]
133 {
134 println!(
135 "Warning: HNSW vector indexing not available. Install with --features vector-hnsw"
136 );
137 self.index = Some(());
138 }
139
140 Ok(())
141 }
142
143 pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(String, f32)>> {
145 let _index = self
146 .index
147 .as_ref()
148 .ok_or_else(|| GraphRAGError::VectorSearch {
149 message: "Index not built. Call build_index() first.".to_string(),
150 })?;
151
152 #[cfg(feature = "vector-hnsw")]
153 {
154 let query_point = Vector::new(query_embedding.to_vec());
155 let mut search = Search::default();
156
157 let results = _index.search(&query_point, &mut search);
158
159 let mut scored_results = Vec::new();
160 for item in results.into_iter().take(top_k) {
161 let distance = item.distance;
162 let similarity = (-distance).exp().clamp(0.0, 1.0);
164 scored_results.push((item.value.clone(), similarity));
165 }
166
167 Ok(scored_results)
168 }
169
170 #[cfg(not(feature = "vector-hnsw"))]
171 {
172 let query_vec = query_embedding;
174 let mut scored_results = Vec::new();
175
176 for (id, embedding) in &self.embeddings {
177 let similarity = self.cosine_similarity(query_vec, embedding);
178 scored_results.push((id.clone(), similarity));
179 }
180
181 scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
183 scored_results.truncate(top_k);
184
185 Ok(scored_results)
186 }
187 }
188
189 #[cfg(not(feature = "vector-hnsw"))]
191 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
192 if a.len() != b.len() {
193 return 0.0;
194 }
195
196 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
197 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
198 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
199
200 if norm_a == 0.0 || norm_b == 0.0 {
201 0.0
202 } else {
203 dot_product / (norm_a * norm_b)
204 }
205 }
206
207 pub fn len(&self) -> usize {
209 self.embeddings.len()
210 }
211
212 pub fn is_empty(&self) -> bool {
214 self.embeddings.is_empty()
215 }
216
217 pub fn dimension(&self) -> Option<usize> {
219 self.embeddings.values().next().map(|v| v.len())
220 }
221
222 pub fn remove_vector(&mut self, id: &str) -> Result<()> {
224 self.embeddings.remove(id);
225 if !self.embeddings.is_empty() {
227 self.build_index()?;
228 } else {
229 self.index = None;
230 }
231 Ok(())
232 }
233
234 pub fn get_ids(&self) -> Vec<String> {
236 self.embeddings.keys().cloned().collect()
237 }
238
239 pub fn contains(&self, id: &str) -> bool {
241 self.embeddings.contains_key(id)
242 }
243
244 pub fn get_embedding(&self, id: &str) -> Option<&Vec<f32>> {
246 self.embeddings.get(id)
247 }
248
249 pub fn batch_add_vectors(&mut self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
251 #[cfg(feature = "parallel-processing")]
252 if let Some(processor) = self.parallel_processor.clone() {
253 return self.batch_add_vectors_parallel(vectors, &processor);
254 }
255
256 for (id, embedding) in vectors {
258 self.add_vector(id, embedding)?;
259 }
260 Ok(())
261 }
262
263 #[cfg(feature = "parallel-processing")]
265 fn batch_add_vectors_parallel(
266 &mut self,
267 vectors: Vec<(String, Vec<f32>)>,
268 processor: &ParallelProcessor,
269 ) -> Result<()> {
270 if !processor.should_use_parallel(vectors.len()) {
271 for (id, embedding) in vectors {
273 self.add_vector(id, embedding)?;
274 }
275 return Ok(());
276 }
277
278 #[cfg(feature = "parallel-processing")]
279 {
280 use rayon::prelude::*;
281 use std::collections::HashMap;
282
283 let validation_results: std::result::Result<Vec<_>, crate::GraphRAGError> = vectors
285 .par_iter()
286 .map(|(id, embedding)| {
287 if embedding.is_empty() {
288 Err(crate::GraphRAGError::VectorSearch {
289 message: format!("Empty embedding vector for ID: {id}"),
290 })
291 } else {
292 Ok((id.clone(), embedding.clone()))
293 }
294 })
295 .collect();
296
297 let validated_vectors = match validation_results {
298 Ok(vectors) => vectors,
299 Err(e) => {
300 eprintln!("Vector validation failed: {e}");
301 for (id, embedding) in vectors {
303 self.add_vector(id, embedding)?;
304 }
305 return Ok(());
306 },
307 };
308
309 let mut unique_vectors = HashMap::new();
311 for (id, embedding) in validated_vectors {
312 if unique_vectors.contains_key(&id) {
313 eprintln!("Warning: Duplicate vector ID '{id}' - using latest");
314 }
315 unique_vectors.insert(id, embedding);
316 }
317
318 let vector_pairs: Vec<_> = unique_vectors.into_iter().collect();
320
321 for (id, embedding) in vector_pairs {
325 self.embeddings.insert(id, embedding);
326 }
327
328 println!("Added {} vectors in parallel batch", vectors.len());
329 }
330
331 #[cfg(not(feature = "parallel-processing"))]
332 {
333 for (id, embedding) in vectors {
335 self.add_vector(id, embedding)?;
336 }
337 }
338
339 Ok(())
340 }
341
342 pub fn batch_search(
344 &self,
345 queries: &[Vec<f32>],
346 top_k: usize,
347 ) -> Result<Vec<Vec<(String, f32)>>> {
348 #[cfg(feature = "parallel-processing")]
349 {
350 if let Some(processor) = &self.parallel_processor {
351 if processor.should_use_parallel(queries.len()) {
352 use rayon::prelude::*;
353 return queries
354 .par_iter()
355 .map(|query| self.search(query, top_k))
356 .collect();
357 }
358 }
359 }
360
361 queries
363 .iter()
364 .map(|query| self.search(query, top_k))
365 .collect()
366 }
367
368 pub fn compute_all_similarities(&self) -> HashMap<(String, String), f32> {
370 #[cfg(feature = "parallel-processing")]
371 if let Some(processor) = &self.parallel_processor {
372 return self.compute_similarities_parallel(processor);
373 }
374
375 self.compute_similarities_sequential()
377 }
378
379 #[cfg(feature = "parallel-processing")]
381 fn compute_similarities_parallel(
382 &self,
383 processor: &ParallelProcessor,
384 ) -> HashMap<(String, String), f32> {
385 let ids: Vec<String> = self.embeddings.keys().cloned().collect();
386 let total_pairs = (ids.len() * (ids.len() - 1)) / 2;
387
388 if !processor.should_use_parallel(total_pairs) {
389 return self.compute_similarities_sequential();
390 }
391
392 #[cfg(feature = "parallel-processing")]
393 {
394 use rayon::prelude::*;
395
396 let embedding_vec: Vec<(String, Vec<f32>)> = ids
398 .iter()
399 .filter_map(|id| self.embeddings.get(id).map(|emb| (id.clone(), emb.clone())))
400 .collect();
401
402 if embedding_vec.len() < 2 {
403 return HashMap::new();
404 }
405
406 let mut pairs = Vec::new();
408 for i in 0..embedding_vec.len() {
409 for j in (i + 1)..embedding_vec.len() {
410 pairs.push((i, j));
411 }
412 }
413
414 let chunk_size = processor.config().chunk_batch_size.min(pairs.len());
416 let similarities: HashMap<(String, String), f32> = pairs
417 .par_chunks(chunk_size)
418 .map(|chunk| {
419 let mut local_similarities = HashMap::new();
420
421 for &(i, j) in chunk {
422 let (first_id, first_emb) = &embedding_vec[i];
423 let (second_id, second_emb) = &embedding_vec[j];
424
425 let similarity = VectorUtils::cosine_similarity(first_emb, second_emb);
426
427 if similarity > 0.1 {
429 local_similarities
430 .insert((first_id.clone(), second_id.clone()), similarity);
431 }
432 }
433
434 local_similarities
435 })
436 .reduce(HashMap::new, |mut acc, chunk_similarities| {
437 acc.extend(chunk_similarities);
438 acc
439 });
440
441 println!(
442 "Computed {} similarities from {} vectors in parallel",
443 similarities.len(),
444 embedding_vec.len()
445 );
446
447 similarities
448 }
449
450 #[cfg(not(feature = "parallel-processing"))]
451 {
452 self.compute_similarities_sequential()
453 }
454 }
455
456 fn compute_similarities_sequential(&self) -> HashMap<(String, String), f32> {
458 let ids: Vec<String> = self.embeddings.keys().cloned().collect();
459 let mut similarities = HashMap::new();
460
461 for (i, id1) in ids.iter().enumerate() {
462 if let Some(emb1) = self.embeddings.get(id1) {
463 for id2 in ids.iter().skip(i + 1) {
464 if let Some(emb2) = self.embeddings.get(id2) {
465 let sim = VectorUtils::cosine_similarity(emb1, emb2);
466 if sim > 0.1 {
468 similarities.insert((id1.clone(), id2.clone()), sim);
469 }
470 }
471 }
472 }
473 }
474
475 similarities
476 }
477
478 pub fn find_similar(
480 &self,
481 query_embedding: &[f32],
482 threshold: f32,
483 ) -> Result<Vec<(String, f32)>> {
484 let results = self.search(query_embedding, self.len())?;
485 Ok(results
486 .into_iter()
487 .filter(|(_, similarity)| *similarity >= threshold)
488 .collect())
489 }
490
491 pub fn statistics(&self) -> VectorIndexStatistics {
493 let dimension = self.dimension().unwrap_or(0);
494 let vector_count = self.len();
495
496 let mut min_norm = f32::INFINITY;
498 let mut max_norm: f32 = 0.0;
499 let mut sum_norm = 0.0;
500
501 for embedding in self.embeddings.values() {
502 let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
503 min_norm = min_norm.min(norm);
504 max_norm = max_norm.max(norm);
505 sum_norm += norm;
506 }
507
508 let avg_norm = if vector_count > 0 {
509 sum_norm / vector_count as f32
510 } else {
511 0.0
512 };
513
514 VectorIndexStatistics {
515 vector_count,
516 dimension,
517 min_norm,
518 max_norm,
519 avg_norm,
520 index_built: self.index.is_some(),
521 }
522 }
523}
524
525impl Default for VectorIndex {
526 fn default() -> Self {
527 Self::new()
528 }
529}
530
/// Snapshot of index size and norm statistics, produced by
/// `VectorIndex::statistics()`.
#[derive(Debug)]
pub struct VectorIndexStatistics {
    /// Number of stored embeddings.
    pub vector_count: usize,
    /// Embedding dimensionality (0 when the index is empty).
    pub dimension: usize,
    /// Smallest L2 norm among stored vectors (`f32::INFINITY` when empty).
    pub min_norm: f32,
    /// Largest L2 norm among stored vectors (0.0 when empty).
    pub max_norm: f32,
    /// Mean L2 norm over all stored vectors (0.0 when empty).
    pub avg_norm: f32,
    /// True when an index structure is currently present (i.e. `build_index()`
    /// succeeded and has not been cleared).
    pub index_built: bool,
}
547
impl VectorIndexStatistics {
    /// Prints the statistics to stdout in an indented, human-readable layout.
    /// Norm details are shown only when at least one vector is stored.
    pub fn print(&self) {
        println!("Vector Index Statistics:");
        println!(" Vector count: {}", self.vector_count);
        println!(" Dimension: {}", self.dimension);
        println!(" Index built: {}", self.index_built);
        if self.vector_count > 0 {
            println!(" Vector norms:");
            println!(" Min: {:.4}", self.min_norm);
            println!(" Max: {:.4}", self.max_norm);
            println!(" Average: {:.4}", self.avg_norm);
        }
    }
}
563
564pub struct VectorUtils;
566
/// Deterministic, hash-based text embedder.
///
/// Not a learned model: each word's vector is derived from its hash and
/// cached, so identical text always produces identical embeddings.
pub struct EmbeddingGenerator {
    /// Length of every produced embedding.
    dimension: usize,
    /// Cache of per-word vectors keyed by lowercased word.
    word_vectors: HashMap<String, Vec<f32>>,
}
572
573impl EmbeddingGenerator {
574 pub fn new(dimension: usize) -> Self {
576 Self {
577 dimension,
578 word_vectors: HashMap::new(),
579 }
580 }
581
582 #[cfg(feature = "parallel-processing")]
584 pub fn with_parallel_processing(
585 dimension: usize,
586 _parallel_processor: ParallelProcessor,
587 ) -> Self {
588 Self {
589 dimension,
590 word_vectors: HashMap::new(),
591 }
592 }
593
594 pub fn generate_embedding(&mut self, text: &str) -> Vec<f32> {
596 let words: Vec<&str> = text.split_whitespace().collect();
597 if words.is_empty() {
598 return vec![0.0; self.dimension];
599 }
600
601 let mut word_embeddings = Vec::new();
603 for word in &words {
604 let normalized_word = word.to_lowercase();
605 if !self.word_vectors.contains_key(&normalized_word) {
606 self.word_vectors.insert(
607 normalized_word.clone(),
608 self.generate_word_vector(&normalized_word),
609 );
610 }
611 word_embeddings.push(self.word_vectors[&normalized_word].clone());
612 }
613
614 let mut result = vec![0.0; self.dimension];
616 for word_vec in word_embeddings {
617 for (i, value) in word_vec.iter().enumerate() {
618 result[i] += value;
619 }
620 }
621
622 let word_count = words.len() as f32;
624 for value in &mut result {
625 *value /= word_count;
626 }
627
628 VectorUtils::normalize(&mut result);
630 result
631 }
632
633 fn generate_word_vector(&self, word: &str) -> Vec<f32> {
635 let mut vector = Vec::with_capacity(self.dimension);
636
637 for i in 0..self.dimension {
639 let mut hasher = DefaultHasher::new();
640 word.hash(&mut hasher);
641 i.hash(&mut hasher);
642
643 let hash = hasher.finish();
644 let value = ((hash % 2000) as f32 - 1000.0) / 1000.0;
646 vector.push(value);
647 }
648
649 VectorUtils::normalize(&mut vector);
651 vector
652 }
653
654 pub fn batch_generate(&mut self, texts: &[&str]) -> Vec<Vec<f32>> {
656 let mut results = Vec::with_capacity(texts.len());
658 for text in texts {
659 results.push(self.generate_embedding(text));
660 }
661 results
662 }
663
664 pub fn batch_generate_chunked(&mut self, texts: &[&str], chunk_size: usize) -> Vec<Vec<f32>> {
666 if texts.len() <= chunk_size {
667 return self.batch_generate(texts);
668 }
669
670 #[cfg(feature = "parallel-processing")]
671 {
672 use rayon::prelude::*;
673
674 let results: Vec<Vec<f32>> = texts
676 .par_chunks(chunk_size)
677 .map(|chunk| {
678 let mut local_generator = EmbeddingGenerator::new(self.dimension);
680 local_generator.word_vectors = self.word_vectors.clone(); chunk
683 .iter()
684 .map(|&text| local_generator.generate_embedding(text))
685 .collect::<Vec<_>>()
686 })
687 .flatten()
688 .collect();
689
690 println!(
695 "Generated {} embeddings in parallel chunks of size {}",
696 texts.len(),
697 chunk_size
698 );
699
700 results
701 }
702
703 #[cfg(not(feature = "parallel-processing"))]
704 {
705 let mut results = Vec::with_capacity(texts.len());
707
708 for chunk in texts.chunks(chunk_size) {
709 for &text in chunk {
710 results.push(self.generate_embedding(text));
711 }
712 }
713
714 results
715 }
716 }
717
718 pub fn dimension(&self) -> usize {
720 self.dimension
721 }
722
723 pub fn cached_words(&self) -> usize {
725 self.word_vectors.len()
726 }
727
728 pub fn clear_cache(&mut self) {
730 self.word_vectors.clear();
731 }
732}
733
734impl Default for EmbeddingGenerator {
735 fn default() -> Self {
736 Self::new(128) }
738}
739
740impl VectorUtils {
741 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
743 if a.len() != b.len() {
744 return 0.0;
745 }
746
747 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
748 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
749 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
750
751 if norm_a == 0.0 || norm_b == 0.0 {
752 0.0
753 } else {
754 dot_product / (norm_a * norm_b)
755 }
756 }
757
758 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
760 if a.len() != b.len() {
761 return f32::INFINITY;
762 }
763
764 a.iter()
765 .zip(b.iter())
766 .map(|(x, y)| (x - y).powi(2))
767 .sum::<f32>()
768 .sqrt()
769 }
770
771 pub fn normalize(vector: &mut [f32]) {
773 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
774 if norm > 0.0 {
775 for x in vector {
776 *x /= norm;
777 }
778 }
779 }
780
781 pub fn random_vector(dimension: usize) -> Vec<f32> {
783 use std::collections::hash_map::DefaultHasher;
784 use std::hash::{Hash, Hasher};
785
786 let mut vector = Vec::with_capacity(dimension);
787 let mut hasher = DefaultHasher::new();
788
789 for i in 0..dimension {
790 i.hash(&mut hasher);
791 let hash = hasher.finish();
792 let value = ((hash % 1000) as f32 - 500.0) / 1000.0; vector.push(value);
794 }
795
796 vector
797 }
798
799 pub fn centroid(vectors: &[Vec<f32>]) -> Option<Vec<f32>> {
801 if vectors.is_empty() {
802 return None;
803 }
804
805 let dimension = vectors[0].len();
806 if !vectors.iter().all(|v| v.len() == dimension) {
807 return None; }
809
810 let mut centroid = vec![0.0; dimension];
811 for vector in vectors {
812 for (i, &value) in vector.iter().enumerate() {
813 centroid[i] += value;
814 }
815 }
816
817 let count = vectors.len() as f32;
818 for value in &mut centroid {
819 *value /= count;
820 }
821
822 Some(centroid)
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    // Basic lifecycle: empty index, single insert, size/dimension reporting.
    #[test]
    fn test_vector_index_creation() {
        let mut index = VectorIndex::new();
        assert!(index.is_empty());

        index
            .add_vector("test".to_string(), vec![0.1, 0.2, 0.3])
            .unwrap();

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert_eq!(index.dimension(), Some(3));
    }

    // Nearest-neighbor search returns the exact-match document first.
    #[test]
    fn test_vector_search() {
        let mut index = VectorIndex::new();
        let documents = [
            ("doc1", vec![1.0, 0.0, 0.0]),
            ("doc2", vec![0.0, 1.0, 0.0]),
            ("doc3", vec![0.8, 0.2, 0.0]),
        ];
        for (id, embedding) in documents {
            index.add_vector(id.to_string(), embedding).unwrap();
        }

        index.build_index().unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let hits = index.search(&query, 2).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 2);
        assert_eq!(hits[0].0, "doc1");
    }

    // Identical vectors score 1.0; orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_copy = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        assert!((VectorUtils::cosine_similarity(&x_axis, &x_axis_copy) - 1.0).abs() < 0.001);
        assert!((VectorUtils::cosine_similarity(&x_axis, &y_axis) - 0.0).abs() < 0.001);
    }

    // Normalizing a 3-4-5 triangle vector yields unit length.
    #[test]
    fn test_vector_normalization() {
        let mut v = vec![3.0, 4.0];
        VectorUtils::normalize(&mut v);

        let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.001);
    }

    // Centroid of three known points is the component-wise mean.
    #[test]
    fn test_centroid_calculation() {
        let points = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let center = VectorUtils::centroid(&points).unwrap();

        let expected = 2.0 / 3.0;
        assert!((center[0] - expected).abs() < 0.001);
        assert!((center[1] - expected).abs() < 0.001);
    }

    // Embeddings are deterministic, text-sensitive, sized, and unit-norm.
    #[test]
    fn test_embedding_generator() {
        let mut generator = EmbeddingGenerator::new(64);

        let first = generator.generate_embedding("hello world");
        let repeat = generator.generate_embedding("hello world");
        let other = generator.generate_embedding("goodbye world");

        assert_eq!(first, repeat);
        assert_ne!(first, other);
        assert_eq!(first.len(), 64);

        let norm: f32 = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    // Batch generation preserves count, dimension, and distinctness.
    #[test]
    fn test_batch_embedding_generation() {
        let mut generator = EmbeddingGenerator::new(32);

        let inputs = vec!["first text", "second text", "third text"];
        let outputs = generator.batch_generate(&inputs);

        assert_eq!(outputs.len(), 3);
        assert!(outputs.iter().all(|e| e.len() == 32));
        assert_ne!(outputs[0], outputs[1]);
        assert_ne!(outputs[1], outputs[2]);
    }

    // Texts sharing words score higher than texts with disjoint vocabulary.
    #[test]
    fn test_embedding_similarity() {
        let mut generator = EmbeddingGenerator::new(64);

        let ml_a = generator.generate_embedding("machine learning artificial intelligence");
        let ml_b = generator.generate_embedding("artificial intelligence machine learning");
        let cooking = generator.generate_embedding("cooking recipes kitchen");

        let related = VectorUtils::cosine_similarity(&ml_a, &ml_b);
        let unrelated = VectorUtils::cosine_similarity(&ml_a, &cooking);

        assert!(related > unrelated);
    }
}