1use crate::{GraphRAGError, Result};
2#[cfg(feature = "parallel-processing")]
3use crate::parallel::ParallelProcessor;
4use std::collections::hash_map::DefaultHasher;
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7
8#[cfg(feature = "vector-hnsw")]
9use instant_distance::{Builder, Point, Search};
10
/// Newtype wrapper around a dense `f32` embedding; serves as the point type
/// for the optional HNSW index.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);
22
23impl Vector {
24 pub fn new(vector_data: Vec<f32>) -> Self {
26 Self(vector_data)
27 }
28
29 pub fn as_slice(&self) -> &[f32] {
31 &self.0
32 }
33}
34
#[cfg(feature = "vector-hnsw")]
impl Point for Vector {
    /// Euclidean (L2) distance between two vectors. Mismatched lengths are
    /// reported as infinitely far apart so they never rank as neighbors.
    fn distance(&self, other: &Self) -> f32 {
        if self.0.len() != other.0.len() {
            return f32::INFINITY;
        }

        let squared_sum: f32 = self
            .0
            .iter()
            .zip(&other.0)
            .map(|(a, b)| {
                let diff = a - b;
                diff * diff
            })
            .sum();
        squared_sum.sqrt()
    }
}
51
/// In-memory vector store with an optional HNSW approximate-nearest-neighbor
/// index (feature `vector-hnsw`) and optional rayon-backed batch operations
/// (feature `parallel-processing`).
pub struct VectorIndex {
    // Built lazily by `build_index()`; `None` until then.
    #[cfg(feature = "vector-hnsw")]
    index: Option<instant_distance::HnswMap<Vector, String>>,
    // Placeholder so `index.is_some()` still tracks "built" without HNSW.
    #[cfg(not(feature = "vector-hnsw"))]
    index: Option<()>,
    // Source of truth: id -> embedding. Non-HNSW searches scan this map.
    embeddings: HashMap<String, Vec<f32>>,
    #[cfg(feature = "parallel-processing")]
    parallel_processor: Option<ParallelProcessor>,
}
62
63impl VectorIndex {
    /// Creates an empty index with no parallel processor configured.
    pub fn new() -> Self {
        Self {
            index: None,
            embeddings: HashMap::new(),
            #[cfg(feature = "parallel-processing")]
            parallel_processor: None,
        }
    }
73
    /// Creates an empty index that will use `parallel_processor` for large
    /// batch operations.
    #[cfg(feature = "parallel-processing")]
    pub fn with_parallel_processing(parallel_processor: ParallelProcessor) -> Self {
        Self {
            index: None,
            embeddings: HashMap::new(),
            parallel_processor: Some(parallel_processor),
        }
    }
83
84 pub fn add_vector(&mut self, id: String, embedding: Vec<f32>) -> Result<()> {
86 if embedding.is_empty() {
87 return Err(GraphRAGError::VectorSearch {
88 message: "Empty embedding vector".to_string(),
89 });
90 }
91
92 self.embeddings.insert(id, embedding);
93 Ok(())
94 }
95
96 pub fn build_index(&mut self) -> Result<()> {
98 if self.embeddings.is_empty() {
99 return Err(GraphRAGError::VectorSearch {
100 message: "No embeddings to build index from".to_string(),
101 });
102 }
103
104 #[cfg(feature = "vector-hnsw")]
105 {
106 let points: Vec<Vector> = self
107 .embeddings
108 .values()
109 .map(|v| Vector::new(v.clone()))
110 .collect();
111
112 let values: Vec<String> = self.embeddings.keys().cloned().collect();
113
114 let builder = Builder::default();
115 let index = builder.build(points, values);
116
117 self.index = Some(index);
118 }
119
120 #[cfg(not(feature = "vector-hnsw"))]
121 {
122 println!(
123 "Warning: HNSW vector indexing not available. Install with --features vector-hnsw"
124 );
125 self.index = Some(());
126 }
127
128 Ok(())
129 }
130
131 pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(String, f32)>> {
133 let _index = self
134 .index
135 .as_ref()
136 .ok_or_else(|| GraphRAGError::VectorSearch {
137 message: "Index not built. Call build_index() first.".to_string(),
138 })?;
139
140 #[cfg(feature = "vector-hnsw")]
141 {
142 let query_point = Vector::new(query_embedding.to_vec());
143 let mut search = Search::default();
144
145 let results = _index.search(&query_point, &mut search);
146
147 let mut scored_results = Vec::new();
148 for item in results.into_iter().take(top_k) {
149 let distance = item.distance;
150 let similarity = (-distance).exp().clamp(0.0, 1.0);
152 scored_results.push((item.value.clone(), similarity));
153 }
154
155 Ok(scored_results)
156 }
157
158 #[cfg(not(feature = "vector-hnsw"))]
159 {
160 let query_vec = query_embedding;
162 let mut scored_results = Vec::new();
163
164 for (id, embedding) in &self.embeddings {
165 let similarity = self.cosine_similarity(query_vec, embedding);
166 scored_results.push((id.clone(), similarity));
167 }
168
169 scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
171 scored_results.truncate(top_k);
172
173 Ok(scored_results)
174 }
175 }
176
177 #[cfg(not(feature = "vector-hnsw"))]
179 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
180 if a.len() != b.len() {
181 return 0.0;
182 }
183
184 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
185 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
186 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
187
188 if norm_a == 0.0 || norm_b == 0.0 {
189 0.0
190 } else {
191 dot_product / (norm_a * norm_b)
192 }
193 }
194
    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }
199
    /// True when no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }
204
205 pub fn dimension(&self) -> Option<usize> {
207 self.embeddings.values().next().map(|v| v.len())
208 }
209
210 pub fn remove_vector(&mut self, id: &str) -> Result<()> {
212 self.embeddings.remove(id);
213 if !self.embeddings.is_empty() {
215 self.build_index()?;
216 } else {
217 self.index = None;
218 }
219 Ok(())
220 }
221
222 pub fn get_ids(&self) -> Vec<String> {
224 self.embeddings.keys().cloned().collect()
225 }
226
    /// Returns true if a vector is stored under `id`.
    pub fn contains(&self, id: &str) -> bool {
        self.embeddings.contains_key(id)
    }
231
    /// Borrows the stored embedding for `id`, if present.
    pub fn get_embedding(&self, id: &str) -> Option<&Vec<f32>> {
        self.embeddings.get(id)
    }
236
    /// Adds many vectors, using the parallel path when a processor is
    /// configured (feature `parallel-processing`), otherwise sequentially.
    ///
    /// # Errors
    /// Fails on the first invalid (empty) embedding.
    pub fn batch_add_vectors(&mut self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
        // Clone the processor handle so `self` isn't borrowed both immutably
        // (for the processor) and mutably (for the insert) at once.
        #[cfg(feature = "parallel-processing")]
        if let Some(processor) = self.parallel_processor.clone() {
            return self.batch_add_vectors_parallel(vectors, &processor);
        }

        for (id, embedding) in vectors {
            self.add_vector(id, embedding)?;
        }
        Ok(())
    }
250
    /// Parallel batch insert: validates all embeddings with rayon, dedupes
    /// ids, then inserts directly into the map (no per-item validation pass).
    #[cfg(feature = "parallel-processing")]
    fn batch_add_vectors_parallel(
        &mut self,
        vectors: Vec<(String, Vec<f32>)>,
        processor: &ParallelProcessor,
    ) -> Result<()> {
        // Small batches aren't worth the rayon overhead; insert sequentially.
        if !processor.should_use_parallel(vectors.len()) {
            for (id, embedding) in vectors {
                self.add_vector(id, embedding)?;
            }
            return Ok(());
        }

        // NOTE(review): this inner cfg is redundant — the whole fn is already
        // gated on `parallel-processing` — so the `not(...)` block below is
        // dead code.
        #[cfg(feature = "parallel-processing")]
        {
            use rayon::prelude::*;
            use std::collections::HashMap;

            // Validate every embedding in parallel; `collect()` into Result
            // short-circuits on the first error.
            let validation_results: std::result::Result<Vec<_>, crate::GraphRAGError> = vectors
                .par_iter()
                .map(|(id, embedding)| {
                    if embedding.is_empty() {
                        Err(crate::GraphRAGError::VectorSearch {
                            message: format!("Empty embedding vector for ID: {id}"),
                        })
                    } else {
                        Ok((id.clone(), embedding.clone()))
                    }
                })
                .collect();

            let validated_vectors = match validation_results {
                Ok(vectors) => vectors,
                Err(e) => {
                    // Best-effort fallback: retry sequentially so the caller
                    // gets the precise per-item error from `add_vector`.
                    eprintln!("Vector validation failed: {e}");
                    for (id, embedding) in vectors {
                        self.add_vector(id, embedding)?;
                    }
                    return Ok(());
                }
            };

            // Last write wins for duplicate ids, matching HashMap::insert
            // semantics of the sequential path.
            let mut unique_vectors = HashMap::new();
            for (id, embedding) in validated_vectors {
                if unique_vectors.contains_key(&id) {
                    eprintln!("Warning: Duplicate vector ID '{id}' - using latest");
                }
                unique_vectors.insert(id, embedding);
            }

            let vector_pairs: Vec<_> = unique_vectors.into_iter().collect();

            // Bypass `add_vector`: items were already validated above.
            for (id, embedding) in vector_pairs {
                self.embeddings.insert(id, embedding);
            }

            println!("Added {} vectors in parallel batch", vectors.len());
        }

        #[cfg(not(feature = "parallel-processing"))]
        {
            for (id, embedding) in vectors {
                self.add_vector(id, embedding)?;
            }
        }

        Ok(())
    }
329
330 pub fn batch_search(
332 &self,
333 queries: &[Vec<f32>],
334 top_k: usize,
335 ) -> Result<Vec<Vec<(String, f32)>>> {
336 #[cfg(feature = "parallel-processing")]
337 {
338 if let Some(processor) = &self.parallel_processor {
339 if processor.should_use_parallel(queries.len()) {
340 use rayon::prelude::*;
341 return queries
342 .par_iter()
343 .map(|query| self.search(query, top_k))
344 .collect();
345 }
346 }
347 }
348
349 queries
351 .iter()
352 .map(|query| self.search(query, top_k))
353 .collect()
354 }
355
    /// Computes cosine similarity for every unordered pair of stored vectors
    /// (pairs with similarity > 0.1 only); parallel when a processor is set.
    pub fn compute_all_similarities(&self) -> HashMap<(String, String), f32> {
        #[cfg(feature = "parallel-processing")]
        if let Some(processor) = &self.parallel_processor {
            return self.compute_similarities_parallel(processor);
        }

        self.compute_similarities_sequential()
    }
366
367 #[cfg(feature = "parallel-processing")]
369 fn compute_similarities_parallel(
370 &self,
371 processor: &ParallelProcessor,
372 ) -> HashMap<(String, String), f32> {
373 let ids: Vec<String> = self.embeddings.keys().cloned().collect();
374 let total_pairs = (ids.len() * (ids.len() - 1)) / 2;
375
376 if !processor.should_use_parallel(total_pairs) {
377 return self.compute_similarities_sequential();
378 }
379
380 #[cfg(feature = "parallel-processing")]
381 {
382 use rayon::prelude::*;
383
384 let embedding_vec: Vec<(String, Vec<f32>)> = ids
386 .iter()
387 .filter_map(|id| {
388 self.embeddings.get(id).map(|emb| (id.clone(), emb.clone()))
389 })
390 .collect();
391
392 if embedding_vec.len() < 2 {
393 return HashMap::new();
394 }
395
396 let mut pairs = Vec::new();
398 for i in 0..embedding_vec.len() {
399 for j in (i + 1)..embedding_vec.len() {
400 pairs.push((i, j));
401 }
402 }
403
404 let chunk_size = processor.config().chunk_batch_size.min(pairs.len());
406 let similarities: HashMap<(String, String), f32> = pairs
407 .par_chunks(chunk_size)
408 .map(|chunk| {
409 let mut local_similarities = HashMap::new();
410
411 for &(i, j) in chunk {
412 let (first_id, first_emb) = &embedding_vec[i];
413 let (second_id, second_emb) = &embedding_vec[j];
414
415 let similarity = VectorUtils::cosine_similarity(first_emb, second_emb);
416
417 if similarity > 0.1 {
419 local_similarities.insert((first_id.clone(), second_id.clone()), similarity);
420 }
421 }
422
423 local_similarities
424 })
425 .reduce(
426 HashMap::new,
427 |mut acc, chunk_similarities| {
428 acc.extend(chunk_similarities);
429 acc
430 }
431 );
432
433 println!(
434 "Computed {} similarities from {} vectors in parallel",
435 similarities.len(),
436 embedding_vec.len()
437 );
438
439 similarities
440 }
441
442 #[cfg(not(feature = "parallel-processing"))]
443 {
444 self.compute_similarities_sequential()
445 }
446 }
447
448 fn compute_similarities_sequential(&self) -> HashMap<(String, String), f32> {
450 let ids: Vec<String> = self.embeddings.keys().cloned().collect();
451 let mut similarities = HashMap::new();
452
453 for (i, id1) in ids.iter().enumerate() {
454 if let Some(emb1) = self.embeddings.get(id1) {
455 for id2 in ids.iter().skip(i + 1) {
456 if let Some(emb2) = self.embeddings.get(id2) {
457 let sim = VectorUtils::cosine_similarity(emb1, emb2);
458 if sim > 0.1 {
460 similarities.insert((id1.clone(), id2.clone()), sim);
461 }
462 }
463 }
464 }
465 }
466
467 similarities
468 }
469
470 pub fn find_similar(
472 &self,
473 query_embedding: &[f32],
474 threshold: f32,
475 ) -> Result<Vec<(String, f32)>> {
476 let results = self.search(query_embedding, self.len())?;
477 Ok(results
478 .into_iter()
479 .filter(|(_, similarity)| *similarity >= threshold)
480 .collect())
481 }
482
483 pub fn statistics(&self) -> VectorIndexStatistics {
485 let dimension = self.dimension().unwrap_or(0);
486 let vector_count = self.len();
487
488 let mut min_norm = f32::INFINITY;
490 let mut max_norm: f32 = 0.0;
491 let mut sum_norm = 0.0;
492
493 for embedding in self.embeddings.values() {
494 let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
495 min_norm = min_norm.min(norm);
496 max_norm = max_norm.max(norm);
497 sum_norm += norm;
498 }
499
500 let avg_norm = if vector_count > 0 {
501 sum_norm / vector_count as f32
502 } else {
503 0.0
504 };
505
506 VectorIndexStatistics {
507 vector_count,
508 dimension,
509 min_norm,
510 max_norm,
511 avg_norm,
512 index_built: self.index.is_some(),
513 }
514 }
515}
516
impl Default for VectorIndex {
    /// Equivalent to [`VectorIndex::new`].
    fn default() -> Self {
        Self::new()
    }
}
522
/// Snapshot of index health returned by `VectorIndex::statistics`.
#[derive(Debug)]
pub struct VectorIndexStatistics {
    /// Number of stored vectors.
    pub vector_count: usize,
    /// Dimension of the stored vectors (0 when the index is empty).
    pub dimension: usize,
    /// Smallest L2 norm among stored vectors.
    pub min_norm: f32,
    /// Largest L2 norm among stored vectors.
    pub max_norm: f32,
    /// Mean L2 norm across stored vectors.
    pub avg_norm: f32,
    /// Whether an index is currently present (i.e. `build_index()` ran).
    pub index_built: bool,
}
539
impl VectorIndexStatistics {
    /// Pretty-prints the statistics to stdout; norm details are shown only
    /// when at least one vector is stored.
    pub fn print(&self) {
        println!("Vector Index Statistics:");
        println!("  Vector count: {}", self.vector_count);
        println!("  Dimension: {}", self.dimension);
        println!("  Index built: {}", self.index_built);
        if self.vector_count > 0 {
            println!("  Vector norms:");
            println!("    Min: {:.4}", self.min_norm);
            println!("    Max: {:.4}", self.max_norm);
            println!("    Average: {:.4}", self.avg_norm);
        }
    }
}
555
/// Namespace for stateless vector-math helpers shared across the module.
pub struct VectorUtils;
558
/// Deterministic, hash-based text embedder (no learned model): each word maps
/// to a pseudo-random vector, and a text embedding is the normalized mean of
/// its word vectors.
pub struct EmbeddingGenerator {
    // Output embedding dimensionality.
    dimension: usize,
    // Cache of word -> word vector, filled lazily as texts are embedded.
    word_vectors: HashMap<String, Vec<f32>>,
}
564
565impl EmbeddingGenerator {
    /// Creates a generator producing embeddings of the given dimension.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            word_vectors: HashMap::new(),
        }
    }
573
    /// Like `new`, but accepts a parallel processor for API symmetry.
    ///
    /// NOTE(review): the processor is currently unused here — chunk-level
    /// parallelism is handled inside `batch_generate_chunked` instead.
    #[cfg(feature = "parallel-processing")]
    pub fn with_parallel_processing(
        dimension: usize,
        _parallel_processor: ParallelProcessor,
    ) -> Self {
        Self {
            dimension,
            word_vectors: HashMap::new(),
        }
    }
585
586 pub fn generate_embedding(&mut self, text: &str) -> Vec<f32> {
588 let words: Vec<&str> = text.split_whitespace().collect();
589 if words.is_empty() {
590 return vec![0.0; self.dimension];
591 }
592
593 let mut word_embeddings = Vec::new();
595 for word in &words {
596 let normalized_word = word.to_lowercase();
597 if !self.word_vectors.contains_key(&normalized_word) {
598 self.word_vectors.insert(
599 normalized_word.clone(),
600 self.generate_word_vector(&normalized_word),
601 );
602 }
603 word_embeddings.push(self.word_vectors[&normalized_word].clone());
604 }
605
606 let mut result = vec![0.0; self.dimension];
608 for word_vec in word_embeddings {
609 for (i, value) in word_vec.iter().enumerate() {
610 result[i] += value;
611 }
612 }
613
614 let word_count = words.len() as f32;
616 for value in &mut result {
617 *value /= word_count;
618 }
619
620 VectorUtils::normalize(&mut result);
622 result
623 }
624
625 fn generate_word_vector(&self, word: &str) -> Vec<f32> {
627 let mut vector = Vec::with_capacity(self.dimension);
628
629 for i in 0..self.dimension {
631 let mut hasher = DefaultHasher::new();
632 word.hash(&mut hasher);
633 i.hash(&mut hasher);
634
635 let hash = hasher.finish();
636 let value = ((hash % 2000) as f32 - 1000.0) / 1000.0;
638 vector.push(value);
639 }
640
641 VectorUtils::normalize(&mut vector);
643 vector
644 }
645
646 pub fn batch_generate(&mut self, texts: &[&str]) -> Vec<Vec<f32>> {
648 let mut results = Vec::with_capacity(texts.len());
650 for text in texts {
651 results.push(self.generate_embedding(text));
652 }
653 results
654 }
655
    /// Embeds `texts` in chunks of `chunk_size`; with `parallel-processing`
    /// each chunk runs on its own rayon worker.
    pub fn batch_generate_chunked(&mut self, texts: &[&str], chunk_size: usize) -> Vec<Vec<f32>> {
        // A single chunk: no point splitting.
        if texts.len() <= chunk_size {
            return self.batch_generate(texts);
        }

        #[cfg(feature = "parallel-processing")]
        {
            use rayon::prelude::*;

            let results: Vec<Vec<f32>> = texts
                .par_chunks(chunk_size)
                .map(|chunk| {
                    // Each worker gets its own generator seeded with a copy of
                    // the current cache; word vectors are hash-derived, so
                    // embeddings remain deterministic across workers.
                    // NOTE(review): words discovered here are NOT merged back
                    // into self.word_vectors, so the shared cache does not
                    // grow on this path — confirm this is intended.
                    let mut local_generator = EmbeddingGenerator::new(self.dimension);
                    local_generator.word_vectors = self.word_vectors.clone();
                    chunk.iter().map(|&text| {
                        local_generator.generate_embedding(text)
                    }).collect::<Vec<_>>()
                })
                .flatten()
                .collect();

            println!(
                "Generated {} embeddings in parallel chunks of size {}",
                texts.len(),
                chunk_size
            );

            results
        }

        #[cfg(not(feature = "parallel-processing"))]
        {
            let mut results = Vec::with_capacity(texts.len());

            for chunk in texts.chunks(chunk_size) {
                for &text in chunk {
                    results.push(self.generate_embedding(text));
                }
            }

            results
        }
    }
708
    /// Dimensionality of generated embeddings.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
713
    /// Number of word vectors currently cached.
    pub fn cached_words(&self) -> usize {
        self.word_vectors.len()
    }
718
    /// Drops all cached word vectors (they will be regenerated on demand,
    /// deterministically, so embeddings are unaffected).
    pub fn clear_cache(&mut self) {
        self.word_vectors.clear();
    }
723}
724
impl Default for EmbeddingGenerator {
    fn default() -> Self {
        // 128 dimensions as a middle-of-the-road default.
        Self::new(128)
    }
}
730
731impl VectorUtils {
732 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
734 if a.len() != b.len() {
735 return 0.0;
736 }
737
738 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
739 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
740 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
741
742 if norm_a == 0.0 || norm_b == 0.0 {
743 0.0
744 } else {
745 dot_product / (norm_a * norm_b)
746 }
747 }
748
749 pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
751 if a.len() != b.len() {
752 return f32::INFINITY;
753 }
754
755 a.iter()
756 .zip(b.iter())
757 .map(|(x, y)| (x - y).powi(2))
758 .sum::<f32>()
759 .sqrt()
760 }
761
762 pub fn normalize(vector: &mut [f32]) {
764 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
765 if norm > 0.0 {
766 for x in vector {
767 *x /= norm;
768 }
769 }
770 }
771
772 pub fn random_vector(dimension: usize) -> Vec<f32> {
774 use std::collections::hash_map::DefaultHasher;
775 use std::hash::{Hash, Hasher};
776
777 let mut vector = Vec::with_capacity(dimension);
778 let mut hasher = DefaultHasher::new();
779
780 for i in 0..dimension {
781 i.hash(&mut hasher);
782 let hash = hasher.finish();
783 let value = ((hash % 1000) as f32 - 500.0) / 1000.0; vector.push(value);
785 }
786
787 vector
788 }
789
790 pub fn centroid(vectors: &[Vec<f32>]) -> Option<Vec<f32>> {
792 if vectors.is_empty() {
793 return None;
794 }
795
796 let dimension = vectors[0].len();
797 if !vectors.iter().all(|v| v.len() == dimension) {
798 return None; }
800
801 let mut centroid = vec![0.0; dimension];
802 for vector in vectors {
803 for (i, &value) in vector.iter().enumerate() {
804 centroid[i] += value;
805 }
806 }
807
808 let count = vectors.len() as f32;
809 for value in &mut centroid {
810 *value /= count;
811 }
812
813 Some(centroid)
814 }
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    // Basic lifecycle: empty -> one vector, with len/dimension bookkeeping.
    #[test]
    fn test_vector_index_creation() {
        let mut index = VectorIndex::new();
        assert!(index.is_empty());

        let embedding = vec![0.1, 0.2, 0.3];
        index.add_vector("test".to_string(), embedding).unwrap();

        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);
        assert_eq!(index.dimension(), Some(3));
    }

    // End-to-end: build the index and check the exact-match query wins.
    #[test]
    fn test_vector_search() {
        let mut index = VectorIndex::new();

        index
            .add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .add_vector("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .add_vector("doc3".to_string(), vec![0.8, 0.2, 0.0])
            .unwrap();

        index.build_index().unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);

        // "doc1" is identical to the query, so it must rank first.
        assert_eq!(results[0].0, "doc1");
    }

    // Identical vectors -> 1.0; orthogonal vectors -> 0.0.
    #[test]
    fn test_cosine_similarity() {
        let vec1 = vec![1.0, 0.0, 0.0];
        let vec2 = vec![1.0, 0.0, 0.0];
        let vec3 = vec![0.0, 1.0, 0.0];

        assert!((VectorUtils::cosine_similarity(&vec1, &vec2) - 1.0).abs() < 0.001);
        assert!((VectorUtils::cosine_similarity(&vec1, &vec3) - 0.0).abs() < 0.001);
    }

    // A 3-4-5 triangle normalizes to unit length.
    #[test]
    fn test_vector_normalization() {
        let mut vector = vec![3.0, 4.0];
        VectorUtils::normalize(&mut vector);

        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    // Component-wise mean of three 2-D points.
    #[test]
    fn test_centroid_calculation() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let centroid = VectorUtils::centroid(&vectors).unwrap();
        assert!((centroid[0] - 2.0 / 3.0).abs() < 0.001);
        assert!((centroid[1] - 2.0 / 3.0).abs() < 0.001);
    }

    // Embeddings are deterministic (same text -> same vector), distinct for
    // different text, sized to the configured dimension, and unit-norm.
    #[test]
    fn test_embedding_generator() {
        let mut generator = EmbeddingGenerator::new(64);

        let text1 = "hello world";
        let text2 = "hello world";
        let text3 = "goodbye world";

        let embedding1 = generator.generate_embedding(text1);
        let embedding2 = generator.generate_embedding(text2);
        let embedding3 = generator.generate_embedding(text3);

        assert_eq!(embedding1, embedding2);

        assert_ne!(embedding1, embedding3);

        assert_eq!(embedding1.len(), 64);

        let norm1 = embedding1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm1 - 1.0).abs() < 0.001);
    }

    // Batch generation preserves count, dimension, and distinctness.
    #[test]
    fn test_batch_embedding_generation() {
        let mut generator = EmbeddingGenerator::new(32);

        let texts = vec!["first text", "second text", "third text"];
        let embeddings = generator.batch_generate(&texts);

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 32));

        assert_ne!(embeddings[0], embeddings[1]);
        assert_ne!(embeddings[1], embeddings[2]);
    }

    // Word-overlap should score higher than unrelated text (bag-of-words:
    // the two "similar" texts contain the same words, just reordered).
    #[test]
    fn test_embedding_similarity() {
        let mut generator = EmbeddingGenerator::new(64);

        let similar1 = generator.generate_embedding("machine learning artificial intelligence");
        let similar2 = generator.generate_embedding("artificial intelligence machine learning");
        let different = generator.generate_embedding("cooking recipes kitchen");

        let sim1 = VectorUtils::cosine_similarity(&similar1, &similar2);
        let sim2 = VectorUtils::cosine_similarity(&similar1, &different);

        assert!(sim1 > sim2);
    }
}