1use indicatif::{ProgressBar, ProgressStyle};
2use rand::rngs::StdRng;
3use rand::{Rng, SeedableRng};
4use rayon::prelude::*;
5use rustc_hash::{FxHashMap, FxHashSet};
6
7pub fn compute_transition_prob(
25 adjacency: &FxHashMap<u32, Vec<(u32, f32)>>,
26 p: f32,
27 q: f32,
28) -> FxHashMap<(u32, u32), Vec<(u32, f32)>> {
29 let neighbours_set: FxHashMap<u32, FxHashSet<u32>> = adjacency
30 .iter()
31 .map(|(node, edges)| (*node, edges.iter().map(|(n, _)| *n).collect()))
32 .collect();
33
34 adjacency
35 .par_iter()
36 .flat_map(|(curr_node, curr_edges)| {
37 curr_edges.par_iter().filter_map(|(prev_node, _)| {
38 adjacency.get(curr_node).map(|next_edges| {
39 let mut probs = Vec::new();
40 let mut total = 0_f32;
41
42 for (next_node, weight) in next_edges.iter() {
43 let unnorm_prob = if next_node == prev_node {
44 weight / p
45 } else if neighbours_set
46 .get(prev_node)
47 .map(|s| s.contains(next_node))
48 .unwrap_or(false)
49 {
50 *weight
51 } else {
52 weight / q
53 };
54 total += unnorm_prob;
55 probs.push((*next_node, unnorm_prob));
56 }
57
58 let mut cumulative = 0.0;
59 let normalised: Vec<(u32, f32)> = probs
60 .into_iter()
61 .map(|(node, prob)| {
62 cumulative += prob / total;
63 (node, cumulative)
64 })
65 .collect();
66
67 ((*prev_node, *curr_node), normalised)
68 })
69 })
70 })
71 .collect()
72}
73
/// Graph data needed to generate node2vec-style random walks.
///
/// Both fields are public, so callers are responsible for keeping
/// `transition_probs` consistent with `adjacency`.
#[derive(Debug, Clone)]
pub struct Node2VecGraph {
    /// Edge list per node: each entry is a `(neighbour, weight)` pair.
    /// Used to pick the first step of a walk, weighted by edge weight.
    pub adjacency: FxHashMap<u32, Vec<(u32, f32)>>,
    /// Lookup keyed by `(previous, current)` node pair; each value is a list
    /// of `(next, cumulative_probability)` pairs that is sampled to pick the
    /// step following `current`.
    pub transition_probs: FxHashMap<(u32, u32), Vec<(u32, f32)>>,
}
89
90impl Node2VecGraph {
91 pub fn generate_walks(
102 &self,
103 walks_per_node: usize,
104 walk_length: usize,
105 seed: usize,
106 ) -> Vec<Vec<u32>> {
107 let total_walks = self.adjacency.len() * walks_per_node;
108 let progress = ProgressBar::new(total_walks as u64);
109
110 progress.set_style(
111 ProgressStyle::default_bar()
112 .template("[Random walk gen: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len}")
113 .unwrap()
114 .progress_chars("#>-"),
115 );
116
117 let walks = self
118 .adjacency
119 .par_iter()
120 .flat_map(|(start_node, _)| {
121 let progress = progress.clone();
122 (0..walks_per_node).into_par_iter().map(move |walk_idx| {
123 let mut walk_seed = seed as u64;
125 walk_seed ^= (*start_node as u64).wrapping_mul(0x517cc1b727220a95);
126 walk_seed ^= (walk_idx as u64).wrapping_mul(0xcbf29ce484222325);
127 walk_seed = walk_seed.wrapping_mul(0x4f1bbcdcbfa54c43); let mut rng = StdRng::seed_from_u64(walk_seed);
130 let walk = self.single_walk(*start_node, walk_length, &mut rng);
131 progress.inc(1);
132 walk
133 })
134 })
135 .collect();
136
137 progress.finish();
138 walks
139 }
140
141 fn single_walk(&self, start_node: u32, walk_length: usize, rng: &mut StdRng) -> Vec<u32> {
153 let mut walk: Vec<u32> = Vec::with_capacity(walk_length);
154 walk.push(start_node);
155
156 if walk_length == 1 {
157 return walk;
158 }
159
160 let mut curr = if let Some(neighbours) = self.adjacency.get(&start_node) {
161 self.sample_neighbor(neighbours, rng)
162 } else {
163 return walk;
164 };
165
166 walk.push(curr);
167
168 for _ in 2..walk_length {
169 let prev = walk[walk.len() - 2];
170
171 if let Some(probs) = self.transition_probs.get(&(prev, curr)) {
172 curr = self.sample_from_cumulative(probs, rng);
173 walk.push(curr);
174 } else {
175 break;
176 }
177 }
178
179 walk
180 }
181
182 fn sample_neighbor(&self, neighbours: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
193 let total: f32 = neighbours.iter().map(|(_, w)| w).sum();
194 let mut rand_val = rng.random::<f32>() * total;
195
196 for (node, weight) in neighbours {
197 rand_val -= weight;
198 if rand_val <= 0.0 {
199 return *node;
200 }
201 }
202
203 neighbours[0].0
204 }
205
206 fn sample_from_cumulative(&self, cumulative: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
218 let rand_val = rng.random::<f32>();
219
220 match cumulative.binary_search_by(|(_, cum_prob)| cum_prob.partial_cmp(&rand_val).unwrap())
221 {
222 Ok(idx) => cumulative[idx].0,
223 Err(idx) => {
224 if idx < cumulative.len() {
225 cumulative[idx].0
226 } else {
227 cumulative[cumulative.len() - 1].0
228 }
229 }
230 }
231 }
232}
233
#[cfg(test)]
mod graph_tests {
    use crate::prelude::*;
    use rustc_hash::FxHashMap;

    /// Builds the fully connected, unit-weight graph on nodes 1, 2 and 3.
    fn triangle_graph() -> FxHashMap<u32, Vec<(u32, f32)>> {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0), (3, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(1, 1.0), (2, 1.0)]);
        adjacency
    }

    /// Every cumulative distribution must end at (approximately) 1.0.
    #[test]
    fn test_transition_probs_sum_to_one() {
        let adjacency = triangle_graph();

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        for (_, cumulative_probs) in probs.iter() {
            let last_prob = cumulative_probs.last().unwrap().1;
            assert!((last_prob - 1.0).abs() < 1e-6);
        }
    }

    /// A smaller `p` must make stepping back to the previous node more likely.
    #[test]
    fn test_transition_probs_p_parameter() {
        // Path graph 1 - 2 - 3.
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        let probs_low_p = compute_transition_prob(&adjacency, 0.5, 1.0);
        let probs_high_p = compute_transition_prob(&adjacency, 2.0, 1.0);

        let low_p_return = &probs_low_p[&(1, 2)];
        let high_p_return = &probs_high_p[&(1, 2)];

        // Node 1 is the first entry of node 2's edge list, so its cumulative
        // value equals its plain transition probability.
        let low_p_val = low_p_return.iter().find(|(n, _)| *n == 1).unwrap().1;
        let high_p_val = high_p_return.iter().find(|(n, _)| *n == 1).unwrap().1;

        assert!(low_p_val > high_p_val);
    }

    /// Distributions stay normalised for any `q`, and a smaller `q` must make
    /// moving away from the previous node more likely.
    #[test]
    fn test_transition_probs_q_parameter() {
        let adjacency = triangle_graph();

        let probs_low_q = compute_transition_prob(&adjacency, 1.0, 0.5);
        let probs_high_q = compute_transition_prob(&adjacency, 1.0, 2.0);

        for probs in [&probs_low_q, &probs_high_q] {
            for cumulative_probs in probs.values() {
                let last = cumulative_probs.last().unwrap().1;
                assert!((last - 1.0).abs() < 1e-6);
            }
        }

        // In the triangle every candidate is a neighbour of the previous
        // node, so `q` never applies. On the path 1 - 2 - 3, node 3 is not a
        // neighbour of node 1, which exercises the `q` branch.
        let mut path = FxHashMap::default();
        path.insert(1, vec![(2, 1.0)]);
        path.insert(2, vec![(1, 1.0), (3, 1.0)]);
        path.insert(3, vec![(2, 1.0)]);

        let path_low_q = compute_transition_prob(&path, 1.0, 0.5);
        let path_high_q = compute_transition_prob(&path, 1.0, 2.0);

        // Node 1 is listed first, so its cumulative value is the probability
        // of returning; favouring the outward move (low `q`) must shrink it.
        let low_q_return = path_low_q[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;
        let high_q_return = path_high_q[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;

        assert!(low_q_return < high_q_return);
    }

    /// With `p == q == 1` and unit weights, both choices are equally likely.
    #[test]
    fn test_transition_probs_uniform_with_p_q_one() {
        let adjacency = triangle_graph();

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        for cumulative_probs in probs.values() {
            if cumulative_probs.len() == 2 {
                let first_prob = cumulative_probs[0].1;
                assert!((first_prob - 0.5).abs() < 1e-6);
            }
        }
    }

    /// A node without edges yields no transition entries.
    #[test]
    fn test_isolated_node() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        assert!(probs.is_empty());
    }

    /// Edge weights must be reflected in the cumulative distribution.
    #[test]
    fn test_weighted_transitions() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 3.0)]);
        adjacency.insert(2, vec![(1, 3.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        let probs_at_2 = &probs[&(1, 2)];
        // These are cumulative values: node 1 comes first, node 3 closes the
        // distribution at 1.0.
        let prob_to_1 = probs_at_2.iter().find(|(n, _)| *n == 1).unwrap().1;
        let prob_to_3 = probs_at_2.iter().find(|(n, _)| *n == 3).unwrap().1;

        assert!(prob_to_1 < prob_to_3);
        assert!((prob_to_3 - 1.0).abs() < 1e-6);
    }
}