1use indicatif::{ProgressBar, ProgressStyle};
2use rand::rngs::StdRng;
3use rand::{Rng, SeedableRng};
4use rayon::prelude::*;
5use rustc_hash::{FxHashMap, FxHashSet};
6
7pub fn compute_transition_prob(
25 adjacency: &FxHashMap<u32, Vec<(u32, f32)>>,
26 p: f32,
27 q: f32,
28) -> FxHashMap<(u32, u32), Vec<(u32, f32)>> {
29 let neighbours_set: FxHashMap<u32, FxHashSet<u32>> = adjacency
30 .iter()
31 .map(|(node, edges)| (*node, edges.iter().map(|(n, _)| *n).collect()))
32 .collect();
33
34 adjacency
35 .par_iter()
36 .flat_map(|(curr_node, curr_edges)| {
37 curr_edges.par_iter().filter_map(|(prev_node, _)| {
38 adjacency.get(curr_node).map(|next_edges| {
39 let mut probs = Vec::new();
40 let mut total = 0_f32;
41
42 for (next_node, weight) in next_edges.iter() {
43 let unnorm_prob = if next_node == prev_node {
44 weight / p
45 } else if neighbours_set
46 .get(prev_node)
47 .map(|s| s.contains(next_node))
48 .unwrap_or(false)
49 {
50 *weight
51 } else {
52 weight / q
53 };
54 total += unnorm_prob;
55 probs.push((*next_node, unnorm_prob));
56 }
57
58 let mut cumulative = 0.0;
59 let normalised: Vec<(u32, f32)> = probs
60 .into_iter()
61 .map(|(node, prob)| {
62 cumulative += prob / total;
63 (node, cumulative)
64 })
65 .collect();
66
67 ((*prev_node, *curr_node), normalised)
68 })
69 })
70 })
71 .collect()
72}
73
/// A weighted graph together with its precomputed second-order transition tables, used to
/// generate biased random walks via [`Node2VecGraph::generate_walks`].
#[derive(Debug, Clone)]
pub struct Node2VecGraph {
    /// Weighted adjacency lists: `node -> [(neighbour, weight)]`. Walks are started from
    /// every key, and the first step of each walk samples from these weights.
    pub adjacency: FxHashMap<u32, Vec<(u32, f32)>>,
    /// Transition tables keyed by `(prev, curr)`, each a list of
    /// `(next, cumulative_probability)` pairs sampled after the first step.
    /// NOTE(review): presumably built with `compute_transition_prob`; confirm at the
    /// construction site.
    pub transition_probs: FxHashMap<(u32, u32), Vec<(u32, f32)>>,
}
89
90impl Node2VecGraph {
91 pub fn generate_walks(
102 &self,
103 walks_per_node: usize,
104 walk_length: usize,
105 seed: usize,
106 ) -> Vec<Vec<u32>> {
107 let total_walks = self.adjacency.len() * walks_per_node;
108 let progress = ProgressBar::new(total_walks as u64);
109
110 progress.set_style(
111 ProgressStyle::default_bar()
112 .template("[Random walk gen: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len}")
113 .unwrap()
114 .progress_chars("#>-"),
115 );
116
117 let walks = self
118 .adjacency
119 .par_iter()
120 .flat_map(|(start_node, _)| {
121 let progress = progress.clone();
122 (0..walks_per_node).into_par_iter().map(move |walk_idx| {
123 let walk_seed = seed
124 .wrapping_mul(*start_node as usize)
125 .wrapping_add(walk_idx);
126 let mut rng = StdRng::seed_from_u64(walk_seed as u64);
127 let walk = self.single_walk(*start_node, walk_length, &mut rng);
128 progress.inc(1);
129 walk
130 })
131 })
132 .collect();
133
134 progress.finish();
135 walks
136 }
137
138 fn single_walk(&self, start_node: u32, walk_length: usize, rng: &mut StdRng) -> Vec<u32> {
150 let mut walk: Vec<u32> = Vec::with_capacity(walk_length);
151 walk.push(start_node);
152
153 if walk_length == 1 {
154 return walk;
155 }
156
157 let mut curr = if let Some(neighbours) = self.adjacency.get(&start_node) {
158 self.sample_neighbor(neighbours, rng)
159 } else {
160 return walk;
161 };
162
163 walk.push(curr);
164
165 for _ in 2..walk_length {
166 let prev = walk[walk.len() - 2];
167
168 if let Some(probs) = self.transition_probs.get(&(prev, curr)) {
169 curr = self.sample_from_cumulative(probs, rng);
170 walk.push(curr);
171 } else {
172 break;
173 }
174 }
175
176 walk
177 }
178
179 fn sample_neighbor(&self, neighbours: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
190 let total: f32 = neighbours.iter().map(|(_, w)| w).sum();
191 let mut rand_val = rng.random::<f32>() * total;
192
193 for (node, weight) in neighbours {
194 rand_val -= weight;
195 if rand_val <= 0.0 {
196 return *node;
197 }
198 }
199
200 neighbours[0].0
201 }
202
203 fn sample_from_cumulative(&self, cumulative: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
215 let rand_val = rng.random::<f32>();
216
217 match cumulative.binary_search_by(|(_, cum_prob)| cum_prob.partial_cmp(&rand_val).unwrap())
218 {
219 Ok(idx) => cumulative[idx].0,
220 Err(idx) => {
221 if idx < cumulative.len() {
222 cumulative[idx].0
223 } else {
224 cumulative[cumulative.len() - 1].0
225 }
226 }
227 }
228 }
229}
230
#[cfg(test)]
mod graph_tests {
    use crate::prelude::*;
    use rustc_hash::FxHashMap;

    /// Unit-weight triangle 1-2-3, with every edge listed in both directions.
    fn triangle_graph() -> FxHashMap<u32, Vec<(u32, f32)>> {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0), (3, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(1, 1.0), (2, 1.0)]);
        adjacency
    }

    /// Unit-weight path 1-2-3, with every edge listed in both directions.
    fn path_graph() -> FxHashMap<u32, Vec<(u32, f32)>> {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);
        adjacency
    }

    /// Converts a cumulative distribution, as returned by `compute_transition_prob`,
    /// back into per-node step probabilities.
    fn step_probabilities(cumulative: &[(u32, f32)]) -> Vec<(u32, f32)> {
        let mut previous = 0.0;
        cumulative
            .iter()
            .map(|&(node, cum)| {
                let prob = cum - previous;
                previous = cum;
                (node, prob)
            })
            .collect()
    }

    /// Probability of stepping to `target` according to the cumulative table.
    fn step_probability(cumulative: &[(u32, f32)], target: u32) -> f32 {
        step_probabilities(cumulative)
            .into_iter()
            .find(|&(node, _)| node == target)
            .map(|(_, prob)| prob)
            .expect("target must be a candidate next node")
    }

    #[test]
    fn test_transition_probs_sum_to_one() {
        let probs = compute_transition_prob(&triangle_graph(), 1.0, 1.0);

        // One table per directed edge of the triangle.
        assert_eq!(probs.len(), 6);
        for cumulative_probs in probs.values() {
            let last_prob = cumulative_probs.last().unwrap().1;
            assert!((last_prob - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_transition_probs_p_parameter() {
        let adjacency = path_graph();

        let probs_low_p = compute_transition_prob(&adjacency, 0.5, 1.0);
        let probs_high_p = compute_transition_prob(&adjacency, 2.0, 1.0);

        // Having walked 1 -> 2, a lower p makes returning to 1 more likely.
        let low_p_return = step_probability(&probs_low_p[&(1, 2)], 1);
        let high_p_return = step_probability(&probs_high_p[&(1, 2)], 1);

        assert!(low_p_return > high_p_return);
    }

    #[test]
    fn test_transition_probs_q_parameter() {
        let triangle = triangle_graph();
        let probs_low_q = compute_transition_prob(&triangle, 1.0, 0.5);
        let probs_high_q = compute_transition_prob(&triangle, 1.0, 2.0);

        for probs in [&probs_low_q, &probs_high_q] {
            for cumulative_probs in probs.values() {
                let last = cumulative_probs.last().unwrap().1;
                assert!((last - 1.0).abs() < 1e-6);
            }
        }

        // On a triangle every candidate is `prev` or its neighbour, so q has no effect.
        // On a path, having walked 1 -> 2, node 3 is "outward" and a lower q favours it.
        let path = path_graph();
        let low_q_outward = step_probability(&compute_transition_prob(&path, 1.0, 0.5)[&(1, 2)], 3);
        let high_q_outward = step_probability(&compute_transition_prob(&path, 1.0, 2.0)[&(1, 2)], 3);
        assert!(low_q_outward > high_q_outward);
    }

    #[test]
    fn test_transition_probs_uniform_with_p_q_one() {
        let probs = compute_transition_prob(&triangle_graph(), 1.0, 1.0);

        assert!(!probs.is_empty());
        for cumulative_probs in probs.values() {
            let expected = 1.0 / cumulative_probs.len() as f32;
            for (_, prob) in step_probabilities(cumulative_probs) {
                assert!((prob - expected).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_isolated_node() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        assert!(probs.is_empty());
    }

    #[test]
    fn test_weighted_transitions() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 3.0)]);
        adjacency.insert(2, vec![(1, 3.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        // Compare per-step probabilities, not cumulative values. Cumulative values always
        // increase along the list, which made the old `prob_to_1 < prob_to_3` check vacuous.
        let probs_at_2 = &probs[&(1, 2)];
        let prob_to_1 = step_probability(probs_at_2, 1);
        let prob_to_3 = step_probability(probs_at_2, 3);

        // With p = q = 1 the bias is neutral, so the split follows edge weights 3 : 1.
        assert!((prob_to_1 - 0.75).abs() < 1e-6);
        assert!((prob_to_3 - 0.25).abs() < 1e-6);
    }
}