//! `node2vec_rs/graph.rs` — graph storage, node2vec transition probabilities,
//! and biased random-walk generation.
1use indicatif::{ProgressBar, ProgressStyle};
2use rand::rngs::StdRng;
3use rand::{Rng, SeedableRng};
4use rayon::prelude::*;
5use rustc_hash::{FxHashMap, FxHashSet};
6
7/////////////
8// Helpers //
9/////////////
10
11/// Compute the transition probabilities
12///
13/// ### Params
14///
15/// * `adjacency` - The adjacency of the graph stored as a HashMap
16/// * `p` - p parameter in node2vec that controls probability to return to
17///   origin node.
18/// * `q` - q parameter in node2vec that controls probability to reach out
19///   futher in the graph.
20///
21/// ### Returns
22///
23/// The transition probabilities as an `FxHashMap`.
24pub fn compute_transition_prob(
25    adjacency: &FxHashMap<u32, Vec<(u32, f32)>>,
26    p: f32,
27    q: f32,
28) -> FxHashMap<(u32, u32), Vec<(u32, f32)>> {
29    let neighbours_set: FxHashMap<u32, FxHashSet<u32>> = adjacency
30        .iter()
31        .map(|(node, edges)| (*node, edges.iter().map(|(n, _)| *n).collect()))
32        .collect();
33
34    adjacency
35        .par_iter()
36        .flat_map(|(curr_node, curr_edges)| {
37            curr_edges.par_iter().filter_map(|(prev_node, _)| {
38                adjacency.get(curr_node).map(|next_edges| {
39                    let mut probs = Vec::new();
40                    let mut total = 0_f32;
41
42                    for (next_node, weight) in next_edges.iter() {
43                        let unnorm_prob = if next_node == prev_node {
44                            weight / p
45                        } else if neighbours_set
46                            .get(prev_node)
47                            .map(|s| s.contains(next_node))
48                            .unwrap_or(false)
49                        {
50                            *weight
51                        } else {
52                            weight / q
53                        };
54                        total += unnorm_prob;
55                        probs.push((*next_node, unnorm_prob));
56                    }
57
58                    let mut cumulative = 0.0;
59                    let normalised: Vec<(u32, f32)> = probs
60                        .into_iter()
61                        .map(|(node, prob)| {
62                            cumulative += prob / total;
63                            (node, cumulative)
64                        })
65                        .collect();
66
67                    ((*prev_node, *curr_node), normalised)
68                })
69            })
70        })
71        .collect()
72}
73
74/////////////////////
75// Graph structure //
76/////////////////////
77
78/// Structure to store the Node2Vec graph
79///
80/// ### Fields
81///
82/// * `adjacency` - The adjacency stored as an FxHashMap.
83/// * `transition_probs` - The transition probabilities stored in a FxHashMap.
84#[derive(Debug, Clone)]
85pub struct Node2VecGraph {
86    pub adjacency: FxHashMap<u32, Vec<(u32, f32)>>,
87    pub transition_probs: FxHashMap<(u32, u32), Vec<(u32, f32)>>,
88}
89
90impl Node2VecGraph {
91    /// Generates random walks from the graph
92    ///
93    /// ### Params
94    /// * `walks_per_node` - Number of walks to generate starting from each node
95    /// * `walk_length` - Length of each walk
96    /// * `seed` - Random seed for reproducibility
97    ///
98    /// ### Returns
99    ///
100    /// Vector of walks, where each walk is a sequence of node IDs
101    pub fn generate_walks(
102        &self,
103        walks_per_node: usize,
104        walk_length: usize,
105        seed: usize,
106    ) -> Vec<Vec<u32>> {
107        let total_walks = self.adjacency.len() * walks_per_node;
108        let progress = ProgressBar::new(total_walks as u64);
109
110        progress.set_style(
111            ProgressStyle::default_bar()
112                .template("[Random walk gen: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len}")
113                .unwrap()
114                .progress_chars("#>-"),
115        );
116
117        let walks = self
118            .adjacency
119            .par_iter()
120            .flat_map(|(start_node, _)| {
121                let progress = progress.clone();
122                (0..walks_per_node).into_par_iter().map(move |walk_idx| {
123                    let walk_seed = seed
124                        .wrapping_mul(*start_node as usize)
125                        .wrapping_add(walk_idx);
126                    let mut rng = StdRng::seed_from_u64(walk_seed as u64);
127                    let walk = self.single_walk(*start_node, walk_length, &mut rng);
128                    progress.inc(1);
129                    walk
130                })
131            })
132            .collect();
133
134        progress.finish();
135        walks
136    }
137
138    /// Performs a single biased random walk
139    ///
140    /// ### Params
141    ///
142    /// * `start_node` - Node to start the walk from
143    /// * `walk_length` - Maximum length of the walk
144    /// * `rng` - Random number generator
145    ///
146    /// ### Returns
147    ///
148    /// The vector of node IDs for this walk
149    fn single_walk(&self, start_node: u32, walk_length: usize, rng: &mut StdRng) -> Vec<u32> {
150        let mut walk: Vec<u32> = Vec::with_capacity(walk_length);
151        walk.push(start_node);
152
153        if walk_length == 1 {
154            return walk;
155        }
156
157        let mut curr = if let Some(neighbours) = self.adjacency.get(&start_node) {
158            self.sample_neighbor(neighbours, rng)
159        } else {
160            return walk;
161        };
162
163        walk.push(curr);
164
165        for _ in 2..walk_length {
166            let prev = walk[walk.len() - 2];
167
168            if let Some(probs) = self.transition_probs.get(&(prev, curr)) {
169                curr = self.sample_from_cumulative(probs, rng);
170                walk.push(curr);
171            } else {
172                break;
173            }
174        }
175
176        walk
177    }
178
179    /// Samples a neighbour based on edge weights
180    ///
181    /// ### Params
182    ///
183    /// * `neighbours` - Slice of (node, weight) tuples
184    /// * `rng` - Random number generator
185    ///
186    /// ### Returns
187    ///
188    /// Node ID based on the neighbours
189    fn sample_neighbor(&self, neighbours: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
190        let total: f32 = neighbours.iter().map(|(_, w)| w).sum();
191        let mut rand_val = rng.random::<f32>() * total;
192
193        for (node, weight) in neighbours {
194            rand_val -= weight;
195            if rand_val <= 0.0 {
196                return *node;
197            }
198        }
199
200        neighbours[0].0
201    }
202
203    /// Samples from a cumulative probability distribution
204    ///
205    /// ### Params
206    ///
207    /// * `cumulative` - Cumulative probabilities as (node, cumulative_prob)
208    ///   pairs
209    /// * `rng` - Random number generator
210    ///
211    /// ### Returns
212    ///
213    /// The node ID
214    fn sample_from_cumulative(&self, cumulative: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
215        let rand_val = rng.random::<f32>();
216
217        match cumulative.binary_search_by(|(_, cum_prob)| cum_prob.partial_cmp(&rand_val).unwrap())
218        {
219            Ok(idx) => cumulative[idx].0,
220            Err(idx) => {
221                if idx < cumulative.len() {
222                    cumulative[idx].0
223                } else {
224                    cumulative[cumulative.len() - 1].0
225                }
226            }
227        }
228    }
229}
230
#[cfg(test)]
mod graph_tests {
    use crate::prelude::*;
    use rustc_hash::FxHashMap;

    #[test]
    fn test_transition_probs_sum_to_one() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0), (3, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(1, 1.0), (2, 1.0)]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        // Check all cumulative probabilities end at ~1.0
        for (_, cumulative_probs) in probs.iter() {
            let last_prob = cumulative_probs.last().unwrap().1;
            assert!((last_prob - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_transition_probs_p_parameter() {
        // Simple chain: 1 - 2 - 3
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        // With p=0.5, returning to origin should be more likely
        let probs_low_p = compute_transition_prob(&adjacency, 0.5, 1.0);
        // With p=2.0, returning to origin should be less likely
        let probs_high_p = compute_transition_prob(&adjacency, 2.0, 1.0);

        // When at node 2 coming from node 1, check probability of returning to 1
        let low_p_return = &probs_low_p[&(1, 2)];
        let high_p_return = &probs_high_p[&(1, 2)];

        // Find the cumulative probability for node 1 (return)
        let low_p_val = low_p_return.iter().find(|(n, _)| *n == 1).unwrap().1;
        let high_p_val = high_p_return.iter().find(|(n, _)| *n == 1).unwrap().1;

        // Lower p should give higher probability of return
        assert!(low_p_val > high_p_val);
    }

    #[test]
    fn test_transition_probs_q_parameter() {
        // Chain 1 - 2 - 3: node 3 is at distance two from node 1, so the
        // step (prev=1, curr=2) -> 3 is scaled by 1/q. (A triangle would
        // never exercise the 1/q branch, since every candidate is a common
        // neighbour of prev.)
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        // With q=0.5, exploring further should be more likely
        let probs_low_q = compute_transition_prob(&adjacency, 1.0, 0.5);
        // With q=2.0, exploring further should be less likely
        let probs_high_q = compute_transition_prob(&adjacency, 1.0, 2.0);

        // At (prev=1, curr=2), node 1 is listed first so its cumulative
        // value equals its own probability mass. Low q shifts mass towards
        // exploring node 3, shrinking the return probability.
        let low_q_return = probs_low_q[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;
        let high_q_return = probs_high_q[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;
        assert!(low_q_return < high_q_return);

        // Both should still normalise to 1.0
        for probs in [&probs_low_q, &probs_high_q] {
            for cumulative_probs in probs.values() {
                let last = cumulative_probs.last().unwrap().1;
                assert!((last - 1.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_transition_probs_uniform_with_p_q_one() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 1.0), (3, 1.0)]);
        adjacency.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adjacency.insert(3, vec![(1, 1.0), (2, 1.0)]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        // With p=q=1 and equal weights, all transitions should be uniform
        for cumulative_probs in probs.values() {
            if cumulative_probs.len() == 2 {
                // Each neighbour should have ~0.5 probability
                let first_prob = cumulative_probs[0].1;
                assert!((first_prob - 0.5).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_isolated_node() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        // Isolated node should have no transition probabilities
        assert!(probs.is_empty());
    }

    #[test]
    fn test_weighted_transitions() {
        let mut adjacency = FxHashMap::default();
        adjacency.insert(1, vec![(2, 3.0)]);
        adjacency.insert(2, vec![(1, 3.0), (3, 1.0)]);
        adjacency.insert(3, vec![(2, 1.0)]);

        let probs = compute_transition_prob(&adjacency, 1.0, 1.0);

        // At node 2 coming from 1, going back to 1 should be more likely (weight 3.0 vs 1.0)
        let probs_at_2 = &probs[&(1, 2)];
        let prob_to_1 = probs_at_2.iter().find(|(n, _)| *n == 1).unwrap().1;
        let prob_to_3 = probs_at_2.iter().find(|(n, _)| *n == 3).unwrap().1;

        assert!(prob_to_1 < prob_to_3); // cumulative, so first is less than second
        assert!((prob_to_3 - 1.0).abs() < 1e-6); // last should sum to 1

        // Return step (weight 3/p = 3) should carry 3/4 of the mass.
        assert!((prob_to_1 - 0.75).abs() < 1e-6);
    }
}