Skip to main content

codemem_graph/
algorithms.rs

1use crate::GraphEngine;
2use petgraph::graph::NodeIndex;
3use petgraph::Direction;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6impl GraphEngine {
7    /// Compute PageRank scores for all nodes using power iteration.
8    ///
9    /// - `damping`: probability of following an edge (default 0.85)
10    /// - `iterations`: max number of power iterations (default 100)
11    /// - `tolerance`: convergence threshold (default 1e-6)
12    ///
13    /// Returns a map from node ID to PageRank score.
14    pub fn pagerank(
15        &self,
16        damping: f64,
17        iterations: usize,
18        tolerance: f64,
19    ) -> HashMap<String, f64> {
20        let n = self.graph.node_count();
21        if n == 0 {
22            return HashMap::new();
23        }
24
25        let nf = n as f64;
26        let initial = 1.0 / nf;
27
28        // Collect all node indices in a stable order
29        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
30        let idx_pos: HashMap<NodeIndex, usize> = indices
31            .iter()
32            .enumerate()
33            .map(|(i, &idx)| (idx, i))
34            .collect();
35
36        let mut scores = vec![initial; n];
37
38        // Precompute out-degrees
39        let out_degree: Vec<usize> = indices
40            .iter()
41            .map(|&idx| {
42                self.graph
43                    .neighbors_directed(idx, Direction::Outgoing)
44                    .count()
45            })
46            .collect();
47
48        for _ in 0..iterations {
49            let mut new_scores = vec![(1.0 - damping) / nf; n];
50
51            // Distribute rank from each node to its out-neighbors
52            for (i, &idx) in indices.iter().enumerate() {
53                let deg = out_degree[i];
54                if deg == 0 {
55                    // Dangling node: distribute evenly to all nodes
56                    let share = damping * scores[i] / nf;
57                    for ns in new_scores.iter_mut() {
58                        *ns += share;
59                    }
60                } else {
61                    let share = damping * scores[i] / deg as f64;
62                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
63                        if let Some(&pos) = idx_pos.get(&neighbor) {
64                            new_scores[pos] += share;
65                        }
66                    }
67                }
68            }
69
70            // Check convergence
71            let diff: f64 = scores
72                .iter()
73                .zip(new_scores.iter())
74                .map(|(a, b)| (a - b).abs())
75                .sum();
76
77            scores = new_scores;
78
79            if diff < tolerance {
80                break;
81            }
82        }
83
84        // Map back to node IDs
85        indices
86            .iter()
87            .enumerate()
88            .filter_map(|(i, &idx)| {
89                self.graph
90                    .node_weight(idx)
91                    .map(|id| (id.clone(), scores[i]))
92            })
93            .collect()
94    }
95
96    /// Compute Personalized PageRank with custom teleport weights.
97    ///
98    /// `seed_weights` maps node IDs to teleport probabilities (will be normalized).
99    /// Nodes not in seed_weights get zero teleport probability.
100    ///
101    /// Used for blast-radius analysis and HippoRAG-2-style retrieval.
102    pub fn personalized_pagerank(
103        &self,
104        seed_weights: &HashMap<String, f64>,
105        damping: f64,
106        iterations: usize,
107        tolerance: f64,
108    ) -> HashMap<String, f64> {
109        let n = self.graph.node_count();
110        if n == 0 {
111            return HashMap::new();
112        }
113
114        let nf = n as f64;
115
116        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
117        let idx_pos: HashMap<NodeIndex, usize> = indices
118            .iter()
119            .enumerate()
120            .map(|(i, &idx)| (idx, i))
121            .collect();
122
123        // Build and normalize the teleport vector
124        let mut teleport = vec![0.0f64; n];
125        let mut teleport_sum = 0.0;
126        for (i, &idx) in indices.iter().enumerate() {
127            if let Some(node_id) = self.graph.node_weight(idx) {
128                if let Some(&w) = seed_weights.get(node_id) {
129                    teleport[i] = w;
130                    teleport_sum += w;
131                }
132            }
133        }
134        // Normalize; if no seeds provided, fall back to uniform
135        if teleport_sum > 0.0 {
136            for t in teleport.iter_mut() {
137                *t /= teleport_sum;
138            }
139        } else {
140            for t in teleport.iter_mut() {
141                *t = 1.0 / nf;
142            }
143        }
144
145        let initial = 1.0 / nf;
146        let mut scores = vec![initial; n];
147
148        let out_degree: Vec<usize> = indices
149            .iter()
150            .map(|&idx| {
151                self.graph
152                    .neighbors_directed(idx, Direction::Outgoing)
153                    .count()
154            })
155            .collect();
156
157        for _ in 0..iterations {
158            let mut new_scores: Vec<f64> = teleport.iter().map(|&t| (1.0 - damping) * t).collect();
159
160            for (i, &idx) in indices.iter().enumerate() {
161                let deg = out_degree[i];
162                if deg == 0 {
163                    // Dangling node: distribute to teleport targets
164                    let share = damping * scores[i];
165                    for (j, t) in teleport.iter().enumerate() {
166                        new_scores[j] += share * t;
167                    }
168                } else {
169                    let share = damping * scores[i] / deg as f64;
170                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
171                        if let Some(&pos) = idx_pos.get(&neighbor) {
172                            new_scores[pos] += share;
173                        }
174                    }
175                }
176            }
177
178            let diff: f64 = scores
179                .iter()
180                .zip(new_scores.iter())
181                .map(|(a, b)| (a - b).abs())
182                .sum();
183
184            scores = new_scores;
185
186            if diff < tolerance {
187                break;
188            }
189        }
190
191        indices
192            .iter()
193            .enumerate()
194            .filter_map(|(i, &idx)| {
195                self.graph
196                    .node_weight(idx)
197                    .map(|id| (id.clone(), scores[i]))
198            })
199            .collect()
200    }
201
202    /// Detect communities using the Louvain algorithm.
203    ///
204    /// Treats the directed graph as undirected for modularity computation.
205    /// `resolution` controls community granularity (1.0 = standard modularity).
206    /// Returns groups of node IDs, one group per community.
207    pub fn louvain_communities(&self, resolution: f64) -> Vec<Vec<String>> {
208        let n = self.graph.node_count();
209        if n == 0 {
210            return Vec::new();
211        }
212
213        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
214        let idx_pos: HashMap<NodeIndex, usize> = indices
215            .iter()
216            .enumerate()
217            .map(|(i, &idx)| (idx, i))
218            .collect();
219
220        // Build undirected adjacency with weights.
221        // adj[i] contains (j, weight) for each undirected neighbor.
222        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
223        let mut total_weight = 0.0;
224
225        for edge_ref in self.graph.edge_indices() {
226            if let Some((src_idx, dst_idx)) = self.graph.edge_endpoints(edge_ref) {
227                let w = self.graph[edge_ref];
228                if let (Some(&si), Some(&di)) = (idx_pos.get(&src_idx), idx_pos.get(&dst_idx)) {
229                    adj[si].push((di, w));
230                    adj[di].push((si, w));
231                    total_weight += w; // Each undirected edge contributes w (counted once)
232                }
233            }
234        }
235
236        if total_weight == 0.0 {
237            // No edges: each node is its own community
238            return indices
239                .iter()
240                .filter_map(|&idx| self.graph.node_weight(idx).map(|id| vec![id.clone()]))
241                .collect();
242        }
243
244        // m = total edge weight (for undirected: sum of all edge weights)
245        let m = total_weight;
246        let m2 = 2.0 * m;
247
248        // Weighted degree of each node (sum of incident edge weights, undirected)
249        let k: Vec<f64> = (0..n)
250            .map(|i| adj[i].iter().map(|&(_, w)| w).sum())
251            .collect();
252
253        // Initial assignment: each node in its own community
254        let mut community: Vec<usize> = (0..n).collect();
255
256        // Iteratively move nodes to improve modularity
257        let mut improved = true;
258        let max_passes = 100;
259        let mut pass = 0;
260
261        while improved && pass < max_passes {
262            improved = false;
263            pass += 1;
264
265            for i in 0..n {
266                let current_comm = community[i];
267
268                // Compute weights to each neighboring community
269                let mut comm_weights: HashMap<usize, f64> = HashMap::new();
270                for &(j, w) in &adj[i] {
271                    *comm_weights.entry(community[j]).or_insert(0.0) += w;
272                }
273
274                // Sum of degrees in each community (excluding node i for its own community)
275                let mut comm_degree_sum: HashMap<usize, f64> = HashMap::new();
276                for j in 0..n {
277                    *comm_degree_sum.entry(community[j]).or_insert(0.0) += k[j];
278                }
279
280                let ki = k[i];
281
282                // Modularity gain for removing i from its current community
283                let w_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
284                let sigma_current = comm_degree_sum.get(&current_comm).copied().unwrap_or(0.0);
285                let remove_cost = w_in_current - resolution * ki * (sigma_current - ki) / m2;
286
287                // Find best community to move to
288                let mut best_comm = current_comm;
289                let mut best_gain = 0.0;
290
291                for (&comm, &w_in_comm) in &comm_weights {
292                    if comm == current_comm {
293                        continue;
294                    }
295                    let sigma_comm = comm_degree_sum.get(&comm).copied().unwrap_or(0.0);
296                    let gain = w_in_comm - resolution * ki * sigma_comm / m2 - remove_cost;
297                    if gain > best_gain {
298                        best_gain = gain;
299                        best_comm = comm;
300                    }
301                }
302
303                if best_comm != current_comm {
304                    community[i] = best_comm;
305                    improved = true;
306                }
307            }
308        }
309
310        // Group nodes by community
311        let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
312        for (i, &idx) in indices.iter().enumerate() {
313            if let Some(node_id) = self.graph.node_weight(idx) {
314                groups
315                    .entry(community[i])
316                    .or_default()
317                    .push(node_id.clone());
318            }
319        }
320
321        let mut result: Vec<Vec<String>> = groups.into_values().collect();
322        for group in result.iter_mut() {
323            group.sort();
324        }
325        result.sort();
326        result
327    }
328
329    /// Compute betweenness centrality for all nodes using Brandes' algorithm.
330    ///
331    /// For graphs with more than 1000 nodes, samples sqrt(n) source nodes
332    /// for approximate computation.
333    ///
334    /// Returns a map from node ID to betweenness centrality score (normalized by
335    /// 1/((n-1)(n-2)) for directed graphs).
336    pub fn betweenness_centrality(&self) -> HashMap<String, f64> {
337        let n = self.graph.node_count();
338        if n <= 2 {
339            return self
340                .graph
341                .node_indices()
342                .filter_map(|idx| self.graph.node_weight(idx).map(|id| (id.clone(), 0.0)))
343                .collect();
344        }
345
346        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
347        let idx_pos: HashMap<NodeIndex, usize> = indices
348            .iter()
349            .enumerate()
350            .map(|(i, &idx)| (idx, i))
351            .collect();
352
353        let mut centrality = vec![0.0f64; n];
354
355        // Determine source nodes (sample for large graphs)
356        let sources: Vec<usize> = if n > 1000 {
357            let sample_size = (n as f64).sqrt() as usize;
358            // Deterministic sampling: evenly spaced
359            let step = n / sample_size;
360            (0..sample_size).map(|i| i * step).collect()
361        } else {
362            (0..n).collect()
363        };
364
365        let scale = if n > 1000 {
366            n as f64 / sources.len() as f64
367        } else {
368            1.0
369        };
370
371        for &s in &sources {
372            // Brandes' algorithm from source s
373            let mut stack: Vec<usize> = Vec::new();
374            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
375            let mut sigma = vec![0.0f64; n]; // number of shortest paths
376            sigma[s] = 1.0;
377            let mut dist: Vec<i64> = vec![-1; n];
378            dist[s] = 0;
379
380            let mut queue: VecDeque<usize> = VecDeque::new();
381            queue.push_back(s);
382
383            while let Some(v) = queue.pop_front() {
384                stack.push(v);
385                let v_idx = indices[v];
386                for neighbor in self.graph.neighbors_directed(v_idx, Direction::Outgoing) {
387                    if let Some(&w) = idx_pos.get(&neighbor) {
388                        if dist[w] < 0 {
389                            dist[w] = dist[v] + 1;
390                            queue.push_back(w);
391                        }
392                        if dist[w] == dist[v] + 1 {
393                            sigma[w] += sigma[v];
394                            predecessors[w].push(v);
395                        }
396                    }
397                }
398            }
399
400            let mut delta = vec![0.0f64; n];
401            while let Some(w) = stack.pop() {
402                for &v in &predecessors[w] {
403                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
404                }
405                if w != s {
406                    centrality[w] += delta[w];
407                }
408            }
409        }
410
411        // Apply sampling scale and normalize
412        let norm = ((n - 1) * (n - 2)) as f64;
413        indices
414            .iter()
415            .enumerate()
416            .filter_map(|(i, &idx)| {
417                self.graph
418                    .node_weight(idx)
419                    .map(|id| (id.clone(), centrality[i] * scale / norm))
420            })
421            .collect()
422    }
423
424    /// Find all strongly connected components using Tarjan's algorithm.
425    ///
426    /// Returns groups of node IDs. Each group is a strongly connected component
427    /// where every node can reach every other node via directed edges.
428    pub fn strongly_connected_components(&self) -> Vec<Vec<String>> {
429        let sccs = petgraph::algo::tarjan_scc(&self.graph);
430
431        let mut result: Vec<Vec<String>> = sccs
432            .into_iter()
433            .map(|component| {
434                let mut ids: Vec<String> = component
435                    .into_iter()
436                    .filter_map(|idx| self.graph.node_weight(idx).cloned())
437                    .collect();
438                ids.sort();
439                ids
440            })
441            .collect();
442
443        result.sort();
444        result
445    }
446
447    /// Compute topological layers using Kahn's algorithm.
448    ///
449    /// Returns layers where all nodes in layer i have no dependencies on nodes
450    /// in layer i or later. For cyclic graphs, SCCs are condensed into single
451    /// super-nodes first, then the resulting DAG is topologically sorted.
452    ///
453    /// Each inner Vec contains the node IDs at that layer.
454    pub fn topological_layers(&self) -> Vec<Vec<String>> {
455        let n = self.graph.node_count();
456        if n == 0 {
457            return Vec::new();
458        }
459
460        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
461        let idx_pos: HashMap<NodeIndex, usize> = indices
462            .iter()
463            .enumerate()
464            .map(|(i, &idx)| (idx, i))
465            .collect();
466
467        // Step 1: Find SCCs
468        let sccs = petgraph::algo::tarjan_scc(&self.graph);
469
470        // Map each node position to its SCC index
471        let mut node_to_scc = vec![0usize; n];
472        for (scc_idx, scc) in sccs.iter().enumerate() {
473            for &node_idx in scc {
474                if let Some(&pos) = idx_pos.get(&node_idx) {
475                    node_to_scc[pos] = scc_idx;
476                }
477            }
478        }
479
480        let num_sccs = sccs.len();
481
482        // Step 2: Build condensed DAG (SCC graph)
483        let mut condensed_adj: Vec<HashSet<usize>> = vec![HashSet::new(); num_sccs];
484        let mut condensed_in_degree = vec![0usize; num_sccs];
485
486        for &idx in &indices {
487            if let Some(&src_pos) = idx_pos.get(&idx) {
488                let src_scc = node_to_scc[src_pos];
489                for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
490                    if let Some(&dst_pos) = idx_pos.get(&neighbor) {
491                        let dst_scc = node_to_scc[dst_pos];
492                        if src_scc != dst_scc && condensed_adj[src_scc].insert(dst_scc) {
493                            condensed_in_degree[dst_scc] += 1;
494                        }
495                    }
496                }
497            }
498        }
499
500        // Step 3: Kahn's algorithm on the condensed DAG
501        let mut queue: VecDeque<usize> = VecDeque::new();
502        for (i, &deg) in condensed_in_degree.iter().enumerate().take(num_sccs) {
503            if deg == 0 {
504                queue.push_back(i);
505            }
506        }
507
508        let mut scc_layers: Vec<Vec<usize>> = Vec::new();
509        while !queue.is_empty() {
510            let mut layer = Vec::new();
511            let mut next_queue = VecDeque::new();
512
513            while let Some(scc_idx) = queue.pop_front() {
514                layer.push(scc_idx);
515                for &neighbor_scc in &condensed_adj[scc_idx] {
516                    condensed_in_degree[neighbor_scc] -= 1;
517                    if condensed_in_degree[neighbor_scc] == 0 {
518                        next_queue.push_back(neighbor_scc);
519                    }
520                }
521            }
522
523            scc_layers.push(layer);
524            queue = next_queue;
525        }
526
527        // Step 4: Expand SCC layers back to node IDs
528        let mut result: Vec<Vec<String>> = Vec::new();
529        for scc_layer in scc_layers {
530            let mut layer_nodes: Vec<String> = Vec::new();
531            for scc_idx in scc_layer {
532                for &node_idx in &sccs[scc_idx] {
533                    if let Some(id) = self.graph.node_weight(node_idx) {
534                        layer_nodes.push(id.clone());
535                    }
536                }
537            }
538            layer_nodes.sort();
539            result.push(layer_nodes);
540        }
541
542        result
543    }
544}
545
546#[cfg(test)]
547mod tests {
548    use crate::GraphEngine;
549    use codemem_core::{Edge, GraphBackend, GraphNode, NodeKind, RelationshipType};
550    use std::collections::{HashMap, HashSet};
551
552    fn file_node(id: &str, label: &str) -> GraphNode {
553        GraphNode {
554            id: id.to_string(),
555            kind: NodeKind::File,
556            label: label.to_string(),
557            payload: HashMap::new(),
558            centrality: 0.0,
559            memory_id: None,
560            namespace: None,
561        }
562    }
563
564    fn test_edge(src: &str, dst: &str) -> Edge {
565        Edge {
566            id: format!("{src}->{dst}"),
567            src: src.to_string(),
568            dst: dst.to_string(),
569            relationship: RelationshipType::Contains,
570            weight: 1.0,
571            properties: HashMap::new(),
572            created_at: chrono::Utc::now(),
573        }
574    }
575
576    // ── PageRank Tests ──────────────────────────────────────────────────────
577
578    #[test]
579    fn pagerank_chain() {
580        // a -> b -> c
581        // c is a sink (dangling node) that redistributes rank uniformly.
582        // Rank flows a -> b -> c, with c accumulating the most. Order: c > b > a.
583        let mut graph = GraphEngine::new();
584        graph.add_node(file_node("a", "a.rs")).unwrap();
585        graph.add_node(file_node("b", "b.rs")).unwrap();
586        graph.add_node(file_node("c", "c.rs")).unwrap();
587        graph.add_edge(test_edge("a", "b")).unwrap();
588        graph.add_edge(test_edge("b", "c")).unwrap();
589
590        let ranks = graph.pagerank(0.85, 100, 1e-6);
591        assert_eq!(ranks.len(), 3);
592        assert!(
593            ranks["c"] > ranks["b"],
594            "c ({}) should rank higher than b ({})",
595            ranks["c"],
596            ranks["b"]
597        );
598        assert!(
599            ranks["b"] > ranks["a"],
600            "b ({}) should rank higher than a ({})",
601            ranks["b"],
602            ranks["a"]
603        );
604    }
605
606    #[test]
607    fn pagerank_star() {
608        // a -> b, a -> c, a -> d
609        // b, c, d are dangling nodes that redistribute rank uniformly.
610        // They each receive direct rank from a, plus redistribution.
611        // a only receives redistributed rank from the dangling nodes.
612        // So each leaf should rank higher than the hub.
613        let mut graph = GraphEngine::new();
614        graph.add_node(file_node("a", "a.rs")).unwrap();
615        graph.add_node(file_node("b", "b.rs")).unwrap();
616        graph.add_node(file_node("c", "c.rs")).unwrap();
617        graph.add_node(file_node("d", "d.rs")).unwrap();
618        graph.add_edge(test_edge("a", "b")).unwrap();
619        graph.add_edge(test_edge("a", "c")).unwrap();
620        graph.add_edge(test_edge("a", "d")).unwrap();
621
622        let ranks = graph.pagerank(0.85, 100, 1e-6);
623        assert_eq!(ranks.len(), 4);
624        // Leaves get direct rank from a AND redistribute back uniformly.
625        // b, c, d should be approximately equal and each higher than a.
626        assert!(
627            ranks["b"] > ranks["a"],
628            "b ({}) should rank higher than a ({})",
629            ranks["b"],
630            ranks["a"]
631        );
632        // b, c, d should be approximately equal
633        assert!(
634            (ranks["b"] - ranks["c"]).abs() < 0.01,
635            "b ({}) and c ({}) should be approximately equal",
636            ranks["b"],
637            ranks["c"]
638        );
639    }
640
641    #[test]
642    fn pagerank_empty_graph() {
643        let graph = GraphEngine::new();
644        let ranks = graph.pagerank(0.85, 100, 1e-6);
645        assert!(ranks.is_empty());
646    }
647
648    #[test]
649    fn pagerank_single_node() {
650        let mut graph = GraphEngine::new();
651        graph.add_node(file_node("a", "a.rs")).unwrap();
652
653        let ranks = graph.pagerank(0.85, 100, 1e-6);
654        assert_eq!(ranks.len(), 1);
655        assert!((ranks["a"] - 1.0).abs() < 0.01);
656    }
657
658    // ── Personalized PageRank Tests ─────────────────────────────────────────
659
660    #[test]
661    fn personalized_pagerank_cycle_seed_c() {
662        // a -> b -> c -> a (cycle)
663        // Seed on c: c and its neighbors should rank highest
664        let mut graph = GraphEngine::new();
665        graph.add_node(file_node("a", "a.rs")).unwrap();
666        graph.add_node(file_node("b", "b.rs")).unwrap();
667        graph.add_node(file_node("c", "c.rs")).unwrap();
668        graph.add_edge(test_edge("a", "b")).unwrap();
669        graph.add_edge(test_edge("b", "c")).unwrap();
670        graph.add_edge(test_edge("c", "a")).unwrap();
671
672        let mut seeds = HashMap::new();
673        seeds.insert("c".to_string(), 1.0);
674
675        let ranks = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
676        assert_eq!(ranks.len(), 3);
677        // c should have highest rank (it's the seed and receives teleport)
678        // a is c's out-neighbor so it should be next
679        assert!(
680            ranks["c"] > ranks["b"],
681            "c ({}) should rank higher than b ({})",
682            ranks["c"],
683            ranks["b"]
684        );
685        assert!(
686            ranks["a"] > ranks["b"],
687            "a ({}) should rank higher than b ({}) since c->a",
688            ranks["a"],
689            ranks["b"]
690        );
691    }
692
693    #[test]
694    fn personalized_pagerank_empty_seeds() {
695        // With no seeds, should fall back to uniform (same as regular pagerank)
696        let mut graph = GraphEngine::new();
697        graph.add_node(file_node("a", "a.rs")).unwrap();
698        graph.add_node(file_node("b", "b.rs")).unwrap();
699        graph.add_edge(test_edge("a", "b")).unwrap();
700
701        let seeds = HashMap::new();
702        let ppr = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
703        let pr = graph.pagerank(0.85, 100, 1e-6);
704
705        // Should be approximately equal
706        assert!((ppr["a"] - pr["a"]).abs() < 0.01);
707        assert!((ppr["b"] - pr["b"]).abs() < 0.01);
708    }
709
710    // ── Louvain Community Detection Tests ───────────────────────────────────
711
712    #[test]
713    fn louvain_two_disconnected_cliques() {
714        // Clique 1: a <-> b <-> c <-> a
715        // Clique 2: d <-> e <-> f <-> d
716        let mut graph = GraphEngine::new();
717        for id in &["a", "b", "c", "d", "e", "f"] {
718            graph.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
719        }
720        // Clique 1
721        graph.add_edge(test_edge("a", "b")).unwrap();
722        graph.add_edge(test_edge("b", "a")).unwrap();
723        graph.add_edge(test_edge("b", "c")).unwrap();
724        graph.add_edge(test_edge("c", "b")).unwrap();
725        graph.add_edge(test_edge("a", "c")).unwrap();
726        graph.add_edge(test_edge("c", "a")).unwrap();
727        // Clique 2
728        graph.add_edge(test_edge("d", "e")).unwrap();
729        graph.add_edge(test_edge("e", "d")).unwrap();
730        graph.add_edge(test_edge("e", "f")).unwrap();
731        graph.add_edge(test_edge("f", "e")).unwrap();
732        graph.add_edge(test_edge("d", "f")).unwrap();
733        graph.add_edge(test_edge("f", "d")).unwrap();
734
735        let communities = graph.louvain_communities(1.0);
736        assert_eq!(
737            communities.len(),
738            2,
739            "Expected 2 communities, got {}: {:?}",
740            communities.len(),
741            communities
742        );
743        // Each community should have 3 nodes
744        assert_eq!(communities[0].len(), 3);
745        assert_eq!(communities[1].len(), 3);
746        // Check that each clique is in a separate community
747        let comm0_set: HashSet<&str> = communities[0].iter().map(|s| s.as_str()).collect();
748        let has_abc = comm0_set.contains("a") && comm0_set.contains("b") && comm0_set.contains("c");
749        let has_def = comm0_set.contains("d") && comm0_set.contains("e") && comm0_set.contains("f");
750        assert!(
751            has_abc || has_def,
752            "First community should be one of the cliques: {:?}",
753            communities[0]
754        );
755    }
756
757    #[test]
758    fn louvain_empty_graph() {
759        let graph = GraphEngine::new();
760        let communities = graph.louvain_communities(1.0);
761        assert!(communities.is_empty());
762    }
763
764    #[test]
765    fn louvain_single_node() {
766        let mut graph = GraphEngine::new();
767        graph.add_node(file_node("a", "a.rs")).unwrap();
768        let communities = graph.louvain_communities(1.0);
769        assert_eq!(communities.len(), 1);
770        assert_eq!(communities[0], vec!["a"]);
771    }
772
773    // ── Betweenness Centrality Tests ────────────────────────────────────────
774
775    #[test]
776    fn betweenness_chain_middle_highest() {
777        // a -> b -> c
778        // b is on the shortest path from a to c, so it should have highest betweenness
779        let mut graph = GraphEngine::new();
780        graph.add_node(file_node("a", "a.rs")).unwrap();
781        graph.add_node(file_node("b", "b.rs")).unwrap();
782        graph.add_node(file_node("c", "c.rs")).unwrap();
783        graph.add_edge(test_edge("a", "b")).unwrap();
784        graph.add_edge(test_edge("b", "c")).unwrap();
785
786        let bc = graph.betweenness_centrality();
787        assert_eq!(bc.len(), 3);
788        assert!(
789            bc["b"] > bc["a"],
790            "b ({}) should have higher betweenness than a ({})",
791            bc["b"],
792            bc["a"]
793        );
794        assert!(
795            bc["b"] > bc["c"],
796            "b ({}) should have higher betweenness than c ({})",
797            bc["b"],
798            bc["c"]
799        );
800        // a and c should have 0 betweenness (they are endpoints)
801        assert!(
802            bc["a"].abs() < f64::EPSILON,
803            "a should have 0 betweenness, got {}",
804            bc["a"]
805        );
806        assert!(
807            bc["c"].abs() < f64::EPSILON,
808            "c should have 0 betweenness, got {}",
809            bc["c"]
810        );
811    }
812
813    #[test]
814    fn betweenness_empty_graph() {
815        let graph = GraphEngine::new();
816        let bc = graph.betweenness_centrality();
817        assert!(bc.is_empty());
818    }
819
820    #[test]
821    fn betweenness_two_nodes() {
822        let mut graph = GraphEngine::new();
823        graph.add_node(file_node("a", "a.rs")).unwrap();
824        graph.add_node(file_node("b", "b.rs")).unwrap();
825        graph.add_edge(test_edge("a", "b")).unwrap();
826
827        let bc = graph.betweenness_centrality();
828        assert_eq!(bc.len(), 2);
829        assert!((bc["a"]).abs() < f64::EPSILON);
830        assert!((bc["b"]).abs() < f64::EPSILON);
831    }
832
833    // ── Strongly Connected Components Tests ─────────────────────────────────
834
835    #[test]
836    fn scc_cycle_all_in_one() {
837        // a -> b -> c -> a: all three should be in one SCC
838        let mut graph = GraphEngine::new();
839        graph.add_node(file_node("a", "a.rs")).unwrap();
840        graph.add_node(file_node("b", "b.rs")).unwrap();
841        graph.add_node(file_node("c", "c.rs")).unwrap();
842        graph.add_edge(test_edge("a", "b")).unwrap();
843        graph.add_edge(test_edge("b", "c")).unwrap();
844        graph.add_edge(test_edge("c", "a")).unwrap();
845
846        let sccs = graph.strongly_connected_components();
847        assert_eq!(
848            sccs.len(),
849            1,
850            "Expected 1 SCC, got {}: {:?}",
851            sccs.len(),
852            sccs
853        );
854        assert_eq!(sccs[0], vec!["a", "b", "c"]);
855    }
856
857    #[test]
858    fn scc_chain_each_separate() {
859        // a -> b -> c: no cycles, each node is its own SCC
860        let mut graph = GraphEngine::new();
861        graph.add_node(file_node("a", "a.rs")).unwrap();
862        graph.add_node(file_node("b", "b.rs")).unwrap();
863        graph.add_node(file_node("c", "c.rs")).unwrap();
864        graph.add_edge(test_edge("a", "b")).unwrap();
865        graph.add_edge(test_edge("b", "c")).unwrap();
866
867        let sccs = graph.strongly_connected_components();
868        assert_eq!(
869            sccs.len(),
870            3,
871            "Expected 3 SCCs, got {}: {:?}",
872            sccs.len(),
873            sccs
874        );
875    }
876
877    #[test]
878    fn scc_empty_graph() {
879        let graph = GraphEngine::new();
880        let sccs = graph.strongly_connected_components();
881        assert!(sccs.is_empty());
882    }
883
884    // ── Topological Sort Tests ──────────────────────────────────────────────
885
886    #[test]
887    fn topological_layers_dag() {
888        // a -> b, a -> c, b -> d, c -> d
889        // Layer 0: [a], Layer 1: [b, c], Layer 2: [d]
890        let mut graph = GraphEngine::new();
891        graph.add_node(file_node("a", "a.rs")).unwrap();
892        graph.add_node(file_node("b", "b.rs")).unwrap();
893        graph.add_node(file_node("c", "c.rs")).unwrap();
894        graph.add_node(file_node("d", "d.rs")).unwrap();
895        graph.add_edge(test_edge("a", "b")).unwrap();
896        graph.add_edge(test_edge("a", "c")).unwrap();
897        graph.add_edge(test_edge("b", "d")).unwrap();
898        graph.add_edge(test_edge("c", "d")).unwrap();
899
900        let layers = graph.topological_layers();
901        assert_eq!(
902            layers.len(),
903            3,
904            "Expected 3 layers, got {}: {:?}",
905            layers.len(),
906            layers
907        );
908        assert_eq!(layers[0], vec!["a"]);
909        assert_eq!(layers[1], vec!["b", "c"]); // sorted within layer
910        assert_eq!(layers[2], vec!["d"]);
911    }
912
913    #[test]
914    fn topological_layers_with_cycle() {
915        // a -> b -> c -> b (cycle between b and c), a -> d
916        // SCCs: {a}, {b, c}, {d}
917        // After condensation: {a} -> {b,c} and {a} -> {d}
918        // Layer 0: [a], Layer 1: [b, c, d] (b and c condensed, d also depends on a)
919        let mut graph = GraphEngine::new();
920        graph.add_node(file_node("a", "a.rs")).unwrap();
921        graph.add_node(file_node("b", "b.rs")).unwrap();
922        graph.add_node(file_node("c", "c.rs")).unwrap();
923        graph.add_node(file_node("d", "d.rs")).unwrap();
924        graph.add_edge(test_edge("a", "b")).unwrap();
925        graph.add_edge(test_edge("b", "c")).unwrap();
926        graph.add_edge(test_edge("c", "b")).unwrap();
927        graph.add_edge(test_edge("a", "d")).unwrap();
928
929        let layers = graph.topological_layers();
930        assert_eq!(
931            layers.len(),
932            2,
933            "Expected 2 layers, got {}: {:?}",
934            layers.len(),
935            layers
936        );
937        assert_eq!(layers[0], vec!["a"]);
938        // Layer 1 should contain b, c (from the cycle SCC) and d
939        assert!(layers[1].contains(&"b".to_string()));
940        assert!(layers[1].contains(&"c".to_string()));
941        assert!(layers[1].contains(&"d".to_string()));
942    }
943
944    #[test]
945    fn topological_layers_empty_graph() {
946        let graph = GraphEngine::new();
947        let layers = graph.topological_layers();
948        assert!(layers.is_empty());
949    }
950
951    #[test]
952    fn topological_layers_single_node() {
953        let mut graph = GraphEngine::new();
954        graph.add_node(file_node("a", "a.rs")).unwrap();
955        let layers = graph.topological_layers();
956        assert_eq!(layers.len(), 1);
957        assert_eq!(layers[0], vec!["a"]);
958    }
959}