1use indicatif::{ProgressBar, ProgressStyle};
2use rand::prelude::SliceRandom;
3use rand::rngs::StdRng;
4use rand::{Rng, SeedableRng};
5use rand_distr::{Distribution, Uniform};
6use rayon::prelude::*;
7use std::io::IsTerminal;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Instant;
11use thousands::*;
12
13use crate::cpu::matrix::{Matrix, MatrixWrapper};
14use crate::cpu::word2vec_model::Word2Vec;
15
/// Hyper-parameters for CPU-based skip-gram training over random walks.
#[derive(Clone, Debug)]
pub struct CpuTrainArgs {
    /// Embedding dimensionality (columns of the input/output matrices).
    pub dim: usize,
    /// Initial learning rate; decayed linearly with training progress down to
    /// a floor of 0.01% of this value (see `train_thread`).
    pub lr: f32,
    /// Number of passes over the full set of walks.
    pub epochs: usize,
    /// Number of negative samples per positive (center, context) pair.
    pub neg: usize,
    /// Maximum context-window radius; the effective radius for each center
    /// token is drawn uniformly from [1, window].
    pub window: usize,
    /// Tokens a thread processes before syncing the global counter and
    /// refreshing its learning rate.
    pub lr_update_rate: usize,
    /// Number of worker threads the walks are partitioned across.
    pub n_threads: usize,
    /// Print progress information (and a progress bar when stdout is a TTY).
    pub verbose: bool,
    /// word2vec-style subsampling threshold for frequent tokens;
    /// values <= 0 disable subsampling.
    pub sample: f32,
}
46
47fn skipgram(
56 model: &mut Word2Vec,
57 walk: &[u32],
58 rng: &mut StdRng,
59 window_dist: &Uniform<usize>,
60 keep_probs: &[f32],
61) {
62 let mut active_walk = Vec::with_capacity(walk.len());
64 for &node in walk {
65 let prob = keep_probs[node as usize];
66 if prob >= 1.0 || rng.random::<f32>() < prob {
67 active_walk.push(node);
68 }
69 }
70
71 let length = active_walk.len();
73 for w in 0..length {
74 let bound = window_dist.sample(rng);
75 let start = w.saturating_sub(bound);
76 let end = (w + bound + 1).min(length);
77
78 for c in start..end {
79 if c != w {
80 model.update(active_walk[w] as usize, active_walk[c] as usize);
81 }
82 }
83 }
84}
85
/// Worker body: trains the shared embedding matrices on one chunk of walks.
///
/// All workers mutate `input`/`output` concurrently through raw pointers with
/// no locking — lock-free SGD in the style of Hogwild (NOTE(review): data-race
/// tolerance and pointer validity rely on `MatrixWrapper`'s interior-mutability
/// contract; confirm against its definition).
#[allow(clippy::too_many_arguments)]
fn train_thread(
    walks: &[Vec<u32>],
    input: &Arc<MatrixWrapper>,
    output: &Arc<MatrixWrapper>,
    args: &CpuTrainArgs,
    neg_table: &Arc<Vec<usize>>,
    keep_probs: &Arc<Vec<f32>>,
    processed_tokens: &Arc<AtomicUsize>,
    total_tokens: usize,
    thread_id: usize,
    seed: usize,
    progress: Option<&ProgressBar>,
) {
    // Per-thread RNG; the caller encodes epoch and thread id into `seed`.
    let mut rng = StdRng::seed_from_u64(seed as u64);
    // Window radius is sampled uniformly from [1, args.window] per center token.
    let window_dist = Uniform::new(1, args.window + 1).unwrap();

    // Raw pointers into the shared weight matrices (no locks taken).
    let input_ptr = input.inner.get();
    let output_ptr = output.inner.get();
    let neg_start = seed.wrapping_add(thread_id);

    // SAFETY(review): creates aliasing &mut references shared across worker
    // threads; soundness depends on MatrixWrapper guaranteeing the buffers
    // outlive this call and tolerating racy element writes — verify there.
    let mut model = unsafe {
        Word2Vec::new(
            &mut *input_ptr,
            &mut *output_ptr,
            args.dim,
            args.lr,
            args.neg,
            neg_table.clone(),
            neg_start,
        )
    };

    // Tokens processed since the last sync with the global counter.
    let mut local_token_count = 0;

    for walk in walks {
        skipgram(&mut model, walk, &mut rng, &window_dist, keep_probs);

        local_token_count += walk.len();

        // Periodically publish local progress and decay the learning rate
        // linearly, with a floor of 0.01% of the initial rate.
        if local_token_count >= args.lr_update_rate {
            let global_tokens = processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
            let progress_ratio = global_tokens as f32 / total_tokens as f32;
            let new_lr = args.lr * (1.0 - progress_ratio).max(0.0001);
            model.set_lr(new_lr);

            // Only thread 0 touches the progress bar, avoiding redraw contention.
            if thread_id == 0 {
                if let Some(pb) = progress {
                    pb.set_position(global_tokens as u64);
                    pb.set_message(format!("{:.6}", new_lr));
                }
            }

            local_token_count = 0;
        }
    }

    // Flush the remainder so the global count stays accurate across epochs.
    processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
}
161
162pub fn create_negative_table(
175 vocab_size: usize,
176 walks: &[Vec<u32>],
177 neg_table_size: usize,
178 seed: usize,
179) -> Arc<Vec<usize>> {
180 const NEG_POW: f64 = 0.75;
181
182 let mut counts = vec![0u32; vocab_size];
184 for walk in walks {
185 for &node in walk {
186 counts[node as usize] += 1;
187 }
188 }
189
190 let mut negative_table = Vec::new();
192 let mut z = 0.0;
193 for &count in &counts {
194 z += (count as f64).powf(NEG_POW);
195 }
196
197 for (idx, &count) in counts.iter().enumerate() {
198 if count > 0 {
199 let c = (count as f64).powf(NEG_POW);
200 let n_samples = (c * neg_table_size as f64 / z) as usize;
201 for _ in 0..n_samples {
202 negative_table.push(idx);
203 }
204 }
205 }
206
207 let mut rng = StdRng::seed_from_u64(seed as u64);
209 negative_table.shuffle(&mut rng);
210
211 Arc::new(negative_table)
212}
213
/// Trains node2vec embeddings on `walks` with multi-threaded skip-gram and
/// negative sampling; returns the trained `(input, output)` matrices.
///
/// Each epoch, `walks` is shuffled in place (deterministically from `seed`)
/// and split into one contiguous chunk per thread; all threads update the
/// shared matrices concurrently without locks (see `train_thread`).
pub fn train_node2vec_cpu(
    mut walks: Vec<Vec<u32>>,
    vocab_size: usize,
    args: CpuTrainArgs,
    neg_table: Arc<Vec<usize>>,
    seed: usize,
) -> (Matrix, Matrix) {
    let mut input_mat = Matrix::new(vocab_size, args.dim);
    let mut output_mat = Matrix::new(vocab_size, args.dim);

    // Input weights get a seeded uniform init scaled by 1/dim (exact range is
    // defined by Matrix::uniform — confirm there); output weights start at
    // zero, as in word2vec.
    input_mat.uniform(1.0 / args.dim as f32, seed);
    output_mat.zero();

    // Wrap the matrices so rayon workers can share them mutably.
    let input = Arc::new(input_mat.make_send());
    let output = Arc::new(output_mat.make_send());

    let total_tokens: usize = walks.iter().map(|w| w.len()).sum();
    let total_tokens_all_epochs = total_tokens * args.epochs;

    // Occurrence counts, used to derive subsampling keep-probabilities.
    let mut counts = vec![0usize; vocab_size];
    for walk in &walks {
        for &node in walk {
            if (node as usize) < vocab_size {
                counts[node as usize] += 1;
            }
        }
    }

    // word2vec-style subsampling: a node with frequency f is kept with
    // probability (sqrt(f/t) + 1) * t/f, capped at 1. Disabled for sample <= 0.
    let mut keep_probs = vec![1.0f32; vocab_size];
    if args.sample > 0.0 {
        let total_f64 = total_tokens as f64;
        for (node, &count) in counts.iter().enumerate() {
            if count > 0 {
                let freq = count as f64 / total_f64;
                let keep_prob =
                    ((freq / args.sample as f64).sqrt() + 1.0) * (args.sample as f64 / freq);
                keep_probs[node] = keep_prob.min(1.0) as f32;
            }
        }
    }
    let keep_probs = Arc::new(keep_probs);

    // Global token counter shared by all threads; drives the lr schedule.
    let processed_tokens = Arc::new(AtomicUsize::new(0));

    if args.verbose {
        println!(
            "Training on {} random walks ({} tokens per epoch, {} total)",
            walks.len().separate_with_underscores(),
            total_tokens.separate_with_underscores(),
            total_tokens_all_epochs.separate_with_underscores()
        );
    }

    // Progress bar only when verbose AND stdout is an interactive terminal.
    let progress = if args.verbose && std::io::stdout().is_terminal() {
        let pb = ProgressBar::new(total_tokens_all_epochs as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("[Training: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({percent}%) | lr: {msg}")
                .unwrap()
                .progress_chars("#>-"),
        );
        Some(pb)
    } else {
        None
    };

    let start_time = Instant::now();

    for epoch in 0..args.epochs {
        if args.verbose {
            if progress.is_some() {
                println!("\nEpoch {}/{}", epoch + 1, args.epochs);
            } else {
                println!(" Epoch {}/{}", epoch + 1, args.epochs);
            }
        }

        // Deterministic per-epoch shuffle: chunk assignment varies between
        // epochs but runs are reproducible for a fixed seed.
        let mut epoch_rng = StdRng::seed_from_u64(seed as u64 + epoch as u64);
        walks.shuffle(&mut epoch_rng);

        // One contiguous chunk of walks per thread.
        let walks_per_thread = walks.len().div_ceil(args.n_threads);
        let walk_chunks: Vec<&[Vec<u32>]> = walks.chunks(walks_per_thread).collect();

        walk_chunks
            .par_iter()
            .enumerate()
            .for_each(|(thread_id, thread_walks)| {
                train_thread(
                    thread_walks,
                    &input,
                    &output,
                    &args,
                    &neg_table,
                    &keep_probs,
                    &processed_tokens,
                    total_tokens_all_epochs,
                    thread_id,
                    // Unique RNG seed per (epoch, thread) pair.
                    seed.wrapping_add(epoch * args.n_threads + thread_id),
                    progress.as_ref(),
                );
            });
    }

    if let Some(pb) = progress {
        pb.finish_with_message(format!("lr: {:.6}", 0.0));
    }

    if args.verbose {
        let elapsed = start_time.elapsed();
        let tokens_per_sec = total_tokens_all_epochs as f64 / elapsed.as_secs_f64();
        println!(
            "\nTraining complete in {:.2}s ({:.0} tokens/sec)",
            elapsed.as_secs_f64(),
            tokens_per_sec
        );
    }

    // All rayon workers have joined by now, so each Arc is unique again;
    // unwrap the wrappers and recover the owned matrices.
    let input = Arc::try_unwrap(input)
        .expect("Failed to unwrap input matrix")
        .inner
        .into_inner();
    let output = Arc::try_unwrap(output)
        .expect("Failed to unwrap output matrix")
        .inner
        .into_inner();

    (input, output)
}