Skip to main content

codemem_graph/
lib.rs

1//! codemem-graph: Graph engine with petgraph algorithms and SQLite persistence.
2//!
3//! Provides BFS, DFS, shortest path, and connected components over
4//! a knowledge graph with 6 node types and 15 relationship types.
5
6#[cfg(test)]
7use codemem_core::NodeKind;
8use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, GraphStats};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::Bfs;
11use petgraph::Direction;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// In-memory graph backed by petgraph, synced to SQLite via codemem-storage.
///
/// The petgraph `DiGraph` only carries topology: each node weight is the
/// node's string ID and each edge weight is an `f64`. The full `GraphNode`
/// and `Edge` payloads are kept in side maps keyed by ID, and `id_to_index`
/// translates string IDs into petgraph indices.
pub struct GraphEngine {
    /// Directed topology; node weights are node IDs, edge weights are `f64`.
    graph: DiGraph<String, f64>,
    /// Map from string node IDs to petgraph NodeIndex.
    id_to_index: HashMap<String, NodeIndex>,
    /// Node data by ID.
    nodes: HashMap<String, GraphNode>,
    /// Edge data by ID.
    edges: HashMap<String, Edge>,
    /// Cached PageRank scores (populated by `recompute_centrality()`).
    cached_pagerank: HashMap<String, f64>,
    /// Cached betweenness centrality scores (populated by `recompute_centrality()`).
    cached_betweenness: HashMap<String, f64>,
}
28
29impl GraphEngine {
30    /// Create a new empty graph.
31    pub fn new() -> Self {
32        Self {
33            graph: DiGraph::new(),
34            id_to_index: HashMap::new(),
35            nodes: HashMap::new(),
36            edges: HashMap::new(),
37            cached_pagerank: HashMap::new(),
38            cached_betweenness: HashMap::new(),
39        }
40    }
41
42    /// Load graph from storage.
43    pub fn from_storage(storage: &codemem_storage::Storage) -> Result<Self, CodememError> {
44        let mut engine = Self::new();
45
46        // Load all nodes
47        let nodes = storage.all_graph_nodes()?;
48        for node in nodes {
49            engine.add_node(node)?;
50        }
51
52        // Load all edges
53        let edges = storage.all_graph_edges()?;
54        for edge in edges {
55            engine.add_edge(edge)?;
56        }
57
58        Ok(engine)
59    }
60
61    /// Get the number of nodes.
62    pub fn node_count(&self) -> usize {
63        self.nodes.len()
64    }
65
66    /// Get the number of edges.
67    pub fn edge_count(&self) -> usize {
68        self.edges.len()
69    }
70
71    /// Multi-hop expansion: given a set of node IDs, expand N hops to find related nodes.
72    pub fn expand(
73        &self,
74        start_ids: &[String],
75        max_hops: usize,
76    ) -> Result<Vec<GraphNode>, CodememError> {
77        let mut visited = std::collections::HashSet::new();
78        let mut result = Vec::new();
79
80        for start_id in start_ids {
81            let nodes = self.bfs(start_id, max_hops)?;
82            for node in nodes {
83                if visited.insert(node.id.clone()) {
84                    result.push(node);
85                }
86            }
87        }
88
89        Ok(result)
90    }
91
92    /// Get neighbors of a node (1-hop).
93    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
94        let idx = self
95            .id_to_index
96            .get(node_id)
97            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
98
99        let mut result = Vec::new();
100        for neighbor_idx in self.graph.neighbors(*idx) {
101            if let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) {
102                if let Some(node) = self.nodes.get(neighbor_id) {
103                    result.push(node.clone());
104                }
105            }
106        }
107
108        Ok(result)
109    }
110
111    /// Return groups of connected node IDs.
112    ///
113    /// Treats the directed graph as undirected: two nodes are in the same
114    /// component if there is a path between them in either direction.
115    /// Each inner `Vec<String>` is one connected component.
116    pub fn connected_components(&self) -> Vec<Vec<String>> {
117        let mut visited: HashSet<NodeIndex> = HashSet::new();
118        let mut components: Vec<Vec<String>> = Vec::new();
119
120        for &start_idx in self.id_to_index.values() {
121            if visited.contains(&start_idx) {
122                continue;
123            }
124
125            // BFS treating edges as undirected
126            let mut component: Vec<String> = Vec::new();
127            let mut queue: VecDeque<NodeIndex> = VecDeque::new();
128            queue.push_back(start_idx);
129            visited.insert(start_idx);
130
131            while let Some(current) = queue.pop_front() {
132                if let Some(node_id) = self.graph.node_weight(current) {
133                    component.push(node_id.clone());
134                }
135
136                // Follow outgoing edges
137                for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
138                    if visited.insert(neighbor) {
139                        queue.push_back(neighbor);
140                    }
141                }
142
143                // Follow incoming edges (treat as undirected)
144                for neighbor in self.graph.neighbors_directed(current, Direction::Incoming) {
145                    if visited.insert(neighbor) {
146                        queue.push_back(neighbor);
147                    }
148                }
149            }
150
151            component.sort();
152            components.push(component);
153        }
154
155        components.sort();
156        components
157    }
158
159    /// Compute degree centrality for every node and update their `centrality` field.
160    ///
161    /// Degree centrality for node *v* is defined as:
162    ///   `(in_degree(v) + out_degree(v)) / (N - 1)`
163    /// where *N* is the total number of nodes.  When N <= 1, centrality is 0.
164    pub fn compute_centrality(&mut self) {
165        let n = self.nodes.len();
166        if n <= 1 {
167            for node in self.nodes.values_mut() {
168                node.centrality = 0.0;
169            }
170            return;
171        }
172
173        let denominator = (n - 1) as f64;
174
175        // Pre-compute centrality values by node ID.
176        let centrality_map: HashMap<String, f64> = self
177            .id_to_index
178            .iter()
179            .map(|(id, &idx)| {
180                let in_deg = self
181                    .graph
182                    .neighbors_directed(idx, Direction::Incoming)
183                    .count();
184                let out_deg = self
185                    .graph
186                    .neighbors_directed(idx, Direction::Outgoing)
187                    .count();
188                let centrality = (in_deg + out_deg) as f64 / denominator;
189                (id.clone(), centrality)
190            })
191            .collect();
192
193        // Apply centrality values to the stored nodes.
194        for (id, centrality) in &centrality_map {
195            if let Some(node) = self.nodes.get_mut(id) {
196                node.centrality = *centrality;
197            }
198        }
199    }
200
201    /// Return all nodes currently in the graph.
202    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
203        self.nodes.values().cloned().collect()
204    }
205
206    /// Recompute and cache PageRank and betweenness centrality scores.
207    ///
208    /// This should be called after loading the graph (e.g., on server start)
209    /// and periodically when the graph changes significantly.
210    pub fn recompute_centrality(&mut self) {
211        self.cached_pagerank = self.pagerank(0.85, 100, 1e-6);
212        self.cached_betweenness = self.betweenness_centrality();
213    }
214
215    /// Get the cached PageRank score for a node. Returns 0.0 if not found.
216    pub fn get_pagerank(&self, node_id: &str) -> f64 {
217        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
218    }
219
220    /// Get the cached betweenness centrality score for a node. Returns 0.0 if not found.
221    pub fn get_betweenness(&self, node_id: &str) -> f64 {
222        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
223    }
224
225    /// Get the maximum degree (in + out) across all nodes in the graph.
226    /// Returns 1.0 if the graph has fewer than 2 nodes to avoid division by zero.
227    pub fn max_degree(&self) -> f64 {
228        if self.nodes.len() <= 1 {
229            return 1.0;
230        }
231        self.id_to_index
232            .values()
233            .map(|&idx| {
234                let in_deg = self
235                    .graph
236                    .neighbors_directed(idx, Direction::Incoming)
237                    .count();
238                let out_deg = self
239                    .graph
240                    .neighbors_directed(idx, Direction::Outgoing)
241                    .count();
242                (in_deg + out_deg) as f64
243            })
244            .fold(1.0f64, f64::max)
245    }
246}
247
248impl Default for GraphEngine {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254impl GraphBackend for GraphEngine {
255    fn add_node(&mut self, node: GraphNode) -> Result<(), CodememError> {
256        let id = node.id.clone();
257
258        if !self.id_to_index.contains_key(&id) {
259            let idx = self.graph.add_node(id.clone());
260            self.id_to_index.insert(id.clone(), idx);
261        }
262
263        self.nodes.insert(id, node);
264        Ok(())
265    }
266
267    fn get_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
268        Ok(self.nodes.get(id).cloned())
269    }
270
271    fn remove_node(&mut self, id: &str) -> Result<bool, CodememError> {
272        if let Some(idx) = self.id_to_index.remove(id) {
273            self.graph.remove_node(idx);
274            self.nodes.remove(id);
275
276            // Remove associated edges
277            let edge_ids: Vec<String> = self
278                .edges
279                .iter()
280                .filter(|(_, e)| e.src == id || e.dst == id)
281                .map(|(eid, _)| eid.clone())
282                .collect();
283            for eid in edge_ids {
284                self.edges.remove(&eid);
285            }
286
287            Ok(true)
288        } else {
289            Ok(false)
290        }
291    }
292
293    fn add_edge(&mut self, edge: Edge) -> Result<(), CodememError> {
294        let src_idx = self
295            .id_to_index
296            .get(&edge.src)
297            .ok_or_else(|| CodememError::NotFound(format!("Source node {}", edge.src)))?;
298        let dst_idx = self
299            .id_to_index
300            .get(&edge.dst)
301            .ok_or_else(|| CodememError::NotFound(format!("Destination node {}", edge.dst)))?;
302
303        self.graph.add_edge(*src_idx, *dst_idx, edge.weight);
304        self.edges.insert(edge.id.clone(), edge);
305        Ok(())
306    }
307
308    fn get_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
309        let edges: Vec<Edge> = self
310            .edges
311            .values()
312            .filter(|e| e.src == node_id || e.dst == node_id)
313            .cloned()
314            .collect();
315        Ok(edges)
316    }
317
318    fn remove_edge(&mut self, id: &str) -> Result<bool, CodememError> {
319        if let Some(edge) = self.edges.remove(id) {
320            // Also remove from petgraph
321            if let (Some(&src_idx), Some(&dst_idx)) = (
322                self.id_to_index.get(&edge.src),
323                self.id_to_index.get(&edge.dst),
324            ) {
325                if let Some(edge_idx) = self.graph.find_edge(src_idx, dst_idx) {
326                    self.graph.remove_edge(edge_idx);
327                }
328            }
329            Ok(true)
330        } else {
331            Ok(false)
332        }
333    }
334
335    fn bfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
336        let start_idx = self
337            .id_to_index
338            .get(start_id)
339            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
340
341        let mut visited = HashSet::new();
342        let mut result = Vec::new();
343        let mut bfs = Bfs::new(&self.graph, *start_idx);
344        let mut depth_map: HashMap<NodeIndex, usize> = HashMap::new();
345        depth_map.insert(*start_idx, 0);
346
347        while let Some(node_idx) = bfs.next(&self.graph) {
348            let depth = depth_map.get(&node_idx).copied().unwrap_or(0);
349            if depth > max_depth {
350                continue;
351            }
352
353            if visited.insert(node_idx) {
354                if let Some(node_id) = self.graph.node_weight(node_idx) {
355                    if let Some(node) = self.nodes.get(node_id) {
356                        result.push(node.clone());
357                    }
358                }
359            }
360
361            // Set depth for neighbors
362            for neighbor in self.graph.neighbors(node_idx) {
363                depth_map.entry(neighbor).or_insert(depth + 1);
364            }
365        }
366
367        Ok(result)
368    }
369
370    fn dfs(&self, start_id: &str, max_depth: usize) -> Result<Vec<GraphNode>, CodememError> {
371        let start_idx = self
372            .id_to_index
373            .get(start_id)
374            .ok_or_else(|| CodememError::NotFound(format!("Node {start_id}")))?;
375
376        let mut visited = HashSet::new();
377        let mut result = Vec::new();
378        let mut stack: Vec<(NodeIndex, usize)> = vec![(*start_idx, 0)];
379
380        while let Some((node_idx, depth)) = stack.pop() {
381            if depth > max_depth || !visited.insert(node_idx) {
382                continue;
383            }
384
385            if let Some(node_id) = self.graph.node_weight(node_idx) {
386                if let Some(node) = self.nodes.get(node_id) {
387                    result.push(node.clone());
388                }
389            }
390
391            for neighbor in self.graph.neighbors(node_idx) {
392                if !visited.contains(&neighbor) {
393                    stack.push((neighbor, depth + 1));
394                }
395            }
396        }
397
398        Ok(result)
399    }
400
401    fn shortest_path(&self, from: &str, to: &str) -> Result<Vec<String>, CodememError> {
402        let from_idx = self
403            .id_to_index
404            .get(from)
405            .ok_or_else(|| CodememError::NotFound(format!("Node {from}")))?;
406        let to_idx = self
407            .id_to_index
408            .get(to)
409            .ok_or_else(|| CodememError::NotFound(format!("Node {to}")))?;
410
411        // BFS shortest path (unweighted)
412        use petgraph::algo::astar;
413        let path = astar(
414            &self.graph,
415            *from_idx,
416            |finish| finish == *to_idx,
417            |_| 1.0f64,
418            |_| 0.0f64,
419        );
420
421        match path {
422            Some((_cost, nodes)) => {
423                let ids: Vec<String> = nodes
424                    .iter()
425                    .filter_map(|idx| self.graph.node_weight(*idx).cloned())
426                    .collect();
427                Ok(ids)
428            }
429            None => Ok(vec![]),
430        }
431    }
432
433    fn stats(&self) -> GraphStats {
434        let mut node_kind_counts = HashMap::new();
435        for node in self.nodes.values() {
436            *node_kind_counts.entry(node.kind.to_string()).or_insert(0) += 1;
437        }
438
439        let mut relationship_type_counts = HashMap::new();
440        for edge in self.edges.values() {
441            *relationship_type_counts
442                .entry(edge.relationship.to_string())
443                .or_insert(0) += 1;
444        }
445
446        GraphStats {
447            node_count: self.nodes.len(),
448            edge_count: self.edges.len(),
449            node_kind_counts,
450            relationship_type_counts,
451        }
452    }
453}
454
455// ── Advanced Graph Algorithms ───────────────────────────────────────────────
456
457impl GraphEngine {
458    /// Compute PageRank scores for all nodes using power iteration.
459    ///
460    /// - `damping`: probability of following an edge (default 0.85)
461    /// - `iterations`: max number of power iterations (default 100)
462    /// - `tolerance`: convergence threshold (default 1e-6)
463    ///
464    /// Returns a map from node ID to PageRank score.
465    pub fn pagerank(
466        &self,
467        damping: f64,
468        iterations: usize,
469        tolerance: f64,
470    ) -> HashMap<String, f64> {
471        let n = self.graph.node_count();
472        if n == 0 {
473            return HashMap::new();
474        }
475
476        let nf = n as f64;
477        let initial = 1.0 / nf;
478
479        // Collect all node indices in a stable order
480        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
481        let idx_pos: HashMap<NodeIndex, usize> = indices
482            .iter()
483            .enumerate()
484            .map(|(i, &idx)| (idx, i))
485            .collect();
486
487        let mut scores = vec![initial; n];
488
489        // Precompute out-degrees
490        let out_degree: Vec<usize> = indices
491            .iter()
492            .map(|&idx| {
493                self.graph
494                    .neighbors_directed(idx, Direction::Outgoing)
495                    .count()
496            })
497            .collect();
498
499        for _ in 0..iterations {
500            let mut new_scores = vec![(1.0 - damping) / nf; n];
501
502            // Distribute rank from each node to its out-neighbors
503            for (i, &idx) in indices.iter().enumerate() {
504                let deg = out_degree[i];
505                if deg == 0 {
506                    // Dangling node: distribute evenly to all nodes
507                    let share = damping * scores[i] / nf;
508                    for ns in new_scores.iter_mut() {
509                        *ns += share;
510                    }
511                } else {
512                    let share = damping * scores[i] / deg as f64;
513                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
514                        if let Some(&pos) = idx_pos.get(&neighbor) {
515                            new_scores[pos] += share;
516                        }
517                    }
518                }
519            }
520
521            // Check convergence
522            let diff: f64 = scores
523                .iter()
524                .zip(new_scores.iter())
525                .map(|(a, b)| (a - b).abs())
526                .sum();
527
528            scores = new_scores;
529
530            if diff < tolerance {
531                break;
532            }
533        }
534
535        // Map back to node IDs
536        indices
537            .iter()
538            .enumerate()
539            .filter_map(|(i, &idx)| {
540                self.graph
541                    .node_weight(idx)
542                    .map(|id| (id.clone(), scores[i]))
543            })
544            .collect()
545    }
546
547    /// Compute Personalized PageRank with custom teleport weights.
548    ///
549    /// `seed_weights` maps node IDs to teleport probabilities (will be normalized).
550    /// Nodes not in seed_weights get zero teleport probability.
551    ///
552    /// Used for blast-radius analysis and HippoRAG-2-style retrieval.
553    pub fn personalized_pagerank(
554        &self,
555        seed_weights: &HashMap<String, f64>,
556        damping: f64,
557        iterations: usize,
558        tolerance: f64,
559    ) -> HashMap<String, f64> {
560        let n = self.graph.node_count();
561        if n == 0 {
562            return HashMap::new();
563        }
564
565        let nf = n as f64;
566
567        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
568        let idx_pos: HashMap<NodeIndex, usize> = indices
569            .iter()
570            .enumerate()
571            .map(|(i, &idx)| (idx, i))
572            .collect();
573
574        // Build and normalize the teleport vector
575        let mut teleport = vec![0.0f64; n];
576        let mut teleport_sum = 0.0;
577        for (i, &idx) in indices.iter().enumerate() {
578            if let Some(node_id) = self.graph.node_weight(idx) {
579                if let Some(&w) = seed_weights.get(node_id) {
580                    teleport[i] = w;
581                    teleport_sum += w;
582                }
583            }
584        }
585        // Normalize; if no seeds provided, fall back to uniform
586        if teleport_sum > 0.0 {
587            for t in teleport.iter_mut() {
588                *t /= teleport_sum;
589            }
590        } else {
591            for t in teleport.iter_mut() {
592                *t = 1.0 / nf;
593            }
594        }
595
596        let initial = 1.0 / nf;
597        let mut scores = vec![initial; n];
598
599        let out_degree: Vec<usize> = indices
600            .iter()
601            .map(|&idx| {
602                self.graph
603                    .neighbors_directed(idx, Direction::Outgoing)
604                    .count()
605            })
606            .collect();
607
608        for _ in 0..iterations {
609            let mut new_scores: Vec<f64> = teleport.iter().map(|&t| (1.0 - damping) * t).collect();
610
611            for (i, &idx) in indices.iter().enumerate() {
612                let deg = out_degree[i];
613                if deg == 0 {
614                    // Dangling node: distribute to teleport targets
615                    let share = damping * scores[i];
616                    for (j, t) in teleport.iter().enumerate() {
617                        new_scores[j] += share * t;
618                    }
619                } else {
620                    let share = damping * scores[i] / deg as f64;
621                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
622                        if let Some(&pos) = idx_pos.get(&neighbor) {
623                            new_scores[pos] += share;
624                        }
625                    }
626                }
627            }
628
629            let diff: f64 = scores
630                .iter()
631                .zip(new_scores.iter())
632                .map(|(a, b)| (a - b).abs())
633                .sum();
634
635            scores = new_scores;
636
637            if diff < tolerance {
638                break;
639            }
640        }
641
642        indices
643            .iter()
644            .enumerate()
645            .filter_map(|(i, &idx)| {
646                self.graph
647                    .node_weight(idx)
648                    .map(|id| (id.clone(), scores[i]))
649            })
650            .collect()
651    }
652
653    /// Detect communities using the Louvain algorithm.
654    ///
655    /// Treats the directed graph as undirected for modularity computation.
656    /// `resolution` controls community granularity (1.0 = standard modularity).
657    /// Returns groups of node IDs, one group per community.
658    pub fn louvain_communities(&self, resolution: f64) -> Vec<Vec<String>> {
659        let n = self.graph.node_count();
660        if n == 0 {
661            return Vec::new();
662        }
663
664        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
665        let idx_pos: HashMap<NodeIndex, usize> = indices
666            .iter()
667            .enumerate()
668            .map(|(i, &idx)| (idx, i))
669            .collect();
670
671        // Build undirected adjacency with weights.
672        // adj[i] contains (j, weight) for each undirected neighbor.
673        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
674        let mut total_weight = 0.0;
675
676        for edge_ref in self.graph.edge_indices() {
677            if let Some((src_idx, dst_idx)) = self.graph.edge_endpoints(edge_ref) {
678                let w = self.graph[edge_ref];
679                if let (Some(&si), Some(&di)) = (idx_pos.get(&src_idx), idx_pos.get(&dst_idx)) {
680                    adj[si].push((di, w));
681                    adj[di].push((si, w));
682                    total_weight += w; // Each undirected edge contributes w (counted once)
683                }
684            }
685        }
686
687        if total_weight == 0.0 {
688            // No edges: each node is its own community
689            return indices
690                .iter()
691                .filter_map(|&idx| self.graph.node_weight(idx).map(|id| vec![id.clone()]))
692                .collect();
693        }
694
695        // m = total edge weight (for undirected: sum of all edge weights)
696        let m = total_weight;
697        let m2 = 2.0 * m;
698
699        // Weighted degree of each node (sum of incident edge weights, undirected)
700        let k: Vec<f64> = (0..n)
701            .map(|i| adj[i].iter().map(|&(_, w)| w).sum())
702            .collect();
703
704        // Initial assignment: each node in its own community
705        let mut community: Vec<usize> = (0..n).collect();
706
707        // Iteratively move nodes to improve modularity
708        let mut improved = true;
709        let max_passes = 100;
710        let mut pass = 0;
711
712        while improved && pass < max_passes {
713            improved = false;
714            pass += 1;
715
716            for i in 0..n {
717                let current_comm = community[i];
718
719                // Compute weights to each neighboring community
720                let mut comm_weights: HashMap<usize, f64> = HashMap::new();
721                for &(j, w) in &adj[i] {
722                    *comm_weights.entry(community[j]).or_insert(0.0) += w;
723                }
724
725                // Sum of degrees in each community (excluding node i for its own community)
726                let mut comm_degree_sum: HashMap<usize, f64> = HashMap::new();
727                for j in 0..n {
728                    *comm_degree_sum.entry(community[j]).or_insert(0.0) += k[j];
729                }
730
731                let ki = k[i];
732
733                // Modularity gain for removing i from its current community
734                let w_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
735                let sigma_current = comm_degree_sum.get(&current_comm).copied().unwrap_or(0.0);
736                let remove_cost = w_in_current - resolution * ki * (sigma_current - ki) / m2;
737
738                // Find best community to move to
739                let mut best_comm = current_comm;
740                let mut best_gain = 0.0;
741
742                for (&comm, &w_in_comm) in &comm_weights {
743                    if comm == current_comm {
744                        continue;
745                    }
746                    let sigma_comm = comm_degree_sum.get(&comm).copied().unwrap_or(0.0);
747                    let gain = w_in_comm - resolution * ki * sigma_comm / m2 - remove_cost;
748                    if gain > best_gain {
749                        best_gain = gain;
750                        best_comm = comm;
751                    }
752                }
753
754                if best_comm != current_comm {
755                    community[i] = best_comm;
756                    improved = true;
757                }
758            }
759        }
760
761        // Group nodes by community
762        let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
763        for (i, &idx) in indices.iter().enumerate() {
764            if let Some(node_id) = self.graph.node_weight(idx) {
765                groups
766                    .entry(community[i])
767                    .or_default()
768                    .push(node_id.clone());
769            }
770        }
771
772        let mut result: Vec<Vec<String>> = groups.into_values().collect();
773        for group in result.iter_mut() {
774            group.sort();
775        }
776        result.sort();
777        result
778    }
779
780    /// Compute betweenness centrality for all nodes using Brandes' algorithm.
781    ///
782    /// For graphs with more than 1000 nodes, samples sqrt(n) source nodes
783    /// for approximate computation.
784    ///
785    /// Returns a map from node ID to betweenness centrality score (normalized by
786    /// 1/((n-1)(n-2)) for directed graphs).
787    pub fn betweenness_centrality(&self) -> HashMap<String, f64> {
788        let n = self.graph.node_count();
789        if n <= 2 {
790            return self
791                .graph
792                .node_indices()
793                .filter_map(|idx| self.graph.node_weight(idx).map(|id| (id.clone(), 0.0)))
794                .collect();
795        }
796
797        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
798        let idx_pos: HashMap<NodeIndex, usize> = indices
799            .iter()
800            .enumerate()
801            .map(|(i, &idx)| (idx, i))
802            .collect();
803
804        let mut centrality = vec![0.0f64; n];
805
806        // Determine source nodes (sample for large graphs)
807        let sources: Vec<usize> = if n > 1000 {
808            let sample_size = (n as f64).sqrt() as usize;
809            // Deterministic sampling: evenly spaced
810            let step = n / sample_size;
811            (0..sample_size).map(|i| i * step).collect()
812        } else {
813            (0..n).collect()
814        };
815
816        let scale = if n > 1000 {
817            n as f64 / sources.len() as f64
818        } else {
819            1.0
820        };
821
822        for &s in &sources {
823            // Brandes' algorithm from source s
824            let mut stack: Vec<usize> = Vec::new();
825            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
826            let mut sigma = vec![0.0f64; n]; // number of shortest paths
827            sigma[s] = 1.0;
828            let mut dist: Vec<i64> = vec![-1; n];
829            dist[s] = 0;
830
831            let mut queue: VecDeque<usize> = VecDeque::new();
832            queue.push_back(s);
833
834            while let Some(v) = queue.pop_front() {
835                stack.push(v);
836                let v_idx = indices[v];
837                for neighbor in self.graph.neighbors_directed(v_idx, Direction::Outgoing) {
838                    if let Some(&w) = idx_pos.get(&neighbor) {
839                        if dist[w] < 0 {
840                            dist[w] = dist[v] + 1;
841                            queue.push_back(w);
842                        }
843                        if dist[w] == dist[v] + 1 {
844                            sigma[w] += sigma[v];
845                            predecessors[w].push(v);
846                        }
847                    }
848                }
849            }
850
851            let mut delta = vec![0.0f64; n];
852            while let Some(w) = stack.pop() {
853                for &v in &predecessors[w] {
854                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
855                }
856                if w != s {
857                    centrality[w] += delta[w];
858                }
859            }
860        }
861
862        // Apply sampling scale and normalize
863        let norm = ((n - 1) * (n - 2)) as f64;
864        indices
865            .iter()
866            .enumerate()
867            .filter_map(|(i, &idx)| {
868                self.graph
869                    .node_weight(idx)
870                    .map(|id| (id.clone(), centrality[i] * scale / norm))
871            })
872            .collect()
873    }
874
875    /// Find all strongly connected components using Tarjan's algorithm.
876    ///
877    /// Returns groups of node IDs. Each group is a strongly connected component
878    /// where every node can reach every other node via directed edges.
879    pub fn strongly_connected_components(&self) -> Vec<Vec<String>> {
880        let sccs = petgraph::algo::tarjan_scc(&self.graph);
881
882        let mut result: Vec<Vec<String>> = sccs
883            .into_iter()
884            .map(|component| {
885                let mut ids: Vec<String> = component
886                    .into_iter()
887                    .filter_map(|idx| self.graph.node_weight(idx).cloned())
888                    .collect();
889                ids.sort();
890                ids
891            })
892            .collect();
893
894        result.sort();
895        result
896    }
897
898    /// Compute topological layers using Kahn's algorithm.
899    ///
900    /// Returns layers where all nodes in layer i have no dependencies on nodes
901    /// in layer i or later. For cyclic graphs, SCCs are condensed into single
902    /// super-nodes first, then the resulting DAG is topologically sorted.
903    ///
904    /// Each inner Vec contains the node IDs at that layer.
905    pub fn topological_layers(&self) -> Vec<Vec<String>> {
906        let n = self.graph.node_count();
907        if n == 0 {
908            return Vec::new();
909        }
910
911        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
912        let idx_pos: HashMap<NodeIndex, usize> = indices
913            .iter()
914            .enumerate()
915            .map(|(i, &idx)| (idx, i))
916            .collect();
917
918        // Step 1: Find SCCs
919        let sccs = petgraph::algo::tarjan_scc(&self.graph);
920
921        // Map each node position to its SCC index
922        let mut node_to_scc = vec![0usize; n];
923        for (scc_idx, scc) in sccs.iter().enumerate() {
924            for &node_idx in scc {
925                if let Some(&pos) = idx_pos.get(&node_idx) {
926                    node_to_scc[pos] = scc_idx;
927                }
928            }
929        }
930
931        let num_sccs = sccs.len();
932
933        // Step 2: Build condensed DAG (SCC graph)
934        let mut condensed_adj: Vec<HashSet<usize>> = vec![HashSet::new(); num_sccs];
935        let mut condensed_in_degree = vec![0usize; num_sccs];
936
937        for &idx in &indices {
938            if let Some(&src_pos) = idx_pos.get(&idx) {
939                let src_scc = node_to_scc[src_pos];
940                for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
941                    if let Some(&dst_pos) = idx_pos.get(&neighbor) {
942                        let dst_scc = node_to_scc[dst_pos];
943                        if src_scc != dst_scc && condensed_adj[src_scc].insert(dst_scc) {
944                            condensed_in_degree[dst_scc] += 1;
945                        }
946                    }
947                }
948            }
949        }
950
951        // Step 3: Kahn's algorithm on the condensed DAG
952        let mut queue: VecDeque<usize> = VecDeque::new();
953        for (i, &deg) in condensed_in_degree.iter().enumerate().take(num_sccs) {
954            if deg == 0 {
955                queue.push_back(i);
956            }
957        }
958
959        let mut scc_layers: Vec<Vec<usize>> = Vec::new();
960        while !queue.is_empty() {
961            let mut layer = Vec::new();
962            let mut next_queue = VecDeque::new();
963
964            while let Some(scc_idx) = queue.pop_front() {
965                layer.push(scc_idx);
966                for &neighbor_scc in &condensed_adj[scc_idx] {
967                    condensed_in_degree[neighbor_scc] -= 1;
968                    if condensed_in_degree[neighbor_scc] == 0 {
969                        next_queue.push_back(neighbor_scc);
970                    }
971                }
972            }
973
974            scc_layers.push(layer);
975            queue = next_queue;
976        }
977
978        // Step 4: Expand SCC layers back to node IDs
979        let mut result: Vec<Vec<String>> = Vec::new();
980        for scc_layer in scc_layers {
981            let mut layer_nodes: Vec<String> = Vec::new();
982            for scc_idx in scc_layer {
983                for &node_idx in &sccs[scc_idx] {
984                    if let Some(id) = self.graph.node_weight(node_idx) {
985                        layer_nodes.push(id.clone());
986                    }
987                }
988            }
989            layer_nodes.sort();
990            result.push(layer_nodes);
991        }
992
993        result
994    }
995}
996
#[cfg(test)]
mod tests {
    use super::*;
    use codemem_core::RelationshipType;

    /// Build a `File` node with the given ID and label, an empty payload and
    /// zero centrality.
    fn file_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            kind: NodeKind::File,
            label: label.to_string(),
            payload: HashMap::new(),
            centrality: 0.0,
            memory_id: None,
            namespace: None,
        }
    }

    /// Build a `Contains` edge of weight 1.0 from `src` to `dst`, with the ID
    /// `"{src}->{dst}"`.
    fn test_edge(src: &str, dst: &str) -> Edge {
        Edge {
            id: format!("{src}->{dst}"),
            src: src.to_string(),
            dst: dst.to_string(),
            relationship: RelationshipType::Contains,
            weight: 1.0,
            properties: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    /// Build a graph containing one file node per entry of `ids` (labelled
    /// `"{id}.rs"`), followed by one edge per `(src, dst)` pair.
    ///
    /// Nodes and edges are inserted in exactly the order given, so tests stay
    /// in control of insertion order.
    fn graph_of(ids: &[&str], edges: &[(&str, &str)]) -> GraphEngine {
        let mut graph = GraphEngine::new();
        for id in ids {
            graph.add_node(file_node(id, &format!("{id}.rs"))).unwrap();
        }
        for (src, dst) in edges {
            graph.add_edge(test_edge(src, dst)).unwrap();
        }
        graph
    }

    #[test]
    fn add_nodes_and_edges() {
        let graph = graph_of(&["a", "b"], &[("a", "b")]);

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn bfs_traversal() {
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let nodes = graph.bfs("a", 1).unwrap();
        assert_eq!(nodes.len(), 2); // a and b (c is at depth 2)
    }

    #[test]
    fn shortest_path() {
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let path = graph.shortest_path("a", "c").unwrap();
        assert_eq!(path, vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_single_component() {
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_multiple() {
        let graph = graph_of(&["a", "b", "c", "d"], &[("a", "b"), ("c", "d")]);

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c", "d"]);
    }

    #[test]
    fn connected_components_isolated_node() {
        // "c" is isolated
        let graph = graph_of(&["a", "b", "c"], &[("a", "b")]);

        let components = graph.connected_components();
        assert_eq!(components.len(), 2);
        // Sorted: ["a","b"] comes before ["c"]
        assert_eq!(components[0], vec!["a", "b"]);
        assert_eq!(components[1], vec!["c"]);
    }

    #[test]
    fn connected_components_reverse_edge_connects() {
        // Directed edge c->a should still put a and c in the same component
        // when treated as undirected.
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("c", "a")]);

        let components = graph.connected_components();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn connected_components_empty_graph() {
        let graph = GraphEngine::new();
        let components = graph.connected_components();
        assert!(components.is_empty());
    }

    #[test]
    fn compute_centrality_simple() {
        // Graph: a -> b -> c
        // Node a: out=1, in=0 => centrality = 1/2 = 0.5
        // Node b: out=1, in=1 => centrality = 2/2 = 1.0
        // Node c: out=0, in=1 => centrality = 1/2 = 0.5
        let mut graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        let c = graph.get_node("c").unwrap().unwrap();

        assert!((a.centrality - 0.5).abs() < f64::EPSILON);
        assert!((b.centrality - 1.0).abs() < f64::EPSILON);
        assert!((c.centrality - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_star() {
        // Graph: a -> b, a -> c, a -> d (star topology)
        // Node a: out=3, in=0 => centrality = 3/3 = 1.0
        // Nodes b, c, d: out=0, in=1 => centrality = 1/3 each
        let mut graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("a", "d")],
        );

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        assert!((a.centrality - 1.0).abs() < f64::EPSILON);

        // All three leaves have identical degree, so check every one of them.
        for leaf in ["b", "c", "d"] {
            let node = graph.get_node(leaf).unwrap().unwrap();
            assert!(
                (node.centrality - 1.0 / 3.0).abs() < f64::EPSILON,
                "{leaf} should have centrality 1/3, got {}",
                node.centrality
            );
        }
    }

    #[test]
    fn compute_centrality_single_node() {
        let mut graph = graph_of(&["a"], &[]);

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_centrality_no_edges() {
        let mut graph = graph_of(&["a", "b"], &[]);

        graph.compute_centrality();

        let a = graph.get_node("a").unwrap().unwrap();
        let b = graph.get_node("b").unwrap().unwrap();
        assert!((a.centrality - 0.0).abs() < f64::EPSILON);
        assert!((b.centrality - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn get_all_nodes_returns_all() {
        let graph = graph_of(&["a", "b", "c"], &[]);

        // Sort by ID before asserting so the check does not depend on the
        // order in which the engine returns nodes.
        let mut all = graph.get_all_nodes();
        all.sort_by(|x, y| x.id.cmp(&y.id));
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].id, "a");
        assert_eq!(all[1].id, "b");
        assert_eq!(all[2].id, "c");
    }

    // ── PageRank Tests ──────────────────────────────────────────────────────

    #[test]
    fn pagerank_chain() {
        // a -> b -> c
        // c is a sink (dangling node) that redistributes rank uniformly.
        // Rank flows a -> b -> c, with c accumulating the most. Order: c > b > a.
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let ranks = graph.pagerank(0.85, 100, 1e-6);
        assert_eq!(ranks.len(), 3);
        assert!(
            ranks["c"] > ranks["b"],
            "c ({}) should rank higher than b ({})",
            ranks["c"],
            ranks["b"]
        );
        assert!(
            ranks["b"] > ranks["a"],
            "b ({}) should rank higher than a ({})",
            ranks["b"],
            ranks["a"]
        );
    }

    #[test]
    fn pagerank_star() {
        // a -> b, a -> c, a -> d
        // b, c, d are dangling nodes that redistribute rank uniformly.
        // They each receive direct rank from a, plus redistribution.
        // a only receives redistributed rank from the dangling nodes.
        // So each leaf should rank higher than the hub.
        let graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("a", "d")],
        );

        let ranks = graph.pagerank(0.85, 100, 1e-6);
        assert_eq!(ranks.len(), 4);
        assert!(
            ranks["b"] > ranks["a"],
            "b ({}) should rank higher than a ({})",
            ranks["b"],
            ranks["a"]
        );
        // The leaves are symmetric, so their ranks should be approximately equal.
        assert!(
            (ranks["b"] - ranks["c"]).abs() < 0.01,
            "b ({}) and c ({}) should be approximately equal",
            ranks["b"],
            ranks["c"]
        );
    }

    #[test]
    fn pagerank_empty_graph() {
        let graph = GraphEngine::new();
        let ranks = graph.pagerank(0.85, 100, 1e-6);
        assert!(ranks.is_empty());
    }

    #[test]
    fn pagerank_single_node() {
        let graph = graph_of(&["a"], &[]);

        let ranks = graph.pagerank(0.85, 100, 1e-6);
        assert_eq!(ranks.len(), 1);
        assert!((ranks["a"] - 1.0).abs() < 0.01);
    }

    // ── Personalized PageRank Tests ─────────────────────────────────────────

    #[test]
    fn personalized_pagerank_cycle_seed_c() {
        // a -> b -> c -> a (cycle)
        // Seed on c: c and its out-neighbor should rank highest
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("c", "a")]);

        let seeds = HashMap::from([("c".to_string(), 1.0)]);

        let ranks = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
        assert_eq!(ranks.len(), 3);
        // c should have highest rank (it's the seed and receives teleport)
        // a is c's out-neighbor so it should be next
        assert!(
            ranks["c"] > ranks["b"],
            "c ({}) should rank higher than b ({})",
            ranks["c"],
            ranks["b"]
        );
        assert!(
            ranks["a"] > ranks["b"],
            "a ({}) should rank higher than b ({}) since c->a",
            ranks["a"],
            ranks["b"]
        );
    }

    #[test]
    fn personalized_pagerank_empty_seeds() {
        // With no seeds, should fall back to uniform (same as regular pagerank)
        let graph = graph_of(&["a", "b"], &[("a", "b")]);

        let seeds = HashMap::new();
        let ppr = graph.personalized_pagerank(&seeds, 0.85, 100, 1e-6);
        let pr = graph.pagerank(0.85, 100, 1e-6);

        // Should be approximately equal
        assert!((ppr["a"] - pr["a"]).abs() < 0.01);
        assert!((ppr["b"] - pr["b"]).abs() < 0.01);
    }

    // ── Louvain Community Detection Tests ───────────────────────────────────

    #[test]
    fn louvain_two_disconnected_cliques() {
        // Clique 1: a <-> b <-> c <-> a
        // Clique 2: d <-> e <-> f <-> d
        let graph = graph_of(
            &["a", "b", "c", "d", "e", "f"],
            &[
                // Clique 1
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("a", "c"),
                ("c", "a"),
                // Clique 2
                ("d", "e"),
                ("e", "d"),
                ("e", "f"),
                ("f", "e"),
                ("d", "f"),
                ("f", "d"),
            ],
        );

        let communities = graph.louvain_communities(1.0);
        assert_eq!(
            communities.len(),
            2,
            "Expected 2 communities, got {}: {:?}",
            communities.len(),
            communities
        );
        // Each community should have 3 nodes
        assert_eq!(communities[0].len(), 3);
        assert_eq!(communities[1].len(), 3);
        // Check that each clique is in a separate community
        let comm0_set: HashSet<&str> = communities[0].iter().map(|s| s.as_str()).collect();
        let has_abc = ["a", "b", "c"].iter().all(|id| comm0_set.contains(id));
        let has_def = ["d", "e", "f"].iter().all(|id| comm0_set.contains(id));
        assert!(
            has_abc || has_def,
            "First community should be one of the cliques: {:?}",
            communities[0]
        );
    }

    #[test]
    fn louvain_empty_graph() {
        let graph = GraphEngine::new();
        let communities = graph.louvain_communities(1.0);
        assert!(communities.is_empty());
    }

    #[test]
    fn louvain_single_node() {
        let graph = graph_of(&["a"], &[]);
        let communities = graph.louvain_communities(1.0);
        assert_eq!(communities.len(), 1);
        assert_eq!(communities[0], vec!["a"]);
    }

    // ── Betweenness Centrality Tests ────────────────────────────────────────

    #[test]
    fn betweenness_chain_middle_highest() {
        // a -> b -> c
        // b is on the shortest path from a to c, so it should have highest betweenness
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let bc = graph.betweenness_centrality();
        assert_eq!(bc.len(), 3);
        assert!(
            bc["b"] > bc["a"],
            "b ({}) should have higher betweenness than a ({})",
            bc["b"],
            bc["a"]
        );
        assert!(
            bc["b"] > bc["c"],
            "b ({}) should have higher betweenness than c ({})",
            bc["b"],
            bc["c"]
        );
        // a and c should have 0 betweenness (they are endpoints)
        assert!(
            bc["a"].abs() < f64::EPSILON,
            "a should have 0 betweenness, got {}",
            bc["a"]
        );
        assert!(
            bc["c"].abs() < f64::EPSILON,
            "c should have 0 betweenness, got {}",
            bc["c"]
        );
    }

    #[test]
    fn betweenness_empty_graph() {
        let graph = GraphEngine::new();
        let bc = graph.betweenness_centrality();
        assert!(bc.is_empty());
    }

    #[test]
    fn betweenness_two_nodes() {
        // With only two nodes there is no intermediate node on any path.
        let graph = graph_of(&["a", "b"], &[("a", "b")]);

        let bc = graph.betweenness_centrality();
        assert_eq!(bc.len(), 2);
        assert!((bc["a"]).abs() < f64::EPSILON);
        assert!((bc["b"]).abs() < f64::EPSILON);
    }

    // ── Strongly Connected Components Tests ─────────────────────────────────

    #[test]
    fn scc_cycle_all_in_one() {
        // a -> b -> c -> a: all three should be in one SCC
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("c", "a")]);

        let sccs = graph.strongly_connected_components();
        assert_eq!(
            sccs.len(),
            1,
            "Expected 1 SCC, got {}: {:?}",
            sccs.len(),
            sccs
        );
        assert_eq!(sccs[0], vec!["a", "b", "c"]);
    }

    #[test]
    fn scc_chain_each_separate() {
        // a -> b -> c: no cycles, each node is its own SCC
        let graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        let sccs = graph.strongly_connected_components();
        assert_eq!(
            sccs.len(),
            3,
            "Expected 3 SCCs, got {}: {:?}",
            sccs.len(),
            sccs
        );
        // Components come back sorted, so the singletons appear in ID order.
        assert_eq!(sccs[0], vec!["a"]);
        assert_eq!(sccs[1], vec!["b"]);
        assert_eq!(sccs[2], vec!["c"]);
    }

    #[test]
    fn scc_empty_graph() {
        let graph = GraphEngine::new();
        let sccs = graph.strongly_connected_components();
        assert!(sccs.is_empty());
    }

    // ── Topological Sort Tests ──────────────────────────────────────────────

    #[test]
    fn topological_layers_dag() {
        // a -> b, a -> c, b -> d, c -> d
        // Layer 0: [a], Layer 1: [b, c], Layer 2: [d]
        let graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
        );

        let layers = graph.topological_layers();
        assert_eq!(
            layers.len(),
            3,
            "Expected 3 layers, got {}: {:?}",
            layers.len(),
            layers
        );
        assert_eq!(layers[0], vec!["a"]);
        assert_eq!(layers[1], vec!["b", "c"]); // sorted within layer
        assert_eq!(layers[2], vec!["d"]);
    }

    #[test]
    fn topological_layers_with_cycle() {
        // a -> b -> c -> b (cycle between b and c), a -> d
        // SCCs: {a}, {b, c}, {d}
        // After condensation: {a} -> {b,c} and {a} -> {d}
        // Layer 0: [a], Layer 1: [b, c, d] (b and c condensed, d also depends on a)
        let graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("b", "c"), ("c", "b"), ("a", "d")],
        );

        let layers = graph.topological_layers();
        assert_eq!(
            layers.len(),
            2,
            "Expected 2 layers, got {}: {:?}",
            layers.len(),
            layers
        );
        assert_eq!(layers[0], vec!["a"]);
        // Layer 1 holds exactly b, c (from the cycle SCC) and d, sorted by ID.
        assert_eq!(layers[1], vec!["b", "c", "d"]);
    }

    #[test]
    fn topological_layers_empty_graph() {
        let graph = GraphEngine::new();
        let layers = graph.topological_layers();
        assert!(layers.is_empty());
    }

    #[test]
    fn topological_layers_single_node() {
        let graph = graph_of(&["a"], &[]);
        let layers = graph.topological_layers();
        assert_eq!(layers.len(), 1);
        assert_eq!(layers[0], vec!["a"]);
    }

    // ── Centrality Caching Tests ────────────────────────────────────────────

    #[test]
    fn recompute_centrality_caches_pagerank() {
        let mut graph = graph_of(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);

        // Before recompute, cached values should be 0.0
        assert_eq!(graph.get_pagerank("a"), 0.0);
        assert_eq!(graph.get_betweenness("a"), 0.0);

        graph.recompute_centrality();

        // After recompute, cached PageRank values should be non-zero
        assert!(graph.get_pagerank("a") > 0.0);
        assert!(graph.get_pagerank("b") > 0.0);
        assert!(graph.get_pagerank("c") > 0.0);

        // c (the sink of a -> b -> c) should out-rank the source a
        assert!(
            graph.get_pagerank("c") > graph.get_pagerank("a"),
            "c ({}) should have higher PageRank than a ({})",
            graph.get_pagerank("c"),
            graph.get_pagerank("a")
        );

        // b (middle of the chain) should have higher betweenness than endpoint a
        assert!(
            graph.get_betweenness("b") > graph.get_betweenness("a"),
            "b ({}) should have higher betweenness than a ({})",
            graph.get_betweenness("b"),
            graph.get_betweenness("a")
        );
    }

    #[test]
    fn get_pagerank_returns_zero_for_unknown_node() {
        let graph = GraphEngine::new();
        assert_eq!(graph.get_pagerank("nonexistent"), 0.0);
        assert_eq!(graph.get_betweenness("nonexistent"), 0.0);
    }

    #[test]
    fn max_degree_returns_correct_value() {
        // a -> b, a -> c, a -> d (star: a has degree 3)
        let graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("a", "c"), ("a", "d")],
        );

        assert!((graph.max_degree() - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn enhanced_graph_strength_differs_from_simple_edge_count() {
        // Build a graph where centrality tells nodes apart that a plain edge
        // count cannot.
        // a -> b -> c -> d (chain)
        let mut graph = graph_of(
            &["a", "b", "c", "d"],
            &[("a", "b"), ("b", "c"), ("c", "d")],
        );
        graph.recompute_centrality();

        // Nodes b and c both have 2 edges (in+out), so a simple edge count
        // would give them equal scores.
        let edges_b = graph.get_edges("b").unwrap().len();
        let edges_c = graph.get_edges("c").unwrap().len();
        assert_eq!(edges_b, edges_c, "b and c should have same edge count");

        // b and c each lie on two shortest paths (b: a->c, a->d; c: a->d,
        // b->d), so betweenness alone need not separate them. PageRank flows
        // downstream along the chain and is expected to differ.
        let pr_b = graph.get_pagerank("b");
        let pr_c = graph.get_pagerank("c");
        let bt_b = graph.get_betweenness("b");
        let bt_c = graph.get_betweenness("c");

        // At least one centrality metric should differ between b and c
        let centrality_differs = (pr_b - pr_c).abs() > 1e-6 || (bt_b - bt_c).abs() > 1e-6;
        assert!(
            centrality_differs,
            "Centrality should differ: b(pr={pr_b}, bt={bt_b}) vs c(pr={pr_c}, bt={bt_c})"
        );
    }
}