pub mod contrastive;
pub mod crosslingual;
pub mod fasttext;
pub mod glove;
pub mod sentence;
pub mod sentence_encoder;
pub mod universal;

pub use fasttext::{FastText, FastTextConfig};
pub use glove::{
    cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
    GloVeTrainerConfig,
};

use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;

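/// Common interface implemented by the word-embedding models in this module
/// (`Word2Vec`, `GloVe`, `FastText`), so they can be used interchangeably
/// behind `&dyn WordEmbedding`.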
pub trait WordEmbedding {
    /// Returns the embedding vector for `word`, or an error if it is unknown.
    fn embedding(&self, word: &str) -> Result<Array1<f64>>;

    /// Dimensionality of the embedding vectors.
    fn dimension(&self) -> usize;

    /// Cosine similarity between the embeddings of two words.
    fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
        let v1 = self.embedding(word1)?;
        let v2 = self.embedding(word2)?;
        Ok(embedding_cosine_similarity(&v1, &v2))
    }

    /// Returns the `top_n` words most similar to `word`.
    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Solves the analogy `a : b :: c : ?` and returns the `top_n` candidates.
    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;

    /// Number of words in the model's vocabulary.
    fn vocab_size(&self) -> usize;
}

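/// Cosine similarity between two dense vectors:
/// `cos(a, b) = (a · b) / (‖a‖ · ‖b‖)`; returns `0.0` when either vector has
/// zero norm, so the result is always finite.
///
/// A minimal usage sketch (marked `ignore` since the crate name is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let a = Array1::from_vec(vec![1.0, 0.0]);
/// let b = Array1::from_vec(vec![1.0, 1.0]);
/// let sim = embedding_cosine_similarity(&a, &b); // ≈ 0.7071
/// ```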
pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot_product / (norm_a * norm_b)
    } else {
        0.0
    }
}

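/// Computes the symmetric matrix of pairwise cosine similarities between the
/// embeddings of `words`. Only the upper triangle is computed and then
/// mirrored; fails if any word is missing from the model.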
pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
    let vectors: Vec<Array1<f64>> = words
        .iter()
        .map(|&w| model.embedding(w))
        .collect::<Result<Vec<_>>>()?;

    let n = vectors.len();
    let mut matrix = vec![vec![0.0; n]; n];

    for i in 0..n {
        for j in i..n {
            let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
            matrix[i][j] = sim;
            matrix[j][i] = sim;
        }
    }

    Ok(matrix)
}

impl WordEmbedding for GloVe {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}

impl WordEmbedding for FastText {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary_size()
    }
}

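/// Node in the Huffman tree used for hierarchical softmax. Frequent words end
/// up closer to the root, so their updates touch fewer internal nodes
/// (O(log |V|) on average instead of O(|V|) as with a full softmax).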
#[derive(Debug, Clone)]
struct HuffmanNode {
    /// Node id: leaves use the word index, internal nodes are appended after.
    id: usize,
    /// Word frequency (leaves) or sum of child frequencies (internal nodes).
    frequency: usize,
    /// Index of the left child, if any.
    left: Option<usize>,
    /// Index of the right child, if any.
    right: Option<usize>,
    /// Whether this node is a leaf (i.e. a vocabulary word).
    is_leaf: bool,
}

#[derive(Debug, Clone)]
struct HuffmanTree {
    /// Binary Huffman code for each word (0 = left branch, 1 = right branch).
    codes: Vec<Vec<u8>>,
    /// Internal-node indices along the root-to-leaf path for each word.
    paths: Vec<Vec<usize>>,
    /// Number of internal nodes.
    num_internal: usize,
}

impl HuffmanTree {
    fn build(frequencies: &[usize]) -> Result<Self> {
        let vocab_size = frequencies.len();
        if vocab_size == 0 {
            return Err(TextError::EmbeddingError(
                "Cannot build Huffman tree with empty vocabulary".into(),
            ));
        }
        if vocab_size == 1 {
            return Ok(Self {
                codes: vec![vec![0]],
                paths: vec![vec![0]],
                num_internal: 1,
            });
        }

        let mut nodes: Vec<HuffmanNode> = frequencies
            .iter()
            .enumerate()
            .map(|(id, &freq)| HuffmanNode {
                id,
                // Guard against zero counts so every leaf has positive weight.
                frequency: freq.max(1),
                left: None,
                right: None,
                is_leaf: true,
            })
            .collect();

        // Queue kept sorted in descending frequency order, so the two least
        // frequent nodes can be popped from the back.
        let mut queue: Vec<(usize, usize)> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (i, n.frequency))
            .collect();
        queue.sort_by_key(|item| std::cmp::Reverse(item.1));

        while queue.len() > 1 {
            let (idx1, freq1) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
            let (idx2, freq2) = queue
                .pop()
                .ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;

            let new_id = nodes.len();
            let new_node = HuffmanNode {
                id: new_id,
                frequency: freq1 + freq2,
                left: Some(idx1),
                right: Some(idx2),
                is_leaf: false,
            };
            nodes.push(new_node);

            // Insert the merged node while preserving descending order.
            let new_freq = freq1 + freq2;
            let insert_pos = queue
                .binary_search_by(|(_, f)| new_freq.cmp(f))
                .unwrap_or_else(|pos| pos);
            queue.insert(insert_pos, (new_id, new_freq));
        }

        let num_internal = nodes.len() - vocab_size;
        let mut codes = vec![Vec::new(); vocab_size];
        let mut paths = vec![Vec::new(); vocab_size];

        // Depth-first traversal from the root, recording the code (branch
        // directions) and path (internal-node indices) for every leaf.
        let root_idx = nodes.len() - 1;
        let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];

        while let Some((node_idx, code, path)) = stack.pop() {
            let node = &nodes[node_idx];

            if node.is_leaf {
                codes[node.id] = code;
                paths[node.id] = path;
            } else {
                let internal_idx = node.id - vocab_size;

                if let Some(left_idx) = node.left {
                    let mut left_code = code.clone();
                    left_code.push(0);
                    let mut left_path = path.clone();
                    left_path.push(internal_idx);
                    stack.push((left_idx, left_code, left_path));
                }

                if let Some(right_idx) = node.right {
                    let mut right_code = code.clone();
                    right_code.push(1);
                    let mut right_path = path.clone();
                    right_path.push(internal_idx);
                    stack.push((right_idx, right_code, right_path));
                }
            }
        }

        Ok(Self {
            codes,
            paths,
            num_internal,
        })
    }
}

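/// Cumulative-distribution table for weighted random sampling, used here for
/// negative sampling; `create_sampling_table` feeds it unigram counts raised
/// to the 3/4 power (the word2vec heuristic). Draws cost O(log |V|) via
/// binary search over the CDF.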
#[derive(Debug, Clone)]
struct SamplingTable {
    /// Cumulative distribution over the vocabulary, normalized to 1.0.
    cdf: Vec<f64>,
    /// The raw (unnormalized) weights, kept for frequency lookups.
    weights: Vec<f64>,
}

impl SamplingTable {
    fn new(weights: &[f64]) -> Result<Self> {
        if weights.is_empty() {
            return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
        }

        if weights.iter().any(|&w| w < 0.0) {
            return Err(TextError::EmbeddingError(
                "Weights must be non-negative".into(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TextError::EmbeddingError(
                "Sum of weights must be positive".into(),
            ));
        }

        let mut cdf = Vec::with_capacity(weights.len());
        let mut total = 0.0;

        for &w in weights {
            total += w;
            cdf.push(total / sum);
        }

        Ok(Self {
            cdf,
            weights: weights.to_vec(),
        })
    }

    /// Draws a random index with probability proportional to its weight.
    fn sample<R: Rng>(&self, rng: &mut R) -> usize {
        let r = rng.random::<f64>();

        // Binary search for the first CDF entry >= r; Err(idx) is the
        // insertion point, which is exactly the index we want.
        match self.cdf.binary_search_by(|&cdf_val| {
            cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
        }) {
            Ok(idx) => idx,
            Err(idx) => idx,
        }
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

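/// Training objective used by `Word2Vec`: CBOW predicts the target word from
/// the averaged context, while skip-gram predicts each context word from the
/// target.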
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
    /// Continuous bag-of-words: context predicts the target word.
    CBOW,
    /// Skip-gram: the target word predicts each context word.
    SkipGram,
}

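/// Hyperparameters for `Word2Vec` training. The defaults follow common
/// word2vec settings: 100-dimensional vectors, window of 5, 5 negative
/// samples, and a subsampling threshold of 1e-3.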
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
    /// Dimensionality of the learned vectors.
    pub vector_size: usize,
    /// Maximum context window size (the effective window is sampled per word).
    pub window_size: usize,
    /// Minimum corpus frequency for a word to enter the vocabulary.
    pub min_count: usize,
    /// Number of passes over the training corpus.
    pub epochs: usize,
    /// Initial learning rate (decayed linearly across epochs).
    pub learning_rate: f64,
    /// CBOW or skip-gram.
    pub algorithm: Word2VecAlgorithm,
    /// Number of negative samples per positive example.
    pub negative_samples: usize,
    /// Subsampling threshold for frequent words (0.0 disables subsampling).
    pub subsample: f64,
    /// Batch size for training updates.
    pub batch_size: usize,
    /// Use hierarchical softmax instead of negative sampling.
    pub hierarchical_softmax: bool,
}

impl Default for Word2VecConfig {
    fn default() -> Self {
        Self {
            vector_size: 100,
            window_size: 5,
            min_count: 5,
            epochs: 5,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            negative_samples: 5,
            subsample: 1e-3,
            batch_size: 128,
            hierarchical_softmax: false,
        }
    }
}

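/// Word2Vec model supporting skip-gram and CBOW training with either
/// negative sampling or hierarchical softmax.
///
/// A minimal usage sketch (marked `ignore`; the crate path and error
/// plumbing are assumed here):
///
/// ```ignore
/// let mut model = Word2Vec::new()
///     .with_vector_size(50)
///     .with_min_count(1)
///     .with_epochs(5);
/// model.train(&["the quick brown fox", "the lazy dog"])?;
/// let vec = model.get_word_vector("fox")?;
/// let similar = model.most_similar("fox", 5)?;
/// ```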
pub struct Word2Vec {
    /// Training hyperparameters.
    config: Word2VecConfig,
    /// Vocabulary built from the training corpus.
    vocabulary: Vocabulary,
    /// Input (word) embedding matrix, `vocab_size x vector_size`.
    input_embeddings: Option<Array2<f64>>,
    /// Output (context) embedding matrix, `vocab_size x vector_size`.
    output_embeddings: Option<Array2<f64>>,
    /// Tokenizer used to split texts into words.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Negative-sampling table over the vocabulary.
    sampling_table: Option<SamplingTable>,
    /// Huffman tree for hierarchical softmax, if enabled.
    huffman_tree: Option<HuffmanTree>,
    /// Internal-node parameters for hierarchical softmax.
    hs_params: Option<Array2<f64>>,
    /// Learning rate for the current epoch (decayed from the configured rate).
    current_learning_rate: f64,
}

impl Debug for Word2Vec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Word2Vec")
            .field("config", &self.config)
            .field("vocabulary", &self.vocabulary)
            .field("input_embeddings", &self.input_embeddings)
            .field("output_embeddings", &self.output_embeddings)
            .field("sampling_table", &self.sampling_table)
            .field("huffman_tree", &self.huffman_tree)
            .field("current_learning_rate", &self.current_learning_rate)
            .finish()
    }
}

impl Default for Word2Vec {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for Word2Vec {
    fn clone(&self) -> Self {
        // `Box<dyn Tokenizer>` is not `Clone`, so the clone falls back to the
        // default word tokenizer.
        let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());

        Self {
            config: self.config.clone(),
            vocabulary: self.vocabulary.clone(),
            input_embeddings: self.input_embeddings.clone(),
            output_embeddings: self.output_embeddings.clone(),
            tokenizer,
            sampling_table: self.sampling_table.clone(),
            huffman_tree: self.huffman_tree.clone(),
            hs_params: self.hs_params.clone(),
            current_learning_rate: self.current_learning_rate,
        }
    }
}

impl Word2Vec {
    /// Creates a model with the default configuration.
    pub fn new() -> Self {
        Self {
            config: Word2VecConfig::default(),
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            huffman_tree: None,
            hs_params: None,
            current_learning_rate: 0.025,
        }
    }

    /// Creates a model from an explicit configuration.
    pub fn with_config(config: Word2VecConfig) -> Self {
        let learning_rate = config.learning_rate;
        Self {
            config,
            vocabulary: Vocabulary::new(),
            input_embeddings: None,
            output_embeddings: None,
            tokenizer: Box::new(WordTokenizer::default()),
            sampling_table: None,
            huffman_tree: None,
            hs_params: None,
            current_learning_rate: learning_rate,
        }
    }

    /// Sets the tokenizer used to split input texts.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }

    /// Sets the embedding dimensionality.
    pub fn with_vector_size(mut self, vector_size: usize) -> Self {
        self.config.vector_size = vector_size;
        self
    }

    /// Sets the maximum context window size.
    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.config.window_size = window_size;
        self
    }

    /// Sets the minimum word frequency for vocabulary inclusion.
    pub fn with_min_count(mut self, min_count: usize) -> Self {
        self.config.min_count = min_count;
        self
    }

    /// Sets the number of training epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Sets the initial learning rate.
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.config.learning_rate = learning_rate;
        self.current_learning_rate = learning_rate;
        self
    }

    /// Sets the training algorithm (CBOW or skip-gram).
    pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Sets the number of negative samples per positive example.
    pub fn with_negative_samples(mut self, negative_samples: usize) -> Self {
        self.config.negative_samples = negative_samples;
        self
    }

    /// Sets the subsampling threshold for frequent words.
    pub fn with_subsample(mut self, subsample: f64) -> Self {
        self.config.subsample = subsample;
        self
    }

    /// Sets the batch size.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.config.batch_size = batch_size;
        self
    }

    /// Builds the vocabulary from `texts`, initializes the embedding
    /// matrices, and prepares the sampling table (and Huffman tree when
    /// hierarchical softmax is enabled).
    pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for building vocabulary".into(),
            ));
        }

        // Count word occurrences across all texts.
        let mut word_counts = HashMap::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                *word_counts.entry(token).or_insert(0) += 1;
            }
        }

        // Keep only words that meet the minimum count threshold.
        self.vocabulary = Vocabulary::new();
        for (word, count) in &word_counts {
            if *count >= self.config.min_count {
                self.vocabulary.add_token(word);
            }
        }

        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "No words meet the minimum count threshold".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let vector_size = self.config.vector_size;

        // Initialize both embedding matrices uniformly in [-1/dim, 1/dim].
        let mut rng = scirs2_core::random::rng();
        let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });
        let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
        });

        self.input_embeddings = Some(input_embeddings);
        self.output_embeddings = Some(output_embeddings);

        self.create_sampling_table(&word_counts)?;

        if self.config.hierarchical_softmax {
            let frequencies: Vec<usize> = (0..vocab_size)
                .map(|i| {
                    self.vocabulary
                        .get_token(i)
                        .and_then(|word| word_counts.get(word).copied())
                        .unwrap_or(1)
                })
                .collect();

            let tree = HuffmanTree::build(&frequencies)?;
            let num_internal = tree.num_internal;

            self.hs_params = Some(Array2::zeros((num_internal, vector_size)));
            self.huffman_tree = Some(tree);
        }

        Ok(())
    }

    /// Builds the negative-sampling table from word counts, raising unigram
    /// counts to the 3/4 power as in the original word2vec.
    fn create_sampling_table(&mut self, word_counts: &HashMap<String, usize>) -> Result<()> {
        let mut sampling_weights = vec![0.0; self.vocabulary.len()];

        for (word, &count) in word_counts.iter() {
            if let Some(idx) = self.vocabulary.get_index(word) {
                sampling_weights[idx] = (count as f64).powf(0.75);
            }
        }

        self.sampling_table = Some(SamplingTable::new(&sampling_weights)?);
        Ok(())
    }

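    /// Trains the model on `texts`, building the vocabulary first if it is
    /// empty. Each epoch decays the learning rate linearly (with a floor of
    /// 0.01% of the initial rate), then applies the configured algorithm
    /// (CBOW or skip-gram, with negative sampling or hierarchical softmax)
    /// to every sentence.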
    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for training".into(),
            ));
        }

        if self.vocabulary.is_empty() {
            self.build_vocabulary(texts)?;
        }

        if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Embeddings not initialized. Call build_vocabulary() first".into(),
            ));
        }

        // Convert texts to sentences of word indices, dropping unknown words.
        let mut sentences = Vec::new();
        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            let filtered_tokens: Vec<usize> = tokens
                .iter()
                .filter_map(|token| self.vocabulary.get_index(token))
                .collect();
            if !filtered_tokens.is_empty() {
                sentences.push(filtered_tokens);
            }
        }

        for epoch in 0..self.config.epochs {
            // Linear learning-rate decay with a floor at 0.01% of the
            // initial rate.
            self.current_learning_rate =
                self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
            self.current_learning_rate = self
                .current_learning_rate
                .max(self.config.learning_rate * 0.0001);

            for sentence in &sentences {
                let subsampled_sentence = if self.config.subsample > 0.0 {
                    self.subsample_sentence(sentence)?
                } else {
                    sentence.clone()
                };

                if subsampled_sentence.is_empty() {
                    continue;
                }

                if self.config.hierarchical_softmax {
                    match self.config.algorithm {
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_hs_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_hs_sentence(&subsampled_sentence)?;
                        }
                    }
                } else {
                    match self.config.algorithm {
                        Word2VecAlgorithm::CBOW => {
                            self.train_cbow_sentence(&subsampled_sentence)?;
                        }
                        Word2VecAlgorithm::SkipGram => {
                            self.train_skipgram_sentence(&subsampled_sentence)?;
                        }
                    }
                }
            }
        }

        Ok(())
    }

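    /// Randomly drops frequent words, following the word2vec subsampling
    /// heuristic: a word with sampling weight `f` is kept with probability
    /// `(sqrt(f / t) + 1) * (t / f)`, where `t` is the configured threshold
    /// scaled by the vocabulary size.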
    fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
        let mut rng = scirs2_core::random::rng();
        let total_words: f64 = self.vocabulary.len() as f64;
        let threshold = self.config.subsample * total_words;

        let subsampled: Vec<usize> = sentence
            .iter()
            .filter(|&&word_idx| {
                let word_freq = self.get_word_frequency(word_idx);
                if word_freq == 0.0 {
                    // Unknown frequency: always keep the word.
                    return true;
                }
                let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
                rng.random::<f64>() < keep_prob
            })
            .copied()
            .collect();

        Ok(subsampled)
    }

    /// Looks up the sampling weight of a word, falling back to 1.0 when no
    /// sampling table has been built.
    fn get_word_frequency(&self, word_idx: usize) -> f64 {
        if let Some(table) = &self.sampling_table {
            table.weights()[word_idx]
        } else {
            1.0
        }
    }

    fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Sample an effective window size in [1, window_size].
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            // Collect the context words around the target position.
            let mut context_words = Vec::new();
            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue;
            }

            // Average the context vectors (the CBOW input).
            let mut context_sum = Array1::zeros(vector_size);
            for &context_idx in &context_words {
                context_sum += &input_embeddings.row(context_idx);
            }
            let context_avg = &context_sum / context_words.len() as f64;

            // Positive update: push the target's output vector toward the context.
            let mut target_output = output_embeddings.row_mut(target_word);
            let dot_product = (&context_avg * &target_output).sum();
            let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
            let error = (1.0 - sigmoid) * self.current_learning_rate;

            let mut target_update = target_output.to_owned();
            target_update.scaled_add(error, &context_avg);
            target_output.assign(&target_update);

            // Negative updates: push sampled words' output vectors away.
            if let Some(sampler) = &self.sampling_table {
                for _ in 0..negative_samples {
                    let negative_idx = sampler.sample(&mut rng);
                    if negative_idx == target_word {
                        continue;
                    }

                    let mut negative_output = output_embeddings.row_mut(negative_idx);
                    let dot_product = (&context_avg * &negative_output).sum();
                    let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                    let error = -sigmoid * self.current_learning_rate;

                    let mut negative_update = negative_output.to_owned();
                    negative_update.scaled_add(error, &context_avg);
                    negative_output.assign(&negative_update);
                }
            }

            // Propagate the gradients back to each context word's input vector.
            for &context_idx in &context_words {
                let mut input_vec = input_embeddings.row_mut(context_idx);

                let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error =
                    (1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;

                let mut input_update = input_vec.to_owned();
                input_update.scaled_add(error, &output_embeddings.row(target_word));

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == target_word {
                            continue;
                        }

                        let dot_product =
                            (&context_avg * &output_embeddings.row(negative_idx)).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error =
                            -sigmoid * self.current_learning_rate / context_words.len() as f64;

                        input_update.scaled_add(error, &output_embeddings.row(negative_idx));
                    }
                }

                input_vec.assign(&input_update);
            }
        }

        Ok(())
    }

    fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self.input_embeddings.as_mut().expect("Operation failed");
        let output_embeddings = self.output_embeddings.as_mut().expect("Operation failed");
        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let negative_samples = self.config.negative_samples;

        for pos in 0..sentence.len() {
            // Sample an effective window size in [1, window_size].
            let mut rng = scirs2_core::random::rng();
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            #[allow(clippy::needless_range_loop)]
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_word = sentence[i];
                let target_input = input_embeddings.row(target_word);
                let mut context_output = output_embeddings.row_mut(context_word);

                // Positive update for the (target, context) pair.
                let dot_product = (&target_input * &context_output).sum();
                let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                let error = (1.0 - sigmoid) * self.current_learning_rate;

                let mut context_update = context_output.to_owned();
                context_update.scaled_add(error, &target_input);
                context_output.assign(&context_update);

                // Accumulate the gradient for the target's input vector.
                let mut input_update = Array1::zeros(vector_size);
                input_update.scaled_add(error, &context_output);

                if let Some(sampler) = &self.sampling_table {
                    for _ in 0..negative_samples {
                        let negative_idx = sampler.sample(&mut rng);
                        if negative_idx == context_word {
                            continue;
                        }

                        let mut negative_output = output_embeddings.row_mut(negative_idx);
                        let dot_product = (&target_input * &negative_output).sum();
                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
                        let error = -sigmoid * self.current_learning_rate;

                        let mut negative_update = negative_output.to_owned();
                        negative_update.scaled_add(error, &target_input);
                        negative_output.assign(&negative_update);

                        input_update.scaled_add(error, &negative_output);
                    }
                }

                // Apply the accumulated gradient to the target word.
                let mut target_input_mut = input_embeddings.row_mut(target_word);
                target_input_mut += &input_update;
            }
        }

        Ok(())
    }

    /// Dimensionality of the learned vectors.
    pub fn vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Returns the input embedding for `word`, or an error if the model is
    /// untrained or the word is out of vocabulary.
    pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        match self.vocabulary.get_index(word) {
            Some(idx) => Ok(self
                .input_embeddings
                .as_ref()
                .expect("Operation failed")
                .row(idx)
                .to_owned()),
            None => Err(TextError::VocabularyError(format!(
                "Word '{word}' not in vocabulary"
            ))),
        }
    }

    /// Returns the `top_n` words most similar to `word` by cosine similarity.
    pub fn most_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        let word_vec = self.get_word_vector(word)?;
        self.most_similar_by_vector(&word_vec, top_n, &[word])
    }

    /// Returns the `top_n` words whose embeddings are most similar to
    /// `vector`, skipping any words listed in `exclude_words`.
    pub fn most_similar_by_vector(
        &self,
        vector: &Array1<f64>,
        top_n: usize,
        exclude_words: &[&str],
    ) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
        let vocab_size = self.vocabulary.len();

        let exclude_indices: Vec<usize> = exclude_words
            .iter()
            .filter_map(|&word| self.vocabulary.get_index(word))
            .collect();

        let mut similarities = Vec::with_capacity(vocab_size);

        for i in 0..vocab_size {
            if exclude_indices.contains(&i) {
                continue;
            }

            let word_vec = input_embeddings.row(i);
            let similarity = cosine_similarity(vector, &word_vec.to_owned());

            if let Some(word) = self.vocabulary.get_token(i) {
                similarities.push((word.to_string(), similarity));
            }
        }

        // Sort descending by similarity and keep the top results.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let result = similarities.into_iter().take(top_n).collect();
        Ok(result)
    }

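    /// Solves the analogy `a : b :: c : ?` using the classic vector-offset
    /// method: candidates are ranked by similarity to the normalized vector
    /// `b - a + c`, excluding the three query words themselves.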
    pub fn analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let a_vec = self.get_word_vector(a)?;
        let b_vec = self.get_word_vector(b)?;
        let c_vec = self.get_word_vector(c)?;

        // d = b - a + c
        let mut d_vec = b_vec.clone();
        d_vec -= &a_vec;
        d_vec += &c_vec;

        // Normalize to unit length (skipped for a zero vector to avoid NaNs).
        let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
        if norm > 0.0 {
            d_vec.mapv_inplace(|val| val / norm);
        }

        self.most_similar_by_vector(&d_vec, top_n, &[a, b, c])
    }

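    /// Saves the model in the word2vec text format: a header line containing
    /// `<vocab_size> <vector_size>`, then one line per word with the word
    /// followed by its vector components.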
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        if self.input_embeddings.is_none() {
            return Err(TextError::EmbeddingError(
                "Model not trained. Call train() first".into(),
            ));
        }

        let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;

        writeln!(
            &mut file,
            "{} {}",
            self.vocabulary.len(),
            self.config.vector_size
        )
        .map_err(|e| TextError::IoError(e.to_string()))?;

        let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");

        for i in 0..self.vocabulary.len() {
            if let Some(word) = self.vocabulary.get_token(i) {
                write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;

                let vector = input_embeddings.row(i);
                for j in 0..self.config.vector_size {
                    write!(&mut file, "{:.6} ", vector[j])
                        .map_err(|e| TextError::IoError(e.to_string()))?;
                }

                writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
            }
        }

        Ok(())
    }

    /// Loads a model from the word2vec text format written by `save()`.
    /// Only the input embeddings are restored; output embeddings and the
    /// sampling table are not part of the file, so the loaded model can be
    /// queried but needs `build_vocabulary()` before further training.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut reader = BufReader::new(file);

        let mut header = String::new();
        reader
            .read_line(&mut header)
            .map_err(|e| TextError::IoError(e.to_string()))?;

        let parts: Vec<&str> = header.split_whitespace().collect();
        if parts.len() != 2 {
            return Err(TextError::EmbeddingError(
                "Invalid model file format".into(),
            ));
        }

        let vocab_size = parts[0].parse::<usize>().map_err(|_| {
            TextError::EmbeddingError("Invalid vocabulary size in model file".into())
        })?;

        let vector_size = parts[1]
            .parse::<usize>()
            .map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;

        let mut model = Self::new().with_vector_size(vector_size);
        let mut vocabulary = Vocabulary::new();
        let mut input_embeddings = Array2::zeros((vocab_size, vector_size));

        let mut i = 0;
        for line in reader.lines() {
            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
            let parts: Vec<&str> = line.split_whitespace().collect();

            if parts.len() != vector_size + 1 {
                let line_num = i + 2;
                return Err(TextError::EmbeddingError(format!(
                    "Invalid vector format at line {line_num}"
                )));
            }

            let word = parts[0];
            vocabulary.add_token(word);

            for j in 0..vector_size {
                input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
                    TextError::EmbeddingError(format!(
                        "Invalid vector component at line {}, position {}",
                        i + 2,
                        j + 1
                    ))
                })?;
            }

            i += 1;
        }

        if i != vocab_size {
            return Err(TextError::EmbeddingError(format!(
                "Expected {vocab_size} words but found {i}"
            )));
        }

        model.vocabulary = vocabulary;
        model.input_embeddings = Some(input_embeddings);
        // Output embeddings are not serialized, so they stay unset.
        model.output_embeddings = None;

        Ok(model)
    }

    /// Returns the vocabulary as a list of words in index order.
    pub fn get_vocabulary(&self) -> Vec<String> {
        let mut vocab = Vec::new();
        for i in 0..self.vocabulary.len() {
            if let Some(token) = self.vocabulary.get_token(i) {
                vocab.push(token.to_string());
            }
        }
        vocab
    }

    /// Configured vector dimensionality.
    pub fn get_vector_size(&self) -> usize {
        self.config.vector_size
    }

    /// Configured training algorithm.
    pub fn get_algorithm(&self) -> Word2VecAlgorithm {
        self.config.algorithm
    }

    /// Configured maximum window size.
    pub fn get_window_size(&self) -> usize {
        self.config.window_size
    }

    /// Configured minimum word count.
    pub fn get_min_count(&self) -> usize {
        self.config.min_count
    }

    /// Returns a copy of the input embedding matrix, if trained.
    pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
        self.input_embeddings.clone()
    }

    /// Configured number of negative samples.
    pub fn get_negative_samples(&self) -> usize {
        self.config.negative_samples
    }

    /// Configured initial learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    /// Configured number of epochs.
    pub fn get_epochs(&self) -> usize {
        self.config.epochs
    }

    /// Configured subsampling threshold.
    pub fn get_subsampling_threshold(&self) -> f64 {
        self.config.subsample
    }

    /// Whether hierarchical softmax is enabled.
    pub fn uses_hierarchical_softmax(&self) -> bool {
        self.config.hierarchical_softmax
    }

    /// Restores a model from an externally held vocabulary and embedding
    /// matrix (e.g. one previously returned by `get_embeddings_matrix`),
    /// validating that the shapes agree with the configuration.
    pub fn restore_weights(
        &mut self,
        vocabulary: Vec<String>,
        embeddings: Array2<f64>,
    ) -> Result<()> {
        let embed_shape = embeddings.shape();
        let n_words = vocabulary.len();

        if embed_shape[0] != n_words {
            return Err(TextError::EmbeddingError(format!(
                "Embedding row count {} does not match vocabulary size {}",
                embed_shape[0], n_words
            )));
        }

        if embed_shape[1] != self.config.vector_size {
            return Err(TextError::EmbeddingError(format!(
                "Embedding dimension {} does not match configured vector_size {}",
                embed_shape[1], self.config.vector_size
            )));
        }

        self.vocabulary = Vocabulary::new();
        for word in &vocabulary {
            self.vocabulary.add_token(word);
        }

        self.input_embeddings = Some(embeddings);
        Ok(())
    }

    /// Skip-gram training step with hierarchical softmax: for each
    /// (target, context) pair, walks the context word's Huffman path and
    /// performs a binary logistic update at every internal node.
    fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self
            .input_embeddings
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
        let hs_params = self
            .hs_params
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
        let tree = self
            .huffman_tree
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;

        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let lr = self.current_learning_rate;

        let codes = tree.codes.clone();
        let paths = tree.paths.clone();

        let mut rng = scirs2_core::random::rng();

        for pos in 0..sentence.len() {
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i == pos {
                    continue;
                }

                let context_word = sentence[i];
                let code = &codes[context_word];
                let path = &paths[context_word];

                let mut grad_input = Array1::zeros(vector_size);

                for (&node_idx, &label) in path.iter().zip(code.iter()) {
                    if node_idx >= hs_params.nrows() {
                        continue;
                    }

                    let input_vec = input_embeddings.row(target_word);
                    let param_vec = hs_params.row(node_idx);

                    let dot: f64 = input_vec
                        .iter()
                        .zip(param_vec.iter())
                        .map(|(a, b)| a * b)
                        .sum();
                    let sigmoid = 1.0 / (1.0 + (-dot).exp());

                    // Binary logistic target: 1 for a left branch (code 0),
                    // 0 for a right branch (code 1).
                    let target = if label == 0 { 1.0 } else { 0.0 };
                    let gradient = (target - sigmoid) * lr;

                    grad_input.scaled_add(gradient, &param_vec.to_owned());

                    let input_owned = input_vec.to_owned();
                    let mut param_mut = hs_params.row_mut(node_idx);
                    param_mut.scaled_add(gradient, &input_owned);
                }

                let mut input_mut = input_embeddings.row_mut(target_word);
                input_mut += &grad_input;
            }
        }

        Ok(())
    }

    /// CBOW training step with hierarchical softmax: the averaged context
    /// vector is pushed along the target word's Huffman path, and the
    /// resulting gradient is distributed evenly over the context words.
    fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
        if sentence.len() < 2 {
            return Ok(());
        }

        let input_embeddings = self
            .input_embeddings
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
        let hs_params = self
            .hs_params
            .as_mut()
            .ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
        let tree = self
            .huffman_tree
            .as_ref()
            .ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;

        let vector_size = self.config.vector_size;
        let window_size = self.config.window_size;
        let lr = self.current_learning_rate;

        let codes = tree.codes.clone();
        let paths = tree.paths.clone();

        let mut rng = scirs2_core::random::rng();

        for pos in 0..sentence.len() {
            let window = 1 + rng.random_range(0..window_size);
            let target_word = sentence[pos];

            let mut context_words = Vec::new();
            for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
                if i != pos {
                    context_words.push(sentence[i]);
                }
            }

            if context_words.is_empty() {
                continue;
            }

            // Average the context vectors (the CBOW input).
            let mut context_avg = Array1::zeros(vector_size);
            for &ctx_idx in &context_words {
                context_avg += &input_embeddings.row(ctx_idx);
            }
            context_avg /= context_words.len() as f64;

            let code = &codes[target_word];
            let path = &paths[target_word];

            let mut grad_context = Array1::zeros(vector_size);

            for (&node_idx, &label) in path.iter().zip(code.iter()) {
                if node_idx >= hs_params.nrows() {
                    continue;
                }

                let param_vec = hs_params.row(node_idx);

                let dot: f64 = context_avg
                    .iter()
                    .zip(param_vec.iter())
                    .map(|(a, b)| a * b)
                    .sum();
                let sigmoid = 1.0 / (1.0 + (-dot).exp());

                let target = if label == 0 { 1.0 } else { 0.0 };
                let gradient = (target - sigmoid) * lr;

                grad_context.scaled_add(gradient, &param_vec.to_owned());

                let ctx_owned = context_avg.clone();
                let mut param_mut = hs_params.row_mut(node_idx);
                param_mut.scaled_add(gradient, &ctx_owned);
            }

            // Spread the gradient evenly over the contributing context words.
            let grad_per_word = &grad_context / context_words.len() as f64;
            for &ctx_idx in &context_words {
                let mut input_mut = input_embeddings.row_mut(ctx_idx);
                input_mut += &grad_per_word;
            }
        }

        Ok(())
    }
}

impl WordEmbedding for Word2Vec {
    fn embedding(&self, word: &str) -> Result<Array1<f64>> {
        self.get_word_vector(word)
    }

    fn dimension(&self) -> usize {
        self.vector_size()
    }

    fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.most_similar(word, top_n)
    }

    fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
        self.analogy(a, b, c, top_n)
    }

    fn vocab_size(&self) -> usize {
        self.vocabulary.len()
    }
}

/// Cosine similarity between two vectors; behaves like
/// `embedding_cosine_similarity` and returns `0.0` for zero-norm inputs.
#[allow(dead_code)]
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let dot_product = (a * b).sum();
    let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
    let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot_product / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let similarity = cosine_similarity(&a, &b);
        let expected = 0.9746318461970762;
        assert_relative_eq!(similarity, expected, max_relative = 1e-10);
    }

    #[test]
    fn test_word2vec_config() {
        let config = Word2VecConfig::default();
        assert_eq!(config.vector_size, 100);
        assert_eq!(config.window_size, 5);
        assert_eq!(config.min_count, 5);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
    }

    #[test]
    fn test_word2vec_builder() {
        let model = Word2Vec::new()
            .with_vector_size(200)
            .with_window_size(10)
            .with_learning_rate(0.05)
            .with_algorithm(Word2VecAlgorithm::CBOW);

        assert_eq!(model.config.vector_size, 200);
        assert_eq!(model.config.window_size, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
    }

    #[test]
    fn test_build_vocabulary() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new().with_min_count(1);
        let result = model.build_vocabulary(&texts);
        assert!(result.is_ok());

        // 9 unique words: the, quick, brown, fox, jumps, over, lazy, dog, a.
        assert_eq!(model.vocabulary.len(), 9);

        assert!(model.input_embeddings.is_some());
        assert!(model.output_embeddings.is_some());
        assert_eq!(
            model
                .input_embeddings
                .as_ref()
                .expect("Operation failed")
                .shape(),
            &[9, 100]
        );
    }

    #[test]
    fn test_skipgram_training_small() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_window_size(2)
            .with_min_count(1)
            .with_epochs(1)
            .with_algorithm(Word2VecAlgorithm::SkipGram);

        let result = model.train(&texts);
        assert!(result.is_ok());

        let result = model.get_word_vector("fox");
        assert!(result.is_ok());
        let vec = result.expect("Operation failed");
        assert_eq!(vec.len(), 10);
    }

    #[test]
    fn test_huffman_tree_build() {
        let frequencies = vec![5, 3, 8, 1, 2];
        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");

        assert_eq!(tree.codes.len(), 5);
        assert_eq!(tree.paths.len(), 5);

        // Every word must receive a non-empty code.
        for code in &tree.codes {
            assert!(!code.is_empty());
        }

        // A binary tree over n leaves has n - 1 internal nodes.
        assert_eq!(tree.num_internal, 4);
    }

    #[test]
    fn test_huffman_tree_single_word() {
        let frequencies = vec![10];
        let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
        assert_eq!(tree.codes.len(), 1);
        assert_eq!(tree.paths.len(), 1);
    }
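
    // A small sanity check for `SamplingTable`, added as a sketch: it relies
    // only on behavior visible in this module (CDF construction and
    // binary-search sampling) and checks that drawn indices stay in range
    // and that degenerate inputs are rejected.
    #[test]
    fn test_sampling_table_basic() {
        let weights = vec![1.0, 2.0, 3.0];
        let table = SamplingTable::new(&weights).expect("table build failed");

        let mut rng = scirs2_core::random::rng();
        for _ in 0..100 {
            let idx = table.sample(&mut rng);
            assert!(idx < weights.len());
        }

        // Empty, negative, and all-zero weights must all be rejected.
        assert!(SamplingTable::new(&[]).is_err());
        assert!(SamplingTable::new(&[-1.0]).is_err());
        assert!(SamplingTable::new(&[0.0, 0.0]).is_err());
    }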

    #[test]
    fn test_skipgram_hierarchical_softmax() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let config = Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::SkipGram,
            hierarchical_softmax: true,
            ..Default::default()
        };

        let mut model = Word2Vec::with_config(config);
        let result = model.train(&texts);
        assert!(
            result.is_ok(),
            "HS skipgram training failed: {:?}",
            result.err()
        );

        assert!(model.uses_hierarchical_softmax());

        let vec = model.get_word_vector("fox");
        assert!(vec.is_ok());
        assert_eq!(vec.expect("get vec").len(), 10);
    }

    #[test]
    fn test_cbow_hierarchical_softmax() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let config = Word2VecConfig {
            vector_size: 10,
            window_size: 2,
            min_count: 1,
            epochs: 3,
            learning_rate: 0.025,
            algorithm: Word2VecAlgorithm::CBOW,
            hierarchical_softmax: true,
            ..Default::default()
        };

        let mut model = Word2Vec::with_config(config);
        let result = model.train(&texts);
        assert!(
            result.is_ok(),
            "HS CBOW training failed: {:?}",
            result.err()
        );

        let vec = model.get_word_vector("dog");
        assert!(vec.is_ok());
    }

    #[test]
    fn test_word_embedding_trait_word2vec() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "a quick brown fox jumps over a lazy dog",
        ];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);

        model.train(&texts).expect("Training failed");

        let emb: &dyn WordEmbedding = &model;
        assert_eq!(emb.dimension(), 10);
        assert!(emb.vocab_size() > 0);

        let vec = emb.embedding("fox");
        assert!(vec.is_ok());

        let sim = emb.similarity("fox", "dog");
        assert!(sim.is_ok());
        assert!(sim.expect("sim").is_finite());

        let similar = emb.find_similar("fox", 2);
        assert!(similar.is_ok());

        let analogy = emb.solve_analogy("the", "fox", "dog", 2);
        assert!(analogy.is_ok());
    }

    #[test]
    fn test_embedding_cosine_similarity_fn() {
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);
        assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);

        let c = Array1::from_vec(vec![1.0, 1.0]);
        let d = Array1::from_vec(vec![1.0, 1.0]);
        assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_pairwise_similarity_fn() {
        let texts = ["the quick brown fox", "the lazy brown dog"];

        let mut model = Word2Vec::new()
            .with_vector_size(10)
            .with_min_count(1)
            .with_epochs(1);
        model.train(&texts).expect("Training failed");

        let words = vec!["the", "fox", "dog"];
        let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");

        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Diagonal entries are self-similarities and must be 1.
        for i in 0..3 {
            assert!((matrix[i][i] - 1.0).abs() < 1e-6);
        }

        // The matrix must be symmetric.
        for i in 0..3 {
            for j in 0..3 {
                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
            }
        }
    }
}