1use crate::trainer::Trainer;
4use crate::vocab::Vocabulary;
5use crate::{Result, Word2VecError};
6use std::path::Path;
7use std::sync::Arc;
8
9pub struct Word2Vec {
11 config: TrainingConfig,
13 vocab: Arc<Vocabulary>,
15 pub syn0: Vec<f32>,
18 pub syn1neg: Vec<f32>,
21}
22
/// Hyper-parameters for building the vocabulary and training the model.
///
/// `TrainingConfig::default()` supplies a complete set of starting values.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of `f32` components stored per word vector.
    pub vector_size: usize,
    /// Context window size.
    // NOTE(review): whether this counts words per side or in total is decided
    // by the trainer — confirm there.
    pub window_size: usize,
    /// Number of negative samples (presumably per training pair — confirm).
    pub negative_samples: usize,
    /// Minimum count handed to `Vocabulary::new` when the vocabulary is built.
    pub min_count: u64,
    /// Sampling threshold handed to `Vocabulary::new`.
    pub sample: f64,
    /// Learning rate (presumably the initial value — confirm in the trainer).
    pub alpha: f32,
    /// Lower learning-rate bound (presumably the decay floor — confirm).
    pub min_alpha: f32,
    /// Number of training epochs.
    pub epochs: usize,
    /// Number of worker threads.
    pub threads: usize,
}

impl Default for TrainingConfig {
    /// Returns the stock configuration.
    fn default() -> Self {
        TrainingConfig {
            // Model shape.
            vector_size: 100,
            window_size: 5,
            negative_samples: 5,

            // Vocabulary filtering.
            min_count: 10,
            sample: 1e-4,

            // Learning-rate bounds.
            alpha: 0.025,
            min_alpha: 0.0001,

            // Run length and parallelism.
            epochs: 3,
            threads: 8,
        }
    }
}
61
62impl Word2Vec {
63 pub fn new(config: TrainingConfig, vocab: Vocabulary) -> Self {
65 let vocab_size = vocab.len();
66 let max_word_id = vocab.max_word_id();
67 let vector_size = config.vector_size;
68
69 let array_size = vocab_size * vector_size;
72
73 let mut syn0 = vec![0.0f32; array_size];
75 let syn1neg = vec![0.0f32; array_size];
76
77 use rand::Rng;
78 let mut rng = rand::rng();
79
80 for remapped_id in 0..vocab_size {
82 let offset = remapped_id * vector_size;
83 for i in 0..vector_size {
84 syn0[offset + i] = (rng.random::<f32>() - 0.5) / vector_size as f32;
85 }
86 }
87
88 eprintln!("Model initialized:");
91 eprintln!(" Vocab size (trained): {}", vocab_size);
92 eprintln!(" Max word_id (MeCab): {}", max_word_id);
93 eprintln!(
94 " Array size: {} elements ({} MB)",
95 array_size,
96 array_size * 4 / 1024 / 1024
97 );
98 eprintln!(" Indexing: DENSE (remapped IDs 0-{})", vocab_size - 1);
99
100 Self {
101 config,
102 vocab: Arc::new(vocab),
103 syn0,
104 syn1neg,
105 }
106 }
107
108 pub fn train_from_file<P: AsRef<Path>>(&mut self, corpus_path: P) -> Result<()> {
110 let mut trainer = Trainer::new(corpus_path.as_ref(), self.vocab.clone(), &self.config);
111
112 trainer.train(&mut self.syn0, &mut self.syn1neg)?;
113 Ok(())
114 }
115
116 pub fn save_text<P: AsRef<Path>>(&self, path: P) -> Result<()> {
118 crate::io::save_word2vec_text(
119 path,
120 &self.syn0,
121 self.vocab.as_ref(),
122 self.config.vector_size,
123 )
124 }
125
126 pub fn save_mcv1<P: AsRef<Path>>(&self, path: P, max_word_id: u32) -> Result<()> {
128 crate::io::save_mcv1_format(
129 path,
130 &self.syn0,
131 self.vocab.as_ref(),
132 self.config.vector_size,
133 max_word_id,
134 )
135 }
136
137 pub fn vocab(&self) -> &Vocabulary {
139 &self.vocab
140 }
141
142 pub fn config(&self) -> &TrainingConfig {
144 &self.config
145 }
146}
147
148#[derive(Default)]
150pub struct Word2VecBuilder {
151 config: TrainingConfig,
152}
153
154impl Word2VecBuilder {
155 pub fn new() -> Self {
157 Self::default()
158 }
159
160 pub fn vector_size(mut self, size: usize) -> Self {
162 self.config.vector_size = size;
163 self
164 }
165
166 pub fn window_size(mut self, size: usize) -> Self {
168 self.config.window_size = size;
169 self
170 }
171
172 pub fn negative_samples(mut self, n: usize) -> Self {
174 self.config.negative_samples = n;
175 self
176 }
177
178 pub fn min_count(mut self, count: u64) -> Self {
180 self.config.min_count = count;
181 self
182 }
183
184 pub fn sample(mut self, threshold: f64) -> Self {
186 self.config.sample = threshold;
187 self
188 }
189
190 pub fn alpha(mut self, alpha: f32) -> Self {
192 self.config.alpha = alpha;
193 self
194 }
195
196 pub fn min_alpha(mut self, alpha: f32) -> Self {
198 self.config.min_alpha = alpha;
199 self
200 }
201
202 pub fn epochs(mut self, epochs: usize) -> Self {
204 self.config.epochs = epochs;
205 self
206 }
207
208 pub fn threads(mut self, threads: usize) -> Self {
210 self.config.threads = threads;
211 self
212 }
213
214 pub fn build_from_corpus<P: AsRef<Path>>(self, corpus_path: P) -> Result<Word2Vec> {
216 let mut vocab = Vocabulary::new(self.config.min_count, self.config.sample);
217 vocab.build_from_file(&corpus_path)?;
218
219 if vocab.is_empty() {
220 return Err(Word2VecError::Vocabulary(
221 "Vocabulary is empty after filtering".to_string(),
222 ));
223 }
224
225 Ok(Word2Vec::new(self.config, vocab))
226 }
227}