pub mod contrastive;
pub mod crosslingual;
pub mod fasttext;
pub mod glove;
pub mod sentence;
pub mod sentence_encoder;
pub mod universal;
pub use fasttext::{FastText, FastTextConfig};
pub use glove::{
cosine_similarity as glove_cosine_similarity, CooccurrenceMatrix, GloVe, GloVeTrainer,
GloVeTrainerConfig,
};
use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
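/// Common interface implemented by the word-embedding models in this module
/// (Word2Vec, GloVe, FastText): vector lookup, similarity, nearest-neighbour
/// search, and analogy solving.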
pub trait WordEmbedding {
fn embedding(&self, word: &str) -> Result<Array1<f64>>;
fn dimension(&self) -> usize;
fn similarity(&self, word1: &str, word2: &str) -> Result<f64> {
let v1 = self.embedding(word1)?;
let v2 = self.embedding(word2)?;
Ok(embedding_cosine_similarity(&v1, &v2))
}
fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>>;
fn vocab_size(&self) -> usize;
}
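/// Cosine similarity between two embedding vectors, returning 0.0 when either
/// vector has zero norm.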
pub fn embedding_cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
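/// Computes the symmetric matrix of pairwise cosine similarities for a list of
/// words using the given embedding model. Fails if any word is out of vocabulary.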
pub fn pairwise_similarity(model: &dyn WordEmbedding, words: &[&str]) -> Result<Vec<Vec<f64>>> {
let vectors: Vec<Array1<f64>> = words
.iter()
.map(|&w| model.embedding(w))
.collect::<Result<Vec<_>>>()?;
let n = vectors.len();
let mut matrix = vec![vec![0.0; n]; n];
for i in 0..n {
for j in i..n {
let sim = embedding_cosine_similarity(&vectors[i], &vectors[j]);
matrix[i][j] = sim;
matrix[j][i] = sim;
}
}
Ok(matrix)
}
impl WordEmbedding for GloVe {
fn embedding(&self, word: &str) -> Result<Array1<f64>> {
self.get_word_vector(word)
}
fn dimension(&self) -> usize {
self.vector_size()
}
fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.most_similar(word, top_n)
}
fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.analogy(a, b, c, top_n)
}
fn vocab_size(&self) -> usize {
self.vocabulary_size()
}
}
impl WordEmbedding for FastText {
fn embedding(&self, word: &str) -> Result<Array1<f64>> {
self.get_word_vector(word)
}
fn dimension(&self) -> usize {
self.vector_size()
}
fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.most_similar(word, top_n)
}
fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.analogy(a, b, c, top_n)
}
fn vocab_size(&self) -> usize {
self.vocabulary_size()
}
}
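/// Node in the Huffman tree used for hierarchical softmax; internal nodes store
/// the indices of their children, leaves correspond to vocabulary words.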
#[derive(Debug, Clone)]
struct HuffmanNode {
id: usize,
frequency: usize,
left: Option<usize>,
right: Option<usize>,
is_leaf: bool,
}
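/// Huffman coding of the vocabulary for hierarchical softmax: `codes[w]` is the
/// bit sequence assigned to word `w`, and `paths[w]` lists the internal-node
/// indices visited from the root to its leaf.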
#[derive(Debug, Clone)]
struct HuffmanTree {
codes: Vec<Vec<u8>>,
paths: Vec<Vec<usize>>,
num_internal: usize,
}
impl HuffmanTree {
fn build(frequencies: &[usize]) -> Result<Self> {
let vocab_size = frequencies.len();
if vocab_size == 0 {
return Err(TextError::EmbeddingError(
"Cannot build Huffman tree with empty vocabulary".into(),
));
}
if vocab_size == 1 {
return Ok(Self {
codes: vec![vec![0]],
paths: vec![vec![0]],
num_internal: 1,
});
}
let mut nodes: Vec<HuffmanNode> = frequencies
.iter()
.enumerate()
.map(|(id, &freq)| HuffmanNode {
id,
frequency: freq.max(1),
left: None,
right: None,
is_leaf: true,
})
.collect();
let mut queue: Vec<(usize, usize)> = nodes
.iter()
.enumerate()
.map(|(i, n)| (i, n.frequency))
.collect();
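// Keep the queue sorted by descending frequency so pop() always removes the
// two least frequent nodes; merged nodes are re-inserted at the correct position.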
queue.sort_by_key(|item| std::cmp::Reverse(item.1));
while queue.len() > 1 {
let (idx1, freq1) = queue
.pop()
.ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
let (idx2, freq2) = queue
.pop()
.ok_or_else(|| TextError::EmbeddingError("Queue empty".into()))?;
let new_id = nodes.len();
let new_node = HuffmanNode {
id: new_id,
frequency: freq1 + freq2,
left: Some(idx1),
right: Some(idx2),
is_leaf: false,
};
nodes.push(new_node);
let new_freq = freq1 + freq2;
let insert_pos = queue
.binary_search_by(|(_, f)| new_freq.cmp(f))
.unwrap_or_else(|pos| pos);
queue.insert(insert_pos, (new_id, new_freq));
}
let num_internal = nodes.len() - vocab_size;
let mut codes = vec![Vec::new(); vocab_size];
let mut paths = vec![Vec::new(); vocab_size];
let root_idx = nodes.len() - 1;
let mut stack: Vec<(usize, Vec<u8>, Vec<usize>)> = vec![(root_idx, Vec::new(), Vec::new())];
while let Some((node_idx, code, path)) = stack.pop() {
let node = &nodes[node_idx];
if node.is_leaf {
codes[node.id] = code;
paths[node.id] = path;
} else {
let internal_idx = node.id - vocab_size;
if let Some(left_idx) = node.left {
let mut left_code = code.clone();
left_code.push(0);
let mut left_path = path.clone();
left_path.push(internal_idx);
stack.push((left_idx, left_code, left_path));
}
if let Some(right_idx) = node.right {
let mut right_code = code.clone();
right_code.push(1);
let mut right_path = path.clone();
right_path.push(internal_idx);
stack.push((right_idx, right_code, right_path));
}
}
}
Ok(Self {
codes,
paths,
num_internal,
})
}
}
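/// Cumulative-distribution table used to draw negative samples in proportion to
/// the supplied weights.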
#[derive(Debug, Clone)]
struct SamplingTable {
cdf: Vec<f64>,
weights: Vec<f64>,
}
impl SamplingTable {
fn new(weights: &[f64]) -> Result<Self> {
if weights.is_empty() {
return Err(TextError::EmbeddingError("Weights cannot be empty".into()));
}
if weights.iter().any(|&w| w < 0.0) {
return Err(TextError::EmbeddingError("Weights must be positive".into()));
}
let sum: f64 = weights.iter().sum();
if sum <= 0.0 {
return Err(TextError::EmbeddingError(
"Sum of _weights must be positive".into(),
));
}
let mut cdf = Vec::with_capacity(weights.len());
let mut total = 0.0;
for &w in weights {
total += w;
cdf.push(total / sum);
}
Ok(Self {
cdf,
weights: weights.to_vec(),
})
}
fn sample<R: Rng>(&self, rng: &mut R) -> usize {
let r = rng.random::<f64>();
match self.cdf.binary_search_by(|&cdf_val| {
cdf_val.partial_cmp(&r).unwrap_or(std::cmp::Ordering::Equal)
}) {
Ok(idx) => idx,
// `Err` gives the insertion point; clamp it in case floating-point rounding
// leaves the last CDF entry slightly below 1.0.
Err(idx) => idx.min(self.cdf.len() - 1),
}
}
fn weights(&self) -> &[f64] {
&self.weights
}
}
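/// Training algorithm: CBOW predicts a word from its averaged context, while
/// skip-gram predicts each context word from the centre word.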
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Word2VecAlgorithm {
CBOW,
SkipGram,
}
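/// Hyperparameters for `Word2Vec` training: embedding dimension, context window,
/// minimum word count, epochs, learning rate, algorithm, number of negative
/// samples, subsampling threshold, batch size, and whether to use hierarchical
/// softmax instead of negative sampling.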
#[derive(Debug, Clone)]
pub struct Word2VecConfig {
pub vector_size: usize,
pub window_size: usize,
pub min_count: usize,
pub epochs: usize,
pub learning_rate: f64,
pub algorithm: Word2VecAlgorithm,
pub negative_samples: usize,
pub subsample: f64,
pub batch_size: usize,
pub hierarchical_softmax: bool,
}
impl Default for Word2VecConfig {
fn default() -> Self {
Self {
vector_size: 100,
window_size: 5,
min_count: 5,
epochs: 5,
learning_rate: 0.025,
algorithm: Word2VecAlgorithm::SkipGram,
negative_samples: 5,
subsample: 1e-3,
batch_size: 128,
hierarchical_softmax: false,
}
}
}
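/// Word2Vec model supporting CBOW and skip-gram training with either negative
/// sampling or hierarchical softmax.
///
/// A minimal usage sketch (marked `ignore` since the enclosing crate path is not
/// assumed here):
///
/// ```ignore
/// let mut model = Word2Vec::new()
///     .with_vector_size(50)
///     .with_min_count(1)
///     .with_epochs(3)
///     .with_algorithm(Word2VecAlgorithm::SkipGram);
/// model.train(&["the quick brown fox", "the lazy brown dog"])?;
/// let vector = model.get_word_vector("fox")?;
/// let neighbours = model.most_similar("fox", 5)?;
/// ```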
pub struct Word2Vec {
config: Word2VecConfig,
vocabulary: Vocabulary,
input_embeddings: Option<Array2<f64>>,
output_embeddings: Option<Array2<f64>>,
tokenizer: Box<dyn Tokenizer + Send + Sync>,
sampling_table: Option<SamplingTable>,
huffman_tree: Option<HuffmanTree>,
hs_params: Option<Array2<f64>>,
current_learning_rate: f64,
}
impl Debug for Word2Vec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Word2Vec")
.field("config", &self.config)
.field("vocabulary", &self.vocabulary)
.field("input_embeddings", &self.input_embeddings)
.field("output_embeddings", &self.output_embeddings)
.field("sampling_table", &self.sampling_table)
.field("huffman_tree", &self.huffman_tree)
.field("current_learning_rate", &self.current_learning_rate)
.finish()
}
}
impl Default for Word2Vec {
fn default() -> Self {
Self::new()
}
}
impl Clone for Word2Vec {
fn clone(&self) -> Self {
let tokenizer: Box<dyn Tokenizer + Send + Sync> = Box::new(WordTokenizer::default());
Self {
config: self.config.clone(),
vocabulary: self.vocabulary.clone(),
input_embeddings: self.input_embeddings.clone(),
output_embeddings: self.output_embeddings.clone(),
tokenizer,
sampling_table: self.sampling_table.clone(),
huffman_tree: self.huffman_tree.clone(),
hs_params: self.hs_params.clone(),
current_learning_rate: self.current_learning_rate,
}
}
}
impl Word2Vec {
pub fn new() -> Self {
Self {
config: Word2VecConfig::default(),
vocabulary: Vocabulary::new(),
input_embeddings: None,
output_embeddings: None,
tokenizer: Box::new(WordTokenizer::default()),
sampling_table: None,
huffman_tree: None,
hs_params: None,
current_learning_rate: 0.025,
}
}
pub fn with_config(config: Word2VecConfig) -> Self {
let learning_rate = config.learning_rate;
Self {
config,
vocabulary: Vocabulary::new(),
input_embeddings: None,
output_embeddings: None,
tokenizer: Box::new(WordTokenizer::default()),
sampling_table: None,
huffman_tree: None,
hs_params: None,
current_learning_rate: learning_rate,
}
}
pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
self.tokenizer = tokenizer;
self
}
pub fn with_vector_size(mut self, vectorsize: usize) -> Self {
self.config.vector_size = vectorsize;
self
}
pub fn with_window_size(mut self, windowsize: usize) -> Self {
self.config.window_size = windowsize;
self
}
pub fn with_min_count(mut self, mincount: usize) -> Self {
self.config.min_count = mincount;
self
}
pub fn with_epochs(mut self, epochs: usize) -> Self {
self.config.epochs = epochs;
self
}
pub fn with_learning_rate(mut self, learningrate: f64) -> Self {
self.config.learning_rate = learningrate;
self.current_learning_rate = learningrate;
self
}
pub fn with_algorithm(mut self, algorithm: Word2VecAlgorithm) -> Self {
self.config.algorithm = algorithm;
self
}
pub fn with_negative_samples(mut self, negativesamples: usize) -> Self {
self.config.negative_samples = negativesamples;
self
}
pub fn with_subsample(mut self, subsample: f64) -> Self {
self.config.subsample = subsample;
self
}
pub fn with_batch_size(mut self, batchsize: usize) -> Self {
self.config.batch_size = batchsize;
self
}
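/// Tokenizes `texts`, keeps words occurring at least `min_count` times, and
/// initializes the input/output embedding matrices (plus the negative-sampling
/// table and, if enabled, the Huffman tree for hierarchical softmax).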
pub fn build_vocabulary(&mut self, texts: &[&str]) -> Result<()> {
if texts.is_empty() {
return Err(TextError::InvalidInput(
"No texts provided for building vocabulary".into(),
));
}
let mut word_counts = HashMap::new();
let mut _total_words = 0;
for &text in texts {
let tokens = self.tokenizer.tokenize(text)?;
for token in tokens {
*word_counts.entry(token).or_insert(0) += 1;
_total_words += 1;
}
}
self.vocabulary = Vocabulary::new();
for (word, count) in &word_counts {
if *count >= self.config.min_count {
self.vocabulary.add_token(word);
}
}
if self.vocabulary.is_empty() {
return Err(TextError::VocabularyError(
"No words meet the minimum count threshold".into(),
));
}
let vocab_size = self.vocabulary.len();
let vector_size = self.config.vector_size;
let mut rng = scirs2_core::random::rng();
let input_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
(rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
});
let output_embeddings = Array2::from_shape_fn((vocab_size, vector_size), |_| {
(rng.random::<f64>() * 2.0 - 1.0) / vector_size as f64
});
self.input_embeddings = Some(input_embeddings);
self.output_embeddings = Some(output_embeddings);
self.create_sampling_table(&word_counts)?;
if self.config.hierarchical_softmax {
let frequencies: Vec<usize> = (0..vocab_size)
.map(|i| {
self.vocabulary
.get_token(i)
.and_then(|word| word_counts.get(word).copied())
.unwrap_or(1)
})
.collect();
let tree = HuffmanTree::build(&frequencies)?;
let num_internal = tree.num_internal;
let hs_params = Array2::zeros((num_internal, vector_size));
self.hs_params = Some(hs_params);
self.huffman_tree = Some(tree);
}
Ok(())
}
fn create_sampling_table(&mut self, wordcounts: &HashMap<String, usize>) -> Result<()> {
let mut sampling_weights = vec![0.0; self.vocabulary.len()];
for (word, &count) in wordcounts.iter() {
if let Some(idx) = self.vocabulary.get_index(word) {
sampling_weights[idx] = (count as f64).powf(0.75);
}
}
match SamplingTable::new(&sampling_weights) {
Ok(table) => {
self.sampling_table = Some(table);
Ok(())
}
Err(e) => Err(e),
}
}
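/// Trains the model on `texts` for the configured number of epochs, building the
/// vocabulary first if necessary. The learning rate decays linearly per epoch and
/// sentences are optionally subsampled before each update.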
pub fn train(&mut self, texts: &[&str]) -> Result<()> {
if texts.is_empty() {
return Err(TextError::InvalidInput(
"No texts provided for training".into(),
));
}
if self.vocabulary.is_empty() {
self.build_vocabulary(texts)?;
}
if self.input_embeddings.is_none() || self.output_embeddings.is_none() {
return Err(TextError::EmbeddingError(
"Embeddings not initialized. Call build_vocabulary() first".into(),
));
}
let mut _total_tokens = 0;
let mut sentences = Vec::new();
for &text in texts {
let tokens = self.tokenizer.tokenize(text)?;
let filtered_tokens: Vec<usize> = tokens
.iter()
.filter_map(|token| self.vocabulary.get_index(token))
.collect();
if !filtered_tokens.is_empty() {
_total_tokens += filtered_tokens.len();
sentences.push(filtered_tokens);
}
}
for epoch in 0..self.config.epochs {
self.current_learning_rate =
self.config.learning_rate * (1.0 - (epoch as f64 / self.config.epochs as f64));
self.current_learning_rate = self
.current_learning_rate
.max(self.config.learning_rate * 0.0001);
for sentence in &sentences {
let subsampled_sentence = if self.config.subsample > 0.0 {
self.subsample_sentence(sentence)?
} else {
sentence.clone()
};
if subsampled_sentence.is_empty() {
continue;
}
if self.config.hierarchical_softmax {
match self.config.algorithm {
Word2VecAlgorithm::SkipGram => {
self.train_skipgram_hs_sentence(&subsampled_sentence)?;
}
Word2VecAlgorithm::CBOW => {
self.train_cbow_hs_sentence(&subsampled_sentence)?;
}
}
} else {
match self.config.algorithm {
Word2VecAlgorithm::CBOW => {
self.train_cbow_sentence(&subsampled_sentence)?;
}
Word2VecAlgorithm::SkipGram => {
self.train_skipgram_sentence(&subsampled_sentence)?;
}
}
}
}
}
Ok(())
}
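/// Randomly drops frequent words following the word2vec subsampling heuristic:
/// a word with weight f is kept with probability (sqrt(f / t) + 1) * (t / f),
/// where t is the subsample threshold scaled by the vocabulary size.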
fn subsample_sentence(&self, sentence: &[usize]) -> Result<Vec<usize>> {
let mut rng = scirs2_core::random::rng();
let total_words: f64 = self.vocabulary.len() as f64;
let threshold = self.config.subsample * total_words;
let subsampled: Vec<usize> = sentence
.iter()
.filter(|&&word_idx| {
let word_freq = self.get_word_frequency(word_idx);
if word_freq == 0.0 {
return true;
}
let keep_prob = ((word_freq / threshold).sqrt() + 1.0) * (threshold / word_freq);
rng.random::<f64>() < keep_prob
})
.copied()
.collect();
Ok(subsampled)
}
fn get_word_frequency(&self, wordidx: usize) -> f64 {
if let Some(table) = &self.sampling_table {
table.weights()[wordidx]
} else {
1.0
}
}
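/// CBOW with negative sampling: the averaged context vector is pushed towards
/// the target word's output vector and away from sampled negative words, and the
/// gradient is distributed back over the context input vectors.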
fn train_cbow_sentence(&mut self, sentence: &[usize]) -> Result<()> {
if sentence.len() < 2 {
return Ok(());
}
let input_embeddings = self.input_embeddings.as_mut().expect("input embeddings initialized by build_vocabulary");
let output_embeddings = self.output_embeddings.as_mut().expect("output embeddings initialized by build_vocabulary");
let vector_size = self.config.vector_size;
let window_size = self.config.window_size;
let negative_samples = self.config.negative_samples;
for pos in 0..sentence.len() {
let mut rng = scirs2_core::random::rng();
let window = 1 + rng.random_range(0..window_size);
let target_word = sentence[pos];
let mut context_words = Vec::new();
#[allow(clippy::needless_range_loop)]
for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
if i != pos {
context_words.push(sentence[i]);
}
}
if context_words.is_empty() {
continue;
}
let mut context_sum = Array1::zeros(vector_size);
for &context_idx in &context_words {
context_sum += &input_embeddings.row(context_idx);
}
let context_avg = &context_sum / context_words.len() as f64;
let mut target_output = output_embeddings.row_mut(target_word);
let dot_product = (&context_avg * &target_output).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error = (1.0 - sigmoid) * self.current_learning_rate;
let mut target_update = target_output.to_owned();
target_update.scaled_add(error, &context_avg);
target_output.assign(&target_update);
if let Some(sampler) = &self.sampling_table {
for _ in 0..negative_samples {
let negative_idx = sampler.sample(&mut rng);
if negative_idx == target_word {
continue;
}
let mut negative_output = output_embeddings.row_mut(negative_idx);
let dot_product = (&context_avg * &negative_output).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error = -sigmoid * self.current_learning_rate;
let mut negative_update = negative_output.to_owned();
negative_update.scaled_add(error, &context_avg);
negative_output.assign(&negative_update);
}
}
for &context_idx in &context_words {
let mut input_vec = input_embeddings.row_mut(context_idx);
let dot_product = (&context_avg * &output_embeddings.row(target_word)).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error =
(1.0 - sigmoid) * self.current_learning_rate / context_words.len() as f64;
let mut input_update = input_vec.to_owned();
input_update.scaled_add(error, &output_embeddings.row(target_word));
if let Some(sampler) = &self.sampling_table {
for _ in 0..negative_samples {
let negative_idx = sampler.sample(&mut rng);
if negative_idx == target_word {
continue;
}
let dot_product =
(&context_avg * &output_embeddings.row(negative_idx)).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error =
-sigmoid * self.current_learning_rate / context_words.len() as f64;
input_update.scaled_add(error, &output_embeddings.row(negative_idx));
}
}
input_vec.assign(&input_update);
}
}
Ok(())
}
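/// Skip-gram with negative sampling: for every (centre, context) pair inside a
/// randomly shrunk window, the centre word's input vector and the context word's
/// output vector are updated against sampled negative words.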
fn train_skipgram_sentence(&mut self, sentence: &[usize]) -> Result<()> {
if sentence.len() < 2 {
return Ok(());
}
let input_embeddings = self.input_embeddings.as_mut().expect("input embeddings initialized by build_vocabulary");
let output_embeddings = self.output_embeddings.as_mut().expect("output embeddings initialized by build_vocabulary");
let vector_size = self.config.vector_size;
let window_size = self.config.window_size;
let negative_samples = self.config.negative_samples;
for pos in 0..sentence.len() {
let mut rng = scirs2_core::random::rng();
let window = 1 + rng.random_range(0..window_size);
let target_word = sentence[pos];
#[allow(clippy::needless_range_loop)]
for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
if i == pos {
continue;
}
let context_word = sentence[i];
let target_input = input_embeddings.row(target_word);
let mut context_output = output_embeddings.row_mut(context_word);
let dot_product = (&target_input * &context_output).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error = (1.0 - sigmoid) * self.current_learning_rate;
let mut context_update = context_output.to_owned();
context_update.scaled_add(error, &target_input);
context_output.assign(&context_update);
let mut input_update = Array1::zeros(vector_size);
input_update.scaled_add(error, &context_output);
if let Some(sampler) = &self.sampling_table {
for _ in 0..negative_samples {
let negative_idx = sampler.sample(&mut rng);
if negative_idx == context_word {
continue;
}
let mut negative_output = output_embeddings.row_mut(negative_idx);
let dot_product = (&target_input * &negative_output).sum();
let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
let error = -sigmoid * self.current_learning_rate;
let mut negative_update = negative_output.to_owned();
negative_update.scaled_add(error, &target_input);
negative_output.assign(&negative_update);
input_update.scaled_add(error, &negative_output);
}
}
let mut target_input_mut = input_embeddings.row_mut(target_word);
target_input_mut += &input_update;
}
}
Ok(())
}
pub fn vector_size(&self) -> usize {
self.config.vector_size
}
pub fn get_word_vector(&self, word: &str) -> Result<Array1<f64>> {
if self.input_embeddings.is_none() {
return Err(TextError::EmbeddingError(
"Model not trained. Call train() first".into(),
));
}
match self.vocabulary.get_index(word) {
Some(idx) => Ok(self
.input_embeddings
.as_ref()
.expect("Operation failed")
.row(idx)
.to_owned()),
None => Err(TextError::VocabularyError(format!(
"Word '{word}' not in vocabulary"
))),
}
}
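/// Returns the `topn` vocabulary words most similar to `word` by cosine
/// similarity, excluding the query word itself.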
pub fn most_similar(&self, word: &str, topn: usize) -> Result<Vec<(String, f64)>> {
let word_vec = self.get_word_vector(word)?;
self.most_similar_by_vector(&word_vec, topn, &[word])
}
pub fn most_similar_by_vector(
&self,
vector: &Array1<f64>,
top_n: usize,
exclude_words: &[&str],
) -> Result<Vec<(String, f64)>> {
if self.input_embeddings.is_none() {
return Err(TextError::EmbeddingError(
"Model not trained. Call train() first".into(),
));
}
let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
let vocab_size = self.vocabulary.len();
let exclude_indices: Vec<usize> = exclude_words
.iter()
.filter_map(|&word| self.vocabulary.get_index(word))
.collect();
let mut similarities = Vec::with_capacity(vocab_size);
for i in 0..vocab_size {
if exclude_indices.contains(&i) {
continue;
}
let word_vec = input_embeddings.row(i);
let similarity = cosine_similarity(vector, &word_vec.to_owned());
if let Some(word) = self.vocabulary.get_token(i) {
similarities.push((word.to_string(), similarity));
}
}
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let result = similarities.into_iter().take(top_n).collect();
Ok(result)
}
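/// Solves "a is to b as c is to ?" via vector arithmetic (b - a + c), returning
/// the `topn` nearest words while excluding a, b, and c themselves.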
pub fn analogy(&self, a: &str, b: &str, c: &str, topn: usize) -> Result<Vec<(String, f64)>> {
if self.input_embeddings.is_none() {
return Err(TextError::EmbeddingError(
"Model not trained. Call train() first".into(),
));
}
let a_vec = self.get_word_vector(a)?;
let b_vec = self.get_word_vector(b)?;
let c_vec = self.get_word_vector(c)?;
let mut d_vec = b_vec.clone();
d_vec -= &a_vec;
d_vec += &c_vec;
let norm = (d_vec.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
if norm > 0.0 {
d_vec.mapv_inplace(|val| val / norm);
}
self.most_similar_by_vector(&d_vec, topn, &[a, b, c])
}
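/// Saves the model in a word2vec-style text format: a header line with the
/// vocabulary size and vector size, followed by one line per word containing the
/// word and its vector components.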
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
if self.input_embeddings.is_none() {
return Err(TextError::EmbeddingError(
"Model not trained. Call train() first".into(),
));
}
let mut file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
writeln!(
&mut file,
"{} {}",
self.vocabulary.len(),
self.config.vector_size
)
.map_err(|e| TextError::IoError(e.to_string()))?;
let input_embeddings = self.input_embeddings.as_ref().expect("Operation failed");
for i in 0..self.vocabulary.len() {
if let Some(word) = self.vocabulary.get_token(i) {
write!(&mut file, "{word} ").map_err(|e| TextError::IoError(e.to_string()))?;
let vector = input_embeddings.row(i);
for j in 0..self.config.vector_size {
write!(&mut file, "{:.6} ", vector[j])
.map_err(|e| TextError::IoError(e.to_string()))?;
}
writeln!(&mut file).map_err(|e| TextError::IoError(e.to_string()))?;
}
}
Ok(())
}
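/// Loads a model previously written by `save`. Only the input embeddings are
/// restored; output embeddings and training state are not.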
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
let mut reader = BufReader::new(file);
let mut header = String::new();
reader
.read_line(&mut header)
.map_err(|e| TextError::IoError(e.to_string()))?;
let parts: Vec<&str> = header.split_whitespace().collect();
if parts.len() != 2 {
return Err(TextError::EmbeddingError(
"Invalid model file format".into(),
));
}
let vocab_size = parts[0].parse::<usize>().map_err(|_| {
TextError::EmbeddingError("Invalid vocabulary size in model file".into())
})?;
let vector_size = parts[1]
.parse::<usize>()
.map_err(|_| TextError::EmbeddingError("Invalid vector size in model file".into()))?;
let mut model = Self::new().with_vector_size(vector_size);
let mut vocabulary = Vocabulary::new();
let mut input_embeddings = Array2::zeros((vocab_size, vector_size));
let mut i = 0;
for line in reader.lines() {
let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() != vector_size + 1 {
let line_num = i + 2;
return Err(TextError::EmbeddingError(format!(
"Invalid vector format at line {line_num}"
)));
}
let word = parts[0];
vocabulary.add_token(word);
for j in 0..vector_size {
input_embeddings[(i, j)] = parts[j + 1].parse::<f64>().map_err(|_| {
TextError::EmbeddingError(format!(
"Invalid vector component at line {}, position {}",
i + 2,
j + 1
))
})?;
}
i += 1;
}
if i != vocab_size {
return Err(TextError::EmbeddingError(format!(
"Expected {vocab_size} words but found {i}"
)));
}
model.vocabulary = vocabulary;
model.input_embeddings = Some(input_embeddings);
model.output_embeddings = None;
Ok(model)
}
pub fn get_vocabulary(&self) -> Vec<String> {
let mut vocab = Vec::new();
for i in 0..self.vocabulary.len() {
if let Some(token) = self.vocabulary.get_token(i) {
vocab.push(token.to_string());
}
}
vocab
}
pub fn get_vector_size(&self) -> usize {
self.config.vector_size
}
pub fn get_algorithm(&self) -> Word2VecAlgorithm {
self.config.algorithm
}
pub fn get_window_size(&self) -> usize {
self.config.window_size
}
pub fn get_min_count(&self) -> usize {
self.config.min_count
}
pub fn get_embeddings_matrix(&self) -> Option<Array2<f64>> {
self.input_embeddings.clone()
}
pub fn get_negative_samples(&self) -> usize {
self.config.negative_samples
}
pub fn get_learning_rate(&self) -> f64 {
self.config.learning_rate
}
pub fn get_epochs(&self) -> usize {
self.config.epochs
}
pub fn get_subsampling_threshold(&self) -> f64 {
self.config.subsample
}
pub fn uses_hierarchical_softmax(&self) -> bool {
self.config.hierarchical_softmax
}
pub fn restore_weights(
&mut self,
vocabulary: Vec<String>,
embeddings: Array2<f64>,
) -> Result<()> {
let embed_shape = embeddings.shape();
let n_words = vocabulary.len();
if embed_shape[0] != n_words {
return Err(TextError::EmbeddingError(format!(
"Embedding row count {} does not match vocabulary size {}",
embed_shape[0], n_words
)));
}
if embed_shape[1] != self.config.vector_size {
return Err(TextError::EmbeddingError(format!(
"Embedding dimension {} does not match configured vector_size {}",
embed_shape[1], self.config.vector_size
)));
}
self.vocabulary = Vocabulary::new();
for word in &vocabulary {
self.vocabulary.add_token(word);
}
self.input_embeddings = Some(embeddings);
Ok(())
}
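/// Skip-gram with hierarchical softmax: each context word's Huffman code is
/// predicted from the centre word via a sequence of binary decisions at the
/// internal nodes on its root-to-leaf path.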
fn train_skipgram_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
if sentence.len() < 2 {
return Ok(());
}
let input_embeddings = self
.input_embeddings
.as_mut()
.ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
let hs_params = self
.hs_params
.as_mut()
.ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
let tree = self
.huffman_tree
.as_ref()
.ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
let vector_size = self.config.vector_size;
let window_size = self.config.window_size;
let lr = self.current_learning_rate;
let codes = tree.codes.clone();
let paths = tree.paths.clone();
let mut rng = scirs2_core::random::rng();
for pos in 0..sentence.len() {
let window = 1 + rng.random_range(0..window_size);
let target_word = sentence[pos];
for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
if i == pos {
continue;
}
let context_word = sentence[i];
let code = &codes[context_word];
let path = &paths[context_word];
let mut grad_input = Array1::zeros(vector_size);
for (&node_idx, &label) in path.iter().zip(code.iter()) {
if node_idx >= hs_params.nrows() {
continue;
}
let input_vec = input_embeddings.row(target_word);
let param_vec = hs_params.row(node_idx);
let dot: f64 = input_vec
.iter()
.zip(param_vec.iter())
.map(|(a, b)| a * b)
.sum();
let sigmoid = 1.0 / (1.0 + (-dot).exp());
let target = if label == 0 { 1.0 } else { 0.0 };
let gradient = (target - sigmoid) * lr;
grad_input.scaled_add(gradient, &param_vec.to_owned());
let input_owned = input_vec.to_owned();
let mut param_mut = hs_params.row_mut(node_idx);
param_mut.scaled_add(gradient, &input_owned);
}
let mut input_mut = input_embeddings.row_mut(target_word);
input_mut += &grad_input;
}
}
Ok(())
}
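/// CBOW with hierarchical softmax: the averaged context vector predicts the
/// target word's Huffman code, and the gradient is split evenly across the
/// context words.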
fn train_cbow_hs_sentence(&mut self, sentence: &[usize]) -> Result<()> {
if sentence.len() < 2 {
return Ok(());
}
let input_embeddings = self
.input_embeddings
.as_mut()
.ok_or_else(|| TextError::EmbeddingError("Input embeddings not initialized".into()))?;
let hs_params = self
.hs_params
.as_mut()
.ok_or_else(|| TextError::EmbeddingError("HS params not initialized".into()))?;
let tree = self
.huffman_tree
.as_ref()
.ok_or_else(|| TextError::EmbeddingError("Huffman tree not built".into()))?;
let vector_size = self.config.vector_size;
let window_size = self.config.window_size;
let lr = self.current_learning_rate;
let codes = tree.codes.clone();
let paths = tree.paths.clone();
let mut rng = scirs2_core::random::rng();
for pos in 0..sentence.len() {
let window = 1 + rng.random_range(0..window_size);
let target_word = sentence[pos];
let mut context_words = Vec::new();
for i in pos.saturating_sub(window)..=(pos + window).min(sentence.len() - 1) {
if i != pos {
context_words.push(sentence[i]);
}
}
if context_words.is_empty() {
continue;
}
let mut context_avg = Array1::zeros(vector_size);
for &ctx_idx in &context_words {
context_avg += &input_embeddings.row(ctx_idx);
}
context_avg /= context_words.len() as f64;
let code = &codes[target_word];
let path = &paths[target_word];
let mut grad_context = Array1::zeros(vector_size);
for (&node_idx, &label) in path.iter().zip(code.iter()) {
if node_idx >= hs_params.nrows() {
continue;
}
let param_vec = hs_params.row(node_idx);
let dot: f64 = context_avg
.iter()
.zip(param_vec.iter())
.map(|(a, b)| a * b)
.sum();
let sigmoid = 1.0 / (1.0 + (-dot).exp());
let target = if label == 0 { 1.0 } else { 0.0 };
let gradient = (target - sigmoid) * lr;
grad_context.scaled_add(gradient, &param_vec.to_owned());
let ctx_owned = context_avg.clone();
let mut param_mut = hs_params.row_mut(node_idx);
param_mut.scaled_add(gradient, &ctx_owned);
}
let grad_per_word = &grad_context / context_words.len() as f64;
for &ctx_idx in &context_words {
let mut input_mut = input_embeddings.row_mut(ctx_idx);
input_mut += &grad_per_word;
}
}
Ok(())
}
}
impl WordEmbedding for Word2Vec {
fn embedding(&self, word: &str) -> Result<Array1<f64>> {
self.get_word_vector(word)
}
fn dimension(&self) -> usize {
self.vector_size()
}
fn find_similar(&self, word: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.most_similar(word, top_n)
}
fn solve_analogy(&self, a: &str, b: &str, c: &str, top_n: usize) -> Result<Vec<(String, f64)>> {
self.analogy(a, b, c, top_n)
}
fn vocab_size(&self) -> usize {
self.vocabulary.len()
}
}
#[allow(dead_code)]
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
let dot_product = (a * b).sum();
let norm_a = (a.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
let norm_b = (b.iter().fold(0.0, |sum, &val| sum + val * val)).sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_cosine_similarity() {
let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
let similarity = cosine_similarity(&a, &b);
let expected = 0.9746318461970762;
assert_relative_eq!(similarity, expected, max_relative = 1e-10);
}
#[test]
fn test_word2vec_config() {
let config = Word2VecConfig::default();
assert_eq!(config.vector_size, 100);
assert_eq!(config.window_size, 5);
assert_eq!(config.min_count, 5);
assert_eq!(config.epochs, 5);
assert_eq!(config.algorithm, Word2VecAlgorithm::SkipGram);
}
#[test]
fn test_word2vec_builder() {
let model = Word2Vec::new()
.with_vector_size(200)
.with_window_size(10)
.with_learning_rate(0.05)
.with_algorithm(Word2VecAlgorithm::CBOW);
assert_eq!(model.config.vector_size, 200);
assert_eq!(model.config.window_size, 10);
assert_eq!(model.config.learning_rate, 0.05);
assert_eq!(model.config.algorithm, Word2VecAlgorithm::CBOW);
}
#[test]
fn test_build_vocabulary() {
let texts = [
"the quick brown fox jumps over the lazy dog",
"a quick brown fox jumps over a lazy dog",
];
let mut model = Word2Vec::new().with_min_count(1);
let result = model.build_vocabulary(&texts);
assert!(result.is_ok());
assert_eq!(model.vocabulary.len(), 9);
assert!(model.input_embeddings.is_some());
assert!(model.output_embeddings.is_some());
assert_eq!(
model
.input_embeddings
.as_ref()
.expect("Operation failed")
.shape(),
&[9, 100]
);
}
#[test]
fn test_skipgram_training_small() {
let texts = [
"the quick brown fox jumps over the lazy dog",
"a quick brown fox jumps over a lazy dog",
];
let mut model = Word2Vec::new()
.with_vector_size(10)
.with_window_size(2)
.with_min_count(1)
.with_epochs(1)
.with_algorithm(Word2VecAlgorithm::SkipGram);
let result = model.train(&texts);
assert!(result.is_ok());
let result = model.get_word_vector("fox");
assert!(result.is_ok());
let vec = result.expect("Operation failed");
assert_eq!(vec.len(), 10);
}
#[test]
fn test_huffman_tree_build() {
let frequencies = vec![5, 3, 8, 1, 2];
let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
assert_eq!(tree.codes.len(), 5);
assert_eq!(tree.paths.len(), 5);
for code in &tree.codes {
assert!(!code.is_empty());
}
assert_eq!(tree.num_internal, 4);
}
#[test]
fn test_huffman_tree_single_word() {
let frequencies = vec![10];
let tree = HuffmanTree::build(&frequencies).expect("Huffman build failed");
assert_eq!(tree.codes.len(), 1);
assert_eq!(tree.paths.len(), 1);
}
#[test]
fn test_skipgram_hierarchical_softmax() {
let texts = [
"the quick brown fox jumps over the lazy dog",
"a quick brown fox jumps over a lazy dog",
];
let config = Word2VecConfig {
vector_size: 10,
window_size: 2,
min_count: 1,
epochs: 3,
learning_rate: 0.025,
algorithm: Word2VecAlgorithm::SkipGram,
hierarchical_softmax: true,
..Default::default()
};
let mut model = Word2Vec::with_config(config);
let result = model.train(&texts);
assert!(
result.is_ok(),
"HS skipgram training failed: {:?}",
result.err()
);
assert!(model.uses_hierarchical_softmax());
let vec = model.get_word_vector("fox");
assert!(vec.is_ok());
assert_eq!(vec.expect("get vec").len(), 10);
}
#[test]
fn test_cbow_hierarchical_softmax() {
let texts = [
"the quick brown fox jumps over the lazy dog",
"a quick brown fox jumps over a lazy dog",
];
let config = Word2VecConfig {
vector_size: 10,
window_size: 2,
min_count: 1,
epochs: 3,
learning_rate: 0.025,
algorithm: Word2VecAlgorithm::CBOW,
hierarchical_softmax: true,
..Default::default()
};
let mut model = Word2Vec::with_config(config);
let result = model.train(&texts);
assert!(
result.is_ok(),
"HS CBOW training failed: {:?}",
result.err()
);
let vec = model.get_word_vector("dog");
assert!(vec.is_ok());
}
#[test]
fn test_word_embedding_trait_word2vec() {
let texts = [
"the quick brown fox jumps over the lazy dog",
"a quick brown fox jumps over a lazy dog",
];
let mut model = Word2Vec::new()
.with_vector_size(10)
.with_min_count(1)
.with_epochs(1);
model.train(&texts).expect("Training failed");
let emb: &dyn WordEmbedding = &model;
assert_eq!(emb.dimension(), 10);
assert!(emb.vocab_size() > 0);
let vec = emb.embedding("fox");
assert!(vec.is_ok());
let sim = emb.similarity("fox", "dog");
assert!(sim.is_ok());
assert!(sim.expect("sim").is_finite());
let similar = emb.find_similar("fox", 2);
assert!(similar.is_ok());
let analogy = emb.solve_analogy("the", "fox", "dog", 2);
assert!(analogy.is_ok());
}
#[test]
fn test_embedding_cosine_similarity_fn() {
let a = Array1::from_vec(vec![1.0, 0.0]);
let b = Array1::from_vec(vec![0.0, 1.0]);
assert!((embedding_cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
let c = Array1::from_vec(vec![1.0, 1.0]);
let d = Array1::from_vec(vec![1.0, 1.0]);
assert!((embedding_cosine_similarity(&c, &d) - 1.0).abs() < 1e-6);
}
#[test]
fn test_pairwise_similarity_fn() {
let texts = ["the quick brown fox", "the lazy brown dog"];
let mut model = Word2Vec::new()
.with_vector_size(10)
.with_min_count(1)
.with_epochs(1);
model.train(&texts).expect("Training failed");
let words = vec!["the", "fox", "dog"];
let matrix = pairwise_similarity(&model, &words).expect("pairwise failed");
assert_eq!(matrix.len(), 3);
assert_eq!(matrix[0].len(), 3);
for i in 0..3 {
assert!((matrix[i][i] - 1.0).abs() < 1e-6);
}
for i in 0..3 {
for j in 0..3 {
assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
}
}
}
}