//! node2vec_rs/graph.rs
//!
//! Node2Vec graph construction: second-order transition probabilities and
//! biased random-walk generation.
1use indicatif::{ProgressBar, ProgressStyle};
2use rand::rngs::StdRng;
3use rand::{Rng, SeedableRng};
4use rayon::prelude::*;
5use rustc_hash::{FxHashMap, FxHashSet};
6
7/////////////
8// Helpers //
9/////////////
10
11/// Compute the transition probabilities
12///
13/// ### Params
14///
15/// * `adjacency` - The adjacency of the graph stored as a HashMap
16/// * `p` - p parameter in node2vec that controls probability to return to
17///   origin node.
18/// * `q` - q parameter in node2vec that controls probability to reach out
19///   futher in the graph.
20///
21/// ### Returns
22///
23/// The transition probabilities as an `FxHashMap`.
24pub fn compute_transition_prob(
25    adjacency: &FxHashMap<u32, Vec<(u32, f32)>>,
26    p: f32,
27    q: f32,
28) -> FxHashMap<(u32, u32), Vec<(u32, f32)>> {
29    let neighbours_set: FxHashMap<u32, FxHashSet<u32>> = adjacency
30        .iter()
31        .map(|(node, edges)| (*node, edges.iter().map(|(n, _)| *n).collect()))
32        .collect();
33
34    adjacency
35        .par_iter()
36        .flat_map(|(curr_node, curr_edges)| {
37            curr_edges.par_iter().filter_map(|(prev_node, _)| {
38                adjacency.get(curr_node).map(|next_edges| {
39                    let mut probs = Vec::new();
40                    let mut total = 0_f32;
41
42                    for (next_node, weight) in next_edges.iter() {
43                        let unnorm_prob = if next_node == prev_node {
44                            weight / p
45                        } else if neighbours_set
46                            .get(prev_node)
47                            .map(|s| s.contains(next_node))
48                            .unwrap_or(false)
49                        {
50                            *weight
51                        } else {
52                            weight / q
53                        };
54                        total += unnorm_prob;
55                        probs.push((*next_node, unnorm_prob));
56                    }
57
58                    let mut cumulative = 0.0;
59                    let normalised: Vec<(u32, f32)> = probs
60                        .into_iter()
61                        .map(|(node, prob)| {
62                            cumulative += prob / total;
63                            (node, cumulative)
64                        })
65                        .collect();
66
67                    ((*prev_node, *curr_node), normalised)
68                })
69            })
70        })
71        .collect()
72}
73
74/////////////////////
75// Graph structure //
76/////////////////////
77
/// Structure to store the Node2Vec graph
///
/// ### Fields
///
/// * `adjacency` - Node -> Vec of (neighbour, edge weight), stored as an
///   FxHashMap.
/// * `transition_probs` - (prev, curr) -> cumulative transition probabilities
///   over the neighbours of `curr`, stored in an FxHashMap (as produced by
///   `compute_transition_prob`).
#[derive(Debug, Clone)]
pub struct Node2VecGraph {
    pub adjacency: FxHashMap<u32, Vec<(u32, f32)>>,
    pub transition_probs: FxHashMap<(u32, u32), Vec<(u32, f32)>>,
}
89
90impl Node2VecGraph {
91    /// Generates random walks from the graph
92    ///
93    /// ### Params
94    /// * `walks_per_node` - Number of walks to generate starting from each node
95    /// * `walk_length` - Length of each walk
96    /// * `seed` - Random seed for reproducibility
97    ///
98    /// ### Returns
99    ///
100    /// Vector of walks, where each walk is a sequence of node IDs
101    pub fn generate_walks(
102        &self,
103        walks_per_node: usize,
104        walk_length: usize,
105        seed: usize,
106    ) -> Vec<Vec<u32>> {
107        let total_walks = self.adjacency.len() * walks_per_node;
108        let progress = ProgressBar::new(total_walks as u64);
109
110        progress.set_style(
111            ProgressStyle::default_bar()
112                .template("[Random walk gen: {elapsed_precise}] {bar:40.cyan/blue} {pos}/{len}")
113                .unwrap()
114                .progress_chars("#>-"),
115        );
116
117        let walks = self
118            .adjacency
119            .par_iter()
120            .flat_map(|(start_node, _)| {
121                let progress = progress.clone();
122                (0..walks_per_node).into_par_iter().map(move |walk_idx| {
123                    // Properly hash the combination to avoid RNG synchronization
124                    let mut walk_seed = seed as u64;
125                    walk_seed ^= (*start_node as u64).wrapping_mul(0x517cc1b727220a95);
126                    walk_seed ^= (walk_idx as u64).wrapping_mul(0xcbf29ce484222325);
127                    walk_seed = walk_seed.wrapping_mul(0x4f1bbcdcbfa54c43); // Final mix
128
129                    let mut rng = StdRng::seed_from_u64(walk_seed);
130                    let walk = self.single_walk(*start_node, walk_length, &mut rng);
131                    progress.inc(1);
132                    walk
133                })
134            })
135            .collect();
136
137        progress.finish();
138        walks
139    }
140
141    /// Performs a single biased random walk
142    ///
143    /// ### Params
144    ///
145    /// * `start_node` - Node to start the walk from
146    /// * `walk_length` - Maximum length of the walk
147    /// * `rng` - Random number generator
148    ///
149    /// ### Returns
150    ///
151    /// The vector of node IDs for this walk
152    fn single_walk(&self, start_node: u32, walk_length: usize, rng: &mut StdRng) -> Vec<u32> {
153        let mut walk: Vec<u32> = Vec::with_capacity(walk_length);
154        walk.push(start_node);
155
156        if walk_length == 1 {
157            return walk;
158        }
159
160        let mut curr = if let Some(neighbours) = self.adjacency.get(&start_node) {
161            self.sample_neighbor(neighbours, rng)
162        } else {
163            return walk;
164        };
165
166        walk.push(curr);
167
168        for _ in 2..walk_length {
169            let prev = walk[walk.len() - 2];
170
171            if let Some(probs) = self.transition_probs.get(&(prev, curr)) {
172                curr = self.sample_from_cumulative(probs, rng);
173                walk.push(curr);
174            } else {
175                break;
176            }
177        }
178
179        walk
180    }
181
182    /// Samples a neighbour based on edge weights
183    ///
184    /// ### Params
185    ///
186    /// * `neighbours` - Slice of (node, weight) tuples
187    /// * `rng` - Random number generator
188    ///
189    /// ### Returns
190    ///
191    /// Node ID based on the neighbours
192    fn sample_neighbor(&self, neighbours: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
193        let total: f32 = neighbours.iter().map(|(_, w)| w).sum();
194        let mut rand_val = rng.random::<f32>() * total;
195
196        for (node, weight) in neighbours {
197            rand_val -= weight;
198            if rand_val <= 0.0 {
199                return *node;
200            }
201        }
202
203        neighbours[0].0
204    }
205
206    /// Samples from a cumulative probability distribution
207    ///
208    /// ### Params
209    ///
210    /// * `cumulative` - Cumulative probabilities as (node, cumulative_prob)
211    ///   pairs
212    /// * `rng` - Random number generator
213    ///
214    /// ### Returns
215    ///
216    /// The node ID
217    fn sample_from_cumulative(&self, cumulative: &[(u32, f32)], rng: &mut impl Rng) -> u32 {
218        let rand_val = rng.random::<f32>();
219
220        match cumulative.binary_search_by(|(_, cum_prob)| cum_prob.partial_cmp(&rand_val).unwrap())
221        {
222            Ok(idx) => cumulative[idx].0,
223            Err(idx) => {
224                if idx < cumulative.len() {
225                    cumulative[idx].0
226                } else {
227                    cumulative[cumulative.len() - 1].0
228                }
229            }
230        }
231    }
232}
233
#[cfg(test)]
mod graph_tests {
    use crate::prelude::*;
    use rustc_hash::FxHashMap;

    /// Fully connected triangle 1-2-3 with unit edge weights, shared by
    /// several tests below.
    fn triangle_graph() -> FxHashMap<u32, Vec<(u32, f32)>> {
        let mut adj = FxHashMap::default();
        adj.insert(1, vec![(2, 1.0), (3, 1.0)]);
        adj.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adj.insert(3, vec![(1, 1.0), (2, 1.0)]);
        adj
    }

    #[test]
    fn test_transition_probs_sum_to_one() {
        let probs = compute_transition_prob(&triangle_graph(), 1.0, 1.0);

        // Every cumulative distribution must terminate at ~1.0.
        for cumulative in probs.values() {
            let tail = cumulative.last().unwrap().1;
            assert!((tail - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_transition_probs_p_parameter() {
        // Simple chain: 1 - 2 - 3
        let mut adj = FxHashMap::default();
        adj.insert(1, vec![(2, 1.0)]);
        adj.insert(2, vec![(1, 1.0), (3, 1.0)]);
        adj.insert(3, vec![(2, 1.0)]);

        // p < 1 boosts the return edge, p > 1 suppresses it.
        let low_p = compute_transition_prob(&adj, 0.5, 1.0);
        let high_p = compute_transition_prob(&adj, 2.0, 1.0);

        // Standing at node 2 having arrived from node 1, compare the
        // cumulative probability on the return edge back to node 1.
        let return_low = low_p[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;
        let return_high = high_p[&(1, 2)].iter().find(|(n, _)| *n == 1).unwrap().1;

        assert!(return_low > return_high);
    }

    #[test]
    fn test_transition_probs_q_parameter() {
        // Triangle: 1 - 2 - 3, with 1 - 3 edge
        let adj = triangle_graph();

        // q < 1 favours exploring outwards, q > 1 discourages it.
        let low_q = compute_transition_prob(&adj, 1.0, 0.5);
        let high_q = compute_transition_prob(&adj, 1.0, 2.0);

        // Both variants must still normalise to 1.0.
        for probs in [&low_q, &high_q] {
            for cumulative in probs.values() {
                let tail = cumulative.last().unwrap().1;
                assert!((tail - 1.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_transition_probs_uniform_with_p_q_one() {
        let probs = compute_transition_prob(&triangle_graph(), 1.0, 1.0);

        // With p = q = 1 and equal weights, two-neighbour distributions are
        // uniform: the first cumulative entry sits at ~0.5.
        for cumulative in probs.values() {
            if cumulative.len() == 2 {
                let first = cumulative[0].1;
                assert!((first - 0.5).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_isolated_node() {
        let mut adj = FxHashMap::default();
        adj.insert(1, vec![]);

        let probs = compute_transition_prob(&adj, 1.0, 1.0);

        // A node with no edges contributes no (prev, curr) entries at all.
        assert!(probs.is_empty());
    }

    #[test]
    fn test_weighted_transitions() {
        let mut adj = FxHashMap::default();
        adj.insert(1, vec![(2, 3.0)]);
        adj.insert(2, vec![(1, 3.0), (3, 1.0)]);
        adj.insert(3, vec![(2, 1.0)]);

        let probs = compute_transition_prob(&adj, 1.0, 1.0);

        // At node 2 coming from 1, the heavier edge (weight 3.0) back to 1
        // dominates the lighter edge (weight 1.0) to 3.
        let at_two = &probs[&(1, 2)];
        let cum_to_1 = at_two.iter().find(|(n, _)| *n == 1).unwrap().1;
        let cum_to_3 = at_two.iter().find(|(n, _)| *n == 3).unwrap().1;

        assert!(cum_to_1 < cum_to_3); // cumulative, so first is less than second
        assert!((cum_to_3 - 1.0).abs() < 1e-6); // last should sum to 1
    }
}
349}