//! Graph algorithms for the storage-layer `GraphEngine`:
//! PageRank and personalized PageRank, Louvain community detection,
//! betweenness centrality, strongly connected components, and
//! topological layering of the (possibly cyclic) dependency graph.
1use super::GraphEngine;
2use codemem_core::{Edge, GraphNode, NodeKind};
3use petgraph::graph::NodeIndex;
4use petgraph::Direction;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7impl GraphEngine {
8    /// Compute PageRank scores for all nodes using power iteration.
9    ///
10    /// - `damping`: probability of following an edge (default 0.85)
11    /// - `iterations`: max number of power iterations (default 100)
12    /// - `tolerance`: convergence threshold (default 1e-6)
13    ///
14    /// Returns a map from node ID to PageRank score.
15    pub fn pagerank(
16        &self,
17        damping: f64,
18        iterations: usize,
19        tolerance: f64,
20    ) -> HashMap<String, f64> {
21        let n = self.graph.node_count();
22        if n == 0 {
23            return HashMap::new();
24        }
25
26        let nf = n as f64;
27        let initial = 1.0 / nf;
28
29        // Collect all node indices in a stable order
30        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
31        let idx_pos: HashMap<NodeIndex, usize> = indices
32            .iter()
33            .enumerate()
34            .map(|(i, &idx)| (idx, i))
35            .collect();
36
37        let mut scores = vec![initial; n];
38
39        // Precompute out-degrees
40        let out_degree: Vec<usize> = indices
41            .iter()
42            .map(|&idx| {
43                self.graph
44                    .neighbors_directed(idx, Direction::Outgoing)
45                    .count()
46            })
47            .collect();
48
49        for _ in 0..iterations {
50            let mut new_scores = vec![(1.0 - damping) / nf; n];
51
52            // Distribute rank from each node to its out-neighbors
53            for (i, &idx) in indices.iter().enumerate() {
54                let deg = out_degree[i];
55                if deg == 0 {
56                    // Dangling node: distribute evenly to all nodes
57                    let share = damping * scores[i] / nf;
58                    for ns in new_scores.iter_mut() {
59                        *ns += share;
60                    }
61                } else {
62                    let share = damping * scores[i] / deg as f64;
63                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
64                        if let Some(&pos) = idx_pos.get(&neighbor) {
65                            new_scores[pos] += share;
66                        }
67                    }
68                }
69            }
70
71            // Check convergence
72            let diff: f64 = scores
73                .iter()
74                .zip(new_scores.iter())
75                .map(|(a, b)| (a - b).abs())
76                .sum();
77
78            scores = new_scores;
79
80            if diff < tolerance {
81                break;
82            }
83        }
84
85        // Map back to node IDs
86        indices
87            .iter()
88            .enumerate()
89            .filter_map(|(i, &idx)| {
90                self.graph
91                    .node_weight(idx)
92                    .map(|id| (id.clone(), scores[i]))
93            })
94            .collect()
95    }
96
97    /// Compute Personalized PageRank with custom teleport weights.
98    ///
99    /// `seed_weights` maps node IDs to teleport probabilities (will be normalized).
100    /// Nodes not in seed_weights get zero teleport probability.
101    ///
102    /// Used for blast-radius analysis and HippoRAG-2-style retrieval.
103    pub fn personalized_pagerank(
104        &self,
105        seed_weights: &HashMap<String, f64>,
106        damping: f64,
107        iterations: usize,
108        tolerance: f64,
109    ) -> HashMap<String, f64> {
110        let n = self.graph.node_count();
111        if n == 0 {
112            return HashMap::new();
113        }
114
115        let nf = n as f64;
116
117        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
118        let idx_pos: HashMap<NodeIndex, usize> = indices
119            .iter()
120            .enumerate()
121            .map(|(i, &idx)| (idx, i))
122            .collect();
123
124        // Build and normalize the teleport vector
125        let mut teleport = vec![0.0f64; n];
126        let mut teleport_sum = 0.0;
127        for (i, &idx) in indices.iter().enumerate() {
128            if let Some(node_id) = self.graph.node_weight(idx) {
129                if let Some(&w) = seed_weights.get(node_id) {
130                    teleport[i] = w;
131                    teleport_sum += w;
132                }
133            }
134        }
135        // Normalize; if no seeds provided, fall back to uniform
136        if teleport_sum > 0.0 {
137            for t in teleport.iter_mut() {
138                *t /= teleport_sum;
139            }
140        } else {
141            for t in teleport.iter_mut() {
142                *t = 1.0 / nf;
143            }
144        }
145
146        let initial = 1.0 / nf;
147        let mut scores = vec![initial; n];
148
149        let out_degree: Vec<usize> = indices
150            .iter()
151            .map(|&idx| {
152                self.graph
153                    .neighbors_directed(idx, Direction::Outgoing)
154                    .count()
155            })
156            .collect();
157
158        for _ in 0..iterations {
159            let mut new_scores: Vec<f64> = teleport.iter().map(|&t| (1.0 - damping) * t).collect();
160
161            for (i, &idx) in indices.iter().enumerate() {
162                let deg = out_degree[i];
163                if deg == 0 {
164                    // Dangling node: distribute to teleport targets
165                    let share = damping * scores[i];
166                    for (j, t) in teleport.iter().enumerate() {
167                        new_scores[j] += share * t;
168                    }
169                } else {
170                    let share = damping * scores[i] / deg as f64;
171                    for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
172                        if let Some(&pos) = idx_pos.get(&neighbor) {
173                            new_scores[pos] += share;
174                        }
175                    }
176                }
177            }
178
179            let diff: f64 = scores
180                .iter()
181                .zip(new_scores.iter())
182                .map(|(a, b)| (a - b).abs())
183                .sum();
184
185            scores = new_scores;
186
187            if diff < tolerance {
188                break;
189            }
190        }
191
192        indices
193            .iter()
194            .enumerate()
195            .filter_map(|(i, &idx)| {
196                self.graph
197                    .node_weight(idx)
198                    .map(|id| (id.clone(), scores[i]))
199            })
200            .collect()
201    }
202
    /// Detect communities using the Louvain algorithm (local-move phase).
    ///
    /// Treats the directed graph as undirected for modularity computation.
    /// `resolution` controls community granularity (1.0 = standard modularity;
    /// larger values favor more, smaller communities).
    /// Returns groups of node IDs, one group per community; each group and the
    /// outer list are sorted for deterministic output.
    pub fn louvain_communities(&self, resolution: f64) -> Vec<Vec<String>> {
        let n = self.graph.node_count();
        if n == 0 {
            return Vec::new();
        }

        // Stable node order plus a reverse map so per-node state can live in
        // flat vectors indexed by dense position.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        // Build undirected adjacency with weights.
        // Deduplicate bidirectional edges: for A->B and B->A, merge into one
        // undirected edge with combined weight. The (min, max) key makes both
        // directions hit the same entry.
        let mut undirected_weights: HashMap<(usize, usize), f64> = HashMap::new();
        for edge_ref in self.graph.edge_indices() {
            if let Some((src_idx, dst_idx)) = self.graph.edge_endpoints(edge_ref) {
                let w = self.graph[edge_ref];
                if let (Some(&si), Some(&di)) = (idx_pos.get(&src_idx), idx_pos.get(&dst_idx)) {
                    let key = if si <= di { (si, di) } else { (di, si) };
                    *undirected_weights.entry(key).or_insert(0.0) += w;
                }
            }
        }

        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
        let mut total_weight = 0.0;

        for (&(si, di), &w) in &undirected_weights {
            adj[si].push((di, w));
            // Self-loops are stored once so they are not double-counted.
            if si != di {
                adj[di].push((si, w));
            }
            total_weight += w;
        }

        if total_weight == 0.0 {
            // No edges: each node is its own community
            return indices
                .iter()
                .filter_map(|&idx| self.graph.node_weight(idx).map(|id| vec![id.clone()]))
                .collect();
        }

        // m = total undirected edge weight
        let m = total_weight;
        let m2 = 2.0 * m;

        // Weighted degree of each node (sum of incident undirected edge weights)
        let k: Vec<f64> = (0..n)
            .map(|i| adj[i].iter().map(|&(_, w)| w).sum())
            .collect();

        // Initial assignment: each node in its own community
        let mut community: Vec<usize> = (0..n).collect();

        // sigma_tot[c] = sum of degrees of nodes in community c.
        // Maintained incrementally to avoid O(n^2) per pass.
        let mut sigma_tot: Vec<f64> = k.clone();

        // Iteratively move nodes to improve modularity: sweep every node,
        // greedily relocate it to the neighboring community with the best
        // gain, and repeat until a full pass makes no move (or the cap hits).
        let mut improved = true;
        let max_passes = 100;
        let mut pass = 0;

        while improved && pass < max_passes {
            improved = false;
            pass += 1;

            for i in 0..n {
                let current_comm = community[i];
                let ki = k[i];

                // Compute weights to each neighboring community
                let mut comm_weights: HashMap<usize, f64> = HashMap::new();
                for &(j, w) in &adj[i] {
                    *comm_weights.entry(community[j]).or_insert(0.0) += w;
                }

                // Standard Louvain delta-Q formula:
                // delta_Q = [w_in_new/m - resolution * ki * sigma_new / m2]
                //         - [w_in_current/m - resolution * ki * (sigma_current - ki) / m2]
                // (sigma_current - ki) excludes node i's own degree from its
                // current community when pricing the removal.
                let w_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
                let sigma_current = sigma_tot[current_comm];
                let remove_cost =
                    w_in_current / m - resolution * ki * (sigma_current - ki) / (m2 * m);

                // Find best community to move to; only a strictly positive
                // gain triggers a move (best_gain starts at 0.0).
                let mut best_comm = current_comm;
                let mut best_gain = 0.0;

                for (&comm, &w_in_comm) in &comm_weights {
                    if comm == current_comm {
                        continue;
                    }
                    let sigma_comm = sigma_tot[comm];
                    let gain =
                        w_in_comm / m - resolution * ki * sigma_comm / (m2 * m) - remove_cost;
                    if gain > best_gain {
                        best_gain = gain;
                        best_comm = comm;
                    }
                }

                if best_comm != current_comm {
                    // Update sigma_tot incrementally
                    sigma_tot[current_comm] -= ki;
                    sigma_tot[best_comm] += ki;
                    community[i] = best_comm;
                    improved = true;
                }
            }
        }

        // Group nodes by community
        let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
        for (i, &idx) in indices.iter().enumerate() {
            if let Some(node_id) = self.graph.node_weight(idx) {
                groups
                    .entry(community[i])
                    .or_default()
                    .push(node_id.clone());
            }
        }

        // Sort members and groups so output is deterministic across runs
        // (HashMap iteration order is not).
        let mut result: Vec<Vec<String>> = groups.into_values().collect();
        for group in result.iter_mut() {
            group.sort();
        }
        result.sort();
        result
    }
342
343    /// Compute betweenness centrality for all nodes using Brandes' algorithm.
344    ///
345    /// For graphs with more than 1000 nodes, samples sqrt(n) source nodes
346    /// for approximate computation.
347    ///
348    /// Returns a map from node ID to betweenness centrality score (normalized by
349    /// 1/((n-1)(n-2)) for directed graphs).
350    pub fn betweenness_centrality(&self) -> HashMap<String, f64> {
351        let n = self.graph.node_count();
352        if n <= 2 {
353            return self
354                .graph
355                .node_indices()
356                .filter_map(|idx| self.graph.node_weight(idx).map(|id| (id.clone(), 0.0)))
357                .collect();
358        }
359
360        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
361        let idx_pos: HashMap<NodeIndex, usize> = indices
362            .iter()
363            .enumerate()
364            .map(|(i, &idx)| (idx, i))
365            .collect();
366
367        let mut centrality = vec![0.0f64; n];
368
369        // Determine source nodes (sample for large graphs)
370        let sources: Vec<usize> = if n > 1000 {
371            let sample_size = (n as f64).sqrt() as usize;
372            // Deterministic sampling: evenly spaced
373            let step = n / sample_size;
374            (0..sample_size).map(|i| i * step).collect()
375        } else {
376            (0..n).collect()
377        };
378
379        let scale = if n > 1000 {
380            n as f64 / sources.len() as f64
381        } else {
382            1.0
383        };
384
385        for &s in &sources {
386            // Brandes' algorithm from source s
387            let mut stack: Vec<usize> = Vec::new();
388            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
389            let mut sigma = vec![0.0f64; n]; // number of shortest paths
390            sigma[s] = 1.0;
391            let mut dist: Vec<i64> = vec![-1; n];
392            dist[s] = 0;
393
394            let mut queue: VecDeque<usize> = VecDeque::new();
395            queue.push_back(s);
396
397            while let Some(v) = queue.pop_front() {
398                stack.push(v);
399                let v_idx = indices[v];
400                for neighbor in self.graph.neighbors_directed(v_idx, Direction::Outgoing) {
401                    if let Some(&w) = idx_pos.get(&neighbor) {
402                        if dist[w] < 0 {
403                            dist[w] = dist[v] + 1;
404                            queue.push_back(w);
405                        }
406                        if dist[w] == dist[v] + 1 {
407                            sigma[w] += sigma[v];
408                            predecessors[w].push(v);
409                        }
410                    }
411                }
412            }
413
414            let mut delta = vec![0.0f64; n];
415            while let Some(w) = stack.pop() {
416                for &v in &predecessors[w] {
417                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
418                }
419                if w != s {
420                    centrality[w] += delta[w];
421                }
422            }
423        }
424
425        // Apply sampling scale and normalize
426        let norm = ((n - 1) * (n - 2)) as f64;
427        indices
428            .iter()
429            .enumerate()
430            .filter_map(|(i, &idx)| {
431                self.graph
432                    .node_weight(idx)
433                    .map(|id| (id.clone(), centrality[i] * scale / norm))
434            })
435            .collect()
436    }
437
438    /// Find all strongly connected components using Tarjan's algorithm.
439    ///
440    /// Returns groups of node IDs. Each group is a strongly connected component
441    /// where every node can reach every other node via directed edges.
442    pub fn strongly_connected_components(&self) -> Vec<Vec<String>> {
443        let sccs = petgraph::algo::tarjan_scc(&self.graph);
444
445        let mut result: Vec<Vec<String>> = sccs
446            .into_iter()
447            .map(|component| {
448                let mut ids: Vec<String> = component
449                    .into_iter()
450                    .filter_map(|idx| self.graph.node_weight(idx).cloned())
451                    .collect();
452                ids.sort();
453                ids
454            })
455            .collect();
456
457        result.sort();
458        result
459    }
460
    /// Compute topological layers using Kahn's algorithm.
    ///
    /// Returns layers where all nodes in layer i have no dependencies on nodes
    /// in layer i or later. For cyclic graphs, SCCs are condensed into single
    /// super-nodes first, then the resulting DAG is topologically sorted.
    ///
    /// Each inner Vec contains the node IDs at that layer, sorted.
    pub fn topological_layers(&self) -> Vec<Vec<String>> {
        let n = self.graph.node_count();
        if n == 0 {
            return Vec::new();
        }

        // Stable node order and reverse map to dense positions.
        let indices: Vec<NodeIndex> = self.graph.node_indices().collect();
        let idx_pos: HashMap<NodeIndex, usize> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| (idx, i))
            .collect();

        // Step 1: Find SCCs
        let sccs = petgraph::algo::tarjan_scc(&self.graph);

        // Map each node position to its SCC index
        let mut node_to_scc = vec![0usize; n];
        for (scc_idx, scc) in sccs.iter().enumerate() {
            for &node_idx in scc {
                if let Some(&pos) = idx_pos.get(&node_idx) {
                    node_to_scc[pos] = scc_idx;
                }
            }
        }

        let num_sccs = sccs.len();

        // Step 2: Build condensed DAG (SCC graph).
        // HashSet dedupes parallel SCC->SCC edges so each distinct edge
        // contributes exactly one to the target's in-degree; intra-SCC
        // edges are skipped.
        let mut condensed_adj: Vec<HashSet<usize>> = vec![HashSet::new(); num_sccs];
        let mut condensed_in_degree = vec![0usize; num_sccs];

        for &idx in &indices {
            if let Some(&src_pos) = idx_pos.get(&idx) {
                let src_scc = node_to_scc[src_pos];
                for neighbor in self.graph.neighbors_directed(idx, Direction::Outgoing) {
                    if let Some(&dst_pos) = idx_pos.get(&neighbor) {
                        let dst_scc = node_to_scc[dst_pos];
                        // insert() returns true only on first insertion,
                        // so the in-degree increment happens exactly once
                        // per distinct condensed edge.
                        if src_scc != dst_scc && condensed_adj[src_scc].insert(dst_scc) {
                            condensed_in_degree[dst_scc] += 1;
                        }
                    }
                }
            }
        }

        // Step 3: Kahn's algorithm on the condensed DAG.
        // Seed the queue with every SCC that has no incoming edges.
        let mut queue: VecDeque<usize> = VecDeque::new();
        for (i, &deg) in condensed_in_degree.iter().enumerate().take(num_sccs) {
            if deg == 0 {
                queue.push_back(i);
            }
        }

        // Each outer loop drains one full frontier (one layer), collecting
        // newly in-degree-zero SCCs into the next frontier.
        let mut scc_layers: Vec<Vec<usize>> = Vec::new();
        while !queue.is_empty() {
            let mut layer = Vec::new();
            let mut next_queue = VecDeque::new();

            while let Some(scc_idx) = queue.pop_front() {
                layer.push(scc_idx);
                for &neighbor_scc in &condensed_adj[scc_idx] {
                    condensed_in_degree[neighbor_scc] -= 1;
                    if condensed_in_degree[neighbor_scc] == 0 {
                        next_queue.push_back(neighbor_scc);
                    }
                }
            }

            scc_layers.push(layer);
            queue = next_queue;
        }

        // Step 4: Expand SCC layers back to node IDs, sorted per layer.
        let mut result: Vec<Vec<String>> = Vec::new();
        for scc_layer in scc_layers {
            let mut layer_nodes: Vec<String> = Vec::new();
            for scc_idx in scc_layer {
                for &node_idx in &sccs[scc_idx] {
                    if let Some(id) = self.graph.node_weight(node_idx) {
                        layer_nodes.push(id.clone());
                    }
                }
            }
            layer_nodes.sort();
            result.push(layer_nodes);
        }

        result
    }
558
559    /// Return top-N nodes by centrality and edges between them.
560    /// Optionally filter by namespace and/or node kinds.
561    pub fn subgraph_top_n(
562        &self,
563        n: usize,
564        namespace: Option<&str>,
565        kinds: Option<&[NodeKind]>,
566    ) -> (Vec<GraphNode>, Vec<Edge>) {
567        let mut candidates: Vec<&GraphNode> = self
568            .nodes
569            .values()
570            .filter(|node| {
571                if let Some(ns) = namespace {
572                    match &node.namespace {
573                        Some(node_ns) => node_ns == ns,
574                        None => false,
575                    }
576                } else {
577                    true
578                }
579            })
580            .filter(|node| {
581                if let Some(k) = kinds {
582                    k.contains(&node.kind)
583                } else {
584                    true
585                }
586            })
587            .collect();
588
589        // Sort by centrality descending
590        candidates.sort_by(|a, b| {
591            b.centrality
592                .partial_cmp(&a.centrality)
593                .unwrap_or(std::cmp::Ordering::Equal)
594        });
595
596        // Take top N
597        candidates.truncate(n);
598
599        let top_ids: HashSet<&str> = candidates.iter().map(|node| node.id.as_str()).collect();
600        let nodes_vec: Vec<GraphNode> = candidates.into_iter().cloned().collect();
601
602        // Collect edges where both src and dst are in the top-N set
603        let edges_vec: Vec<Edge> = self
604            .edges
605            .values()
606            .filter(|edge| {
607                top_ids.contains(edge.src.as_str()) && top_ids.contains(edge.dst.as_str())
608            })
609            .cloned()
610            .collect();
611
612        (nodes_vec, edges_vec)
613    }
614
615    /// Return node-to-community-ID mapping for Louvain.
616    pub fn louvain_with_assignment(&self, resolution: f64) -> HashMap<String, usize> {
617        let communities = self.louvain_communities(resolution);
618        let mut assignment = HashMap::new();
619        for (idx, community) in communities.into_iter().enumerate() {
620            for node_id in community {
621                assignment.insert(node_id, idx);
622            }
623        }
624        assignment
625    }
626}
627
628#[cfg(test)]
629#[path = "../tests/graph_algorithms_tests.rs"]
630mod tests;