//! Graph algorithms for the `codemem_storage` graph engine.
//! File: codemem_storage/graph/algorithms.rs
1use super::GraphEngine;
2use codemem_core::{Edge, GraphNode, NodeKind};
3use petgraph::graph::NodeIndex;
4use petgraph::Direction;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7impl GraphEngine {
8    /// Compute PageRank scores for all nodes using power iteration.
9    ///
10    /// - `damping`: probability of following an edge (default 0.85)
11    /// - `iterations`: max number of power iterations (default 100)
12    /// - `tolerance`: convergence threshold (default 1e-6)
13    ///
14    /// Returns a map from node ID to PageRank score.
15    pub fn pagerank(
16        &self,
17        damping: f64,
18        iterations: usize,
19        tolerance: f64,
20    ) -> HashMap<String, f64> {
21        let n = self.graph.node_count();
22        if n == 0 {
23            return HashMap::new();
24        }
25
26        let nf = n as f64;
27        let initial = 1.0 / nf;
28
29        // Collect all node indices in a stable order
30        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
31        let idx_pos: HashMap<NodeIndex, usize> = indices
32            .iter()
33            .enumerate()
34            .map(|(i, &idx)| (idx, i))
35            .collect();
36
37        let mut scores = vec![initial; n];
38
39        // Precompute out-degrees
40        let out_degree: Vec<usize> = indices
41            .iter()
42            .map(|&idx| {
43                self.graph
44                    .neighbors_directed(idx, Direction::Outgoing)
45                    .count()
46            })
47            .collect();
48
49        for _ in 0..iterations {
50            let mut new_scores = vec![(1.0 - damping) / nf; n];
51
52            // Distribute rank from each node to its out-neighbors
53            for (i, &idx) in indices.iter().enumerate() {
54                let deg = out_degree[i];
55                if deg == 0 {
56                    // Dangling node: distribute evenly to all nodes
57                    let share = damping * scores[i] / nf;
58                    for ns in new_scores.iter_mut() {
59                        *ns += share;
60                    }
61                } else {
62                    let share = damping * scores[i] / deg as f64;
63                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
64                        if let Some(&pos) = idx_pos.get(&neighbor) {
65                            new_scores[pos] += share;
66                        }
67                    }
68                }
69            }
70
71            // Check convergence
72            let diff: f64 = scores
73                .iter()
74                .zip(new_scores.iter())
75                .map(|(a, b)| (a - b).abs())
76                .sum();
77
78            scores = new_scores;
79
80            if diff < tolerance {
81                break;
82            }
83        }
84
85        // Map back to node IDs
86        indices
87            .iter()
88            .enumerate()
89            .filter_map(|(i, &idx)| {
90                self.graph
91                    .node_weight(idx)
92                    .map(|id| (id.clone(), scores[i]))
93            })
94            .collect()
95    }
96
    /// Compute Personalized PageRank with custom teleport weights.
    ///
    /// `seed_weights` maps node IDs to teleport probabilities (will be normalized).
    /// Nodes not in seed_weights get zero teleport probability.
    ///
    /// Used for blast-radius analysis and HippoRAG-2-style retrieval.
    ///
    /// - `damping`: probability of following an edge rather than teleporting.
    /// - `iterations`: maximum number of power iterations.
    /// - `tolerance`: L1 convergence threshold between successive iterates.
    ///
    /// Returns a map from node ID to PPR score; an empty graph yields an
    /// empty map.
    #[cfg(test)]
    pub fn personalized_pagerank(
        &self,
        seed_weights: &HashMap<String, f64>,
        damping: f64,
        iterations: usize,
        tolerance: f64,
    ) -> HashMap<String, f64> {
        let n = self.graph.node_count();
        if n == 0 {
            return HashMap::new();
        }

        let nf = n as f64;

        // Stable node ordering plus NodeIndex -> dense-position lookup so
        // scores can live in flat vectors.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        // Build and normalize the teleport vector
        let mut teleport = vec![0.0f64; n];
        let mut teleport_sum = 0.0;
        for (i, &idx) in indices.iter().enumerate() {
            if let Some(node_id) = self.graph.node_weight(idx) {
                if let Some(&w) = seed_weights.get(node_id) {
                    teleport[i] = w;
                    teleport_sum += w;
                }
            }
        }
        // Normalize; if no seeds provided, fall back to uniform
        // (which makes this equivalent to plain PageRank).
        if teleport_sum > 0.0 {
            for t in teleport.iter_mut() {
                *t /= teleport_sum;
            }
        } else {
            for t in teleport.iter_mut() {
                *t = 1.0 / nf;
            }
        }

        // Start from the uniform distribution.
        let initial = 1.0 / nf;
        let mut scores = vec![initial; n];

        // Out-degree per node, precomputed once (one count per outgoing edge).
        let out_degree: Vec<usize> = indices
            .iter()
            .map(|&idx| {
                self.graph
                    .neighbors_directed(idx, Direction::Outgoing)
                    .count()
            })
            .collect();

        for _ in 0..iterations {
            // Teleport term: each node receives (1 - damping) * its teleport weight.
            let mut new_scores: Vec<f64> = teleport.iter().map(|&t| (1.0 - damping) * t).collect();

            for (i, &idx) in indices.iter().enumerate() {
                let deg = out_degree[i];
                if deg == 0 {
                    // Dangling node: distribute to teleport targets
                    // (proportionally to the teleport vector, not uniformly).
                    let share = damping * scores[i];
                    for (j, t) in teleport.iter().enumerate() {
                        new_scores[j] += share * t;
                    }
                } else {
                    // Regular node: split damped mass evenly among out-neighbors.
                    let share = damping * scores[i] / deg as f64;
                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
                        if let Some(&pos) = idx_pos.get(&neighbor) {
                            new_scores[pos] += share;
                        }
                    }
                }
            }

            // L1 distance between successive iterates; stop once below tolerance.
            let diff: f64 = scores
                .iter()
                .zip(new_scores.iter())
                .map(|(a, b)| (a - b).abs())
                .sum();

            scores = new_scores;

            if diff < tolerance {
                break;
            }
        }

        // Map dense positions back to node IDs.
        indices
            .iter()
            .enumerate()
            .filter_map(|(i, &idx)| {
                self.graph
                    .node_weight(idx)
                    .map(|id| (id.clone(), scores[i]))
            })
            .collect()
    }
203
    /// Detect communities using the Louvain algorithm.
    ///
    /// Treats the directed graph as undirected for modularity computation.
    /// `resolution` controls community granularity (1.0 = standard modularity).
    /// Returns groups of node IDs, one group per community.
    ///
    /// NOTE(review): this implements only the local-move phase of Louvain,
    /// repeated until no single-node move improves modularity — there is no
    /// graph-aggregation (coarsening) phase, so results can differ from a
    /// full multi-level Louvain.
    pub fn louvain_communities(&self, resolution: f64) -> Vec<Vec<String>> {
        let n = self.graph.node_count();
        if n == 0 {
            return Vec::new();
        }

        // Stable node ordering plus NodeIndex -> dense-position lookup.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        // Build undirected adjacency with weights.
        // Deduplicate bidirectional edges: for A->B and B->A, merge into one
        // undirected edge with combined weight.
        let mut undirected_weights: HashMap<(usize, usize), f64> = HashMap::new();
        for edge_ref in self.graph.edge_indices() {
            if let Some((src_idx, dst_idx)) = self.graph.edge_endpoints(edge_ref) {
                let w = self.graph[edge_ref];
                if let (Some(&si), Some(&di)) = (idx_pos.get(&src_idx), idx_pos.get(&dst_idx)) {
                    // Canonicalize the pair so (a, b) and (b, a) share a key.
                    let key = if si <= di { (si, di) } else { (di, si) };
                    *undirected_weights.entry(key).or_insert(0.0) += w;
                }
            }
        }

        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
        let mut total_weight = 0.0;

        for (&(si, di), &w) in &undirected_weights {
            adj[si].push((di, w));
            if si != di {
                adj[di].push((si, w));
            }
            // NOTE(review): a self-loop contributes its weight once to the
            // node's degree below; classic modularity counts self-loops twice
            // in the degree — confirm which convention is intended here.
            total_weight += w;
        }

        if total_weight == 0.0 {
            // No edges: each node is its own community
            return indices
                .iter()
                .filter_map(|&idx| self.graph.node_weight(idx).map(|id| vec![id.clone()]))
                .collect();
        }

        // m = total undirected edge weight
        let m = total_weight;
        let m2 = 2.0 * m;

        // Weighted degree of each node (sum of incident undirected edge weights)
        let k: Vec<f64> = (0..n)
            .map(|i| adj[i].iter().map(|&(_, w)| w).sum())
            .collect();

        // Initial assignment: each node in its own community
        let mut community: Vec<usize> = (0..n).collect();

        // sigma_tot[c] = sum of degrees of nodes in community c.
        // Maintained incrementally to avoid O(n^2) per pass.
        let mut sigma_tot: Vec<f64> = k.clone();

        // Iteratively move nodes to improve modularity
        let mut improved = true;
        let max_passes = 100;
        let mut pass = 0;

        while improved && pass < max_passes {
            improved = false;
            pass += 1;

            for i in 0..n {
                let current_comm = community[i];
                let ki = k[i];

                // Compute weights to each neighboring community
                let mut comm_weights: HashMap<usize, f64> = HashMap::new();
                for &(j, w) in &adj[i] {
                    *comm_weights.entry(community[j]).or_insert(0.0) += w;
                }

                // Standard Louvain delta-Q formula:
                // delta_Q = [w_in_new/m - resolution * ki * sigma_new / m2]
                //         - [w_in_current/m - resolution * ki * (sigma_current - ki) / m2]
                // `sigma_current - ki` excludes i's own degree, since the
                // current community is evaluated as if i had been removed.
                let w_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
                let sigma_current = sigma_tot[current_comm];
                let remove_cost =
                    w_in_current / m - resolution * ki * (sigma_current - ki) / (m2 * m);

                // Find best community to move to
                let mut best_comm = current_comm;
                let mut best_gain = 0.0;

                for (&comm, &w_in_comm) in &comm_weights {
                    if comm == current_comm {
                        continue;
                    }
                    let sigma_comm = sigma_tot[comm];
                    let gain =
                        w_in_comm / m - resolution * ki * sigma_comm / (m2 * m) - remove_cost;
                    // Strictly positive gain required, so ties keep the node put.
                    if gain > best_gain {
                        best_gain = gain;
                        best_comm = comm;
                    }
                }

                if best_comm != current_comm {
                    // Update sigma_tot incrementally
                    sigma_tot[current_comm] -= ki;
                    sigma_tot[best_comm] += ki;
                    community[i] = best_comm;
                    improved = true;
                }
            }
        }

        // Group nodes by community
        let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
        for (i, &idx) in indices.iter().enumerate() {
            if let Some(node_id) = self.graph.node_weight(idx) {
                groups
                    .entry(community[i])
                    .or_default()
                    .push(node_id.clone());
            }
        }

        // Sort members and communities for a deterministic return value.
        let mut result: Vec<Vec<String>> = groups.into_values().collect();
        for group in result.iter_mut() {
            group.sort();
        }
        result.sort();
        result
    }
343
    /// Compute betweenness centrality for all nodes using Brandes' algorithm.
    ///
    /// For graphs with more than 1000 nodes, samples sqrt(n) source nodes
    /// for approximate computation.
    ///
    /// Returns a map from node ID to betweenness centrality score (normalized by
    /// 1/((n-1)(n-2)) for directed graphs).
    pub fn betweenness_centrality(&self) -> HashMap<String, f64> {
        let n = self.graph.node_count();
        if n <= 2 {
            // With at most two nodes no node can lie between two others;
            // this also guards the ((n-1)(n-2)) normalization against zero.
            return self
                .graph
                .node_indices()
                .filter_map(|idx| self.graph.node_weight(idx).map(|id| (id.clone(), 0.0)))
                .collect();
        }

        // Stable node ordering plus NodeIndex -> dense-position lookup.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        let mut centrality = vec![0.0f64; n];

        // Determine source nodes (sample for large graphs)
        let sources: Vec<usize> = if n > 1000 {
            let sample_size = (n as f64).sqrt() as usize;
            // Deterministic sampling: evenly spaced
            let step = n / sample_size;
            (0..sample_size).map(|i| i * step).collect()
        } else {
            (0..n).collect()
        };

        // When sampling, scale each source's contribution up so the totals
        // estimate the full all-sources computation.
        let scale = if n > 1000 {
            n as f64 / sources.len() as f64
        } else {
            1.0
        };

        for &s in &sources {
            // Brandes' algorithm from source s.
            // Phase 1 (BFS): for each node w reachable from s compute
            //   dist[w]  - unweighted shortest-path distance,
            //   sigma[w] - number of shortest s->w paths,
            //   predecessors[w] - parents of w on those paths;
            // `stack` records visit order (non-decreasing distance).
            let mut stack: Vec<usize> = Vec::new();
            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
            let mut sigma = vec![0.0f64; n]; // number of shortest paths
            sigma[s] = 1.0;
            let mut dist: Vec<i64> = vec![-1; n]; // -1 marks "unvisited"
            dist[s] = 0;

            let mut queue: VecDeque<usize> = VecDeque::new();
            queue.push_back(s);

            while let Some(v) = queue.pop_front() {
                stack.push(v);
                let v_idx = indices[v];
                for neighbor in self.graph.neighbors_directed(v_idx, Direction::Outgoing) {
                    if let Some(&w) = idx_pos.get(&neighbor) {
                        if dist[w] < 0 {
                            dist[w] = dist[v] + 1;
                            queue.push_back(w);
                        }
                        if dist[w] == dist[v] + 1 {
                            // v lies on at least one shortest path to w.
                            sigma[w] += sigma[v];
                            predecessors[w].push(v);
                        }
                    }
                }
            }

            // Phase 2: pop nodes in reverse BFS order and accumulate each
            // node's pair-dependency onto its predecessors.
            let mut delta = vec![0.0f64; n];
            while let Some(w) = stack.pop() {
                for &v in &predecessors[w] {
                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
                }
                if w != s {
                    centrality[w] += delta[w];
                }
            }
        }

        // Apply sampling scale and normalize
        let norm = ((n - 1) * (n - 2)) as f64;
        indices
            .iter()
            .enumerate()
            .filter_map(|(i, &idx)| {
                self.graph
                    .node_weight(idx)
                    .map(|id| (id.clone(), centrality[i] * scale / norm))
            })
            .collect()
    }
438
439    /// Find all strongly connected components using Tarjan's algorithm.
440    ///
441    /// Returns groups of node IDs. Each group is a strongly connected component
442    /// where every node can reach every other node via directed edges.
443    #[cfg(test)]
444    pub fn strongly_connected_components(&self) -> Vec<Vec<String>> {
445        let sccs = petgraph::algo::tarjan_scc(&self.graph);
446
447        let mut result: Vec<Vec<String>> = sccs
448            .into_iter()
449            .map(|component| {
450                let mut ids: Vec<String> = component
451                    .into_iter()
452                    .filter_map(|idx| self.graph.node_weight(idx).cloned())
453                    .collect();
454                ids.sort();
455                ids
456            })
457            .collect();
458
459        result.sort();
460        result
461    }
462
    /// Compute topological layers using Kahn's algorithm.
    ///
    /// Returns layers where all nodes in layer i have no dependencies on nodes
    /// in layer i or later. For cyclic graphs, SCCs are condensed into single
    /// super-nodes first, then the resulting DAG is topologically sorted.
    ///
    /// Each inner Vec contains the node IDs at that layer.
    pub fn topological_layers(&self) -> Vec<Vec<String>> {
        let n = self.graph.node_count();
        if n == 0 {
            return Vec::new();
        }

        // Stable node ordering plus NodeIndex -> dense-position lookup.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        // Step 1: Find SCCs
        let sccs = petgraph::algo::tarjan_scc(&self.graph);

        // Map each node position to its SCC index
        let mut node_to_scc = vec![0usize; n];
        for (scc_idx, scc) in sccs.iter().enumerate() {
            for &node_idx in scc {
                if let Some(&pos) = idx_pos.get(&node_idx) {
                    node_to_scc[pos] = scc_idx;
                }
            }
        }

        let num_sccs = sccs.len();

        // Step 2: Build condensed DAG (SCC graph)
        let mut condensed_adj: Vec<HashSet<usize>> = vec![HashSet::new(); num_sccs];
        let mut condensed_in_degree = vec![0usize; num_sccs];

        for &idx in &indices {
            if let Some(&src_pos) = idx_pos.get(&idx) {
                let src_scc = node_to_scc[src_pos];
                for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
                    if let Some(&dst_pos) = idx_pos.get(&neighbor) {
                        let dst_scc = node_to_scc[dst_pos];
                        // `insert` returns true only for a new condensed edge,
                        // so parallel edges between two SCCs bump the in-degree
                        // exactly once. Intra-SCC edges are skipped entirely.
                        if src_scc != dst_scc && condensed_adj[src_scc].insert(dst_scc) {
                            condensed_in_degree[dst_scc] += 1;
                        }
                    }
                }
            }
        }

        // Step 3: Kahn's algorithm on the condensed DAG.
        // Seed the queue with every SCC that has no incoming condensed edges.
        let mut queue: VecDeque<usize> = VecDeque::new();
        for (i, &deg) in condensed_in_degree.iter().enumerate().take(num_sccs) {
            if deg == 0 {
                queue.push_back(i);
            }
        }

        // Drain one full frontier per outer iteration: everything currently in
        // `queue` becomes one layer, and the SCCs it unblocks form the next.
        let mut scc_layers: Vec<Vec<usize>> = Vec::new();
        while !queue.is_empty() {
            let mut layer = Vec::new();
            let mut next_queue = VecDeque::new();

            while let Some(scc_idx) = queue.pop_front() {
                layer.push(scc_idx);
                for &neighbor_scc in &condensed_adj[scc_idx] {
                    condensed_in_degree[neighbor_scc] -= 1;
                    if condensed_in_degree[neighbor_scc] == 0 {
                        next_queue.push_back(neighbor_scc);
                    }
                }
            }

            scc_layers.push(layer);
            queue = next_queue;
        }

        // Step 4: Expand SCC layers back to node IDs
        let mut result: Vec<Vec<String>> = Vec::new();
        for scc_layer in scc_layers {
            let mut layer_nodes: Vec<String> = Vec::new();
            for scc_idx in scc_layer {
                for &node_idx in &sccs[scc_idx] {
                    if let Some(id) = self.graph.node_weight(node_idx) {
                        layer_nodes.push(id.clone());
                    }
                }
            }
            // Sorted for deterministic output within each layer.
            layer_nodes.sort();
            result.push(layer_nodes);
        }

        result
    }
560
561    /// Return top-N nodes by centrality and edges between them.
562    /// Optionally filter by namespace and/or node kinds.
563    pub fn subgraph_top_n(
564        &self,
565        n: usize,
566        namespace: Option<&str>,
567        kinds: Option<&[NodeKind]>,
568    ) -> (Vec<GraphNode>, Vec<Edge>) {
569        let mut candidates: Vec<&GraphNode> = self
570            .nodes
571            .values()
572            .filter(|node| {
573                if let Some(ns) = namespace {
574                    match &node.namespace {
575                        Some(node_ns) => node_ns == ns,
576                        None => false,
577                    }
578                } else {
579                    true
580                }
581            })
582            .filter(|node| {
583                if let Some(k) = kinds {
584                    k.contains(&node.kind)
585                } else {
586                    true
587                }
588            })
589            .collect();
590
591        // Sort by centrality descending
592        candidates.sort_by(|a, b| {
593            b.centrality
594                .partial_cmp(&a.centrality)
595                .unwrap_or(std::cmp::Ordering::Equal)
596        });
597
598        // Take top N
599        candidates.truncate(n);
600
601        let top_ids: HashSet<&str> = candidates.iter().map(|node| node.id.as_str()).collect();
602        let nodes_vec: Vec<GraphNode> = candidates.into_iter().cloned().collect();
603
604        // Collect edges where both src and dst are in the top-N set
605        let edges_vec: Vec<Edge> = self
606            .edges
607            .values()
608            .filter(|edge| {
609                top_ids.contains(edge.src.as_str()) && top_ids.contains(edge.dst.as_str())
610            })
611            .cloned()
612            .collect();
613
614        (nodes_vec, edges_vec)
615    }
616
617    /// Return node-to-community-ID mapping for Louvain.
618    pub fn louvain_with_assignment(&self, resolution: f64) -> HashMap<String, usize> {
619        let communities = self.louvain_communities(resolution);
620        let mut assignment = HashMap::new();
621        for (idx, community) in communities.into_iter().enumerate() {
622            for node_id in community {
623                assignment.insert(node_id, idx);
624            }
625        }
626        assignment
627    }
628}
629
630#[cfg(test)]
631#[path = "../tests/graph_algorithms_tests.rs"]
632mod tests;