1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
//! Multi-threaded word2vec trainer with Hogwild! algorithm
use crate::Result;
use crate::model::TrainingConfig;
use crate::skipgram::SkipGram;
use crate::vocab::Vocabulary;
use rand::Rng;
use rayon::prelude::*;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Drives Word2Vec training (skip-gram with negative sampling) over a corpus file.
///
/// A `Trainer` owns where the corpus lives, a handle to the shared vocabulary, and its own
/// copy of the training hyper-parameters.
pub struct Trainer {
    /// Corpus file: one sentence per line, words given as whitespace-separated numeric ids.
    corpus_path: PathBuf,
    /// Vocabulary shared with the caller (membership, counts, remapped ids, subsampling).
    vocab: Arc<Vocabulary>,
    /// Training hyper-parameters, cloned from the caller's configuration.
    config: TrainingConfig,
}
impl Trainer {
    /// How many in-vocabulary words a worker processes before it publishes its count and
    /// refreshes its learning rate (the same interval the original word2vec tool uses).
    const ALPHA_UPDATE_WORDS: u64 = 10_000;

    /// A progress line is printed whenever the global word counter crosses a multiple of this.
    const PROGRESS_REPORT_WORDS: u64 = 100_000;

    /// Dot products beyond `±MAX_EXP` saturate the sigmoid to exactly 1 or 0.
    const MAX_EXP: f32 = 6.0;

    /// Create a trainer for the corpus at `corpus_path`.
    ///
    /// The configuration is cloned, so later changes to the caller's copy do not affect
    /// this trainer.
    pub fn new(corpus_path: &Path, vocab: Arc<Vocabulary>, config: &TrainingConfig) -> Self {
        Self {
            corpus_path: corpus_path.to_path_buf(),
            vocab,
            config: config.clone(),
        }
    }

    /// Train the model in place using the Hogwild! algorithm.
    ///
    /// Hogwild! is a lock-free parallel SGD scheme: worker threads update the shared
    /// parameters without any locking, accepting occasional lost updates.
    /// Reference: "Hogwild!: A Lock-Free Approach to Parallelizing SGD" (NIPS 2011).
    ///
    /// `syn0` (input vectors) and `syn1neg` (output vectors) are row-major matrices with one
    /// `config.vector_size`-wide row per remapped word id. Rows that would fall outside
    /// their matrix are skipped. Each matrix is checked against its own length.
    ///
    /// The corpus is loaded into memory once and replayed for `config.epochs` epochs. The
    /// learning rate decays linearly from `config.alpha` to `config.min_alpha` over all
    /// epochs. Each worker refreshes its rate every
    /// [`ALPHA_UPDATE_WORDS`](Self::ALPHA_UPDATE_WORDS) words.
    ///
    /// NOTE(review): the concurrent unsynchronized `f32` writes are, strictly, data races
    /// under Rust's memory model. They are kept deliberately for Hogwild! throughput.
    ///
    /// # Errors
    /// Returns an error if the corpus file cannot be opened or read.
    pub fn train(&mut self, syn0: &mut [f32], syn1neg: &mut [f32]) -> Result<()> {
        let vocab_size = self.vocab.len();
        // The two matrices are bounds-checked independently: nothing here guarantees they
        // have the same length.
        let syn0_len = syn0.len();
        let syn1neg_len = syn1neg.len();
        eprintln!("\nStarting training using file {:?}", self.corpus_path);
        eprintln!("Vocab size: {}", vocab_size);
        eprintln!("Words in train file: {}", self.vocab.total_words());
        eprintln!("Parallelization: Hogwild! (lock-free)");

        // Negative-sampling table over remapped ids (dense ids keep rows cache-friendly).
        let mut skipgram = SkipGram::new();
        let word_counts: Vec<(u32, u64)> = self
            .vocab
            .iter()
            .map(|info| (info.remapped_id, info.count))
            .collect();
        skipgram.build_neg_table(&word_counts);

        // Global word counter across ALL epochs. It drives both the progress display and
        // the learning-rate schedule.
        let words_processed = AtomicU64::new(0);
        let words_per_epoch = self.vocab.total_words();
        let total_words_all_epochs = words_per_epoch * self.config.epochs as u64;
        // `max(1)` keeps an empty vocabulary from producing a 0/0 = NaN learning rate.
        let schedule_len = total_words_all_epochs.max(1) as f32;
        let start_alpha = self.config.alpha;
        let min_alpha = self.config.min_alpha;

        // Learning rate after `words_done` words: linear decay over the whole run, held at
        // `min_alpha` if the counter overshoots the expected total.
        let alpha_at = |words_done: u64| {
            let fraction = (words_done as f32 / schedule_len).min(1.0);
            start_alpha - (start_alpha - min_alpha) * fraction
        };

        // Adds `words` to the global counter, prints progress when a reporting boundary is
        // crossed, and returns the new global total.
        let publish_words = |words: u64| {
            let before = words_processed.fetch_add(words, Ordering::Relaxed);
            let after = before + words;
            if before / Self::PROGRESS_REPORT_WORDS != after / Self::PROGRESS_REPORT_WORDS {
                let progress = after as f32 / schedule_len * 100.0;
                eprint!("\rAlpha: {:.6} Progress: {:.2}% ", alpha_at(after), progress);
            }
            after
        };

        // Raw addresses are stored as usize because raw pointers are not Send/Sync.
        // Workers rebuild the pointers from these addresses. The caller's buffers are
        // borrowed for the whole call, so they outlive every worker.
        let syn0_addr = syn0.as_mut_ptr() as usize;
        let syn1neg_addr = syn1neg.as_mut_ptr() as usize;
        let vector_size = self.config.vector_size;

        // Load corpus once into memory (reused across all epochs).
        eprintln!("Loading corpus into memory...");
        let sentences = self.load_corpus()?;
        let total_sentences = sentences.len();
        eprintln!("Loaded {} sentences", total_sentences);

        // One contiguous chunk per configured thread. `threads.max(1)` avoids dividing by
        // zero, and the outer `max(1)` keeps `chunks` from panicking on an empty corpus.
        let threads = self.config.threads.max(1);
        let chunk_size = total_sentences.div_ceil(threads).max(1);

        for epoch in 0..self.config.epochs {
            eprintln!("\nEpoch {}/{}", epoch + 1, self.config.epochs);
            let epoch_alpha = alpha_at(words_processed.load(Ordering::Relaxed));
            eprintln!(" Starting alpha: {:.6}", epoch_alpha);
            eprintln!(" Processing {} sentences...", total_sentences);

            let sentence_chunks: Vec<_> = sentences.chunks(chunk_size).collect();
            // Hogwild!: every chunk writes syn0/syn1neg concurrently without locks.
            sentence_chunks.into_par_iter().for_each(|chunk| {
                let mut rng = rand::rng();
                let syn0_ptr = syn0_addr as *mut f32;
                let syn1neg_ptr = syn1neg_addr as *mut f32;
                // Gradient scratch buffer, reused for every word pair in this chunk.
                let mut neu1e = vec![0.0f32; vector_size];
                // In-vocabulary words seen since this worker last published its count.
                let mut pending_words = 0u64;
                let mut alpha = alpha_at(words_processed.load(Ordering::Relaxed));

                for sentence in chunk {
                    for (pos, center_id) in sentence.iter().copied().enumerate() {
                        if !self.vocab.contains(center_id) {
                            continue;
                        }
                        // Count in-vocabulary words before subsampling, as the original
                        // word2vec tool does.
                        // NOTE(review): assumes `Vocabulary::total_words()` counts the same
                        // tokens. If it counts fewer, alpha simply bottoms out early.
                        pending_words += 1;

                        // Subsampling of frequent words.
                        if self
                            .vocab
                            .get(center_id)
                            .is_some_and(|info| rng.random::<f32>() > info.sample_prob)
                        {
                            continue;
                        }

                        // Dynamic window: the effective radius is drawn from 1..=window_size.
                        let window = rng.random_range(1..=self.config.window_size);
                        for offset in 1..=window {
                            let left = pos.checked_sub(offset).map(|i| sentence[i]);
                            let right = sentence.get(pos + offset).copied();
                            for context_id in left.into_iter().chain(right) {
                                if !self.vocab.contains(context_id) {
                                    continue;
                                }
                                // SAFETY: the pointers address the caller's `syn0`/`syn1neg`
                                // slices, which are borrowed for the duration of `train`. The
                                // callee bounds-checks each row against that matrix's length.
                                // Concurrent writes from other workers are the Hogwild! trade-off.
                                unsafe {
                                    self.train_word_pair_hogwild(
                                        center_id,
                                        context_id,
                                        alpha,
                                        syn0_ptr,
                                        syn1neg_ptr,
                                        syn0_len,
                                        syn1neg_len,
                                        &skipgram,
                                        &mut neu1e,
                                        &mut rng,
                                    );
                                }
                            }
                        }
                    }

                    // Periodically publish progress and pick up the decayed learning rate.
                    if pending_words >= Self::ALPHA_UPDATE_WORDS {
                        alpha = alpha_at(publish_words(pending_words));
                        pending_words = 0;
                    }
                }

                if pending_words > 0 {
                    publish_words(pending_words);
                }
            });
            eprintln!("\n Epoch {} complete", epoch + 1);
        }
        eprintln!("\nTraining complete!");
        Ok(())
    }

    /// Train one (center, context) skip-gram pair with negative sampling, Hogwild!-style.
    ///
    /// The positive step pairs the center word's `syn0` row with the context word's
    /// `syn1neg` row. Then `config.negative_samples` rows drawn from `skipgram` each get
    /// one step. Draws equal to the context are skipped. Finally the accumulated gradient
    /// is added to the center row.
    ///
    /// Words without a remapped id are ignored. So are rows that would fall outside their
    /// matrix.
    ///
    /// Returns the summed negative log-likelihood of the steps that were taken (`0.0` if
    /// none were).
    ///
    /// # Safety
    /// - `syn0_ptr` must be valid for reads and writes of `syn0_len` `f32`s.
    /// - `syn1neg_ptr` must be valid for reads and writes of `syn1neg_len` `f32`s.
    /// - The two regions must not overlap.
    /// - Rows are `neu1e.len()` floats wide; `neu1e` is used as scratch space.
    /// - Other threads may update the same rows concurrently without synchronization. This
    ///   is the deliberate Hogwild! trade-off.
    #[inline]
    #[allow(clippy::too_many_arguments)] // hot-path helper; bundling the raw pointers buys nothing
    unsafe fn train_word_pair_hogwild(
        &self,
        center_id: u32,
        context_id: u32,
        alpha: f32,
        syn0_ptr: *mut f32,
        syn1neg_ptr: *mut f32,
        syn0_len: usize,
        syn1neg_len: usize,
        skipgram: &SkipGram,
        neu1e: &mut [f32],
        rng: &mut impl Rng,
    ) -> f32 {
        let vector_size = neu1e.len();
        // Fast O(1) lookup: word id -> dense remapped id (row index).
        let Some(center_remapped) = self.vocab.get_remapped_id(center_id) else {
            return 0.0;
        };
        let Some(context_remapped) = self.vocab.get_remapped_id(context_id) else {
            return 0.0;
        };

        let l1 = center_remapped as usize * vector_size;
        if l1 + vector_size > syn0_len {
            return 0.0;
        }
        neu1e.fill(0.0);
        let mut loss = 0.0f32;

        // SAFETY: `l1 + vector_size <= syn0_len`. Each `l2` row below is checked against
        // `syn1neg_len` before use. The caller guarantees both pointers cover those lengths.
        unsafe {
            let center_vec = syn0_ptr.add(l1);

            // Positive sample: the observed context word.
            let l2 = context_remapped as usize * vector_size;
            if l2 + vector_size <= syn1neg_len {
                loss += Self::sgd_step(center_vec, syn1neg_ptr.add(l2), 1.0, alpha, neu1e);
            }

            // Negative samples drawn from the negative-sampling table.
            for _ in 0..self.config.negative_samples {
                let neg_remapped = skipgram.sample_negative(rng);
                // A draw equal to the true context would contradict the positive example.
                if neg_remapped == context_remapped {
                    continue;
                }
                let l2 = neg_remapped as usize * vector_size;
                if l2 + vector_size > syn1neg_len {
                    continue;
                }
                loss += Self::sgd_step(center_vec, syn1neg_ptr.add(l2), 0.0, alpha, neu1e);
            }

            // Apply the accumulated input-side gradient to the center row.
            for (i, &grad) in neu1e.iter().enumerate() {
                *center_vec.add(i) += grad;
            }
        }
        loss
    }

    /// One logistic-regression SGD step between input row `center` and output row `output`.
    ///
    /// Moves `sigmoid(center · output)` toward `label`, where 1.0 means an observed context
    /// and 0.0 means a negative sample. The input-side gradient is added into `neu1e` and
    /// `output` is updated in place; `center` is only read.
    ///
    /// Returns the step's loss: `-ln(sigmoid(f))` for label 1 and `-ln(1 - sigmoid(f))`
    /// for label 0.
    ///
    /// # Safety
    /// `center` and `output` must each be valid for `neu1e.len()` `f32`s, and `output` must
    /// be writable.
    #[inline]
    unsafe fn sgd_step(
        center: *const f32,
        output: *mut f32,
        label: f32,
        alpha: f32,
        neu1e: &mut [f32],
    ) -> f32 {
        // SAFETY: the caller guarantees both rows span `neu1e.len()` floats.
        unsafe {
            let mut f = 0.0f32;
            for i in 0..neu1e.len() {
                f += *center.add(i) * *output.add(i);
            }

            let sigmoid_f = if f > Self::MAX_EXP {
                1.0
            } else if f < -Self::MAX_EXP {
                0.0
            } else {
                1.0 / (1.0 + (-f).exp())
            };
            let g = (label - sigmoid_f) * alpha;

            for (i, grad) in neu1e.iter_mut().enumerate() {
                *grad += g * *output.add(i);
                *output.add(i) += g * *center.add(i);
            }

            // -ln(sigmoid(f)) = softplus(-f) and -ln(1 - sigmoid(f)) = softplus(f). Both use
            // the unclamped dot product, so the loss stays finite even where the sigmoid
            // saturates.
            let margin = if label > 0.5 { -f } else { f };
            Self::softplus(margin)
        }
    }

    /// Numerically stable `ln(1 + e^x)`.
    #[inline]
    fn softplus(x: f32) -> f32 {
        if x > 0.0 {
            x + (-x).exp().ln_1p()
        } else {
            x.exp().ln_1p()
        }
    }

    /// Load the whole corpus into memory, one `Vec<u32>` of word ids per line.
    ///
    /// Lines are split on whitespace. Tokens that do not parse as `u32` are silently
    /// dropped, and lines that end up empty are skipped.
    ///
    /// # Errors
    /// Returns an error if the corpus file cannot be opened or a line cannot be read.
    fn load_corpus(&self) -> Result<Vec<Vec<u32>>> {
        let file = File::open(&self.corpus_path)?;
        let reader = BufReader::new(file);
        let mut sentences = Vec::new();
        for line in reader.lines() {
            let line = line?;
            let sentence: Vec<u32> = line
                .split_whitespace()
                .filter_map(|token| token.parse::<u32>().ok())
                .collect();
            if !sentence.is_empty() {
                sentences.push(sentence);
            }
        }
        Ok(sentences)
    }
}