//! node2vec_rs/cpu/train.rs — CPU training for node2vec embeddings.

use indicatif::{ProgressBar, ProgressStyle};
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, Uniform};
use rayon::prelude::*;
use std::io::IsTerminal;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use thousands::*;

use crate::cpu::matrix::{Matrix, MatrixWrapper};
use crate::cpu::word2vec_model::Word2Vec;

/////////////
// Helpers //
/////////////

/// Train arguments for CPU implementation.
#[derive(Clone, Debug)]
pub struct CpuTrainArgs {
    /// Dimension of the embedding vectors.
    pub dim: usize,
    /// Initial learning rate (decayed linearly toward a small floor during training).
    pub lr: f32,
    /// Number of epochs to train.
    pub epochs: usize,
    /// Number of negative samples to use.
    pub neg: usize,
    /// Maximum window size for context words; the actual radius is resampled
    /// per center word in `[1, window]`.
    pub window: usize,
    /// Re-derive the learning rate after roughly this many locally processed tokens.
    pub lr_update_rate: usize,
    /// Number of threads to use.
    pub n_threads: usize,
    /// Whether to print progress.
    pub verbose: bool,
}

44/// Skipgram training on a single walk
45///
46/// ### Params
47///
48/// * `model` - The Word2Vec model
49/// * `walk` - The walk (sequence of node IDs)
50/// * `rng` - Random number generator
51/// * `window_dist` - Uniform distribution for sampling window size
52fn skipgram(model: &mut Word2Vec, walk: &[u32], rng: &mut StdRng, window_dist: &Uniform<usize>) {
53    let length = walk.len();
54    for w in 0..length {
55        let bound = window_dist.sample(rng);
56        let start = w.saturating_sub(bound);
57        let end = (w + bound + 1).min(length);
58
59        for c in start..end {
60            if c != w {
61                model.update(walk[w] as usize, walk[c] as usize);
62            }
63        }
64    }
65}
66
/// Train the model on a single thread.
///
/// Threads update the shared embedding matrices without locking
/// (Hogwild-style lock-free SGD — NOTE(review): soundness relies on
/// `MatrixWrapper`'s interior-mutability contract; confirm in matrix.rs).
///
/// ### Params
///
/// * `walks` - The walks (sequences of node IDs) assigned to this thread
/// * `input` - The shared input embedding matrix
/// * `output` - The shared output embedding matrix
/// * `args` - The training arguments
/// * `neg_table` - The negative sampling table
/// * `processed_tokens` - Global counter of tokens processed by all threads
/// * `total_tokens` - The total number of tokens (all epochs combined)
/// * `thread_id` - The thread ID (thread 0 owns progress-bar updates)
/// * `seed` - The random seed for this thread/epoch
/// * `progress` - The progress bar, if any
fn train_thread(
    walks: &[Vec<u32>],
    input: &Arc<MatrixWrapper>,
    output: &Arc<MatrixWrapper>,
    args: &CpuTrainArgs,
    neg_table: &Arc<Vec<usize>>,
    processed_tokens: &Arc<AtomicUsize>,
    total_tokens: usize,
    thread_id: usize,
    seed: usize,
    progress: Option<&ProgressBar>,
) {
    let mut rng = StdRng::seed_from_u64(seed as u64);
    // window radius resampled per center word in [1, args.window]
    let window_dist = Uniform::new(1, args.window + 1).unwrap();

    // get mutable access to matrices
    let input_ptr = input.inner.get();
    let output_ptr = output.inner.get();

    // per-thread offset so threads don't draw identical negative-sample
    // sequences from the shared table
    let neg_start = seed.wrapping_add(thread_id) as usize;

    // SAFETY(review): every thread materialises `&mut` references into the
    // same two matrices concurrently; data races on embedding rows are
    // tolerated by design (Hogwild) — confirm MatrixWrapper guarantees this
    // is at worst a benign race, not UB.
    let mut model = unsafe {
        Word2Vec::new(
            &mut *input_ptr,
            &mut *output_ptr,
            args.dim,
            args.lr,
            args.neg,
            neg_table.clone(),
            neg_start,
        )
    };

    // tokens processed since the last flush into the global counter
    let mut local_token_count = 0;

    for walk in walks {
        skipgram(&mut model, walk, &mut rng, &window_dist);
        local_token_count += walk.len();

        // update learning rate periodically
        if local_token_count >= args.lr_update_rate {
            // fetch_add returns the value BEFORE this thread's contribution
            let global_tokens = processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
            let progress_ratio = global_tokens as f32 / total_tokens as f32;
            // linear decay with a floor of 0.01% of the initial lr
            let new_lr = args.lr * (1.0 - progress_ratio).max(0.0001);
            model.set_lr(new_lr);

            // only thread 0 touches the progress bar to avoid contention
            if thread_id == 0 {
                if let Some(pb) = progress {
                    pb.set_position(global_tokens as u64);
                    pb.set_message(format!("{:.6}", new_lr));
                }
            }

            local_token_count = 0;
        }
    }

    // flush any remainder below lr_update_rate into the global counter
    processed_tokens.fetch_add(local_token_count, Ordering::SeqCst);
}

143/// Create negative sampling table
144///
145/// ### Params
146///
147/// * `vocab_size` - Number of unique nodes
148/// * `walks` - The walks to compute node frequencies from
149/// * `neg_table_size` - Size of negative sampling table
150///
151/// ### Returns
152///
153/// Negative sampling table as Arc<Vec<usize>>
154pub fn create_negative_table(
155    vocab_size: usize,
156    walks: &[Vec<u32>],
157    neg_table_size: usize,
158    seed: usize,
159) -> Arc<Vec<usize>> {
160    const NEG_POW: f64 = 0.75;
161
162    // count node frequencies
163    let mut counts = vec![0u32; vocab_size];
164    for walk in walks {
165        for &node in walk {
166            counts[node as usize] += 1;
167        }
168    }
169
170    // build negative sampling table
171    let mut negative_table = Vec::new();
172    let mut z = 0.0;
173    for &count in &counts {
174        z += (count as f64).powf(NEG_POW);
175    }
176
177    for (idx, &count) in counts.iter().enumerate() {
178        if count > 0 {
179            let c = (count as f64).powf(NEG_POW);
180            let n_samples = (c * neg_table_size as f64 / z) as usize;
181            for _ in 0..n_samples {
182                negative_table.push(idx);
183            }
184        }
185    }
186
187    // shuffle
188    let mut rng = StdRng::seed_from_u64(seed as u64);
189    negative_table.shuffle(&mut rng);
190
191    Arc::new(negative_table)
192}
193
//////////
// Main //
//////////

198/// Train node2vec model on generated walks
199///
200/// ### Params
201///
202/// * `walks` - The generated random walks
203/// * `vocab_size` - Size of vocabulary (number of unique nodes)
204/// * `args` - Training arguments
205/// * `neg_table` - Negative sampling table
206/// * `seed` - Random seed for reproducibility
207///
208/// ### Returns
209///
210/// Tuple of (input_matrix, output_matrix) containing learned embeddings
211pub fn train_node2vec_cpu(
212    mut walks: Vec<Vec<u32>>,
213    vocab_size: usize,
214    args: CpuTrainArgs,
215    neg_table: Arc<Vec<usize>>,
216    seed: usize,
217) -> (Matrix, Matrix) {
218    // Initialise matrices
219    let mut input_mat = Matrix::new(vocab_size, args.dim);
220    let mut output_mat = Matrix::new(vocab_size, args.dim);
221
222    input_mat.uniform(1.0 / args.dim as f32, seed);
223    output_mat.zero();
224
225    let input = Arc::new(input_mat.make_send());
226    let output = Arc::new(output_mat.make_send());
227
228    // calculate total tokens for progress tracking
229    let total_tokens: usize = walks.iter().map(|w| w.len()).sum();
230    let total_tokens_all_epochs = total_tokens * args.epochs;
231
232    let processed_tokens = Arc::new(AtomicUsize::new(0));
233
234    if args.verbose {
235        println!(
236            "Training on {} random walks ({} tokens per epoch, {} total)",
237            walks.len().separate_with_underscores(),
238            total_tokens.separate_with_underscores(),
239            total_tokens_all_epochs.separate_with_underscores()
240        );
241    }
242
243    let progress = if args.verbose && std::io::stdout().is_terminal() {
244        let pb = ProgressBar::new(total_tokens_all_epochs as u64);
245        pb.set_style(
246                ProgressStyle::default_bar()
247                    .template("[Training: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({percent}%) | lr: {msg}")
248                    .unwrap()
249                    .progress_chars("#>-"),
250            );
251        Some(pb)
252    } else {
253        None
254    };
255
256    let start_time = Instant::now();
257
258    // split walks across threads
259    let walks_per_thread = (walks.len() + args.n_threads - 1) / args.n_threads;
260    let walk_chunks: Vec<Vec<Vec<u32>>> = walks
261        .chunks(walks_per_thread)
262        .map(|chunk| chunk.to_vec())
263        .collect();
264
265    // train for multiple epochs
266    for epoch in 0..args.epochs {
267        if args.verbose {
268            if progress.is_some() {
269                println!("\nEpoch {}/{}", epoch + 1, args.epochs);
270            } else {
271                println!("  Epoch {}/{}", epoch + 1, args.epochs);
272            }
273        }
274
275        // Shuffle walks at the start of each epoch
276        let mut epoch_rng = StdRng::seed_from_u64(seed as u64 + epoch as u64);
277        walks.shuffle(&mut epoch_rng);
278
279        // Parallel training across threads
280        walk_chunks
281            .par_iter()
282            .enumerate()
283            .for_each(|(thread_id, thread_walks)| {
284                train_thread(
285                    thread_walks,
286                    &input,
287                    &output,
288                    &args,
289                    &neg_table,
290                    &processed_tokens,
291                    total_tokens_all_epochs,
292                    thread_id,
293                    seed.wrapping_add(epoch * args.n_threads + thread_id),
294                    progress.as_ref(),
295                );
296            });
297    }
298
299    if let Some(pb) = progress {
300        pb.finish_with_message(format!("lr: {:.6}", 0.0));
301    }
302
303    if args.verbose {
304        let elapsed = start_time.elapsed();
305        let tokens_per_sec = total_tokens_all_epochs as f64 / elapsed.as_secs_f64();
306        println!(
307            "\nTraining complete in {:.2}s ({:.0} tokens/sec)",
308            elapsed.as_secs_f64(),
309            tokens_per_sec
310        );
311    }
312
313    // Unwrap matrices
314    let input = Arc::try_unwrap(input)
315        .expect("Failed to unwrap input matrix")
316        .inner
317        .into_inner();
318    let output = Arc::try_unwrap(output)
319        .expect("Failed to unwrap output matrix")
320        .inner
321        .into_inner();
322
323    (input, output)
324}