1use super::core::EmbeddingModel;
11use super::negative_sampling::NegativeSampler;
12use super::random_walk::RandomWalkGenerator;
13use super::types::{Node2VecConfig, RandomWalk};
14use crate::base::{DiGraph, EdgeWeight, Graph, Node};
15use crate::error::Result;
16use scirs2_core::random::seq::SliceRandom;
17
18pub struct Node2Vec<N: Node> {
23 config: Node2VecConfig,
24 model: EmbeddingModel<N>,
25 walk_generator: RandomWalkGenerator<N>,
26}
27
28impl<N: Node> Node2Vec<N> {
29 pub fn new(config: Node2VecConfig) -> Self {
31 Node2Vec {
32 model: EmbeddingModel::new(config.dimensions),
33 config,
34 walk_generator: RandomWalkGenerator::new(),
35 }
36 }
37
38 pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
40 where
41 N: Clone + std::fmt::Debug,
42 E: EdgeWeight + Into<f64>,
43 Ix: petgraph::graph::IndexType,
44 {
45 let mut all_walks = Vec::new();
46
47 for node in graph.nodes() {
48 for _ in 0..self.config.num_walks {
49 let walk = self.walk_generator.node2vec_walk(
50 graph,
51 node,
52 self.config.walk_length,
53 self.config.p,
54 self.config.q,
55 )?;
56 all_walks.push(walk);
57 }
58 }
59
60 Ok(all_walks)
61 }
62
63 pub fn generate_walks_digraph<E, Ix>(
65 &mut self,
66 graph: &DiGraph<N, E, Ix>,
67 ) -> Result<Vec<RandomWalk<N>>>
68 where
69 N: Clone + std::fmt::Debug,
70 E: EdgeWeight + Into<f64>,
71 Ix: petgraph::graph::IndexType,
72 {
73 let mut all_walks = Vec::new();
74
75 for node in graph.nodes() {
76 for _ in 0..self.config.num_walks {
77 let walk = self.walk_generator.node2vec_walk_digraph(
78 graph,
79 node,
80 self.config.walk_length,
81 self.config.p,
82 self.config.q,
83 )?;
84 all_walks.push(walk);
85 }
86 }
87
88 Ok(all_walks)
89 }
90
91 pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
93 where
94 N: Clone + std::fmt::Debug,
95 E: EdgeWeight + Into<f64>,
96 Ix: petgraph::graph::IndexType,
97 {
98 let mut rng = scirs2_core::random::rng();
100 self.model.initialize_random(graph, &mut rng);
101
102 let negative_sampler = NegativeSampler::new(graph);
104
105 for epoch in 0..self.config.epochs {
107 let walks = self.generate_walks(graph)?;
109
110 let context_pairs =
112 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
113
114 let mut shuffled_pairs = context_pairs;
116 shuffled_pairs.shuffle(&mut rng);
117
118 let current_lr = self.config.learning_rate
121 * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
122
123 self.model.train_skip_gram(
124 &shuffled_pairs,
125 &negative_sampler,
126 current_lr,
127 self.config.negative_samples,
128 &mut rng,
129 )?;
130 }
131
132 Ok(())
133 }
134
135 pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
137 where
138 N: Clone + std::fmt::Debug,
139 E: EdgeWeight + Into<f64>,
140 Ix: petgraph::graph::IndexType,
141 {
142 let mut rng = scirs2_core::random::rng();
144 self.model.initialize_random_digraph(graph, &mut rng);
145
146 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
149 let node_degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64).collect();
150
151 let total_degree: f64 = node_degrees.iter().sum();
153 let frequencies: Vec<f64> = node_degrees
154 .iter()
155 .map(|d| (d / total_degree.max(1.0)).powf(0.75))
156 .collect();
157 let total_freq: f64 = frequencies.iter().sum();
158 let normalized: Vec<f64> = frequencies
159 .iter()
160 .map(|f| f / total_freq.max(1e-10))
161 .collect();
162
163 let mut cumulative = vec![0.0; normalized.len()];
164 if !cumulative.is_empty() {
165 cumulative[0] = normalized[0];
166 for i in 1..normalized.len() {
167 cumulative[i] = cumulative[i - 1] + normalized[i];
168 }
169 }
170
171 for epoch in 0..self.config.epochs {
173 let walks = self.generate_walks_digraph(graph)?;
174 let context_pairs =
175 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
176
177 let mut shuffled_pairs = context_pairs;
178 shuffled_pairs.shuffle(&mut rng);
179
180 let current_lr = self.config.learning_rate
181 * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
182
183 for pair in &shuffled_pairs {
186 self.train_pair_digraph(
187 pair,
188 &nodes,
189 &cumulative,
190 current_lr,
191 self.config.negative_samples,
192 &mut rng,
193 );
194 }
195 }
196
197 Ok(())
198 }
199
200 fn train_pair_digraph(
202 &mut self,
203 pair: &super::types::ContextPair<N>,
204 nodes: &[N],
205 cumulative: &[f64],
206 learning_rate: f64,
207 num_negative: usize,
208 rng: &mut impl scirs2_core::random::Rng,
209 ) where
210 N: Clone,
211 {
212 let dim = self.config.dimensions;
213
214 let target_emb = match self.model.embeddings.get(&pair.target) {
216 Some(e) => e.clone(),
217 None => return,
218 };
219
220 let context_emb = match self.model.context_embeddings.get(&pair.context) {
222 Some(e) => e.clone(),
223 None => return,
224 };
225
226 let dot: f64 = target_emb
228 .vector
229 .iter()
230 .zip(context_emb.vector.iter())
231 .map(|(a, b)| a * b)
232 .sum();
233 let sig = 1.0 / (1.0 + (-dot).exp());
234 let g = learning_rate * (1.0 - sig);
235
236 let mut target_grad = vec![0.0; dim];
237 for d in 0..dim {
238 target_grad[d] += g * context_emb.vector[d];
239 }
240
241 if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
243 for d in 0..dim {
244 ctx.vector[d] += g * target_emb.vector[d];
245 }
246 }
247
248 for _ in 0..num_negative {
250 let r = rng.random::<f64>();
251 let neg_idx = cumulative
252 .iter()
253 .position(|&c| r <= c)
254 .unwrap_or(cumulative.len().saturating_sub(1));
255
256 if neg_idx >= nodes.len() {
257 continue;
258 }
259
260 let neg_node = &nodes[neg_idx];
261 if neg_node == &pair.target || neg_node == &pair.context {
262 continue;
263 }
264
265 if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
266 let neg_dot: f64 = target_emb
267 .vector
268 .iter()
269 .zip(neg_emb.vector.iter())
270 .map(|(a, b)| a * b)
271 .sum();
272 let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
273 let neg_g = learning_rate * (-neg_sig);
274
275 for d in 0..dim {
276 target_grad[d] += neg_g * neg_emb.vector[d];
277 }
278
279 if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
281 for d in 0..dim {
282 neg_ctx.vector[d] += neg_g * target_emb.vector[d];
283 }
284 }
285 }
286 }
287
288 if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
290 for d in 0..dim {
291 target.vector[d] += target_grad[d];
292 }
293 }
294 }
295
296 pub fn model(&self) -> &EmbeddingModel<N> {
298 &self.model
299 }
300
301 pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
303 &mut self.model
304 }
305
306 pub fn config(&self) -> &Node2VecConfig {
308 &self.config
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected triangle 0-1-2.
    fn make_triangle() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..3 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(0, 2, 1.0);
        g
    }

    /// Undirected star with hub 0 and leaves 1..=4.
    fn make_star_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        for i in 1..5 {
            let _ = g.add_edge(0, i, 1.0);
        }
        g
    }

    /// Directed chain 0 -> 1 -> 2 -> 3 -> 4 (node 4 is a sink).
    fn make_directed_chain() -> DiGraph<i32, f64> {
        let mut g = DiGraph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let _ = g.add_edge(3, 4, 1.0);
        g
    }

    #[test]
    fn test_node2vec_train_basic() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        let result = n2v.train(&g);
        assert!(result.is_ok(), "Node2Vec training should succeed");

        for node in [0, 1, 2] {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding"
            );
        }
    }

    #[test]
    fn test_node2vec_walk_generation() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 10,
            num_walks: 2,
            p: 1.0,
            q: 1.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g);
        assert!(walks.is_ok());

        let walks = walks.expect("walks should be valid");
        // 3 nodes x 2 walks per node.
        assert_eq!(walks.len(), 6);

        for walk in &walks {
            assert!(walk.nodes.len() <= 10);
            assert!(!walk.nodes.is_empty());
        }
    }

    #[test]
    fn test_node2vec_biased_walks() {
        let g = make_star_graph();
        // Low p (likely to return) and high q (stay local).
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 20,
            num_walks: 5,
            p: 0.5,
            q: 2.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g);
        assert!(walks.is_ok());

        let walks = walks.expect("walks should be valid");
        assert!(!walks.is_empty());

        for walk in &walks {
            for node in &walk.nodes {
                assert!(
                    (0..5).contains(node),
                    "Walk should only contain valid nodes, got {node}"
                );
            }
        }
    }

    #[test]
    fn test_node2vec_embedding_similarity() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 16,
            walk_length: 10,
            num_walks: 10,
            window_size: 3,
            p: 1.0,
            q: 1.0,
            epochs: 5,
            learning_rate: 0.05,
            negative_samples: 3,
        };

        let mut n2v = Node2Vec::new(config);
        // Fail here, not in the similarity checks below, if training breaks.
        let result = n2v.train(&g);
        assert!(result.is_ok(), "Node2Vec training should succeed");

        let model = n2v.model();
        let sim_01 = model.most_similar(&0, 2);
        assert!(sim_01.is_ok());

        let sim_01 = sim_01.expect("similarity should be valid");
        assert_eq!(sim_01.len(), 2, "Should find 2 most similar nodes");

        for (node, score) in &sim_01 {
            assert!(
                score.is_finite(),
                "Similarity for node {node} should be finite"
            );
        }
    }

    #[test]
    fn test_node2vec_digraph_walk_generation() {
        let g = make_directed_chain();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 2,
            p: 1.0,
            q: 1.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v
            .generate_walks_digraph(&g)
            .expect("directed walks should be generated");

        // 5 nodes (including the sink) x 2 walks per node.
        assert_eq!(walks.len(), 10);

        for walk in &walks {
            assert!(walk.nodes.len() <= 4);
            for node in &walk.nodes {
                assert!(
                    (0..5).contains(node),
                    "Directed walk should only contain valid nodes, got {node}"
                );
            }
        }
    }

    #[test]
    fn test_node2vec_digraph_train() {
        let g = make_directed_chain();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        let result = n2v.train_digraph(&g);
        assert!(result.is_ok(), "DiGraph Node2Vec training should succeed");

        for node in 0..5 {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding in directed graph"
            );
        }
    }

    #[test]
    fn test_node2vec_config() {
        let config = Node2VecConfig::default();
        assert_eq!(config.dimensions, 128);
        assert_eq!(config.walk_length, 80);
        assert_eq!(config.p, 1.0);
        assert_eq!(config.q, 1.0);

        let n2v: Node2Vec<i32> = Node2Vec::new(config);
        assert_eq!(n2v.config().dimensions, 128);
    }
}