Skip to main content

oxirs_graphrag/
graph_embedder.rs

//! # Graph Embedder
//!
//! Node2Vec-inspired random walks and lightweight structural embeddings for knowledge graphs.
//!
//! Provides:
//! - Biased random walks with return parameter `p` and in-out parameter `q`
//! - Walk co-occurrence and neighborhood-aggregation embeddings, produced by a deterministic
//!   sinusoidal projection (no model is trained)
//! - Cosine similarity between node embedding vectors
//! - Graph connectivity checks (BFS) and adjacency matrix construction
//!
//! All algorithms in this module treat edges as undirected.
10
11use scirs2_core::random::Random;
12
13// ─── Public types ─────────────────────────────────────────────────────────────
14
/// A weighted edge between two node ids.
///
/// The edge is stored as `from → to`, but the algorithms in this module treat it as
/// undirected.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphEdge {
    /// One endpoint; expected to be `< node_count` of the owning [`Graph`].
    pub from: usize,
    /// The other endpoint; expected to be `< node_count` of the owning [`Graph`].
    pub to: usize,
    /// Edge weight, used as the adjacency-matrix entry and as the random-walk transition weight.
    pub weight: f32,
}

/// Simple sparse graph: node ids `0..node_count` plus a flat edge list.
#[derive(Debug, Clone, PartialEq)]
pub struct Graph {
    /// Number of nodes; valid node ids are `0..node_count`.
    pub node_count: usize,
    /// Edges in insertion order. Parallel edges and self-loops are not rejected.
    pub edges: Vec<GraphEdge>,
}
29
30impl Graph {
31    /// Create a new empty graph with `node_count` nodes and no edges
32    pub fn new(node_count: usize) -> Self {
33        Self {
34            node_count,
35            edges: Vec::new(),
36        }
37    }
38
39    /// Add a directed edge from → to with the given weight
40    pub fn add_edge(&mut self, from: usize, to: usize, weight: f32) {
41        self.edges.push(GraphEdge { from, to, weight });
42    }
43
44    /// Total number of edges
45    pub fn edge_count(&self) -> usize {
46        self.edges.len()
47    }
48
49    /// BFS from node 0 to check if all nodes are reachable (i.e. graph is connected).
50    ///
51    /// Treats edges as undirected for the connectivity check.
52    pub fn is_connected(&self) -> bool {
53        if self.node_count == 0 {
54            return true;
55        }
56        let mut visited = vec![false; self.node_count];
57        let mut queue = std::collections::VecDeque::new();
58        queue.push_back(0usize);
59        visited[0] = true;
60        let mut count = 1usize;
61
62        while let Some(node) = queue.pop_front() {
63            for edge in &self.edges {
64                let neighbor = if edge.from == node {
65                    Some(edge.to)
66                } else if edge.to == node {
67                    Some(edge.from)
68                } else {
69                    None
70                };
71                if let Some(nb) = neighbor {
72                    if nb < self.node_count && !visited[nb] {
73                        visited[nb] = true;
74                        count += 1;
75                        queue.push_back(nb);
76                    }
77                }
78            }
79        }
80        count == self.node_count
81    }
82
83    /// Build a dense adjacency matrix (node_count × node_count).
84    ///
85    /// For undirected usage the matrix is symmetric (both from→to and to→from are set).
86    pub fn adjacency_matrix(&self) -> Vec<Vec<f32>> {
87        let n = self.node_count;
88        let mut mat = vec![vec![0.0f32; n]; n];
89        for edge in &self.edges {
90            if edge.from < n && edge.to < n {
91                mat[edge.from][edge.to] = edge.weight;
92                mat[edge.to][edge.from] = edge.weight; // symmetric
93            }
94        }
95        mat
96    }
97}
98
/// Configuration for random walk generation.
#[derive(Debug, Clone, PartialEq)]
pub struct WalkConfig {
    /// Length of each random walk (number of nodes visited, including the start node).
    pub walk_length: usize,
    /// Number of walks started from each node.
    pub walks_per_node: usize,
    /// Return parameter `p`: stepping back to the previous node is weighted by `1/p`.
    ///
    /// Values above 1 discourage immediate backtracking; values below 1 encourage it. The
    /// code divides by this value, so it should be positive and finite.
    pub return_param_p: f32,
    /// In-out parameter `q`: moving to a node not adjacent to the previous node is weighted
    /// by `1/q`.
    ///
    /// Values above 1 keep the walk local (BFS-like); values below 1 push it outward
    /// (DFS-like). The code divides by this value, so it should be positive and finite.
    pub in_out_param_q: f32,
}
111
112impl Default for WalkConfig {
113    fn default() -> Self {
114        Self {
115            walk_length: 10,
116            walks_per_node: 5,
117            return_param_p: 1.0,
118            in_out_param_q: 1.0,
119        }
120    }
121}
122
/// Embedding vector for a single node.
#[derive(Debug, Clone, PartialEq)]
pub struct NodeEmbedding {
    /// Id of the embedded node.
    pub node_id: usize,
    /// Embedding components; the embedders in this module produce exactly the requested `dim`.
    pub vector: Vec<f32>,
}

/// Output of random-walk based embedding.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResult {
    /// One embedding per node, ordered by node id.
    pub embeddings: Vec<NodeEmbedding>,
    /// Number of random walks that were generated to build the embeddings.
    pub walk_count: usize,
}
136
137// ─── Graph Embedder ───────────────────────────────────────────────────────────
138
/// Stateless graph embedding engine: Node2Vec-inspired random walks and structural embeddings.
///
/// All functionality is exposed as associated functions; the unit value carries no state.
#[derive(Debug, Clone, Copy, Default)]
pub struct GraphEmbedder;
141
142impl GraphEmbedder {
143    /// Generate biased random walks for all nodes.
144    ///
145    /// Uses Node2Vec-style second-order biased walks governed by p and q.
146    /// When p=1 and q=1 the walk degenerates to a uniform random walk (DeepWalk).
147    ///
148    /// Returns a list of walks, each being a sequence of node IDs.
149    pub fn random_walks(graph: &Graph, config: &WalkConfig) -> Vec<Vec<usize>> {
150        let mut rng = Random::default();
151        let mut walks = Vec::with_capacity(graph.node_count * config.walks_per_node);
152
153        // Build adjacency list for fast neighbor lookup
154        let adj = Self::build_adjacency(graph);
155
156        for _ in 0..config.walks_per_node {
157            for start in 0..graph.node_count {
158                let walk = Self::single_walk(
159                    &adj,
160                    graph.node_count,
161                    start,
162                    config.walk_length,
163                    config.return_param_p,
164                    config.in_out_param_q,
165                    &mut rng,
166                );
167                walks.push(walk);
168            }
169        }
170        walks
171    }
172
173    /// Generate random-walk-based embeddings.
174    ///
175    /// Each node's embedding is derived from its walk co-occurrence profile
176    /// (simplified: averaged node-ID features from walk context).
177    pub fn embed(graph: &Graph, config: &WalkConfig, dim: usize) -> EmbeddingResult {
178        let walks = Self::random_walks(graph, config);
179        let walk_count = walks.len();
180        let n = graph.node_count;
181
182        // Co-occurrence accumulation: for each node, accumulate context node IDs
183        let mut accum = vec![vec![0.0f64; n]; n];
184        let window = 2usize; // context window half-size
185
186        for walk in &walks {
187            for (idx, &center) in walk.iter().enumerate() {
188                let lo = idx.saturating_sub(window);
189                let hi = (idx + window + 1).min(walk.len());
190                for &ctx in &walk[lo..hi] {
191                    if ctx != center {
192                        accum[center][ctx] += 1.0;
193                    }
194                }
195            }
196        }
197
198        // Project co-occurrence row into `dim`-dimensional space via hash embedding
199        let embeddings: Vec<NodeEmbedding> = (0..n)
200            .map(|node_id| {
201                let row = &accum[node_id];
202                let vector = Self::project_row(row, dim, node_id);
203                NodeEmbedding { node_id, vector }
204            })
205            .collect();
206
207        EmbeddingResult {
208            embeddings,
209            walk_count,
210        }
211    }
212
213    /// Compute structural embeddings based purely on local neighborhood topology.
214    ///
215    /// Each node's embedding aggregates neighbour degree statistics projected into `dim`.
216    pub fn structural_embedding(graph: &Graph, dim: usize) -> Vec<NodeEmbedding> {
217        let n = graph.node_count;
218        (0..n)
219            .map(|node_id| {
220                let neighbors = Self::neighbors(graph, node_id);
221                // Feature: [degree, sum(neighbor_degree), sum(neighbor_weight), ...]
222                let deg = neighbors.len() as f64;
223                let sum_nb_deg: f64 = neighbors
224                    .iter()
225                    .map(|&nb| Self::degree(graph, nb) as f64)
226                    .sum();
227                let sum_weight: f64 = graph
228                    .edges
229                    .iter()
230                    .filter(|e| e.from == node_id || e.to == node_id)
231                    .map(|e| e.weight as f64)
232                    .sum();
233
234                let raw = vec![deg, sum_nb_deg, sum_weight, node_id as f64];
235                let vector = Self::project_row(&raw, dim, node_id);
236                NodeEmbedding { node_id, vector }
237            })
238            .collect()
239    }
240
241    /// Cosine similarity between two node embeddings, in range [-1, 1].
242    pub fn node_similarity(a: &NodeEmbedding, b: &NodeEmbedding) -> f32 {
243        let len = a.vector.len().min(b.vector.len());
244        if len == 0 {
245            return 0.0;
246        }
247        let dot: f32 = a.vector[..len]
248            .iter()
249            .zip(b.vector[..len].iter())
250            .map(|(x, y)| x * y)
251            .sum();
252        let norm_a: f32 = a.vector[..len].iter().map(|x| x * x).sum::<f32>().sqrt();
253        let norm_b: f32 = b.vector[..len].iter().map(|x| x * x).sum::<f32>().sqrt();
254        if norm_a == 0.0 || norm_b == 0.0 {
255            return 0.0;
256        }
257        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
258    }
259
260    /// Return all neighbors of `node` (both directions for undirected usage).
261    pub fn neighbors(graph: &Graph, node: usize) -> Vec<usize> {
262        let mut nbs: Vec<usize> = graph
263            .edges
264            .iter()
265            .filter_map(|e| {
266                if e.from == node {
267                    Some(e.to)
268                } else if e.to == node {
269                    Some(e.from)
270                } else {
271                    None
272                }
273            })
274            .collect();
275        nbs.sort_unstable();
276        nbs.dedup();
277        nbs
278    }
279
280    /// Degree of a node (number of unique neighbors).
281    pub fn degree(graph: &Graph, node: usize) -> usize {
282        Self::neighbors(graph, node).len()
283    }
284
285    // ─── private helpers ──────────────────────────────────────────────────────
286
287    /// Build adjacency list: Vec<Vec<(neighbor, weight)>>
288    fn build_adjacency(graph: &Graph) -> Vec<Vec<(usize, f32)>> {
289        let n = graph.node_count;
290        let mut adj: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
291        for edge in &graph.edges {
292            if edge.from < n && edge.to < n {
293                adj[edge.from].push((edge.to, edge.weight));
294                adj[edge.to].push((edge.from, edge.weight)); // undirected
295            }
296        }
297        adj
298    }
299
300    /// Perform a single Node2Vec walk starting from `start`.
301    fn single_walk(
302        adj: &[Vec<(usize, f32)>],
303        _node_count: usize,
304        start: usize,
305        walk_length: usize,
306        p: f32,
307        q: f32,
308        rng: &mut Random,
309    ) -> Vec<usize> {
310        let mut walk = Vec::with_capacity(walk_length);
311        walk.push(start);
312
313        if adj[start].is_empty() || walk_length <= 1 {
314            // Isolated node — repeat self
315            while walk.len() < walk_length {
316                walk.push(start);
317            }
318            return walk;
319        }
320
321        // First step: uniform random neighbor
322        let first_idx = (rng.random_range(0.0..1.0) * adj[start].len() as f64) as usize;
323        walk.push(adj[start][first_idx].0);
324
325        while walk.len() < walk_length {
326            let cur = *walk.last().expect("walk is non-empty");
327            let prev = walk[walk.len() - 2];
328
329            if adj[cur].is_empty() {
330                walk.push(cur); // stuck — stay
331                continue;
332            }
333
334            // Compute unnormalised transition probabilities (Node2Vec bias)
335            let weights: Vec<f32> = adj[cur]
336                .iter()
337                .map(|&(nb, w)| {
338                    let bias = if nb == prev {
339                        1.0 / p // return
340                    } else if adj[prev].iter().any(|&(x, _)| x == nb) {
341                        1.0 // common neighbor
342                    } else {
343                        1.0 / q // explore away
344                    };
345                    w * bias
346                })
347                .collect();
348
349            let total: f32 = weights.iter().sum();
350            let sample = (rng.random_range(0.0..1.0) as f32) * total;
351            let mut cumulative = 0.0f32;
352            let mut chosen = adj[cur][0].0;
353            for (i, &wt) in weights.iter().enumerate() {
354                cumulative += wt;
355                if sample <= cumulative {
356                    chosen = adj[cur][i].0;
357                    break;
358                }
359            }
360            walk.push(chosen);
361        }
362        walk
363    }
364
365    /// Project a raw feature slice into a `dim`-dimensional f32 vector
366    /// using a simple deterministic hashing / sinusoidal expansion.
367    fn project_row(row: &[f64], dim: usize, node_id: usize) -> Vec<f32> {
368        use std::f64::consts::PI;
369        if dim == 0 {
370            return vec![];
371        }
372
373        // Compute a scalar summary of the row
374        let norm: f64 = row.iter().map(|x| x * x).sum::<f64>().sqrt();
375        let sum: f64 = row.iter().sum();
376
377        let mut vec = Vec::with_capacity(dim);
378        for d in 0..dim {
379            // Deterministic sinusoidal projection
380            let angle =
381                (node_id as f64 * 0.1 + d as f64 * 1.3 + sum * 0.01) * PI / (dim as f64 + 1.0);
382            let val = (angle.sin() * (norm + 1.0).ln()) as f32;
383            vec.push(val);
384        }
385        vec
386    }
387}
388
389// ─── Tests ────────────────────────────────────────────────────────────────────
390
#[cfg(test)]
mod tests {
    use super::*;

    // ─── helpers ─────────────────────────────────────────────────────────────

    /// Triangle graph: 0-1-2-0
    fn triangle() -> Graph {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 1.0);
        g.add_edge(1, 2, 1.0);
        g.add_edge(2, 0, 1.0);
        g
    }

    /// Path graph: 0-1-2-3
    fn path4() -> Graph {
        let mut g = Graph::new(4);
        g.add_edge(0, 1, 1.0);
        g.add_edge(1, 2, 1.0);
        g.add_edge(2, 3, 1.0);
        g
    }

    /// Disconnected graph: {0-1} and {2-3}
    fn disconnected() -> Graph {
        let mut g = Graph::new(4);
        g.add_edge(0, 1, 1.0);
        g.add_edge(2, 3, 1.0);
        g
    }

    fn default_config() -> WalkConfig {
        WalkConfig {
            walk_length: 5,
            walks_per_node: 3,
            return_param_p: 1.0,
            in_out_param_q: 1.0,
        }
    }

    /// Assert that `mat` is square and equal to its transpose (within f32 tolerance).
    fn assert_symmetric(mat: &[Vec<f32>]) {
        for (i, row) in mat.iter().enumerate() {
            assert_eq!(row.len(), mat.len(), "adjacency matrix must be square");
            for (j, &value) in row.iter().enumerate() {
                assert!(
                    (value - mat[j][i]).abs() < 1e-6,
                    "adjacency matrix must be symmetric"
                );
            }
        }
    }

    // ─── Graph construction ───────────────────────────────────────────────────

    #[test]
    fn test_graph_new_no_edges() {
        let g = Graph::new(5);
        assert_eq!(g.node_count, 5);
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn test_add_edge_increments_count() {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 2.0);
        assert_eq!(g.edge_count(), 1);
        g.add_edge(1, 2, 1.5);
        assert_eq!(g.edge_count(), 2);
    }

    #[test]
    fn test_edge_stored_correctly() {
        let mut g = Graph::new(3);
        g.add_edge(0, 2, 0.7);
        let e = &g.edges[0];
        assert_eq!(e.from, 0);
        assert_eq!(e.to, 2);
        assert!((e.weight - 0.7).abs() < 1e-6);
    }

    // ─── is_connected ─────────────────────────────────────────────────────────

    #[test]
    fn test_is_connected_triangle() {
        assert!(triangle().is_connected());
    }

    #[test]
    fn test_is_connected_path() {
        assert!(path4().is_connected());
    }

    #[test]
    fn test_is_connected_disconnected() {
        assert!(!disconnected().is_connected());
    }

    #[test]
    fn test_is_connected_single_node() {
        let g = Graph::new(1);
        assert!(g.is_connected());
    }

    #[test]
    fn test_is_connected_empty_graph() {
        let g = Graph::new(0);
        assert!(g.is_connected()); // vacuously true
    }

    // ─── neighbors and degree ─────────────────────────────────────────────────

    #[test]
    fn test_neighbors_triangle() {
        let g = triangle();
        let nb0 = GraphEmbedder::neighbors(&g, 0);
        assert!(nb0.contains(&1), "1 should be neighbor of 0");
        assert!(nb0.contains(&2), "2 should be neighbor of 0");
        assert_eq!(nb0.len(), 2);
    }

    #[test]
    fn test_neighbors_path_endpoint() {
        let g = path4();
        let nb0 = GraphEmbedder::neighbors(&g, 0);
        assert_eq!(nb0, vec![1]);
    }

    #[test]
    fn test_neighbors_isolated_node() {
        let g = Graph::new(3); // no edges
        let nb = GraphEmbedder::neighbors(&g, 1);
        assert!(nb.is_empty());
    }

    #[test]
    fn test_degree_triangle() {
        let g = triangle();
        assert_eq!(GraphEmbedder::degree(&g, 0), 2);
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
        assert_eq!(GraphEmbedder::degree(&g, 2), 2);
    }

    #[test]
    fn test_degree_path_middle() {
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
    }

    #[test]
    fn test_degree_isolated() {
        let g = Graph::new(3);
        assert_eq!(GraphEmbedder::degree(&g, 0), 0);
    }

    // ─── adjacency_matrix ─────────────────────────────────────────────────────

    #[test]
    fn test_adjacency_matrix_size() {
        let g = triangle();
        let mat = g.adjacency_matrix();
        assert_eq!(mat.len(), 3);
        assert_eq!(mat[0].len(), 3);
    }

    #[test]
    fn test_adjacency_matrix_symmetric() {
        let g = path4();
        assert_symmetric(&g.adjacency_matrix());
    }

    #[test]
    fn test_adjacency_matrix_zero_diagonal() {
        let g = triangle();
        let mat = g.adjacency_matrix();
        for (i, row) in mat.iter().enumerate() {
            assert_eq!(row[i], 0.0, "diagonal must be zero (no self-loops)");
        }
    }

    #[test]
    fn test_adjacency_matrix_edge_weight() {
        let mut g = Graph::new(3);
        g.add_edge(0, 1, 3.5);
        let mat = g.adjacency_matrix();
        assert!((mat[0][1] - 3.5).abs() < 1e-6);
        assert!((mat[1][0] - 3.5).abs() < 1e-6);
    }

    // ─── random_walks ─────────────────────────────────────────────────────────

    #[test]
    fn test_random_walks_count() {
        let g = triangle();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        // walks_per_node * node_count = 3 * 3 = 9
        assert_eq!(walks.len(), 9, "expected 9 walks");
    }

    #[test]
    fn test_random_walks_length() {
        let g = triangle();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        for w in &walks {
            assert_eq!(
                w.len(),
                config.walk_length,
                "each walk must have walk_length nodes"
            );
        }
    }

    #[test]
    fn test_random_walks_node_ids_valid() {
        let g = path4();
        let config = default_config();
        let walks = GraphEmbedder::random_walks(&g, &config);
        for w in &walks {
            for &node in w {
                assert!(node < g.node_count, "node id must be < node_count");
            }
        }
    }

    #[test]
    fn test_random_walks_isolated_nodes() {
        let g = Graph::new(3); // all isolated
        let config = WalkConfig {
            walk_length: 4,
            walks_per_node: 2,
            ..Default::default()
        };
        let walks = GraphEmbedder::random_walks(&g, &config);
        assert_eq!(walks.len(), 6);
        for w in &walks {
            assert_eq!(w.len(), 4);
        }
    }

    // ─── embed ────────────────────────────────────────────────────────────────

    #[test]
    fn test_embed_returns_node_count_embeddings() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 8);
        assert_eq!(result.embeddings.len(), g.node_count);
    }

    #[test]
    fn test_embed_correct_walk_count() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 8);
        assert_eq!(result.walk_count, config.walks_per_node * g.node_count);
    }

    #[test]
    fn test_embed_dimension() {
        let g = triangle();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 16);
        for emb in &result.embeddings {
            assert_eq!(emb.vector.len(), 16, "embedding dimension must match dim");
        }
    }

    #[test]
    fn test_embed_node_ids_assigned() {
        let g = path4();
        let config = default_config();
        let result = GraphEmbedder::embed(&g, &config, 4);
        for (i, emb) in result.embeddings.iter().enumerate() {
            assert_eq!(emb.node_id, i);
        }
    }

    // ─── structural_embedding ─────────────────────────────────────────────────

    #[test]
    fn test_structural_embedding_count() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 8);
        assert_eq!(embeddings.len(), g.node_count);
    }

    #[test]
    fn test_structural_embedding_dimension() {
        let g = path4();
        let dim = 12;
        let embeddings = GraphEmbedder::structural_embedding(&g, dim);
        for emb in &embeddings {
            assert_eq!(emb.vector.len(), dim);
        }
    }

    #[test]
    fn test_structural_embedding_node_ids() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 4);
        for (i, emb) in embeddings.iter().enumerate() {
            assert_eq!(emb.node_id, i);
        }
    }

    // ─── node_similarity ──────────────────────────────────────────────────────

    #[test]
    fn test_node_similarity_self_is_one() {
        let emb = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0, 0.0],
        };
        let sim = GraphEmbedder::node_similarity(&emb, &emb);
        assert!((sim - 1.0).abs() < 1e-6, "self similarity should be 1.0");
    }

    #[test]
    fn test_node_similarity_orthogonal_is_zero() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![0.0, 1.0],
        };
        let sim = GraphEmbedder::node_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors should have similarity 0"
        );
    }

    #[test]
    fn test_node_similarity_range() {
        let g = path4();
        let embeddings = GraphEmbedder::structural_embedding(&g, 8);
        for a in &embeddings {
            for b in &embeddings {
                let sim = GraphEmbedder::node_similarity(a, b);
                assert!(
                    (-1.0..=1.0).contains(&sim),
                    "similarity {sim} must be in [-1, 1]"
                );
            }
        }
    }

    #[test]
    fn test_node_similarity_empty_vectors_is_zero() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![],
        };
        assert_eq!(GraphEmbedder::node_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_node_similarity_opposite_vectors() {
        let a = NodeEmbedding {
            node_id: 0,
            vector: vec![1.0, 0.0],
        };
        let b = NodeEmbedding {
            node_id: 1,
            vector: vec![-1.0, 0.0],
        };
        let sim = GraphEmbedder::node_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-6,
            "opposite vectors: similarity = -1"
        );
    }

    // ─── edge cases ───────────────────────────────────────────────────────────

    #[test]
    fn test_embed_single_node() {
        let g = Graph::new(1);
        let config = WalkConfig {
            walk_length: 3,
            walks_per_node: 2,
            ..Default::default()
        };
        let result = GraphEmbedder::embed(&g, &config, 4);
        assert_eq!(result.embeddings.len(), 1);
        assert_eq!(result.walk_count, 2);
    }

    #[test]
    fn test_structural_embedding_zero_dim() {
        let g = triangle();
        let embeddings = GraphEmbedder::structural_embedding(&g, 0);
        for emb in &embeddings {
            assert!(emb.vector.is_empty());
        }
    }

    #[test]
    fn test_walk_config_default() {
        let c = WalkConfig::default();
        assert_eq!(c.walk_length, 10);
        assert_eq!(c.walks_per_node, 5);
    }

    #[test]
    fn test_walks_total_count_formula() {
        let g = path4(); // 4 nodes
        let config = WalkConfig {
            walk_length: 6,
            walks_per_node: 4,
            ..Default::default()
        };
        let walks = GraphEmbedder::random_walks(&g, &config);
        assert_eq!(walks.len(), 4 * 4, "4 nodes * 4 walks = 16");
    }

    // ─── Additional tests (round 11 extra coverage) ───────────────────────────

    #[test]
    fn test_adjacency_matrix_path4_size() {
        let g = path4(); // 4 nodes
        let mat = g.adjacency_matrix();
        assert_eq!(mat.len(), 4);
        for row in &mat {
            assert_eq!(row.len(), 4);
        }
    }

    #[test]
    fn test_adjacency_matrix_path4_symmetric() {
        let g = path4();
        assert_symmetric(&g.adjacency_matrix());
    }

    #[test]
    fn test_adjacency_matrix_no_self_loops_for_path() {
        let g = path4();
        let mat = g.adjacency_matrix();
        for (i, row) in mat.iter().enumerate() {
            assert_eq!(row[i], 0.0);
        }
    }

    #[test]
    fn test_degree_path_endpoint_is_one() {
        // In path4: 0-1-2-3, node 0 and node 3 have degree 1
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 0), 1);
        assert_eq!(GraphEmbedder::degree(&g, 3), 1);
    }

    #[test]
    fn test_degree_path_middle_is_two() {
        // In path4: 0-1-2-3, node 1 and node 2 have degree 2
        let g = path4();
        assert_eq!(GraphEmbedder::degree(&g, 1), 2);
        assert_eq!(GraphEmbedder::degree(&g, 2), 2);
    }

    #[test]
    fn test_embed_walk_count_equals_nodes_times_walks() {
        let g = path4();
        let config = WalkConfig {
            walk_length: 5,
            walks_per_node: 3,
            ..Default::default()
        };
        let result = GraphEmbedder::embed(&g, &config, 4);
        assert_eq!(
            result.walk_count,
            4 * 3,
            "walk_count = nodes * walks_per_node"
        );
    }

    #[test]
    fn test_structural_embedding_node_ids_sequential() {
        let g = path4();
        let embeddings = GraphEmbedder::structural_embedding(&g, 6);
        let ids: Vec<usize> = embeddings.iter().map(|e| e.node_id).collect();
        let expected: Vec<usize> = (0..4).collect();
        assert_eq!(ids, expected, "node_ids must be sequential from 0");
    }

    // ─── Invariants and robustness ────────────────────────────────────────────

    #[test]
    fn test_is_connected_ignores_out_of_range_edges() {
        let mut g = Graph::new(2);
        g.add_edge(0, 5, 1.0); // endpoint 5 does not exist
        assert!(!g.is_connected());
        g.add_edge(1, 0, 1.0);
        assert!(g.is_connected());
    }

    #[test]
    fn test_adjacency_matrix_ignores_out_of_range_edges() {
        let mut g = Graph::new(2);
        g.add_edge(0, 3, 2.0);
        let mat = g.adjacency_matrix();
        assert!(mat.iter().flatten().all(|&v| v == 0.0));
    }

    #[test]
    fn test_random_walks_follow_edges() {
        // In a path graph every step must move to an adjacent id.
        let g = path4();
        let walks = GraphEmbedder::random_walks(&g, &default_config());
        for w in &walks {
            for pair in w.windows(2) {
                assert_eq!(pair[0].abs_diff(pair[1]), 1, "walk stepped off an edge");
            }
        }
    }

    #[test]
    fn test_random_walks_unit_length_is_start_node() {
        let g = triangle();
        let config = WalkConfig {
            walk_length: 1,
            walks_per_node: 2,
            ..Default::default()
        };
        let walks = GraphEmbedder::random_walks(&g, &config);
        for (i, w) in walks.iter().enumerate() {
            assert_eq!(w, &vec![i % g.node_count]);
        }
    }

    #[test]
    fn test_random_walks_zero_walks_per_node() {
        let g = triangle();
        let config = WalkConfig {
            walks_per_node: 0,
            ..Default::default()
        };
        assert!(GraphEmbedder::random_walks(&g, &config).is_empty());
    }

    #[test]
    fn test_structural_embedding_is_deterministic() {
        let g = path4();
        let first = GraphEmbedder::structural_embedding(&g, 8);
        let second = GraphEmbedder::structural_embedding(&g, 8);
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.vector, b.vector);
        }
    }
}
884        let ids: Vec<usize> = embeddings.iter().map(|e| e.node_id).collect();
885        let expected: Vec<usize> = (0..4).collect();
886        assert_eq!(ids, expected, "node_ids must be sequential from 0");
887    }
888}