1use petgraph::graph::{NodeIndex, UnGraph};
29use scirs2_core::random::rand_prelude::StdRng;
30use scirs2_core::random::{seeded_rng, CoreRandom, Random};
31use std::collections::HashMap;
32
33use crate::{GraphRAGError, GraphRAGResult, Triple};
34
/// Settings for the biased random walks that feed the skip-gram trainer.
#[derive(Debug, Clone)]
pub struct Node2VecWalkConfig {
    /// Number of walk rounds; every round starts one walk from each node.
    pub num_walks: usize,
    /// Target number of nodes per walk (a walk stops early at a dead end).
    pub walk_length: usize,
    /// Return parameter: stepping back to the previous node is weighted `1 / p`.
    pub p: f64,
    /// In-out parameter: stepping to a node that is not adjacent to the
    /// previous node is weighted `1 / q`.
    pub q: f64,
    /// Seed handed to `seeded_rng` before walks are simulated.
    pub random_seed: u64,
}
51
52impl Default for Node2VecWalkConfig {
53 fn default() -> Self {
54 Self {
55 num_walks: 10,
56 walk_length: 80,
57 p: 1.0,
58 q: 1.0,
59 random_seed: 42,
60 }
61 }
62}
63
/// Complete configuration for a node2vec run: walk generation plus training.
#[derive(Debug, Clone)]
pub struct Node2VecConfig {
    /// Random-walk settings.
    pub walk: Node2VecWalkConfig,
    /// Length of every embedding vector.
    pub embedding_dim: usize,
    /// Context radius: positions up to this far on either side of a target
    /// within a walk are used as its context.
    pub window_size: usize,
    /// Number of passes over the full walk corpus.
    pub num_epochs: usize,
    /// Initial learning rate; it is decayed during training.
    pub learning_rate: f64,
    /// When `true`, each final embedding is scaled to unit L2 norm.
    pub normalize: bool,
}
80
81impl Default for Node2VecConfig {
82 fn default() -> Self {
83 Self {
84 walk: Node2VecWalkConfig::default(),
85 embedding_dim: 128,
86 window_size: 5,
87 num_epochs: 5,
88 learning_rate: 0.025,
89 normalize: true,
90 }
91 }
92}
93
/// Result of a node2vec run: one vector per node label plus run statistics.
#[derive(Debug, Clone)]
pub struct Node2VecEmbeddings {
    /// Embedding vector for each node label.
    pub embeddings: HashMap<String, Vec<f64>>,
    /// Configured embedding dimension.
    pub dim: usize,
    /// Sum of the lengths of all simulated walks.
    pub total_walk_steps: usize,
}
106
107impl Node2VecEmbeddings {
108 pub fn get(&self, node: &str) -> Option<&[f64]> {
110 self.embeddings.get(node).map(|v| v.as_slice())
111 }
112
113 pub fn cosine_similarity(&self, a: &str, b: &str) -> Option<f64> {
117 let ea = self.embeddings.get(a)?;
118 let eb = self.embeddings.get(b)?;
119
120 let dot: f64 = ea.iter().zip(eb.iter()).map(|(x, y)| x * y).sum();
121 let norm_a: f64 = ea.iter().map(|x| x * x).sum::<f64>().sqrt();
122 let norm_b: f64 = eb.iter().map(|x| x * x).sum::<f64>().sqrt();
123
124 if norm_a < 1e-12 || norm_b < 1e-12 {
125 None
126 } else {
127 Some(dot / (norm_a * norm_b))
128 }
129 }
130
131 pub fn top_k_similar(&self, query: &str, k: usize) -> Vec<(String, f64)> {
133 let Some(eq) = self.embeddings.get(query) else {
134 return vec![];
135 };
136
137 let norm_q: f64 = eq.iter().map(|x| x * x).sum::<f64>().sqrt();
138 if norm_q < 1e-12 {
139 return vec![];
140 }
141
142 let mut scored: Vec<(String, f64)> = self
143 .embeddings
144 .iter()
145 .filter(|(node, _)| node.as_str() != query)
146 .map(|(node, emb)| {
147 let dot: f64 = emb.iter().zip(eq.iter()).map(|(x, y)| x * y).sum();
148 let norm_e: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
149 let sim = if norm_e < 1e-12 {
150 0.0
151 } else {
152 dot / (norm_q * norm_e)
153 };
154 (node.clone(), sim)
155 })
156 .collect();
157
158 scored.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
159 scored.truncate(k);
160 scored
161 }
162}
163
/// Alias-method table for drawing an index from a discrete distribution
/// with two uniform draws.
struct AliasTable {
    /// `prob[i]` is the probability of keeping slot `i` once it is picked.
    prob: Vec<f64>,
    /// `alias[i]` is the index returned instead when slot `i` is not kept.
    alias: Vec<usize>,
}
173
174impl AliasTable {
175 fn build(weights: &[f64]) -> Option<Self> {
177 let n = weights.len();
178 if n == 0 {
179 return None;
180 }
181
182 let sum: f64 = weights.iter().sum();
183 if sum <= 0.0 {
184 return None;
185 }
186
187 let prob_norm: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
189
190 let mut small: Vec<usize> = Vec::with_capacity(n);
191 let mut large: Vec<usize> = Vec::with_capacity(n);
192 let mut prob = prob_norm.clone();
193 let mut alias = vec![0usize; n];
194
195 for (i, &p) in prob_norm.iter().enumerate() {
196 if p < 1.0 {
197 small.push(i);
198 } else {
199 large.push(i);
200 }
201 }
202
203 while !small.is_empty() && !large.is_empty() {
204 let l = small.pop().expect("checked non-empty");
205 let g = large.last().copied().expect("checked non-empty");
206
207 alias[l] = g;
208 prob[g] -= 1.0 - prob[l];
209
210 if prob[g] < 1.0 {
211 large.pop();
212 small.push(g);
213 }
214 }
215
216 Some(Self { prob, alias })
217 }
218
219 fn sample(&self, rng: &mut CoreRandom<StdRng>) -> usize {
221 let n = self.prob.len();
222 let i = (rng.random_range(0.0..1.0) * n as f64) as usize;
224 let i = i.min(n - 1);
225 if rng.random_range(0.0..1.0) < self.prob[i] {
227 i
228 } else {
229 self.alias[i]
230 }
231 }
232}
233
/// Second-order transition tables keyed by `(previous, current)`: the
/// neighbours of `current` together with an alias table over them.
type EdgeAlias = HashMap<(NodeIndex, NodeIndex), (Vec<NodeIndex>, AliasTable)>;
/// First-step transition tables keyed by node: the node's neighbours
/// together with an alias table over them.
type NodeAlias = HashMap<NodeIndex, (Vec<NodeIndex>, AliasTable)>;
240
/// Learns node embeddings from triples via biased random walks followed by
/// skip-gram style training.
pub struct Node2VecEmbedder {
    /// Walk and training settings applied by `embed`.
    config: Node2VecConfig,
}
261
262impl Node2VecEmbedder {
263 pub fn new(config: Node2VecConfig) -> Self {
265 Self { config }
266 }
267
268 pub fn with_defaults() -> Self {
270 Self::new(Node2VecConfig::default())
271 }
272
273 pub fn embed(&self, triples: &[Triple]) -> GraphRAGResult<Node2VecEmbeddings> {
275 let (graph, node_map) = self.build_graph(triples);
277
278 if graph.node_count() == 0 {
279 return Ok(Node2VecEmbeddings {
280 embeddings: HashMap::new(),
281 dim: self.config.embedding_dim,
282 total_walk_steps: 0,
283 });
284 }
285
286 let mut rng = seeded_rng(self.config.walk.random_seed);
287
288 let (node_alias, edge_alias) = self.build_alias_tables(&graph)?;
290
291 let (walks, total_steps) =
293 self.simulate_walks(&graph, &node_map, &node_alias, &edge_alias, &mut rng);
294
295 let embeddings = self.train_skip_gram(&walks, &node_map, &mut rng)?;
297
298 Ok(Node2VecEmbeddings {
299 embeddings,
300 dim: self.config.embedding_dim,
301 total_walk_steps: total_steps,
302 })
303 }
304
305 fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
308 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
309 let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
310
311 for triple in triples {
312 let s = *node_map
313 .entry(triple.subject.clone())
314 .or_insert_with(|| graph.add_node(triple.subject.clone()));
315 let o = *node_map
316 .entry(triple.object.clone())
317 .or_insert_with(|| graph.add_node(triple.object.clone()));
318 if s != o && graph.find_edge(s, o).is_none() {
319 graph.add_edge(s, o, ());
320 }
321 }
322
323 (graph, node_map)
324 }
325
326 fn build_alias_tables(
332 &self,
333 graph: &UnGraph<String, ()>,
334 ) -> GraphRAGResult<(NodeAlias, EdgeAlias)> {
335 let p = self.config.walk.p;
336 let q = self.config.walk.q;
337
338 let mut node_alias: NodeAlias = HashMap::new();
340 for node in graph.node_indices() {
341 let neighbors: Vec<NodeIndex> = graph.neighbors(node).collect();
342 if neighbors.is_empty() {
343 continue;
344 }
345 let weights: Vec<f64> = vec![1.0; neighbors.len()];
346 if let Some(table) = AliasTable::build(&weights) {
347 node_alias.insert(node, (neighbors, table));
348 }
349 }
350
351 let mut edge_alias: EdgeAlias = HashMap::new();
353 for edge in graph.edge_indices() {
354 let (u, v) = graph
355 .edge_endpoints(edge)
356 .ok_or_else(|| GraphRAGError::InternalError("bad edge".to_string()))?;
357
358 for (prev, cur) in [(u, v), (v, u)] {
360 let neighbors: Vec<NodeIndex> = graph.neighbors(cur).collect();
361 if neighbors.is_empty() {
362 continue;
363 }
364 let weights: Vec<f64> = neighbors
365 .iter()
366 .map(|&next| {
367 if next == prev {
368 1.0 / p
370 } else if graph.find_edge(prev, next).is_some() {
371 1.0
373 } else {
374 1.0 / q
376 }
377 })
378 .collect();
379
380 if let Some(table) = AliasTable::build(&weights) {
381 edge_alias.insert((prev, cur), (neighbors, table));
382 }
383 }
384 }
385
386 Ok((node_alias, edge_alias))
387 }
388
389 fn simulate_walks(
392 &self,
393 graph: &UnGraph<String, ()>,
394 node_map: &HashMap<String, NodeIndex>,
395 node_alias: &NodeAlias,
396 edge_alias: &EdgeAlias,
397 rng: &mut CoreRandom<StdRng>,
398 ) -> (Vec<Vec<String>>, usize) {
399 let walk_length = self.config.walk.walk_length;
400 let num_walks = self.config.walk.num_walks;
401
402 let node_indices: Vec<NodeIndex> = node_map.values().copied().collect();
403 let mut walks: Vec<Vec<String>> = Vec::with_capacity(num_walks * node_indices.len());
404 let mut total_steps = 0usize;
405
406 for _ in 0..num_walks {
407 let mut order = node_indices.clone();
409 for i in (1..order.len()).rev() {
410 let j = (rng.random_range(0.0..1.0) * (i + 1) as f64) as usize;
411 order.swap(i, j.min(i));
412 }
413
414 for &start in &order {
415 let walk = self.single_walk(graph, start, walk_length, node_alias, edge_alias, rng);
416 total_steps += walk.len();
417 walks.push(walk);
418 }
419 }
420
421 (walks, total_steps)
422 }
423
424 fn single_walk(
425 &self,
426 graph: &UnGraph<String, ()>,
427 start: NodeIndex,
428 walk_length: usize,
429 node_alias: &NodeAlias,
430 edge_alias: &EdgeAlias,
431 rng: &mut CoreRandom<StdRng>,
432 ) -> Vec<String> {
433 let mut walk: Vec<String> = Vec::with_capacity(walk_length);
434
435 if let Some(label) = graph.node_weight(start) {
437 walk.push(label.clone());
438 } else {
439 return walk;
440 }
441
442 let mut current = start;
443 let mut prev: Option<NodeIndex> = None;
444
445 for _ in 1..walk_length {
446 let next = if let Some(p) = prev {
447 if let Some((neighbors, table)) = edge_alias.get(&(p, current)) {
449 let idx = table.sample(rng);
450 neighbors.get(idx).copied()
451 } else {
452 None
453 }
454 } else {
455 if let Some((neighbors, table)) = node_alias.get(¤t) {
457 let idx = table.sample(rng);
458 neighbors.get(idx).copied()
459 } else {
460 None
461 }
462 };
463
464 match next {
465 Some(n) => {
466 if let Some(label) = graph.node_weight(n) {
467 walk.push(label.clone());
468 }
469 prev = Some(current);
470 current = n;
471 }
472 None => break, }
474 }
475
476 walk
477 }
478
479 fn train_skip_gram(
487 &self,
488 walks: &[Vec<String>],
489 node_map: &HashMap<String, NodeIndex>,
490 rng: &mut CoreRandom<StdRng>,
491 ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
492 let dim = self.config.embedding_dim;
493 let window = self.config.window_size;
494 let lr_init = self.config.learning_rate;
495
496 let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
498 for node_label in node_map.keys() {
499 let emb: Vec<f64> = (0..dim)
500 .map(|_| (rng.random_range(0.0..1.0) - 0.5) / dim as f64)
501 .collect();
502 embeddings.insert(node_label.clone(), emb);
503 }
504
505 let mut ctx_embeddings: HashMap<String, Vec<f64>> = HashMap::new();
507 for node_label in node_map.keys() {
508 ctx_embeddings.insert(node_label.clone(), vec![0.0f64; dim]);
509 }
510
511 let total_epochs = self.config.num_epochs;
512 let total_pairs: usize = walks
513 .iter()
514 .map(|w| w.len() * (2 * window).min(if w.len() > 1 { w.len() - 1 } else { 0 }))
515 .sum();
516
517 let mut pair_count = 0usize;
518
519 for epoch in 0..total_epochs {
520 let lr = lr_init * (1.0 - epoch as f64 / total_epochs as f64).max(0.001);
522
523 for walk in walks {
524 for (i, target) in walk.iter().enumerate() {
525 let start = i.saturating_sub(window);
526 let end = (i + window + 1).min(walk.len());
527
528 for (j, context) in walk[start..end].iter().enumerate() {
529 let abs_j = start + j;
530 if abs_j == i || context == target {
531 continue;
532 }
533
534 let local_lr =
536 lr * (1.0 - pair_count as f64 / (total_pairs + 1) as f64).max(0.001);
537
538 self.sgd_update(
539 target,
540 context,
541 &mut embeddings,
542 &mut ctx_embeddings,
543 local_lr,
544 dim,
545 );
546
547 pair_count += 1;
548 }
549 }
550 }
551 }
552
553 for (node, emb) in &mut embeddings {
555 if let Some(ctx) = ctx_embeddings.get(node) {
556 for (e, c) in emb.iter_mut().zip(ctx.iter()) {
557 *e = (*e + c) / 2.0;
558 }
559 }
560 }
561
562 if self.config.normalize {
564 for emb in embeddings.values_mut() {
565 let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
566 if norm > 1e-12 {
567 for v in emb.iter_mut() {
568 *v /= norm;
569 }
570 }
571 }
572 }
573
574 Ok(embeddings)
575 }
576
577 fn sgd_update(
584 &self,
585 target: &str,
586 context: &str,
587 embeddings: &mut HashMap<String, Vec<f64>>,
588 ctx_embeddings: &mut HashMap<String, Vec<f64>>,
589 lr: f64,
590 dim: usize,
591 ) {
592 let score = {
594 let Some(te) = embeddings.get(target) else {
595 return;
596 };
597 let Some(ce) = ctx_embeddings.get(context) else {
598 return;
599 };
600 te.iter().zip(ce.iter()).map(|(a, b)| a * b).sum::<f64>()
601 };
602
603 let sigma = 1.0 / (1.0 + (-score).exp());
605 let grad = (1.0 - sigma) * lr;
606
607 let te_snap: Vec<f64> = match embeddings.get(target) {
609 Some(v) => v.clone(),
610 None => return,
611 };
612 let ce_snap: Vec<f64> = match ctx_embeddings.get(context) {
613 Some(v) => v.clone(),
614 None => return,
615 };
616
617 if let Some(te) = embeddings.get_mut(target) {
619 for k in 0..dim {
620 te[k] += grad * ce_snap[k];
621 }
622 }
623
624 if let Some(ce) = ctx_embeddings.get_mut(context) {
626 for k in 0..dim {
627 ce[k] += grad * te_snap[k];
628 }
629 }
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// Triples forming a ring `node_0 - node_1 - … - node_{n-1} - node_0`.
    fn ring_triples(n: usize) -> Vec<Triple> {
        let mut ring = Vec::with_capacity(n);
        for i in 0..n {
            let next = (i + 1) % n;
            ring.push(Triple::new(
                format!("node_{}", i),
                "connects",
                format!("node_{}", next),
            ));
        }
        ring
    }

    /// Triples connecting every pair among `n0 … n{n-1}` exactly once.
    fn complete_triples(n: usize) -> Vec<Triple> {
        (0..n)
            .flat_map(|i| {
                (i + 1..n).map(move |j| Triple::new(format!("n{}", i), "edge", format!("n{}", j)))
            })
            .collect()
    }

    /// Small, fast configuration shared by the tests.
    fn small_config() -> Node2VecConfig {
        let walk = Node2VecWalkConfig {
            num_walks: 5,
            walk_length: 10,
            p: 1.0,
            q: 1.0,
            random_seed: 99,
        };
        Node2VecConfig {
            walk,
            embedding_dim: 16,
            window_size: 2,
            num_epochs: 3,
            learning_rate: 0.05,
            normalize: true,
        }
    }

    /// One embedding is produced per ring node, with the configured dimension.
    #[test]
    fn test_embed_produces_correct_number_of_embeddings() {
        let out = Node2VecEmbedder::new(small_config())
            .embed(&ring_triples(6))
            .expect("embed failed");
        assert_eq!(out.embeddings.len(), 6);
        assert_eq!(out.dim, 16);
    }

    /// Every vector has exactly `embedding_dim` components.
    #[test]
    fn test_embed_correct_dimension() {
        let out = Node2VecEmbedder::new(small_config())
            .embed(&complete_triples(4))
            .expect("embed failed");
        for vector in out.embeddings.values() {
            assert_eq!(vector.len(), 16);
        }
    }

    /// With `normalize` set, every vector has unit L2 norm.
    #[test]
    fn test_normalized_embeddings_have_unit_norm() {
        let mut config = small_config();
        config.normalize = true;
        let out = Node2VecEmbedder::new(config)
            .embed(&ring_triples(5))
            .expect("embed failed");
        for (node, emb) in &out.embeddings {
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "node {} has non-unit norm {:.6}",
                node,
                norm
            );
        }
    }

    /// Cosine similarity stays within `[-1, 1]` up to rounding.
    #[test]
    fn test_cosine_similarity_in_range() {
        let out = Node2VecEmbedder::new(small_config())
            .embed(&complete_triples(5))
            .expect("embed failed");

        let labels: Vec<String> = out.embeddings.keys().cloned().collect();
        if let [first, second, ..] = labels.as_slice() {
            if let Some(sim) = out.cosine_similarity(first, second) {
                let allowed = -1.0 - 1e-9..=1.0 + 1e-9;
                assert!(
                    allowed.contains(&sim),
                    "cosine similarity out of range: {}",
                    sim
                );
            }
        }
    }

    /// `top_k_similar` never returns more than `k` entries.
    #[test]
    fn test_top_k_similar_returns_at_most_k() {
        let out = Node2VecEmbedder::new(small_config())
            .embed(&ring_triples(8))
            .expect("embed failed");

        let neighbours = out.top_k_similar("node_0", 3);
        assert!(neighbours.len() <= 3);
    }

    /// No triples means no embeddings and no walk steps.
    #[test]
    fn test_empty_triples_returns_empty_embeddings() {
        let out = Node2VecEmbedder::new(small_config())
            .embed(&[])
            .expect("embed failed");
        assert!(out.embeddings.is_empty());
        assert_eq!(out.total_walk_steps, 0);
    }

    /// A single triple yields embeddings for both of its endpoints.
    #[test]
    fn test_single_node_isolated() {
        let single_edge = vec![Triple::new("a", "r", "b")];
        let out = Node2VecEmbedder::new(small_config())
            .embed(&single_edge)
            .expect("embed failed");
        assert_eq!(out.embeddings.len(), 2);
    }

    /// Embedding succeeds for both small and large `p`/`q` settings.
    #[test]
    fn test_walk_bias_dfs_vs_bfs() {
        let triples = ring_triples(10);

        // Both configurations differ only in the walk bias value.
        let biased = |bias: f64| Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 3,
                walk_length: 20,
                p: bias,
                q: bias,
                random_seed: 1,
            },
            ..small_config()
        };

        let res_dfs = Node2VecEmbedder::new(biased(0.25))
            .embed(&triples)
            .expect("dfs embed failed");
        let res_bfs = Node2VecEmbedder::new(biased(4.0))
            .embed(&triples)
            .expect("bfs embed failed");

        assert_eq!(res_dfs.embeddings.len(), 10);
        assert_eq!(res_bfs.embeddings.len(), 10);
    }

    /// The reported step count lies between one node per walk and a
    /// generous bound above full-length walks.
    #[test]
    fn test_total_walk_steps_is_plausible() {
        let n = 5usize;
        let config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 2,
                walk_length: 10,
                ..Default::default()
            },
            ..small_config()
        };
        let out = Node2VecEmbedder::new(config)
            .embed(&ring_triples(n))
            .expect("embed failed");

        let lower = n * 2;
        let upper = n * 2 * 10 + n * 2;
        assert!(
            out.total_walk_steps >= lower,
            "expected ≥{} steps, got {}",
            lower,
            out.total_walk_steps
        );
        assert!(
            out.total_walk_steps <= upper,
            "unexpectedly many steps: {}",
            out.total_walk_steps
        );
    }

    /// Sampling always yields an index inside the weight vector.
    #[test]
    fn test_alias_table_samples_valid_index() {
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(777);
        for _ in 0..100 {
            assert!(table.sample(&mut rng) < weights.len());
        }
    }

    /// Uniform weights give roughly equal bucket counts (4000 draws, 4 buckets).
    #[test]
    fn test_alias_table_uniform_weights() {
        let table = AliasTable::build(&[1.0; 4]).expect("alias build failed");
        let mut rng = seeded_rng(42);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            let bucket = table.sample(&mut rng);
            counts[bucket] += 1;
        }
        for &c in counts.iter() {
            assert!(c > 800 && c < 1200, "bucket count out of range: {}", c);
        }
    }
}