1use scirs2_core::random::Random;
12
/// An undirected, weighted edge between two node indices.
///
/// Endpoints are not validated when the edge is created; graph routines skip
/// edges whose endpoints fall outside `0..node_count` where noted.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphEdge {
    /// Index of the first endpoint.
    pub from: usize,
    /// Index of the second endpoint.
    pub to: usize,
    /// Edge weight; used as the adjacency-matrix entry and as the random-walk
    /// transition weight.
    pub weight: f32,
}

/// An undirected, weighted graph stored as a node count plus an edge list.
///
/// Nodes are identified by indices `0..node_count`. The representation does not
/// forbid parallel edges or self-loops.
#[derive(Debug, Clone, PartialEq)]
pub struct Graph {
    /// Number of nodes; valid node ids are `0..node_count`.
    pub node_count: usize,
    /// Edges in insertion order.
    pub edges: Vec<GraphEdge>,
}
29
30impl Graph {
31 pub fn new(node_count: usize) -> Self {
33 Self {
34 node_count,
35 edges: Vec::new(),
36 }
37 }
38
39 pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
41 self.edges.push(GraphEdge { from, to, weight });
42 }
43
44 pub fn edge_count(&self) -> usize {
46 self.edges.len()
47 }
48
49 pub fn is_connected(&self) -> bool {
53 if self.node_count == 0 {
54 return true;
55 }
56 let mut visited = vec![false; self.node_count];
57 let mut queue = std::collections::VecDeque::new();
58 queue.push_back(0usize);
59 visited[0] = true;
60 let mut count = 1usize;
61
62 while let Some(node) = queue.pop_front() {
63 for edge in &self.edges {
64 let neighbor = if edge.from == node {
65 Some(edge.to)
66 } else if edge.to == node {
67 Some(edge.from)
68 } else {
69 None
70 };
71 if let Some(nb) = neighbor {
72 if nb < self.node_count && !visited[nb] {
73 visited[nb] = true;
74 count += 1;
75 queue.push_back(nb);
76 }
77 }
78 }
79 }
80 count == self.node_count
81 }
82
83 pub fn adjacency_matrix(&self) -> Vec<Vec<f32>> {
87 let n = self.node_count;
88 let mut mat = vec![vec![0.0f32; n]; n];
89 for edge in &self.edges {
90 if edge.from < n && edge.to < n {
91 mat[edge.from][edge.to] = edge.weight;
92 mat[edge.to][edge.from] = edge.weight; }
94 }
95 mat
96 }
97}
98
/// Parameters for node2vec-style biased random walks.
#[derive(Debug, Clone, PartialEq)]
pub struct WalkConfig {
    /// Target number of nodes in each walk, start node included.
    pub walk_length: usize,
    /// Number of walks started from every node.
    pub walks_per_node: usize,
    /// Return parameter `p`: a step back to the previous node is weighted by `1 / p`.
    pub return_param_p: f32,
    /// In-out parameter `q`: a step to a node not adjacent to the previous node is
    /// weighted by `1 / q`.
    pub in_out_param_q: f32,
}
111
112impl Default for WalkConfig {
113 fn default() -> Self {
114 Self {
115 walk_length: 10,
116 walks_per_node: 5,
117 return_param_p: 1.0,
118 in_out_param_q: 1.0,
119 }
120 }
121}
122
/// A dense embedding vector associated with one node.
#[derive(Debug, Clone, PartialEq)]
pub struct NodeEmbedding {
    /// Index of the node this vector describes.
    pub node_id: usize,
    /// Embedding components.
    pub vector: Vec<f32>,
}

/// Output of a walk-based embedding run.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResult {
    /// One embedding per node.
    pub embeddings: Vec<NodeEmbedding>,
    /// Number of random walks sampled to build the embeddings.
    pub walk_count: usize,
}
136
/// Stateless namespace for random-walk sampling and node-embedding routines.
#[derive(Debug, Clone, Copy, Default)]
pub struct GraphEmbedder;
141
142impl GraphEmbedder {
143 pub fn random_walks(graph: &Graph, config: &WalkConfig) -> Vec<Vec<usize>> {
150 let mut rng = Random::default();
151 let mut walks = Vec::with_capacity(graph.node_count * config.walks_per_node);
152
153 let adj = Self::build_adjacency(graph);
155
156 for _ in 0..config.walks_per_node {
157 for start in 0..graph.node_count {
158 let walk = Self::single_walk(
159 &adj,
160 graph.node_count,
161 start,
162 config.walk_length,
163 config.return_param_p,
164 config.in_out_param_q,
165 &mut rng,
166 );
167 walks.push(walk);
168 }
169 }
170 walks
171 }
172
173 pub fn embed(graph: &Graph, config: &WalkConfig, dim: usize) -> EmbeddingResult {
178 let walks = Self::random_walks(graph, config);
179 let walk_count = walks.len();
180 let n = graph.node_count;
181
182 let mut accum = vec![vec![0.0f64; n]; n];
184 let window = 2usize; for walk in &walks {
187 for (idx, ¢er) in walk.iter().enumerate() {
188 let lo = idx.saturating_sub(window);
189 let hi = (idx + window + 1).min(walk.len());
190 for &ctx in &walk[lo..hi] {
191 if ctx != center {
192 accum[center][ctx] += 1.0;
193 }
194 }
195 }
196 }
197
198 let embeddings: Vec<NodeEmbedding> = (0..n)
200 .map(|node_id| {
201 let row = &accum[node_id];
202 let vector = Self::project_row(row, dim, node_id);
203 NodeEmbedding { node_id, vector }
204 })
205 .collect();
206
207 EmbeddingResult {
208 embeddings,
209 walk_count,
210 }
211 }
212
213 pub fn structural_embedding(graph: &Graph, dim: usize) -> Vec<NodeEmbedding> {
217 let n = graph.node_count;
218 (0..n)
219 .map(|node_id| {
220 let neighbors = Self::neighbors(graph, node_id);
221 let deg = neighbors.len() as f64;
223 let sum_nb_deg: f64 = neighbors
224 .iter()
225 .map(|&nb| Self::degree(graph, nb) as f64)
226 .sum();
227 let sum_weight: f64 = graph
228 .edges
229 .iter()
230 .filter(|e| e.from == node_id || e.to == node_id)
231 .map(|e| e.weight as f64)
232 .sum();
233
234 let raw = vec![deg, sum_nb_deg, sum_weight, node_id as f64];
235 let vector = Self::project_row(&raw, dim, node_id);
236 NodeEmbedding { node_id, vector }
237 })
238 .collect()
239 }
240
241 pub fn node_similarity(a: &NodeEmbedding, b: &NodeEmbedding) -> f32 {
243 let len = a.vector.len().min(b.vector.len());
244 if len == 0 {
245 return 0.0;
246 }
247 let dot: f32 = a.vector[..len]
248 .iter()
249 .zip(b.vector[..len].iter())
250 .map(|(x, y)| x * y)
251 .sum();
252 let norm_a: f32 = a.vector[..len].iter().map(|x| x * x).sum::<f32>().sqrt();
253 let norm_b: f32 = b.vector[..len].iter().map(|x| x * x).sum::<f32>().sqrt();
254 if norm_a == 0.0 || norm_b == 0.0 {
255 return 0.0;
256 }
257 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
258 }
259
260 pub fn neighbors(graph: &Graph, node: usize) -> Vec<usize> {
262 let mut nbs: Vec<usize> = graph
263 .edges
264 .iter()
265 .filter_map(|e| {
266 if e.from == node {
267 Some(e.to)
268 } else if e.to == node {
269 Some(e.from)
270 } else {
271 None
272 }
273 })
274 .collect();
275 nbs.sort_unstable();
276 nbs.dedup();
277 nbs
278 }
279
280 pub fn degree(graph: &Graph, node: usize) -> usize {
282 Self::neighbors(graph, node).len()
283 }
284
285 fn build_adjacency(graph: &Graph) -> Vec<Vec<(usize, f32)>> {
289 let n = graph.node_count;
290 let mut adj: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
291 for edge in &graph.edges {
292 if edge.from < n && edge.to < n {
293 adj[edge.from].push((edge.to, edge.weight));
294 adj[edge.to].push((edge.from, edge.weight)); }
296 }
297 adj
298 }
299
300 fn single_walk(
302 adj: &[Vec<(usize, f32)>],
303 _node_count: usize,
304 start: usize,
305 walk_length: usize,
306 p: f32,
307 q: f32,
308 rng: &mut Random,
309 ) -> Vec<usize> {
310 let mut walk = Vec::with_capacity(walk_length);
311 walk.push(start);
312
313 if adj[start].is_empty() || walk_length <= 1 {
314 while walk.len() < walk_length {
316 walk.push(start);
317 }
318 return walk;
319 }
320
321 let first_idx = (rng.random_range(0.0..1.0) * adj[start].len() as f64) as usize;
323 walk.push(adj[start][first_idx].0);
324
325 while walk.len() < walk_length {
326 let cur = *walk.last().expect("walk is non-empty");
327 let prev = walk[walk.len() - 2];
328
329 if adj[cur].is_empty() {
330 walk.push(cur); continue;
332 }
333
334 let weights: Vec<f32> = adj[cur]
336 .iter()
337 .map(|&(nb, w)| {
338 let bias = if nb == prev {
339 1.0 / p } else if adj[prev].iter().any(|&(x, _)| x == nb) {
341 1.0 } else {
343 1.0 / q };
345 w * bias
346 })
347 .collect();
348
349 let total: f32 = weights.iter().sum();
350 let sample = (rng.random_range(0.0..1.0) as f32) * total;
351 let mut cumulative = 0.0f32;
352 let mut chosen = adj[cur][0].0;
353 for (i, &wt) in weights.iter().enumerate() {
354 cumulative += wt;
355 if sample <= cumulative {
356 chosen = adj[cur][i].0;
357 break;
358 }
359 }
360 walk.push(chosen);
361 }
362 walk
363 }
364
365 fn project_row(row: &[f64], dim: usize, node_id: usize) -> Vec<f32> {
368 use std::f64::consts::PI;
369 if dim == 0 {
370 return vec![];
371 }
372
373 let norm: f64 = row.iter().map(|x| x * x).sum::<f64>().sqrt();
375 let sum: f64 = row.iter().sum();
376
377 let mut vec = Vec::with_capacity(dim);
378 for d in 0..dim {
379 let angle =
381 (node_id as f64 * 0.1 + d as f64 * 1.3 + sum * 0.01) * PI / (dim as f64 + 1.0);
382 let val = (angle.sin() * (norm + 1.0).ln()) as f32;
383 vec.push(val);
384 }
385 vec
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Triangle 0-1-2 with unit weights; every node has degree 2.
    fn triangle() -> Graph {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 1.0);
        g.add_edge(1, 2, 1.0);
        g.add_edge(2, 0, 1.0);
        g
    }

    /// Path 0-1-2-3 with unit weights.
    fn path4() -> Graph {
        let mut g = Graph::new(4);
        g.add_edge(0, 1, 1.0);
        g.add_edge(1, 2, 1.0);
        g.add_edge(2, 3, 1.0);
        g
    }

    /// Two disjoint edges: 0-1 and 2-3.
    fn disconnected() -> Graph {
        let mut g = Graph::new(4);
        g.add_edge(0, 1, 1.0);
        g.add_edge(2, 3, 1.0);
        g
    }

    /// Small, fast walk settings: 3 walks of 5 nodes per start node, unbiased.
    fn default_config() -> WalkConfig {
        WalkConfig {
            walk_length: 5,
            walks_per_node: 3,
            return_param_p: 1.0,
            in_out_param_q: 1.0,
        }
    }

    /// Asserts `mat[i][j] == mat[j][i]` (within 1e-6) for every entry.
    fn assert_symmetric(mat: &[Vec<f32>]) {
        for (i, row) in mat.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert!(
                    (value - mat[j][i]).abs() < 1e-6,
                    "adjacency matrix must be symmetric"
                );
            }
        }
    }

    // ---- Graph construction ----

    #[test]
    fn test_graph_new_no_edges() {
        let g = Graph::new(5);
        assert_eq!(g.node_count, 5);
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn test_add_edge_increments_count() {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 2.0);
        assert_eq!(g.edge_count(), 1);
        g.add_edge(1, 2, 1.5);
        assert_eq!(g.edge_count(), 2);
    }

    #[test]
    fn test_edge_stored_correctly() {
        let mut g = Graph::new(3);
        g.add_edge(0, 2, 0.7);
        let e = &g.edges[0];
        assert_eq!(e.from, 0);
        assert_eq!(e.to, 2);
        assert!((e.weight - 0.7).abs() < 1e-6);
    }

    // ---- Connectivity ----

    #[test]
    fn test_is_connected_triangle() {
        assert!(triangle().is_connected());
    }

    #[test]
    fn test_is_connected_path() {
        assert!(path4().is_connected());
    }

    #[test]
    fn test_is_connected_disconnected() {
        assert!(!disconnected().is_connected());
    }

    #[test]
    fn test_is_connected_single_node() {
        let g = Graph::new(1);
        assert!(g.is_connected());
    }

    #[test]
    fn test_is_connected_empty_graph() {
        let g = Graph::new(0);
        assert!(g.is_connected());
    }

    #[test]
    fn test_is_connected_ignores_out_of_range_edges() {
        // Node 5 does not exist, so it cannot bridge nodes 0 and 1.
        let mut g = Graph::new(2);
        g.add_edge(0, 5, 1.0);
        g.add_edge(5, 1, 1.0);
        assert!(!g.is_connected());
    }

    // ---- Neighbors and degree ----

    #[test]
    fn test_neighbors_triangle() {
        let g = triangle();
        let nb0 = GraphEmbedder::neighbors(&g, 0);
        assert!(nb0.contains(&1), "1 should be neighbor of 0");
        assert!(nb0.contains(&2), "2 should be neighbor of 0");
        assert_eq!(nb0.len(), 2);
    }

    #[test]
    fn test_neighbors_path_endpoint() {
        let g = path4();
        let nb0 = GraphEmbedder::neighbors(&g, 0);
        assert_eq!(nb0, vec![1]);
    }

    #[test]
    fn test_neighbors_isolated_node() {
        let g = Graph::new(3);
        let nb = GraphEmbedder::neighbors(&g, 1);
        assert!(nb.is_empty());
    }

    #[test]
    fn test_degree_triangle() {
        let g = triangle();
        assert_eq!(GraphEmbedder::degree(&g, 0), 2);
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
        assert_eq!(GraphEmbedder::degree(&g, 2), 2);
    }

    #[test]
    fn test_degree_path_middle() {
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
    }

    #[test]
    fn test_degree_isolated() {
        let g = Graph::new(3);
        assert_eq!(GraphEmbedder::degree(&g, 0), 0);
    }

    // ---- Adjacency matrix ----

    #[test]
    fn test_adjacency_matrix_size() {
        let g = triangle();
        let mat = g.adjacency_matrix();
        assert_eq!(mat.len(), 3);
        assert_eq!(mat[0].len(), 3);
    }

    #[test]
    fn test_adjacency_matrix_symmetric() {
        // Distinct weights and reversed endpoints make the symmetry check meaningful.
        let mut g = Graph::new(4);
        g.add_edge(0, 1, 0.5);
        g.add_edge(2, 1, 1.5);
        g.add_edge(3, 0, 2.5);
        assert_symmetric(&g.adjacency_matrix());
    }

    #[test]
    fn test_adjacency_matrix_zero_diagonal() {
        let g = triangle();
        let mat = g.adjacency_matrix();
        for (i, row) in mat.iter().enumerate() {
            assert_eq!(row[i], 0.0, "diagonal must be zero (no self-loops)");
        }
    }

    #[test]
    fn test_adjacency_matrix_edge_weight() {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 3.5);
        let mat = g.adjacency_matrix();
        assert!((mat[0][1] - 3.5).abs() < 1e-6);
        assert!((mat[1][0] - 3.5).abs() < 1e-6);
    }

    // ---- Random walks ----

    #[test]
    fn test_random_walks_count() {
        let g = triangle();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        assert_eq!(walks.len(), 9, "expected 9 walks");
    }

    #[test]
    fn test_random_walks_length() {
        let g = triangle();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        for w in &walks {
            assert_eq!(
                w.len(),
                config.walk_length,
                "each walk must have walk_length nodes"
            );
        }
    }

    #[test]
    fn test_random_walks_node_ids_valid() {
        let g = path4();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        for w in &walks {
            for &node in w {
                assert!(node < g.node_count, "node id must be < node_count");
            }
        }
    }

    #[test]
    fn test_random_walks_isolated_nodes() {
        let g = Graph::new(3);
        let config = WalkConfig {
            walk_length: 4,
            walks_per_node: 2,
            ..Default::default()
        };
        let walks = GraphEmbedder::random_walks(&g, &config);
        assert_eq!(walks.len(), 6);
        for w in &walks {
            assert_eq!(w.len(), 4);
        }
    }

    #[test]
    fn test_random_walks_follow_edges() {
        let g = path4();
        let walks = GraphEmbedder::random_walks(&g, &default_config());
        for w in &walks {
            for step in w.windows(2) {
                assert!(
                    GraphEmbedder::neighbors(&g, step[0]).contains(&step[1]),
                    "walk step {} -> {} is not an edge",
                    step[0],
                    step[1]
                );
            }
        }
    }

    #[test]
    fn test_random_walks_start_nodes_cycle_through_graph() {
        // One round per walks_per_node, each starting at every node in order.
        let g = path4();
        let walks = GraphEmbedder::random_walks(&g, &default_config());
        for (i, w) in walks.iter().enumerate() {
            assert_eq!(w[0], i % g.node_count);
        }
    }

    // ---- Walk-based embeddings ----

    #[test]
    fn test_embed_returns_node_count_embeddings() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 8);
        assert_eq!(result.embeddings.len(), g.node_count);
    }

    #[test]
    fn test_embed_correct_walk_count() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 8);
        assert_eq!(result.walk_count, config.walks_per_node * g.node_count);
    }

    #[test]
    fn test_embed_dimension() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 16);
        for emb in &result.embeddings {
            assert_eq!(emb.vector.len(), 16, "embedding dimension must match dim");
        }
    }

    #[test]
    fn test_embed_node_ids_assigned() {
        let g = path4();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 4);
        for (i, emb) in result.embeddings.iter().enumerate() {
            assert_eq!(emb.node_id, i);
        }
    }

    // ---- Structural embeddings ----

    #[test]
    fn test_structural_embedding_count() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 8);
        assert_eq!(embeddings.len(), g.node_count);
    }

    #[test]
    fn test_structural_embedding_dimension() {
        let g = path4();
        let dim = 12;
        let embeddings = GraphEmbedder::structural_embedding(&g, dim);
        for emb in &embeddings {
            assert_eq!(emb.vector.len(), dim);
        }
    }

    #[test]
    fn test_structural_embedding_node_ids() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 4);
        for (i, emb) in embeddings.iter().enumerate() {
            assert_eq!(emb.node_id, i);
        }
    }

    // ---- Similarity ----

    #[test]
    fn test_node_similarity_self_is_one() {
        let emb = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0, 0.0],
        };
        let sim = GraphEmbedder::node_similarity(&emb, &emb);
        assert!((sim - 1.0).abs() < 1e-6, "self similarity should be 1.0");
    }

    #[test]
    fn test_node_similarity_orthogonal_is_zero() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![0.0, 1.0],
        };
        let sim = GraphEmbedder::node_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors should have similarity 0"
        );
    }

    #[test]
    fn test_node_similarity_range() {
        let g = path4();
        let embeddings = GraphEmbedder::structural_embedding(&g, 8);
        for a in &embeddings {
            for b in &embeddings {
                let sim = GraphEmbedder::node_similarity(a, b);
                assert!(
                    (-1.0..=1.0).contains(&sim),
                    "similarity {sim} must be in [-1, 1]"
                );
            }
        }
    }

    #[test]
    fn test_node_similarity_empty_vectors_is_zero() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![],
        };
        assert_eq!(GraphEmbedder::node_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_node_similarity_opposite_vectors() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![-1.0, 0.0],
        };
        let sim = GraphEmbedder::node_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-6,
            "opposite vectors: similarity = -1"
        );
    }

    #[test]
    fn test_node_similarity_uses_common_prefix() {
        // Only the first min(len) components are compared, so the trailing 9.0 is ignored.
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0, 9.0],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![2.0, 0.0],
        };
        let sim = GraphEmbedder::node_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // ---- Edge cases and formulas ----

    #[test]
    fn test_embed_single_node() {
        let g = Graph::new(1);
        let config = WalkConfig {
            walk_length: 3,
            walks_per_node: 2,
            ..Default::default()
        };
        let result = GraphEmbedder::embed(&g, &config, 4);
        assert_eq!(result.embeddings.len(), 1);
        assert_eq!(result.walk_count, 2);
    }

    #[test]
    fn test_structural_embedding_zero_dim() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 0);
        for emb in &embeddings {
            assert!(emb.vector.is_empty());
        }
    }

    #[test]
    fn test_walk_config_default() {
        let c = WalkConfig::default();
        assert_eq!(c.walk_length, 10);
        assert_eq!(c.walks_per_node, 5);
    }

    #[test]
    fn test_walks_total_count_formula() {
        let g = path4();
        let config = WalkConfig {
            walk_length: 6,
            walks_per_node: 4,
            ..Default::default()
        };
        let walks = GraphEmbedder::random_walks(&g, &config);
        assert_eq!(walks.len(), 4 * 4, "4 nodes * 4 walks = 16");
    }

    #[test]
    fn test_adjacency_matrix_path4_size() {
        let g = path4();
        let mat = g.adjacency_matrix();
        assert_eq!(mat.len(), 4);
        for row in &mat {
            assert_eq!(row.len(), 4);
        }
    }

    #[test]
    fn test_adjacency_matrix_path4_symmetric() {
        assert_symmetric(&path4().adjacency_matrix());
    }

    #[test]
    fn test_adjacency_matrix_no_self_loops_for_path() {
        let g = path4();
        let mat = g.adjacency_matrix();
        for (i, row) in mat.iter().enumerate() {
            assert_eq!(row[i], 0.0);
        }
    }

    #[test]
    fn test_degree_path_endpoint_is_one() {
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 0), 1);
        assert_eq!(GraphEmbedder::degree(&g, 3), 1);
    }

    #[test]
    fn test_degree_path_middle_is_two() {
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
        assert_eq!(GraphEmbedder::degree(&g, 2), 2);
    }

    #[test]
    fn test_embed_walk_count_equals_nodes_times_walks() {
        let g = path4();
        let config = WalkConfig {
            walk_length: 5,
            walks_per_node: 3,
            ..Default::default()
        };
        let result = GraphEmbedder::embed(&g, &config, 4);
        assert_eq!(
            result.walk_count,
            4 * 3,
            "walk_count = nodes * walks_per_node"
        );
    }

    #[test]
    fn test_structural_embedding_node_ids_sequential() {
        let g = path4();
        let embeddings = GraphEmbedder::structural_embedding(&g, 6);
        let ids: Vec<usize> = embeddings.iter().map(|e| e.node_id).collect();
        let expected: Vec<usize> = (0..4).collect();
        assert_eq!(ids, expected, "node_ids must be sequential from 0");
    }
}