use crate::trainer::Trainer;
use crate::vocab::Vocabulary;
use crate::{Result, Word2VecError};
use std::path::Path;
use std::sync::Arc;
/// A word2vec model: its training configuration, vocabulary and weight buffers.
///
/// Both weight buffers are flat `Vec<f32>`s of
/// `vocab.len() * config.vector_size` elements, laid out one row of
/// `vector_size` values per vocabulary entry (see `Word2Vec::new`).
pub struct Word2Vec {
// Hyper-parameters; borrowed by the `Trainer` and used for the row width.
config: TrainingConfig,
// Held in an `Arc` so a clone of the handle can be given to the `Trainer`.
vocab: Arc<Vocabulary>,
/// Weight buffer that `save_text` / `save_mcv1` write out; randomly
/// initialised by `Word2Vec::new`.
pub syn0: Vec<f32>,
/// Second weight buffer, zero-initialised by `Word2Vec::new` and passed to
/// the trainer together with `syn0`.
/// NOTE(review): presumably the negative-sampling output weights, going by
/// the usual word2vec naming — confirm against `Trainer`.
pub syn1neg: Vec<f32>,
}
/// Hyper-parameters for vocabulary construction and training.
///
/// `min_count` and `sample` are forwarded to `Vocabulary::new`; the whole
/// struct is borrowed by the `Trainer`. Field meanings marked "presumably"
/// follow the usual word2vec conventions and are not established by this
/// file — confirm against `Vocabulary` / `Trainer`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
/// Number of `f32` components stored per vocabulary entry.
pub vector_size: usize,
/// Presumably the maximum context distance on each side of the target word.
pub window_size: usize,
/// Presumably the number of negative samples drawn per training pair.
pub negative_samples: usize,
/// Passed to `Vocabulary::new`; presumably the minimum corpus frequency a
/// word needs to be kept.
pub min_count: u64,
/// Passed to `Vocabulary::new`; presumably the sub-sampling threshold for
/// frequent words.
pub sample: f64,
/// Presumably the initial learning rate.
pub alpha: f32,
/// Presumably the floor the learning rate decays towards.
pub min_alpha: f32,
/// Presumably the number of passes over the corpus.
pub epochs: usize,
/// Presumably the number of worker threads used for training.
pub threads: usize,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
vector_size: 100,
window_size: 5,
negative_samples: 5,
min_count: 10,
sample: 1e-4,
alpha: 0.025,
min_alpha: 0.0001,
epochs: 3,
threads: 8,
}
}
}
impl Word2Vec {
    /// Creates a model for `vocab` with freshly initialised weights.
    ///
    /// Both weight buffers are flat and hold
    /// `vocab.len() * config.vector_size` elements; the row for dense
    /// (remapped) word ID `i` occupies
    /// `i * vector_size..(i + 1) * vector_size`.
    /// `syn0` is filled with uniform random values in
    /// `[-0.5 / vector_size, 0.5 / vector_size)`; `syn1neg` is all zeros.
    /// A short summary of the allocation is written to stderr.
    ///
    /// An empty vocabulary is accepted and yields empty buffers.
    ///
    /// # Panics
    ///
    /// Panics if `vocab.len() * config.vector_size` overflows `usize`.
    pub fn new(config: TrainingConfig, vocab: Vocabulary) -> Self {
        use rand::Rng;

        let vocab_size = vocab.len();
        let max_word_id = vocab.max_word_id();
        let vector_size = config.vector_size;
        // Fail loudly instead of wrapping to a too-small buffer in release builds.
        let array_size = vocab_size
            .checked_mul(vector_size)
            .expect("vocab size * vector size overflows usize");

        // One draw per element, in buffer order. When `vector_size` is 0 the
        // buffer is empty, so the division below is never evaluated.
        let mut rng = rand::rng();
        let scale = vector_size as f32;
        let syn0: Vec<f32> = (0..array_size)
            .map(|_| (rng.random::<f32>() - 0.5) / scale)
            .collect();
        let syn1neg = vec![0.0f32; array_size];

        eprintln!("Model initialized:");
        eprintln!(" Vocab size (trained): {}", vocab_size);
        eprintln!(" Max word_id (MeCab): {}", max_word_id);
        eprintln!(
            " Array size: {} elements ({} MB)",
            array_size,
            array_size * std::mem::size_of::<f32>() / (1024 * 1024)
        );
        // `vocab_size - 1` would underflow for an empty vocabulary (panic in
        // debug builds, nonsense upper bound in release), so handle it apart.
        match vocab_size.checked_sub(1) {
            Some(last_id) => eprintln!(" Indexing: DENSE (remapped IDs 0-{})", last_id),
            None => eprintln!(" Indexing: DENSE (empty vocabulary, no remapped IDs)"),
        }

        Self {
            config,
            vocab: Arc::new(vocab),
            syn0,
            syn1neg,
        }
    }

    /// Trains the model on the corpus at `corpus_path`.
    ///
    /// Creates a `Trainer` over the corpus path, a shared handle to the
    /// vocabulary and the stored configuration, then lets it update `syn0`
    /// and `syn1neg` in place.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Trainer::train`.
    pub fn train_from_file<P: AsRef<Path>>(&mut self, corpus_path: P) -> Result<()> {
        let mut trainer = Trainer::new(
            corpus_path.as_ref(),
            Arc::clone(&self.vocab),
            &self.config,
        );
        trainer.train(&mut self.syn0, &mut self.syn1neg)?;
        Ok(())
    }

    /// Writes `syn0` to `path` via `crate::io::save_word2vec_text`.
    ///
    /// The vocabulary and the configured vector size are passed along so the
    /// writer can split the flat buffer into per-word rows.
    ///
    /// # Errors
    ///
    /// Returns whatever error `crate::io::save_word2vec_text` reports.
    pub fn save_text<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        crate::io::save_word2vec_text(
            path,
            &self.syn0,
            self.vocab.as_ref(),
            self.config.vector_size,
        )
    }

    /// Writes `syn0` to `path` via `crate::io::save_mcv1_format`.
    ///
    /// `max_word_id` is forwarded unchanged to the writer.
    /// NOTE(review): presumably the largest original (non-remapped) word ID
    /// the output must cover — confirm against `crate::io`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `crate::io::save_mcv1_format` reports.
    pub fn save_mcv1<P: AsRef<Path>>(&self, path: P, max_word_id: u32) -> Result<()> {
        crate::io::save_mcv1_format(
            path,
            &self.syn0,
            self.vocab.as_ref(),
            self.config.vector_size,
            max_word_id,
        )
    }

    /// Returns the vocabulary this model was built with.
    pub fn vocab(&self) -> &Vocabulary {
        &self.vocab
    }

    /// Returns the training configuration this model was built with.
    pub fn config(&self) -> &TrainingConfig {
        &self.config
    }
}
/// Fluent builder that collects a `TrainingConfig` and then constructs a
/// `Word2Vec` from a corpus file.
///
/// The derived `Default` starts from `TrainingConfig::default()`.
#[derive(Default)]
pub struct Word2VecBuilder {
// Configuration accumulated by the setter methods.
config: TrainingConfig,
}
impl Word2VecBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn vector_size(mut self, size: usize) -> Self {
self.config.vector_size = size;
self
}
pub fn window_size(mut self, size: usize) -> Self {
self.config.window_size = size;
self
}
pub fn negative_samples(mut self, n: usize) -> Self {
self.config.negative_samples = n;
self
}
pub fn min_count(mut self, count: u64) -> Self {
self.config.min_count = count;
self
}
pub fn sample(mut self, threshold: f64) -> Self {
self.config.sample = threshold;
self
}
pub fn alpha(mut self, alpha: f32) -> Self {
self.config.alpha = alpha;
self
}
pub fn min_alpha(mut self, alpha: f32) -> Self {
self.config.min_alpha = alpha;
self
}
pub fn epochs(mut self, epochs: usize) -> Self {
self.config.epochs = epochs;
self
}
pub fn threads(mut self, threads: usize) -> Self {
self.config.threads = threads;
self
}
pub fn build_from_corpus<P: AsRef<Path>>(self, corpus_path: P) -> Result<Word2Vec> {
let mut vocab = Vocabulary::new(self.config.min_count, self.config.sample);
vocab.build_from_file(&corpus_path)?;
if vocab.is_empty() {
return Err(Word2VecError::Vocabulary(
"Vocabulary is empty after filtering".to_string(),
));
}
Ok(Word2Vec::new(self.config, vocab))
}
}