// scirs2_graph/embeddings/node2vec.rs

//! Node2Vec graph embedding algorithm
//!
//! Implements the Node2Vec algorithm from Grover & Leskovec (2016) for learning
//! continuous feature representations for nodes in networks. Uses biased random
//! walks with return parameter p and in-out parameter q to explore neighborhoods.
//!
//! # References
//! - Grover, A. & Leskovec, J. (2016). node2vec: Scalable Feature Learning for Networks. KDD 2016.
10use super::core::EmbeddingModel;
11use super::negative_sampling::NegativeSampler;
12use super::random_walk::RandomWalkGenerator;
13use super::types::{Node2VecConfig, RandomWalk};
14use crate::base::{DiGraph, EdgeWeight, Graph, Node};
15use crate::error::Result;
16use scirs2_core::random::seq::SliceRandom;
17
/// Node2Vec embedding algorithm
///
/// Learns node embeddings using biased second-order random walks followed
/// by skip-gram optimization with negative sampling.
pub struct Node2Vec<N: Node> {
    // Hyperparameters: dimensions, walk_length, num_walks, window_size,
    // p (return), q (in-out), epochs, learning_rate, negative_samples.
    config: Node2VecConfig,
    // Skip-gram model holding target (`embeddings`) and context
    // (`context_embeddings`) vectors, keyed by node.
    model: EmbeddingModel<N>,
    // Generator for the p/q-biased second-order random walks.
    walk_generator: RandomWalkGenerator<N>,
}
27
28impl<N: Node> Node2Vec<N> {
29    /// Create a new Node2Vec instance
30    pub fn new(config: Node2VecConfig) -> Self {
31        Node2Vec {
32            model: EmbeddingModel::new(config.dimensions),
33            config,
34            walk_generator: RandomWalkGenerator::new(),
35        }
36    }
37
38    /// Generate training data (biased random walks) for Node2Vec on undirected graph
39    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
40    where
41        N: Clone + std::fmt::Debug,
42        E: EdgeWeight + Into<f64>,
43        Ix: petgraph::graph::IndexType,
44    {
45        let mut all_walks = Vec::new();
46
47        for node in graph.nodes() {
48            for _ in 0..self.config.num_walks {
49                let walk = self.walk_generator.node2vec_walk(
50                    graph,
51                    node,
52                    self.config.walk_length,
53                    self.config.p,
54                    self.config.q,
55                )?;
56                all_walks.push(walk);
57            }
58        }
59
60        Ok(all_walks)
61    }
62
63    /// Generate training data (biased random walks) for Node2Vec on directed graph
64    pub fn generate_walks_digraph<E, Ix>(
65        &mut self,
66        graph: &DiGraph<N, E, Ix>,
67    ) -> Result<Vec<RandomWalk<N>>>
68    where
69        N: Clone + std::fmt::Debug,
70        E: EdgeWeight + Into<f64>,
71        Ix: petgraph::graph::IndexType,
72    {
73        let mut all_walks = Vec::new();
74
75        for node in graph.nodes() {
76            for _ in 0..self.config.num_walks {
77                let walk = self.walk_generator.node2vec_walk_digraph(
78                    graph,
79                    node,
80                    self.config.walk_length,
81                    self.config.p,
82                    self.config.q,
83                )?;
84                all_walks.push(walk);
85            }
86        }
87
88        Ok(all_walks)
89    }
90
91    /// Train the Node2Vec model on an undirected graph
92    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
93    where
94        N: Clone + std::fmt::Debug,
95        E: EdgeWeight + Into<f64>,
96        Ix: petgraph::graph::IndexType,
97    {
98        // Initialize random embeddings
99        let mut rng = scirs2_core::random::rng();
100        self.model.initialize_random(graph, &mut rng);
101
102        // Create negative sampler
103        let negative_sampler = NegativeSampler::new(graph);
104
105        // Training loop over epochs
106        for epoch in 0..self.config.epochs {
107            // Generate walks for this epoch
108            let walks = self.generate_walks(graph)?;
109
110            // Generate context pairs from walks
111            let context_pairs =
112                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
113
114            // Shuffle pairs for better training
115            let mut shuffled_pairs = context_pairs;
116            shuffled_pairs.shuffle(&mut rng);
117
118            // Train skip-gram model with negative sampling
119            // Linear learning rate decay
120            let current_lr = self.config.learning_rate
121                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
122
123            self.model.train_skip_gram(
124                &shuffled_pairs,
125                &negative_sampler,
126                current_lr,
127                self.config.negative_samples,
128                &mut rng,
129            )?;
130        }
131
132        Ok(())
133    }
134
135    /// Train the Node2Vec model on a directed graph
136    pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
137    where
138        N: Clone + std::fmt::Debug,
139        E: EdgeWeight + Into<f64>,
140        Ix: petgraph::graph::IndexType,
141    {
142        // Initialize random embeddings for directed graph
143        let mut rng = scirs2_core::random::rng();
144        self.model.initialize_random_digraph(graph, &mut rng);
145
146        // Create negative sampler from the undirected view
147        // For DiGraph, we build a temporary sampler from node degrees
148        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
149        let node_degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64).collect();
150
151        // Build cumulative distribution for negative sampling
152        let total_degree: f64 = node_degrees.iter().sum();
153        let frequencies: Vec<f64> = node_degrees
154            .iter()
155            .map(|d| (d / total_degree.max(1.0)).powf(0.75))
156            .collect();
157        let total_freq: f64 = frequencies.iter().sum();
158        let normalized: Vec<f64> = frequencies
159            .iter()
160            .map(|f| f / total_freq.max(1e-10))
161            .collect();
162
163        let mut cumulative = vec![0.0; normalized.len()];
164        if !cumulative.is_empty() {
165            cumulative[0] = normalized[0];
166            for i in 1..normalized.len() {
167                cumulative[i] = cumulative[i - 1] + normalized[i];
168            }
169        }
170
171        // Training loop
172        for epoch in 0..self.config.epochs {
173            let walks = self.generate_walks_digraph(graph)?;
174            let context_pairs =
175                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
176
177            let mut shuffled_pairs = context_pairs;
178            shuffled_pairs.shuffle(&mut rng);
179
180            let current_lr = self.config.learning_rate
181                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
182
183            // Manual skip-gram training for directed graphs
184            // (since NegativeSampler is built for Graph, not DiGraph)
185            for pair in &shuffled_pairs {
186                self.train_pair_digraph(
187                    pair,
188                    &nodes,
189                    &cumulative,
190                    current_lr,
191                    self.config.negative_samples,
192                    &mut rng,
193                );
194            }
195        }
196
197        Ok(())
198    }
199
200    /// Train on a single context pair for directed graphs
201    fn train_pair_digraph(
202        &mut self,
203        pair: &super::types::ContextPair<N>,
204        nodes: &[N],
205        cumulative: &[f64],
206        learning_rate: f64,
207        num_negative: usize,
208        rng: &mut impl scirs2_core::random::Rng,
209    ) where
210        N: Clone,
211    {
212        let dim = self.config.dimensions;
213
214        // Get target embedding
215        let target_emb = match self.model.embeddings.get(&pair.target) {
216            Some(e) => e.clone(),
217            None => return,
218        };
219
220        // Get context embedding
221        let context_emb = match self.model.context_embeddings.get(&pair.context) {
222            Some(e) => e.clone(),
223            None => return,
224        };
225
226        // Positive sample gradient
227        let dot: f64 = target_emb
228            .vector
229            .iter()
230            .zip(context_emb.vector.iter())
231            .map(|(a, b)| a * b)
232            .sum();
233        let sig = 1.0 / (1.0 + (-dot).exp());
234        let g = learning_rate * (1.0 - sig);
235
236        let mut target_grad = vec![0.0; dim];
237        for d in 0..dim {
238            target_grad[d] += g * context_emb.vector[d];
239        }
240
241        // Update context embedding
242        if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
243            for d in 0..dim {
244                ctx.vector[d] += g * target_emb.vector[d];
245            }
246        }
247
248        // Negative samples
249        for _ in 0..num_negative {
250            let r = rng.random::<f64>();
251            let neg_idx = cumulative
252                .iter()
253                .position(|&c| r <= c)
254                .unwrap_or(cumulative.len().saturating_sub(1));
255
256            if neg_idx >= nodes.len() {
257                continue;
258            }
259
260            let neg_node = &nodes[neg_idx];
261            if neg_node == &pair.target || neg_node == &pair.context {
262                continue;
263            }
264
265            if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
266                let neg_dot: f64 = target_emb
267                    .vector
268                    .iter()
269                    .zip(neg_emb.vector.iter())
270                    .map(|(a, b)| a * b)
271                    .sum();
272                let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
273                let neg_g = learning_rate * (-neg_sig);
274
275                for d in 0..dim {
276                    target_grad[d] += neg_g * neg_emb.vector[d];
277                }
278
279                // Update negative context
280                if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
281                    for d in 0..dim {
282                        neg_ctx.vector[d] += neg_g * target_emb.vector[d];
283                    }
284                }
285            }
286        }
287
288        // Apply accumulated gradient to target
289        if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
290            for d in 0..dim {
291                target.vector[d] += target_grad[d];
292            }
293        }
294    }
295
296    /// Get the trained model
297    pub fn model(&self) -> &EmbeddingModel<N> {
298        &self.model
299    }
300
301    /// Get mutable reference to the model
302    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
303        &mut self.model
304    }
305
306    /// Get the configuration
307    pub fn config(&self) -> &Node2VecConfig {
308        &self.config
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// K3: three nodes, all pairwise connected.
    fn make_triangle() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..3 {
            g.add_node(i);
        }
        for &(u, v) in &[(0, 1), (1, 2), (0, 2)] {
            let _ = g.add_edge(u, v, 1.0);
        }
        g
    }

    /// Star graph: node 0 at the center, leaves 1..=4.
    fn make_star_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        for leaf in 1..5 {
            let _ = g.add_edge(0, leaf, 1.0);
        }
        g
    }

    /// Directed path 0 -> 1 -> 2 -> 3 -> 4.
    fn make_directed_chain() -> DiGraph<i32, f64> {
        let mut g = DiGraph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        for i in 0..4 {
            let _ = g.add_edge(i, i + 1, 1.0);
        }
        g
    }

    #[test]
    fn test_node2vec_train_basic() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        assert!(n2v.train(&g).is_ok(), "Node2Vec training should succeed");

        // Every node must end up with an embedding.
        for node in 0..3 {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding"
            );
        }
    }

    #[test]
    fn test_node2vec_walk_generation() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 10,
            num_walks: 2,
            p: 1.0,
            q: 1.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g).expect("walks should be valid");

        // 3 nodes x 2 walks per node = 6 walks total.
        assert_eq!(walks.len(), 6);

        // Walks are non-empty and bounded by walk_length.
        for walk in &walks {
            assert!(!walk.nodes.is_empty());
            assert!(walk.nodes.len() <= 10);
        }
    }

    #[test]
    fn test_node2vec_biased_walks() {
        // p=0.5 (low) favors returning to the previous node; q=2.0 (high)
        // favors local, BFS-like exploration.
        let g = make_star_graph();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 20,
            num_walks: 5,
            p: 0.5,
            q: 2.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g).expect("walks should be valid");
        assert!(!walks.is_empty());

        // Every visited node must be one of the graph's five nodes.
        for node in walks.iter().flat_map(|w| w.nodes.iter()) {
            assert!(
                (0..5).contains(node),
                "Walk should only contain valid nodes, got {node}"
            );
        }
    }

    #[test]
    fn test_node2vec_embedding_similarity() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 16,
            walk_length: 10,
            num_walks: 10,
            window_size: 3,
            p: 1.0,
            q: 1.0,
            epochs: 5,
            learning_rate: 0.05,
            negative_samples: 3,
        };

        let mut n2v = Node2Vec::new(config);
        let _ = n2v.train(&g);

        // All triangle nodes are structurally equivalent, so similarity
        // queries should at least be computable (finite, non-NaN scores).
        let sims = n2v
            .model()
            .most_similar(&0, 2)
            .expect("similarity should be valid");
        assert_eq!(sims.len(), 2, "Should find 2 most similar nodes");

        for (node, score) in &sims {
            assert!(
                score.is_finite(),
                "Similarity for node {node} should be finite"
            );
        }
    }

    #[test]
    fn test_node2vec_digraph_train() {
        let g = make_directed_chain();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        assert!(
            n2v.train_digraph(&g).is_ok(),
            "DiGraph Node2Vec training should succeed"
        );

        // Every node of the chain must end up with an embedding.
        for node in 0..5 {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding in directed graph"
            );
        }
    }

    #[test]
    fn test_node2vec_config() {
        // Spot-check the documented defaults.
        let config = Node2VecConfig::default();
        assert_eq!(config.dimensions, 128);
        assert_eq!(config.walk_length, 80);
        assert_eq!(config.p, 1.0);
        assert_eq!(config.q, 1.0);

        let n2v: Node2Vec<i32> = Node2Vec::new(config);
        assert_eq!(n2v.config().dimensions, 128);
    }
}