Skip to main content

codemem_graph/
algorithms.rs

1use crate::GraphEngine;
2use petgraph::graph::NodeIndex;
3use petgraph::Direction;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6impl GraphEngine {
7    /// Compute PageRank scores for all nodes using power iteration.
8    ///
9    /// - `damping`: probability of following an edge (default 0.85)
10    /// - `iterations`: max number of power iterations (default 100)
11    /// - `tolerance`: convergence threshold (default 1e-6)
12    ///
13    /// Returns a map from node ID to PageRank score.
14    pub fn pagerank(
15        &self,
16        damping: f64,
17        iterations: usize,
18        tolerance: f64,
19    ) -> HashMap<String, f64> {
20        let n = self.graph.node_count();
21        if n == 0 {
22            return HashMap::new();
23        }
24
25        let nf = n as f64;
26        let initial = 1.0 / nf;
27
28        // Collect all node indices in a stable order
29        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
30        let idx_pos: HashMap<NodeIndex, usize> = indices
31            .iter()
32            .enumerate()
33            .map(|(i, &idx)| (idx, i))
34            .collect();
35
36        let mut scores = vec![initial; n];
37
38        // Precompute out-degrees
39        let out_degree: Vec<usize> = indices
40            .iter()
41            .map(|&idx| {
42                self.graph
43                    .neighbors_directed(idx, Direction::Outgoing)
44                    .count()
45            })
46            .collect();
47
48        for _ in 0..iterations {
49            let mut new_scores = vec![(1.0 - damping) / nf; n];
50
51            // Distribute rank from each node to its out-neighbors
52            for (i, &idx) in indices.iter().enumerate() {
53                let deg = out_degree[i];
54                if deg == 0 {
55                    // Dangling node: distribute evenly to all nodes
56                    let share = damping * scores[i] / nf;
57                    for ns in new_scores.iter_mut() {
58                        *ns += share;
59                    }
60                } else {
61                    let share = damping * scores[i] / deg as f64;
62                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
63                        if let Some(&pos) = idx_pos.get(&neighbor) {
64                            new_scores[pos] += share;
65                        }
66                    }
67                }
68            }
69
70            // Check convergence
71            let diff: f64 = scores
72                .iter()
73                .zip(new_scores.iter())
74                .map(|(a, b)| (a - b).abs())
75                .sum();
76
77            scores = new_scores;
78
79            if diff < tolerance {
80                break;
81            }
82        }
83
84        // Map back to node IDs
85        indices
86            .iter()
87            .enumerate()
88            .filter_map(|(i, &idx)| {
89                self.graph
90                    .node_weight(idx)
91                    .map(|id| (id.clone(), scores[i]))
92            })
93            .collect()
94    }
95
96    /// Compute Personalized PageRank with custom teleport weights.
97    ///
98    /// `seed_weights` maps node IDs to teleport probabilities (will be normalized).
99    /// Nodes not in seed_weights get zero teleport probability.
100    ///
101    /// Used for blast-radius analysis and HippoRAG-2-style retrieval.
102    pub fn personalized_pagerank(
103        &self,
104        seed_weights: &HashMap<String, f64>,
105        damping: f64,
106        iterations: usize,
107        tolerance: f64,
108    ) -> HashMap<String, f64> {
109        let n = self.graph.node_count();
110        if n == 0 {
111            return HashMap::new();
112        }
113
114        let nf = n as f64;
115
116        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
117        let idx_pos: HashMap<NodeIndex, usize> = indices
118            .iter()
119            .enumerate()
120            .map(|(i, &idx)| (idx, i))
121            .collect();
122
123        // Build and normalize the teleport vector
124        let mut teleport = vec![0.0f64; n];
125        let mut teleport_sum = 0.0;
126        for (i, &idx) in indices.iter().enumerate() {
127            if let Some(node_id) = self.graph.node_weight(idx) {
128                if let Some(&w) = seed_weights.get(node_id) {
129                    teleport[i] = w;
130                    teleport_sum += w;
131                }
132            }
133        }
134        // Normalize; if no seeds provided, fall back to uniform
135        if teleport_sum > 0.0 {
136            for t in teleport.iter_mut() {
137                *t /= teleport_sum;
138            }
139        } else {
140            for t in teleport.iter_mut() {
141                *t = 1.0 / nf;
142            }
143        }
144
145        let initial = 1.0 / nf;
146        let mut scores = vec![initial; n];
147
148        let out_degree: Vec<usize> = indices
149            .iter()
150            .map(|&idx| {
151                self.graph
152                    .neighbors_directed(idx, Direction::Outgoing)
153                    .count()
154            })
155            .collect();
156
157        for _ in 0..iterations {
158            let mut new_scores: Vec<f64> = teleport.iter().map(|&t| (1.0 - damping) * t).collect();
159
160            for (i, &idx) in indices.iter().enumerate() {
161                let deg = out_degree[i];
162                if deg == 0 {
163                    // Dangling node: distribute to teleport targets
164                    let share = damping * scores[i];
165                    for (j, t) in teleport.iter().enumerate() {
166                        new_scores[j] += share * t;
167                    }
168                } else {
169                    let share = damping * scores[i] / deg as f64;
170                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
171                        if let Some(&pos) = idx_pos.get(&neighbor) {
172                            new_scores[pos] += share;
173                        }
174                    }
175                }
176            }
177
178            let diff: f64 = scores
179                .iter()
180                .zip(new_scores.iter())
181                .map(|(a, b)| (a - b).abs())
182                .sum();
183
184            scores = new_scores;
185
186            if diff < tolerance {
187                break;
188            }
189        }
190
191        indices
192            .iter()
193            .enumerate()
194            .filter_map(|(i, &idx)| {
195                self.graph
196                    .node_weight(idx)
197                    .map(|id| (id.clone(), scores[i]))
198            })
199            .collect()
200    }
201
202    /// Detect communities using the Louvain algorithm.
203    ///
204    /// Treats the directed graph as undirected for modularity computation.
205    /// `resolution` controls community granularity (1.0 = standard modularity).
206    /// Returns groups of node IDs, one group per community.
207    pub fn louvain_communities(&self, resolution: f64) -> Vec<Vec<String>> {
208        let n = self.graph.node_count();
209        if n == 0 {
210            return Vec::new();
211        }
212
213        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
214        let idx_pos: HashMap<NodeIndex, usize> = indices
215            .iter()
216            .enumerate()
217            .map(|(i, &idx)| (idx, i))
218            .collect();
219
220        // Build undirected adjacency with weights.
221        // adj[i] contains (j, weight) for each undirected neighbor.
222        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
223        let mut total_weight = 0.0;
224
225        for edge_ref in self.graph.edge_indices() {
226            if let Some((src_idx, dst_idx)) = self.graph.edge_endpoints(edge_ref) {
227                let w = self.graph[edge_ref];
228                if let (Some(&si), Some(&di)) = (idx_pos.get(&src_idx), idx_pos.get(&dst_idx)) {
229                    adj[si].push((di, w));
230                    adj[di].push((si, w));
231                    total_weight += w; // Each undirected edge contributes w (counted once)
232                }
233            }
234        }
235
236        if total_weight == 0.0 {
237            // No edges: each node is its own community
238            return indices
239                .iter()
240                .filter_map(|&idx| self.graph.node_weight(idx).map(|id| vec![id.clone()]))
241                .collect();
242        }
243
244        // m = total edge weight (for undirected: sum of all edge weights)
245        let m = total_weight;
246        let m2 = 2.0 * m;
247
248        // Weighted degree of each node (sum of incident edge weights, undirected)
249        let k: Vec<f64> = (0..n)
250            .map(|i| adj[i].iter().map(|&(_, w)| w).sum())
251            .collect();
252
253        // Initial assignment: each node in its own community
254        let mut community: Vec<usize> = (0..n).collect();
255
256        // Iteratively move nodes to improve modularity
257        let mut improved = true;
258        let max_passes = 100;
259        let mut pass = 0;
260
261        while improved && pass < max_passes {
262            improved = false;
263            pass += 1;
264
265            for i in 0..n {
266                let current_comm = community[i];
267
268                // Compute weights to each neighboring community
269                let mut comm_weights: HashMap<usize, f64> = HashMap::new();
270                for &(j, w) in &adj[i] {
271                    *comm_weights.entry(community[j]).or_insert(0.0) += w;
272                }
273
274                // Sum of degrees in each community (excluding node i for its own community)
275                let mut comm_degree_sum: HashMap<usize, f64> = HashMap::new();
276                for j in 0..n {
277                    *comm_degree_sum.entry(community[j]).or_insert(0.0) += k[j];
278                }
279
280                let ki = k[i];
281
282                // Modularity gain for removing i from its current community
283                let w_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
284                let sigma_current = comm_degree_sum.get(&current_comm).copied().unwrap_or(0.0);
285                let remove_cost = w_in_current - resolution * ki * (sigma_current - ki) / m2;
286
287                // Find best community to move to
288                let mut best_comm = current_comm;
289                let mut best_gain = 0.0;
290
291                for (&comm, &w_in_comm) in &comm_weights {
292                    if comm == current_comm {
293                        continue;
294                    }
295                    let sigma_comm = comm_degree_sum.get(&comm).copied().unwrap_or(0.0);
296                    let gain = w_in_comm - resolution * ki * sigma_comm / m2 - remove_cost;
297                    if gain > best_gain {
298                        best_gain = gain;
299                        best_comm = comm;
300                    }
301                }
302
303                if best_comm != current_comm {
304                    community[i] = best_comm;
305                    improved = true;
306                }
307            }
308        }
309
310        // Group nodes by community
311        let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
312        for (i, &idx) in indices.iter().enumerate() {
313            if let Some(node_id) = self.graph.node_weight(idx) {
314                groups
315                    .entry(community[i])
316                    .or_default()
317                    .push(node_id.clone());
318            }
319        }
320
321        let mut result: Vec<Vec<String>> = groups.into_values().collect();
322        for group in result.iter_mut() {
323            group.sort();
324        }
325        result.sort();
326        result
327    }
328
329    /// Compute betweenness centrality for all nodes using Brandes' algorithm.
330    ///
331    /// For graphs with more than 1000 nodes, samples sqrt(n) source nodes
332    /// for approximate computation.
333    ///
334    /// Returns a map from node ID to betweenness centrality score (normalized by
335    /// 1/((n-1)(n-2)) for directed graphs).
336    pub fn betweenness_centrality(&self) -> HashMap<String, f64> {
337        let n = self.graph.node_count();
338        if n <= 2 {
339            return self
340                .graph
341                .node_indices()
342                .filter_map(|idx| self.graph.node_weight(idx).map(|id| (id.clone(), 0.0)))
343                .collect();
344        }
345
346        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
347        let idx_pos: HashMap<NodeIndex, usize> = indices
348            .iter()
349            .enumerate()
350            .map(|(i, &idx)| (idx, i))
351            .collect();
352
353        let mut centrality = vec![0.0f64; n];
354
355        // Determine source nodes (sample for large graphs)
356        let sources: Vec<usize> = if n > 1000 {
357            let sample_size = (n as f64).sqrt() as usize;
358            // Deterministic sampling: evenly spaced
359            let step = n / sample_size;
360            (0..sample_size).map(|i| i * step).collect()
361        } else {
362            (0..n).collect()
363        };
364
365        let scale = if n > 1000 {
366            n as f64 / sources.len() as f64
367        } else {
368            1.0
369        };
370
371        for &s in &sources {
372            // Brandes' algorithm from source s
373            let mut stack: Vec<usize> = Vec::new();
374            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
375            let mut sigma = vec![0.0f64; n]; // number of shortest paths
376            sigma[s] = 1.0;
377            let mut dist: Vec<i64> = vec![-1; n];
378            dist[s] = 0;
379
380            let mut queue: VecDeque<usize> = VecDeque::new();
381            queue.push_back(s);
382
383            while let Some(v) = queue.pop_front() {
384                stack.push(v);
385                let v_idx = indices[v];
386                for neighbor in self.graph.neighbors_directed(v_idx, Direction::Outgoing) {
387                    if let Some(&w) = idx_pos.get(&neighbor) {
388                        if dist[w] < 0 {
389                            dist[w] = dist[v] + 1;
390                            queue.push_back(w);
391                        }
392                        if dist[w] == dist[v] + 1 {
393                            sigma[w] += sigma[v];
394                            predecessors[w].push(v);
395                        }
396                    }
397                }
398            }
399
400            let mut delta = vec![0.0f64; n];
401            while let Some(w) = stack.pop() {
402                for &v in &predecessors[w] {
403                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
404                }
405                if w != s {
406                    centrality[w] += delta[w];
407                }
408            }
409        }
410
411        // Apply sampling scale and normalize
412        let norm = ((n - 1) * (n - 2)) as f64;
413        indices
414            .iter()
415            .enumerate()
416            .filter_map(|(i, &idx)| {
417                self.graph
418                    .node_weight(idx)
419                    .map(|id| (id.clone(), centrality[i] * scale / norm))
420            })
421            .collect()
422    }
423
424    /// Find all strongly connected components using Tarjan's algorithm.
425    ///
426    /// Returns groups of node IDs. Each group is a strongly connected component
427    /// where every node can reach every other node via directed edges.
428    pub fn strongly_connected_components(&self) -> Vec<Vec<String>> {
429        let sccs = petgraph::algo::tarjan_scc(&self.graph);
430
431        let mut result: Vec<Vec<String>> = sccs
432            .into_iter()
433            .map(|component| {
434                let mut ids: Vec<String> = component
435                    .into_iter()
436                    .filter_map(|idx| self.graph.node_weight(idx).cloned())
437                    .collect();
438                ids.sort();
439                ids
440            })
441            .collect();
442
443        result.sort();
444        result
445    }
446
447    /// Compute topological layers using Kahn's algorithm.
448    ///
449    /// Returns layers where all nodes in layer i have no dependencies on nodes
450    /// in layer i or later. For cyclic graphs, SCCs are condensed into single
451    /// super-nodes first, then the resulting DAG is topologically sorted.
452    ///
453    /// Each inner Vec contains the node IDs at that layer.
454    pub fn topological_layers(&self) -> Vec<Vec<String>> {
455        let n = self.graph.node_count();
456        if n == 0 {
457            return Vec::new();
458        }
459
460        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
461        let idx_pos: HashMap<NodeIndex, usize> = indices
462            .iter()
463            .enumerate()
464            .map(|(i, &idx)| (idx, i))
465            .collect();
466
467        // Step 1: Find SCCs
468        let sccs = petgraph::algo::tarjan_scc(&self.graph);
469
470        // Map each node position to its SCC index
471        let mut node_to_scc = vec![0usize; n];
472        for (scc_idx, scc) in sccs.iter().enumerate() {
473            for &node_idx in scc {
474                if let Some(&pos) = idx_pos.get(&node_idx) {
475                    node_to_scc[pos] = scc_idx;
476                }
477            }
478        }
479
480        let num_sccs = sccs.len();
481
482        // Step 2: Build condensed DAG (SCC graph)
483        let mut condensed_adj: Vec<HashSet<usize>> = vec![HashSet::new(); num_sccs];
484        let mut condensed_in_degree = vec![0usize; num_sccs];
485
486        for &idx in &indices {
487            if let Some(&src_pos) = idx_pos.get(&idx) {
488                let src_scc = node_to_scc[src_pos];
489                for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
490                    if let Some(&dst_pos) = idx_pos.get(&neighbor) {
491                        let dst_scc = node_to_scc[dst_pos];
492                        if src_scc != dst_scc && condensed_adj[src_scc].insert(dst_scc) {
493                            condensed_in_degree[dst_scc] += 1;
494                        }
495                    }
496                }
497            }
498        }
499
500        // Step 3: Kahn's algorithm on the condensed DAG
501        let mut queue: VecDeque<usize> = VecDeque::new();
502        for (i, &deg) in condensed_in_degree.iter().enumerate().take(num_sccs) {
503            if deg == 0 {
504                queue.push_back(i);
505            }
506        }
507
508        let mut scc_layers: Vec<Vec<usize>> = Vec::new();
509        while !queue.is_empty() {
510            let mut layer = Vec::new();
511            let mut next_queue = VecDeque::new();
512
513            while let Some(scc_idx) = queue.pop_front() {
514                layer.push(scc_idx);
515                for &neighbor_scc in &condensed_adj[scc_idx] {
516                    condensed_in_degree[neighbor_scc] -= 1;
517                    if condensed_in_degree[neighbor_scc] == 0 {
518                        next_queue.push_back(neighbor_scc);
519                    }
520                }
521            }
522
523            scc_layers.push(layer);
524            queue = next_queue;
525        }
526
527        // Step 4: Expand SCC layers back to node IDs
528        let mut result: Vec<Vec<String>> = Vec::new();
529        for scc_layer in scc_layers {
530            let mut layer_nodes: Vec<String> = Vec::new();
531            for scc_idx in scc_layer {
532                for &node_idx in &sccs[scc_idx] {
533                    if let Some(id) = self.graph.node_weight(node_idx) {
534                        layer_nodes.push(id.clone());
535                    }
536                }
537            }
538            layer_nodes.sort();
539            result.push(layer_nodes);
540        }
541
542        result
543    }
544}
545
546#[cfg(test)]
547mod tests {
548    use crate::GraphEngine;
549    use codemem_core::{Edge, GraphBackend, GraphNode, NodeKind, RelationshipType};
550    use std::collections::{HashMap, HashSet};
551
552    fn file_node(id: &str, label: &str) -> GraphNode {
553        GraphNode {
554            id: id.to_string(),
555            kind: NodeKind::File,
556            label: label.to_string(),
557            payload: HashMap::new(),
558            centrality: 0.0,
559            memory_id: None,
560            namespace: None,
561        }
562    }
563
564    fn test_edge(src: &str, dst: &str) -> Edge {
565        Edge {
566            id: format!("{src}->{dst}"),
567            src: src.to_string(),
568            dst: dst.to_string(),
569            relationship: RelationshipType::Contains,
570            weight: 1.0,
571            properties: HashMap::new(),
572            created_at: chrono::Utc::now(),
573            valid_from: None,
574            valid_to: None,
575        }
576    }
577
578    // ── PageRank Tests ──────────────────────────────────────────────────────
579
580    #[test]
581    fn pagerank_chain() {
582        // a -> b -> c
583        // c is a sink (dangling node) that redistributes rank uniformly.
584        // Rank flows a -> b -> c, with c accumulating the most. Order: c > b > a.
585        let mut graph = GraphEngine::new();
586        graph.add_node(file_node("a", "a.rs")).unwrap();
587        graph.add_node(file_node("b", "b.rs")).unwrap();
588        graph.add_node(file_node("c", "c.rs")).unwrap();
589        graph.add_edge(test_edge("a", "b")).unwrap();
590        graph.add_edge(test_edge("b", "c")).unwrap();
591
592        let ranks = graph.pagerank(0.85, 100, 1e-6);
593        assert_eq!(ranks.len(), 3);
594        assert!(
595            ranks["c"] > ranks["b"],
596            "c ({}) should rank higher than b ({})",
597            ranks["c"],
598            ranks["b"]
599        );
600        assert!(
601            ranks["b"] > ranks["a"],
602            "b ({}) should rank higher than a ({})",
603            ranks["b"],
604            ranks["a"]
605        );
606    }
607
608    #[test]
609    fn pagerank_star() {
610        // a -> b, a -> c, a -> d
611        // b, c, d are dangling nodes that redistribute rank uniformly.
612        // They each receive direct rank from a, plus redistribution.
613        // a only receives redistributed rank from the dangling nodes.
614        // So each leaf should rank higher than the hub.
615        let mut graph = GraphEngine::new();
616        graph.add_node(file_node("a", "a.rs")).unwrap();
617        graph.add_node(file_node("b", "b.rs")).unwrap();
618        graph.add_node(file_node("c", "c.rs")).unwrap();
619        graph.add_node(file_node("d", "d.rs")).unwrap();
620        graph.add_edge(test_edge("a", "b")).unwrap();
621        graph.add_edge(test_edge("a", "c")).unwrap();
622        graph.add_edge(test_edge("a", "d")).unwrap();
623
624        let ranks = graph.pagerank(0.85, 100, 1e-6);
625        assert_eq!(ranks.len(), 4);
626        // Leaves get direct rank from a AND redistribute back uniformly.
627        // b, c, d should be approximately equal and each higher than a.
628        assert!(
629            ranks["b"] > ranks["a"],
630            "b ({}) should rank higher than a ({})",
631            ranks["b"],
632            ranks["a"]
633        );
634        // b, c, d should be approximately equal
635        assert!(
636            (ranks["b"] - ranks["c"]).abs() < 0.01,
637            "b ({}) and c ({}) should be approximately equal",
638            ranks["b"],
639            ranks["c"]
640        );
641    }
642
643    #[test]
644    fn pagerank_empty_graph() {
645        let graph = GraphEngine::new();
646        let ranks = graph.pagerank(0.85, 100, 1e-6);
647        assert!(ranks.is_empty());
648    }
649
650    #[test]
651    fn pagerank_single_node() {
652        let mut graph = GraphEngine::new();
653        graph.add_node(file_node("a", "a.rs")).unwrap();
654
655        let ranks = graph.pagerank(0.85, 100, 1e-6);
656        assert_eq!(ranks.len(), 1);
657        assert!((ranks["a"] - 1.0).abs() < 0.01);
658    }
659
660    // ── Personalized PageRank Tests ─────────────────────────────────────────
661
662    #[test]
663    fn personalized_pagerank_cycle_seed_c() {
664        // a -> b -> c -> a (cycle)
665        // Seed on c: c and its neighbors should rank highest
666        let mut graph = GraphEngine::new();
667        graph.add_node(file_node("a", "a.rs")).unwrap();
668        graph.add_node(file_node("b", "b.rs")).unwrap();
669        graph.add_node(file_node("c", "c.rs")).unwrap();
670        graph.add_edge(test_edge("a", "b")).unwrap();
671        graph.add_edge(test_edge("b", "c")).unwrap();
672        graph.add_edge(test_edge("c", "a")).unwrap();
673
674        let mut seeds = HashMap::new();
675        seeds.insert("c".to_string(), 1.0);
676
677        let ranks = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
678        assert_eq!(ranks.len(), 3);
679        // c should have highest rank (it's the seed and receives teleport)
680        // a is c's out-neighbor so it should be next
681        assert!(
682            ranks["c"] > ranks["b"],
683            "c ({}) should rank higher than b ({})",
684            ranks["c"],
685            ranks["b"]
686        );
687        assert!(
688            ranks["a"] > ranks["b"],
689            "a ({}) should rank higher than b ({}) since c->a",
690            ranks["a"],
691            ranks["b"]
692        );
693    }
694
695    #[test]
696    fn personalized_pagerank_empty_seeds() {
697        // With no seeds, should fall back to uniform (same as regular pagerank)
698        let mut graph = GraphEngine::new();
699        graph.add_node(file_node("a", "a.rs")).unwrap();
700        graph.add_node(file_node("b", "b.rs")).unwrap();
701        graph.add_edge(test_edge("a", "b")).unwrap();
702
703        let seeds = HashMap::new();
704        let ppr = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
705        let pr = graph.pagerank(0.85, 100, 1e-6);
706
707        // Should be approximately equal
708        assert!((ppr["a"] - pr["a"]).abs() < 0.01);
709        assert!((ppr["b"] - pr["b"]).abs() < 0.01);
710    }
711
712    // ── Louvain Community Detection Tests ───────────────────────────────────
713
714    #[test]
715    fn louvain_two_disconnected_cliques() {
716        // Clique 1: a <-> b <-> c <-> a
717        // Clique 2: d <-> e <-> f <-> d
718        let mut graph = GraphEngine::new();
719        for id in &["a", "b", "c", "d", "e", "f"] {
720            graph.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
721        }
722        // Clique 1
723        graph.add_edge(test_edge("a", "b")).unwrap();
724        graph.add_edge(test_edge("b", "a")).unwrap();
725        graph.add_edge(test_edge("b", "c")).unwrap();
726        graph.add_edge(test_edge("c", "b")).unwrap();
727        graph.add_edge(test_edge("a", "c")).unwrap();
728        graph.add_edge(test_edge("c", "a")).unwrap();
729        // Clique 2
730        graph.add_edge(test_edge("d", "e")).unwrap();
731        graph.add_edge(test_edge("e", "d")).unwrap();
732        graph.add_edge(test_edge("e", "f")).unwrap();
733        graph.add_edge(test_edge("f", "e")).unwrap();
734        graph.add_edge(test_edge("d", "f")).unwrap();
735        graph.add_edge(test_edge("f", "d")).unwrap();
736
737        let communities = graph.louvain_communities(1.0);
738        assert_eq!(
739            communities.len(),
740            2,
741            "Expected 2 communities, got {}: {:?}",
742            communities.len(),
743            communities
744        );
745        // Each community should have 3 nodes
746        assert_eq!(communities[0].len(), 3);
747        assert_eq!(communities[1].len(), 3);
748        // Check that each clique is in a separate community
749        let comm0_set: HashSet<&str> = communities[0].iter().map(|s| s.as_str()).collect();
750        let has_abc = comm0_set.contains("a") && comm0_set.contains("b") && comm0_set.contains("c");
751        let has_def = comm0_set.contains("d") && comm0_set.contains("e") && comm0_set.contains("f");
752        assert!(
753            has_abc || has_def,
754            "First community should be one of the cliques: {:?}",
755            communities[0]
756        );
757    }
758
759    #[test]
760    fn louvain_empty_graph() {
761        let graph = GraphEngine::new();
762        let communities = graph.louvain_communities(1.0);
763        assert!(communities.is_empty());
764    }
765
766    #[test]
767    fn louvain_single_node() {
768        let mut graph = GraphEngine::new();
769        graph.add_node(file_node("a", "a.rs")).unwrap();
770        let communities = graph.louvain_communities(1.0);
771        assert_eq!(communities.len(), 1);
772        assert_eq!(communities[0], vec!["a"]);
773    }
774
775    // ── Betweenness Centrality Tests ────────────────────────────────────────
776
777    #[test]
778    fn betweenness_chain_middle_highest() {
779        // a -> b -> c
780        // b is on the shortest path from a to c, so it should have highest betweenness
781        let mut graph = GraphEngine::new();
782        graph.add_node(file_node("a", "a.rs")).unwrap();
783        graph.add_node(file_node("b", "b.rs")).unwrap();
784        graph.add_node(file_node("c", "c.rs")).unwrap();
785        graph.add_edge(test_edge("a", "b")).unwrap();
786        graph.add_edge(test_edge("b", "c")).unwrap();
787
788        let bc = graph.betweenness_centrality();
789        assert_eq!(bc.len(), 3);
790        assert!(
791            bc["b"] > bc["a"],
792            "b ({}) should have higher betweenness than a ({})",
793            bc["b"],
794            bc["a"]
795        );
796        assert!(
797            bc["b"] > bc["c"],
798            "b ({}) should have higher betweenness than c ({})",
799            bc["b"],
800            bc["c"]
801        );
802        // a and c should have 0 betweenness (they are endpoints)
803        assert!(
804            bc["a"].abs() < f64::EPSILON,
805            "a should have 0 betweenness, got {}",
806            bc["a"]
807        );
808        assert!(
809            bc["c"].abs() < f64::EPSILON,
810            "c should have 0 betweenness, got {}",
811            bc["c"]
812        );
813    }
814
815    #[test]
816    fn betweenness_empty_graph() {
817        let graph = GraphEngine::new();
818        let bc = graph.betweenness_centrality();
819        assert!(bc.is_empty());
820    }
821
822    #[test]
823    fn betweenness_two_nodes() {
824        let mut graph = GraphEngine::new();
825        graph.add_node(file_node("a", "a.rs")).unwrap();
826        graph.add_node(file_node("b", "b.rs")).unwrap();
827        graph.add_edge(test_edge("a", "b")).unwrap();
828
829        let bc = graph.betweenness_centrality();
830        assert_eq!(bc.len(), 2);
831        assert!((bc["a"]).abs() < f64::EPSILON);
832        assert!((bc["b"]).abs() < f64::EPSILON);
833    }
834
835    // ── Strongly Connected Components Tests ─────────────────────────────────
836
837    #[test]
838    fn scc_cycle_all_in_one() {
839        // a -> b -> c -> a: all three should be in one SCC
840        let mut graph = GraphEngine::new();
841        graph.add_node(file_node("a", "a.rs")).unwrap();
842        graph.add_node(file_node("b", "b.rs")).unwrap();
843        graph.add_node(file_node("c", "c.rs")).unwrap();
844        graph.add_edge(test_edge("a", "b")).unwrap();
845        graph.add_edge(test_edge("b", "c")).unwrap();
846        graph.add_edge(test_edge("c", "a")).unwrap();
847
848        let sccs = graph.strongly_connected_components();
849        assert_eq!(
850            sccs.len(),
851            1,
852            "Expected 1 SCC, got {}: {:?}",
853            sccs.len(),
854            sccs
855        );
856        assert_eq!(sccs[0], vec!["a", "b", "c"]);
857    }
858
859    #[test]
860    fn scc_chain_each_separate() {
861        // a -> b -> c: no cycles, each node is its own SCC
862        let mut graph = GraphEngine::new();
863        graph.add_node(file_node("a", "a.rs")).unwrap();
864        graph.add_node(file_node("b", "b.rs")).unwrap();
865        graph.add_node(file_node("c", "c.rs")).unwrap();
866        graph.add_edge(test_edge("a", "b")).unwrap();
867        graph.add_edge(test_edge("b", "c")).unwrap();
868
869        let sccs = graph.strongly_connected_components();
870        assert_eq!(
871            sccs.len(),
872            3,
873            "Expected 3 SCCs, got {}: {:?}",
874            sccs.len(),
875            sccs
876        );
877    }
878
879    #[test]
880    fn scc_empty_graph() {
881        let graph = GraphEngine::new();
882        let sccs = graph.strongly_connected_components();
883        assert!(sccs.is_empty());
884    }
885
886    // ── Topological Sort Tests ──────────────────────────────────────────────
887
888    #[test]
889    fn topological_layers_dag() {
890        // a -> b, a -> c, b -> d, c -> d
891        // Layer 0: [a], Layer 1: [b, c], Layer 2: [d]
892        let mut graph = GraphEngine::new();
893        graph.add_node(file_node("a", "a.rs")).unwrap();
894        graph.add_node(file_node("b", "b.rs")).unwrap();
895        graph.add_node(file_node("c", "c.rs")).unwrap();
896        graph.add_node(file_node("d", "d.rs")).unwrap();
897        graph.add_edge(test_edge("a", "b")).unwrap();
898        graph.add_edge(test_edge("a", "c")).unwrap();
899        graph.add_edge(test_edge("b", "d")).unwrap();
900        graph.add_edge(test_edge("c", "d")).unwrap();
901
902        let layers = graph.topological_layers();
903        assert_eq!(
904            layers.len(),
905            3,
906            "Expected 3 layers, got {}: {:?}",
907            layers.len(),
908            layers
909        );
910        assert_eq!(layers[0], vec!["a"]);
911        assert_eq!(layers[1], vec!["b", "c"]); // sorted within layer
912        assert_eq!(layers[2], vec!["d"]);
913    }
914
915    #[test]
916    fn topological_layers_with_cycle() {
917        // a -> b -> c -> b (cycle between b and c), a -> d
918        // SCCs: {a}, {b, c}, {d}
919        // After condensation: {a} -> {b,c} and {a} -> {d}
920        // Layer 0: [a], Layer 1: [b, c, d] (b and c condensed, d also depends on a)
921        let mut graph = GraphEngine::new();
922        graph.add_node(file_node("a", "a.rs")).unwrap();
923        graph.add_node(file_node("b", "b.rs")).unwrap();
924        graph.add_node(file_node("c", "c.rs")).unwrap();
925        graph.add_node(file_node("d", "d.rs")).unwrap();
926        graph.add_edge(test_edge("a", "b")).unwrap();
927        graph.add_edge(test_edge("b", "c")).unwrap();
928        graph.add_edge(test_edge("c", "b")).unwrap();
929        graph.add_edge(test_edge("a", "d")).unwrap();
930
931        let layers = graph.topological_layers();
932        assert_eq!(
933            layers.len(),
934            2,
935            "Expected 2 layers, got {}: {:?}",
936            layers.len(),
937            layers
938        );
939        assert_eq!(layers[0], vec!["a"]);
940        // Layer 1 should contain b, c (from the cycle SCC) and d
941        assert!(layers[1].contains(&"b".to_string()));
942        assert!(layers[1].contains(&"c".to_string()));
943        assert!(layers[1].contains(&"d".to_string()));
944    }
945
946    #[test]
947    fn topological_layers_empty_graph() {
948        let graph = GraphEngine::new();
949        let layers = graph.topological_layers();
950        assert!(layers.is_empty());
951    }
952
953    #[test]
954    fn topological_layers_single_node() {
955        let mut graph = GraphEngine::new();
956        graph.add_node(file_node("a", "a.rs")).unwrap();
957        let layers = graph.topological_layers();
958        assert_eq!(layers.len(), 1);
959        assert_eq!(layers[0], vec!["a"]);
960    }
961}