// mecrab_word2vec/model.rs

//! Word2Vec model structure and builder

use crate::trainer::Trainer;
use crate::vocab::Vocabulary;
use crate::{Result, Word2VecError};
use std::path::Path;
use std::sync::Arc;

9/// Word2Vec model
10pub struct Word2Vec {
11    /// Model configuration
12    config: TrainingConfig,
13    /// Vocabulary
14    vocab: Arc<Vocabulary>,
15    /// Input embeddings (word vectors)
16    /// Shape: [vocab_size, vector_size]
17    pub syn0: Vec<f32>,
18    /// Output embeddings (context vectors for negative sampling)
19    /// Shape: [vocab_size, vector_size]
20    pub syn1neg: Vec<f32>,
21}
/// Hyper-parameters controlling vocabulary building and training.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Dimensionality of the embedding vectors.
    pub vector_size: usize,
    /// Maximum distance between center word and context word.
    pub window_size: usize,
    /// Negative samples drawn per positive (center, context) pair.
    pub negative_samples: usize,
    /// Words occurring fewer times than this are dropped from the vocabulary.
    pub min_count: u64,
    /// Threshold for down-sampling very frequent words.
    pub sample: f64,
    /// Learning rate at the start of training.
    pub alpha: f32,
    /// Floor the learning rate decays down to.
    pub min_alpha: f32,
    /// Number of full passes over the corpus.
    pub epochs: usize,
    /// Worker threads used during training.
    pub threads: usize,
}
46impl Default for TrainingConfig {
47    fn default() -> Self {
48        Self {
49            vector_size: 100,
50            window_size: 5,
51            negative_samples: 5,
52            min_count: 10,
53            sample: 1e-4,
54            alpha: 0.025,
55            min_alpha: 0.0001,
56            epochs: 3,
57            threads: 8,
58        }
59    }
60}
62impl Word2Vec {
63    /// Create a new Word2Vec model with given configuration
64    pub fn new(config: TrainingConfig, vocab: Vocabulary) -> Self {
65        let vocab_size = vocab.len();
66        let max_word_id = vocab.max_word_id();
67        let vector_size = config.vector_size;
68
69        // Use dense indexing: remapped_ids are 0-based and contiguous
70        // This is MUCH more cache-friendly than sparse word_id indexing
71        let array_size = vocab_size * vector_size;
72
73        // Initialize embeddings with small random values
74        let mut syn0 = vec![0.0f32; array_size];
75        let syn1neg = vec![0.0f32; array_size];
76
77        use rand::Rng;
78        let mut rng = rand::rng();
79
80        // Initialize all vectors (remapped_ids are dense 0..vocab_size-1)
81        for remapped_id in 0..vocab_size {
82            let offset = remapped_id * vector_size;
83            for i in 0..vector_size {
84                syn0[offset + i] = (rng.random::<f32>() - 0.5) / vector_size as f32;
85            }
86        }
87
88        // syn1neg initialized to zeros (common practice)
89
90        eprintln!("Model initialized:");
91        eprintln!("  Vocab size (trained): {}", vocab_size);
92        eprintln!("  Max word_id (MeCab): {}", max_word_id);
93        eprintln!(
94            "  Array size: {} elements ({} MB)",
95            array_size,
96            array_size * 4 / 1024 / 1024
97        );
98        eprintln!("  Indexing: DENSE (remapped IDs 0-{})", vocab_size - 1);
99
100        Self {
101            config,
102            vocab: Arc::new(vocab),
103            syn0,
104            syn1neg,
105        }
106    }
107
108    /// Train model from corpus file
109    pub fn train_from_file<P: AsRef<Path>>(&mut self, corpus_path: P) -> Result<()> {
110        let mut trainer = Trainer::new(corpus_path.as_ref(), self.vocab.clone(), &self.config);
111
112        trainer.train(&mut self.syn0, &mut self.syn1neg)?;
113        Ok(())
114    }
115
116    /// Save embeddings in word2vec text format
117    pub fn save_text<P: AsRef<Path>>(&self, path: P) -> Result<()> {
118        crate::io::save_word2vec_text(
119            path,
120            &self.syn0,
121            self.vocab.as_ref(),
122            self.config.vector_size,
123        )
124    }
125
126    /// Save embeddings in MCV1 binary format
127    pub fn save_mcv1<P: AsRef<Path>>(&self, path: P, max_word_id: u32) -> Result<()> {
128        crate::io::save_mcv1_format(
129            path,
130            &self.syn0,
131            self.vocab.as_ref(),
132            self.config.vector_size,
133            max_word_id,
134        )
135    }
136
137    /// Get vocabulary
138    pub fn vocab(&self) -> &Vocabulary {
139        &self.vocab
140    }
141
142    /// Get configuration
143    pub fn config(&self) -> &TrainingConfig {
144        &self.config
145    }
146}
148/// Builder for Word2Vec model
149#[derive(Default)]
150pub struct Word2VecBuilder {
151    config: TrainingConfig,
152}
154impl Word2VecBuilder {
155    /// Create a new builder with default configuration
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    /// Set vector size (default: 100)
161    pub fn vector_size(mut self, size: usize) -> Self {
162        self.config.vector_size = size;
163        self
164    }
165
166    /// Set window size (default: 5)
167    pub fn window_size(mut self, size: usize) -> Self {
168        self.config.window_size = size;
169        self
170    }
171
172    /// Set number of negative samples (default: 5)
173    pub fn negative_samples(mut self, n: usize) -> Self {
174        self.config.negative_samples = n;
175        self
176    }
177
178    /// Set minimum word count (default: 10)
179    pub fn min_count(mut self, count: u64) -> Self {
180        self.config.min_count = count;
181        self
182    }
183
184    /// Set subsampling threshold (default: 1e-4)
185    pub fn sample(mut self, threshold: f64) -> Self {
186        self.config.sample = threshold;
187        self
188    }
189
190    /// Set initial learning rate (default: 0.025)
191    pub fn alpha(mut self, alpha: f32) -> Self {
192        self.config.alpha = alpha;
193        self
194    }
195
196    /// Set minimum learning rate (default: 0.0001)
197    pub fn min_alpha(mut self, alpha: f32) -> Self {
198        self.config.min_alpha = alpha;
199        self
200    }
201
202    /// Set number of epochs (default: 3)
203    pub fn epochs(mut self, epochs: usize) -> Self {
204        self.config.epochs = epochs;
205        self
206    }
207
208    /// Set number of threads (default: 8)
209    pub fn threads(mut self, threads: usize) -> Self {
210        self.config.threads = threads;
211        self
212    }
213
214    /// Build vocabulary from corpus and create model
215    pub fn build_from_corpus<P: AsRef<Path>>(self, corpus_path: P) -> Result<Word2Vec> {
216        let mut vocab = Vocabulary::new(self.config.min_count, self.config.sample);
217        vocab.build_from_file(&corpus_path)?;
218
219        if vocab.is_empty() {
220            return Err(Word2VecError::Vocabulary(
221                "Vocabulary is empty after filtering".to_string(),
222            ));
223        }
224
225        Ok(Word2Vec::new(self.config, vocab))
226    }
227}