// node2vec_rs/cpu/train.rs
1use indicatif::{ProgressBar, ProgressStyle};
2use rand::prelude::SliceRandom;
3use rand::rngs::StdRng;
4use rand::{Rng, SeedableRng};
5use rand_distr::{Distribution, Uniform};
6use rayon::prelude::*;
7use std::io::IsTerminal;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Instant;
11use thousands::*;
12
13use crate::cpu::matrix::{Matrix, MatrixWrapper};
14use crate::cpu::word2vec_model::Word2Vec;
15
16/////////////
17// Helpers //
18/////////////
19
/// Training hyperparameters for the CPU node2vec implementation.
#[derive(Clone, Debug)]
pub struct CpuTrainArgs {
    /// Dimension of the embedding vectors.
    pub dim: usize,
    /// Initial learning rate; decays linearly as training progresses.
    pub lr: f32,
    /// Number of epochs to train.
    pub epochs: usize,
    /// Number of negative samples to use per positive pair.
    pub neg: usize,
    /// Maximum context window size (the effective window is sampled
    /// uniformly from 1..=window for each centre node).
    pub window: usize,
    /// Number of tokens a thread processes between learning-rate updates.
    pub lr_update_rate: usize,
    /// Number of threads to use.
    pub n_threads: usize,
    /// Whether to print progress.
    pub verbose: bool,
    /// Subsampling threshold; nodes with frequency above this are
    /// randomly dropped during training. A value <= 0.0 disables subsampling.
    pub sample: f32,
}
46
47/// Skipgram training on a single walk
48///
49/// ### Params
50///
51/// * `model` - The Word2Vec model
52/// * `walk` - The walk (sequence of node IDs)
53/// * `rng` - Random number generator
54/// * `window_dist` - Uniform distribution for sampling window size
55fn skipgram(
56    model: &mut Word2Vec,
57    walk: &[u32],
58    rng: &mut StdRng,
59    window_dist: &Uniform<usize>,
60    keep_probs: &[f32],
61) {
62    // subsample the walk: drop highly frequent nodes to bring rare nodes closer
63    let mut active_walk = Vec::with_capacity(walk.len());
64    for &node in walk {
65        let prob = keep_probs[node as usize];
66        if prob >= 1.0 || rng.random::<f32>() < prob {
67            active_walk.push(node);
68        }
69    }
70
71    // train on the active walk
72    let length = active_walk.len();
73    for w in 0..length {
74        let bound = window_dist.sample(rng);
75        let start = w.saturating_sub(bound);
76        let end = (w + bound + 1).min(length);
77
78        for c in start..end {
79            if c != w {
80                model.update(active_walk[w] as usize, active_walk[c] as usize);
81            }
82        }
83    }
84}
85
/// Train the model on a single thread.
///
/// Processes this thread's walks sequentially, periodically publishing its
/// token count to the global counter and decaying the learning rate linearly.
///
/// ### Params
///
/// * `walks` - This thread's share of the walks (sequences of node IDs)
/// * `input` - The input (embedding) matrix, shared across threads
/// * `output` - The output (context) matrix, shared across threads
/// * `args` - The training arguments
/// * `neg_table` - The negative sampling table
/// * `keep_probs` - Per-node probability of being kept during subsampling.
/// * `processed_tokens` - Global counter of processed tokens (all threads)
/// * `total_tokens` - Total token count over all epochs (drives lr decay)
/// * `thread_id` - The thread ID (thread 0 owns progress-bar updates)
/// * `seed` - The random seed
/// * `progress` - The progress bar
#[allow(clippy::too_many_arguments)]
fn train_thread(
    walks: &[Vec<u32>],
    input: &Arc<MatrixWrapper>,
    output: &Arc<MatrixWrapper>,
    args: &CpuTrainArgs,
    neg_table: &Arc<Vec<usize>>,
    keep_probs: &Arc<Vec<f32>>,
    processed_tokens: &Arc<AtomicUsize>,
    total_tokens: usize,
    thread_id: usize,
    seed: usize,
    progress: Option<&ProgressBar>,
) {
    // Per-thread RNG, deterministically seeded so runs are reproducible.
    let mut rng = StdRng::seed_from_u64(seed as u64);
    // Window radius for each centre node is drawn from 1..=args.window.
    let window_dist = Uniform::new(1, args.window + 1).unwrap();

    let input_ptr = input.inner.get();
    let output_ptr = output.inner.get();
    let neg_start = seed.wrapping_add(thread_id);

    // SAFETY: every thread obtains `&mut` views of the same two matrices and
    // updates them without locks. This looks like Hogwild-style lock-free SGD,
    // where benign data races are tolerated — NOTE(review): soundness depends
    // on MatrixWrapper's interior-mutability contract; confirm in cpu/matrix.rs.
    let mut model = unsafe {
        Word2Vec::new(
            &mut *input_ptr,
            &mut *output_ptr,
            args.dim,
            args.lr,
            args.neg,
            neg_table.clone(),
            neg_start,
        )
    };

    let mut local_token_count = 0;

    for walk in walks {
        skipgram(&mut model, walk, &mut rng, &window_dist, keep_probs);

        // Use original walk length to pace the learning rate decay consistently
        // (subsampling inside `skipgram` shortens the walk actually trained on).
        local_token_count += walk.len();

        if local_token_count >= args.lr_update_rate {
            // Publish our batch to the global counter, then decay lr linearly
            // toward zero with a floor of 0.01% of the initial rate.
            let global_tokens = processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
            let progress_ratio = global_tokens as f32 / total_tokens as f32;
            let new_lr = args.lr * (1.0 - progress_ratio).max(0.0001);
            model.set_lr(new_lr);

            // Only thread 0 redraws the progress bar, avoiding redundant output.
            if thread_id == 0 {
                if let Some(pb) = progress {
                    pb.set_position(global_tokens as u64);
                    pb.set_message(format!("{:.6}", new_lr));
                }
            }

            local_token_count = 0;
        }
    }

    // Flush the remainder so the global token count stays accurate.
    processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
}
161
162/// Create negative sampling table
163///
164/// ### Params
165///
166/// * `vocab_size` - Number of unique nodes
167/// * `walks` - The walks to compute node frequencies from
168/// * `neg_table_size` - Size of negative sampling table
169/// * `seed` - Random seed for shuffling the negative sampling table.
170///
171/// ### Returns
172///
173/// Negative sampling table as Arc<Vec<usize>>
174pub fn create_negative_table(
175    vocab_size: usize,
176    walks: &[Vec<u32>],
177    neg_table_size: usize,
178    seed: usize,
179) -> Arc<Vec<usize>> {
180    const NEG_POW: f64 = 0.75;
181
182    // count node frequencies
183    let mut counts = vec![0u32; vocab_size];
184    for walk in walks {
185        for &node in walk {
186            counts[node as usize] += 1;
187        }
188    }
189
190    // build negative sampling table
191    let mut negative_table = Vec::new();
192    let mut z = 0.0;
193    for &count in &counts {
194        z += (count as f64).powf(NEG_POW);
195    }
196
197    for (idx, &count) in counts.iter().enumerate() {
198        if count > 0 {
199            let c = (count as f64).powf(NEG_POW);
200            let n_samples = (c * neg_table_size as f64 / z) as usize;
201            for _ in 0..n_samples {
202                negative_table.push(idx);
203            }
204        }
205    }
206
207    // shuffle
208    let mut rng = StdRng::seed_from_u64(seed as u64);
209    negative_table.shuffle(&mut rng);
210
211    Arc::new(negative_table)
212}
213
214//////////
215// Main //
216//////////
217
218/// Train node2vec model on generated walks
219///
220/// ### Params
221///
222/// * `walks` - The generated random walks
223/// * `vocab_size` - Size of vocabulary (number of unique nodes)
224/// * `args` - Training arguments
225/// * `neg_table` - Negative sampling table
226/// * `seed` - Random seed for reproducibility
227///
228/// ### Returns
229///
230/// Tuple of (input_matrix, output_matrix) containing learned embeddings
231pub fn train_node2vec_cpu(
232    mut walks: Vec<Vec<u32>>,
233    vocab_size: usize,
234    args: CpuTrainArgs,
235    neg_table: Arc<Vec<usize>>,
236    seed: usize,
237) -> (Matrix, Matrix) {
238    let mut input_mat = Matrix::new(vocab_size, args.dim);
239    let mut output_mat = Matrix::new(vocab_size, args.dim);
240
241    input_mat.uniform(1.0 / args.dim as f32, seed);
242    output_mat.zero();
243
244    let input = Arc::new(input_mat.make_send());
245    let output = Arc::new(output_mat.make_send());
246
247    let total_tokens: usize = walks.iter().map(|w| w.len()).sum();
248    let total_tokens_all_epochs = total_tokens * args.epochs;
249
250    let mut counts = vec![0usize; vocab_size];
251    for walk in &walks {
252        for &node in walk {
253            if (node as usize) < vocab_size {
254                counts[node as usize] += 1;
255            }
256        }
257    }
258
259    let mut keep_probs = vec![1.0f32; vocab_size];
260    if args.sample > 0.0 {
261        let total_f64 = total_tokens as f64;
262        for (node, &count) in counts.iter().enumerate() {
263            if count > 0 {
264                let freq = count as f64 / total_f64;
265                // Mikolov's subsampling formula
266                let keep_prob =
267                    ((freq / args.sample as f64).sqrt() + 1.0) * (args.sample as f64 / freq);
268                keep_probs[node] = keep_prob.min(1.0) as f32;
269            }
270        }
271    }
272    let keep_probs = Arc::new(keep_probs);
273
274    let processed_tokens = Arc::new(AtomicUsize::new(0));
275
276    if args.verbose {
277        println!(
278            "Training on {} random walks ({} tokens per epoch, {} total)",
279            walks.len().separate_with_underscores(),
280            total_tokens.separate_with_underscores(),
281            total_tokens_all_epochs.separate_with_underscores()
282        );
283    }
284
285    let progress = if args.verbose && std::io::stdout().is_terminal() {
286        let pb = ProgressBar::new(total_tokens_all_epochs as u64);
287        pb.set_style(
288                ProgressStyle::default_bar()
289                    .template("[Training: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({percent}%) | lr: {msg}")
290                    .unwrap()
291                    .progress_chars("#>-"),
292            );
293        Some(pb)
294    } else {
295        None
296    };
297
298    let start_time = Instant::now();
299
300    for epoch in 0..args.epochs {
301        if args.verbose {
302            if progress.is_some() {
303                println!("\nEpoch {}/{}", epoch + 1, args.epochs);
304            } else {
305                println!("  Epoch {}/{}", epoch + 1, args.epochs);
306            }
307        }
308
309        let mut epoch_rng = StdRng::seed_from_u64(seed as u64 + epoch as u64);
310        walks.shuffle(&mut epoch_rng);
311
312        let walks_per_thread = walks.len().div_ceil(args.n_threads);
313        let walk_chunks: Vec<&[Vec<u32>]> = walks.chunks(walks_per_thread).collect();
314
315        walk_chunks
316            .par_iter()
317            .enumerate()
318            .for_each(|(thread_id, thread_walks)| {
319                train_thread(
320                    thread_walks,
321                    &input,
322                    &output,
323                    &args,
324                    &neg_table,
325                    &keep_probs, // <-- Pass it here
326                    &processed_tokens,
327                    total_tokens_all_epochs,
328                    thread_id,
329                    seed.wrapping_add(epoch * args.n_threads + thread_id),
330                    progress.as_ref(),
331                );
332            });
333    }
334
335    if let Some(pb) = progress {
336        pb.finish_with_message(format!("lr: {:.6}", 0.0));
337    }
338
339    if args.verbose {
340        let elapsed = start_time.elapsed();
341        let tokens_per_sec = total_tokens_all_epochs as f64 / elapsed.as_secs_f64();
342        println!(
343            "\nTraining complete in {:.2}s ({:.0} tokens/sec)",
344            elapsed.as_secs_f64(),
345            tokens_per_sec
346        );
347    }
348
349    let input = Arc::try_unwrap(input)
350        .expect("Failed to unwrap input matrix")
351        .inner
352        .into_inner();
353    let output = Arc::try_unwrap(output)
354        .expect("Failed to unwrap output matrix")
355        .inner
356        .into_inner();
357
358    (input, output)
359}