1use indicatif::{ProgressBar, ProgressStyle};
2use rand::prelude::SliceRandom;
3use rand::rngs::StdRng;
4use rand::SeedableRng;
5use rand_distr::{Distribution, Uniform};
6use rayon::prelude::*;
7use std::io::IsTerminal;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Instant;
11use thousands::*;
12
13use crate::cpu::matrix::{Matrix, MatrixWrapper};
14use crate::cpu::word2vec_model::Word2Vec;
15
/// Hyper-parameters controlling CPU skip-gram training of node embeddings.
#[derive(Clone, Debug)]
pub struct CpuTrainArgs {
    /// Embedding dimensionality: the column count of both embedding matrices.
    pub dim: usize,
    /// Starting learning rate; workers decay it linearly (with a small floor) as
    /// tokens are processed.
    pub lr: f32,
    /// Number of full passes over the walk corpus.
    pub epochs: usize,
    /// Forwarded to `Word2Vec::new`; presumably the number of negative samples
    /// drawn per positive (centre, context) pair — confirm in `word2vec_model`.
    pub neg: usize,
    /// Largest context radius. Each centre position draws its radius uniformly
    /// from `1..=window`, so this must be at least 1.
    pub window: usize,
    /// How many tokens a worker trains on between learning-rate refreshes.
    pub lr_update_rate: usize,
    /// Number of chunks the walks are split into per epoch (one rayon task each).
    pub n_threads: usize,
    /// When set, progress and timing information is printed to stdout.
    pub verbose: bool,
}
43
44fn skipgram(model: &mut Word2Vec, walk: &[u32], rng: &mut StdRng, window_dist: &Uniform<usize>) {
53 let length = walk.len();
54 for w in 0..length {
55 let bound = window_dist.sample(rng);
56 let start = w.saturating_sub(bound);
57 let end = (w + bound + 1).min(length);
58
59 for c in start..end {
60 if c != w {
61 model.update(walk[w] as usize, walk[c] as usize);
62 }
63 }
64 }
65}
66
67fn train_thread(
82 walks: &[Vec<u32>],
83 input: &Arc<MatrixWrapper>,
84 output: &Arc<MatrixWrapper>,
85 args: &CpuTrainArgs,
86 neg_table: &Arc<Vec<usize>>,
87 processed_tokens: &Arc<AtomicUsize>,
88 total_tokens: usize,
89 thread_id: usize,
90 seed: usize,
91 progress: Option<&ProgressBar>,
92) {
93 let mut rng = StdRng::seed_from_u64(seed as u64);
94 let window_dist = Uniform::new(1, args.window + 1).unwrap();
95
96 let input_ptr = input.inner.get();
98 let output_ptr = output.inner.get();
99
100 let neg_start = seed.wrapping_add(thread_id) as usize;
102
103 let mut model = unsafe {
104 Word2Vec::new(
105 &mut *input_ptr,
106 &mut *output_ptr,
107 args.dim,
108 args.lr,
109 args.neg,
110 neg_table.clone(),
111 neg_start,
112 )
113 };
114
115 let mut local_token_count = 0;
116
117 for walk in walks {
118 skipgram(&mut model, walk, &mut rng, &window_dist);
119 local_token_count += walk.len();
120
121 if local_token_count >= args.lr_update_rate {
123 let global_tokens = processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
124 let progress_ratio = global_tokens as f32 / total_tokens as f32;
125 let new_lr = args.lr * (1.0 - progress_ratio).max(0.0001);
126 model.set_lr(new_lr);
127
128 if thread_id == 0 {
129 if let Some(pb) = progress {
130 pb.set_position(global_tokens as u64);
131 pb.set_message(format!("{:.6}", new_lr));
132 }
133 }
134
135 local_token_count = 0;
136 }
137 }
138
139 processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
141}
142
143pub fn create_negative_table(
155 vocab_size: usize,
156 walks: &[Vec<u32>],
157 neg_table_size: usize,
158 seed: usize,
159) -> Arc<Vec<usize>> {
160 const NEG_POW: f64 = 0.75;
161
162 let mut counts = vec![0u32; vocab_size];
164 for walk in walks {
165 for &node in walk {
166 counts[node as usize] += 1;
167 }
168 }
169
170 let mut negative_table = Vec::new();
172 let mut z = 0.0;
173 for &count in &counts {
174 z += (count as f64).powf(NEG_POW);
175 }
176
177 for (idx, &count) in counts.iter().enumerate() {
178 if count > 0 {
179 let c = (count as f64).powf(NEG_POW);
180 let n_samples = (c * neg_table_size as f64 / z) as usize;
181 for _ in 0..n_samples {
182 negative_table.push(idx);
183 }
184 }
185 }
186
187 let mut rng = StdRng::seed_from_u64(seed as u64);
189 negative_table.shuffle(&mut rng);
190
191 Arc::new(negative_table)
192}
193
194pub fn train_node2vec_cpu(
212 mut walks: Vec<Vec<u32>>,
213 vocab_size: usize,
214 args: CpuTrainArgs,
215 neg_table: Arc<Vec<usize>>,
216 seed: usize,
217) -> (Matrix, Matrix) {
218 let mut input_mat = Matrix::new(vocab_size, args.dim);
220 let mut output_mat = Matrix::new(vocab_size, args.dim);
221
222 input_mat.uniform(1.0 / args.dim as f32, seed);
223 output_mat.zero();
224
225 let input = Arc::new(input_mat.make_send());
226 let output = Arc::new(output_mat.make_send());
227
228 let total_tokens: usize = walks.iter().map(|w| w.len()).sum();
230 let total_tokens_all_epochs = total_tokens * args.epochs;
231
232 let processed_tokens = Arc::new(AtomicUsize::new(0));
233
234 if args.verbose {
235 println!(
236 "Training on {} random walks ({} tokens per epoch, {} total)",
237 walks.len().separate_with_underscores(),
238 total_tokens.separate_with_underscores(),
239 total_tokens_all_epochs.separate_with_underscores()
240 );
241 }
242
243 let progress = if args.verbose && std::io::stdout().is_terminal() {
244 let pb = ProgressBar::new(total_tokens_all_epochs as u64);
245 pb.set_style(
246 ProgressStyle::default_bar()
247 .template("[Training: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({percent}%) | lr: {msg}")
248 .unwrap()
249 .progress_chars("#>-"),
250 );
251 Some(pb)
252 } else {
253 None
254 };
255
256 let start_time = Instant::now();
257
258 let walks_per_thread = (walks.len() + args.n_threads - 1) / args.n_threads;
260 let walk_chunks: Vec<Vec<Vec<u32>>> = walks
261 .chunks(walks_per_thread)
262 .map(|chunk| chunk.to_vec())
263 .collect();
264
265 for epoch in 0..args.epochs {
267 if args.verbose {
268 if progress.is_some() {
269 println!("\nEpoch {}/{}", epoch + 1, args.epochs);
270 } else {
271 println!(" Epoch {}/{}", epoch + 1, args.epochs);
272 }
273 }
274
275 let mut epoch_rng = StdRng::seed_from_u64(seed as u64 + epoch as u64);
277 walks.shuffle(&mut epoch_rng);
278
279 walk_chunks
281 .par_iter()
282 .enumerate()
283 .for_each(|(thread_id, thread_walks)| {
284 train_thread(
285 thread_walks,
286 &input,
287 &output,
288 &args,
289 &neg_table,
290 &processed_tokens,
291 total_tokens_all_epochs,
292 thread_id,
293 seed.wrapping_add(epoch * args.n_threads + thread_id),
294 progress.as_ref(),
295 );
296 });
297 }
298
299 if let Some(pb) = progress {
300 pb.finish_with_message(format!("lr: {:.6}", 0.0));
301 }
302
303 if args.verbose {
304 let elapsed = start_time.elapsed();
305 let tokens_per_sec = total_tokens_all_epochs as f64 / elapsed.as_secs_f64();
306 println!(
307 "\nTraining complete in {:.2}s ({:.0} tokens/sec)",
308 elapsed.as_secs_f64(),
309 tokens_per_sec
310 );
311 }
312
313 let input = Arc::try_unwrap(input)
315 .expect("Failed to unwrap input matrix")
316 .inner
317 .into_inner();
318 let output = Arc::try_unwrap(output)
319 .expect("Failed to unwrap output matrix")
320 .inner
321 .into_inner();
322
323 (input, output)
324}