Skip to main content

reddb_server/storage/engine/algorithms/
centrality.rs

1//! Centrality Algorithms
2//!
3//! Centrality measures for identifying important nodes:
4//! - Betweenness Centrality: Nodes on many shortest paths (chokepoints)
5//! - Closeness Centrality: Nodes close to all others (good attack starting points)
6//! - Degree Centrality: Nodes with many connections
7//! - Eigenvector Centrality: Nodes connected to other important nodes
8
9use std::collections::{HashMap, HashSet, VecDeque};
10
11use super::super::graph_store::GraphStore;
12
13// ============================================================================
14// Betweenness Centrality (Brandes' Algorithm)
15// ============================================================================
16
/// Betweenness centrality computation using Brandes' algorithm.
///
/// Betweenness centrality measures how often a node lies on shortest paths
/// between other nodes. High-betweenness nodes are chokepoints — critical for
/// network flow.
///
/// The type carries no state; it only namespaces the algorithm, which is
/// exposed through the associated function `BetweennessCentrality::compute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct BetweennessCentrality;
22
/// Result of a betweenness centrality computation.
#[derive(Debug, Clone)]
pub struct BetweennessResult {
    /// Node ID → betweenness centrality score.
    pub scores: HashMap<String, f64>,
    /// Whether normalization (division by `(n-1)(n-2)`) was requested when the
    /// scores were computed.
    pub normalized: bool,
}

impl BetweennessResult {
    /// Returns at most `n` `(node_id, score)` pairs, highest score first.
    ///
    /// Nodes with equal scores are ordered by ascending node ID, so the result
    /// is reproducible even though the backing `HashMap` iterates in arbitrary
    /// order. Scores are compared with `f64::total_cmp`, which is a total
    /// order (a positive NaN, should one ever appear, ranks above every
    /// finite score).
    pub fn top(&self, n: usize) -> Vec<(String, f64)> {
        // Rank borrowed keys so only the `n` returned IDs are cloned.
        let mut ranked: Vec<(&String, f64)> =
            self.scores.iter().map(|(id, score)| (id, *score)).collect();
        // Keys are unique, so this comparator never reports two entries as
        // equal and an unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(n)
            .map(|(id, score)| (id.clone(), score))
            .collect()
    }

    /// Returns the score recorded for `node_id`, or `None` if the node is not
    /// part of the result.
    pub fn score(&self, node_id: &str) -> Option<f64> {
        self.scores.get(node_id).copied()
    }
}
46
47impl BetweennessCentrality {
48    /// Compute betweenness centrality for all nodes
49    ///
50    /// Uses Brandes' algorithm: O(V*E) time, O(V) space
51    pub fn compute(graph: &GraphStore, normalize: bool) -> BetweennessResult {
52        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
53        let n = nodes.len();
54
55        if n < 2 {
56            return BetweennessResult {
57                scores: nodes.into_iter().map(|id| (id, 0.0)).collect(),
58                normalized: normalize,
59            };
60        }
61
62        let mut centrality: HashMap<String, f64> =
63            nodes.iter().map(|id| (id.clone(), 0.0)).collect();
64
65        // Brandes' algorithm
66        for source in &nodes {
67            // Single-source shortest paths
68            let mut stack: Vec<String> = Vec::new();
69            let mut predecessors: HashMap<String, Vec<String>> = HashMap::new();
70            let mut sigma: HashMap<String, f64> =
71                nodes.iter().map(|id| (id.clone(), 0.0)).collect();
72            let mut dist: HashMap<String, i64> = nodes.iter().map(|id| (id.clone(), -1)).collect();
73
74            sigma.insert(source.clone(), 1.0);
75            dist.insert(source.clone(), 0);
76
77            let mut queue: VecDeque<String> = VecDeque::new();
78            queue.push_back(source.clone());
79
80            // BFS
81            while let Some(v) = queue.pop_front() {
82                stack.push(v.clone());
83                let v_dist = *dist.get(&v).unwrap_or(&0);
84
85                for (_, w, _) in graph.outgoing_edges(&v) {
86                    // w found for first time?
87                    if *dist.get(&w).unwrap_or(&-1) < 0 {
88                        queue.push_back(w.clone());
89                        dist.insert(w.clone(), v_dist + 1);
90                    }
91
92                    // Shortest path to w via v?
93                    if *dist.get(&w).unwrap_or(&0) == v_dist + 1 {
94                        let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
95                        let sigma_w = sigma.entry(w.clone()).or_insert(0.0);
96                        *sigma_w += sigma_v;
97                        predecessors.entry(w.clone()).or_default().push(v.clone());
98                    }
99                }
100            }
101
102            // Accumulation
103            let mut delta: HashMap<String, f64> =
104                nodes.iter().map(|id| (id.clone(), 0.0)).collect();
105
106            while let Some(w) = stack.pop() {
107                if let Some(preds) = predecessors.get(&w) {
108                    let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
109                    let delta_w = *delta.get(&w).unwrap_or(&0.0);
110
111                    for v in preds {
112                        let sigma_v = *sigma.get(v).unwrap_or(&1.0);
113                        let d = (sigma_v / sigma_w) * (1.0 + delta_w);
114                        *delta.entry(v.clone()).or_insert(0.0) += d;
115                    }
116                }
117
118                if w != *source {
119                    let c = centrality.entry(w.clone()).or_insert(0.0);
120                    *c += *delta.get(&w).unwrap_or(&0.0);
121                }
122            }
123        }
124
125        // Normalize if requested
126        if normalize && n > 2 {
127            let norm_factor = 1.0 / ((n - 1) * (n - 2)) as f64;
128            for score in centrality.values_mut() {
129                *score *= norm_factor;
130            }
131        }
132
133        BetweennessResult {
134            scores: centrality,
135            normalized: normalize,
136        }
137    }
138}
139
140// ============================================================================
141// Closeness Centrality
142// ============================================================================
143
/// Closeness centrality measures how close a node is to all other nodes.
///
/// High closeness = can reach all nodes quickly = good attack starting point.
/// Low closeness = isolated, harder to reach.
///
/// The type carries no state; it only namespaces the algorithm, which is
/// exposed through the associated function `ClosenessCentrality::compute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClosenessCentrality;
149
/// Result of a closeness centrality computation.
#[derive(Debug, Clone)]
pub struct ClosenessResult {
    /// Node ID → closeness centrality (higher = more central).
    pub scores: HashMap<String, f64>,
}

impl ClosenessResult {
    /// Returns at most `n` `(node_id, score)` pairs, highest score first.
    ///
    /// Nodes with equal scores are ordered by ascending node ID, so the result
    /// is reproducible even though the backing `HashMap` iterates in arbitrary
    /// order. Scores are compared with `f64::total_cmp`, which is a total
    /// order (a positive NaN, should one ever appear, ranks above every
    /// finite score).
    pub fn top(&self, n: usize) -> Vec<(String, f64)> {
        // Rank borrowed keys so only the `n` returned IDs are cloned.
        let mut ranked: Vec<(&String, f64)> =
            self.scores.iter().map(|(id, score)| (id, *score)).collect();
        // Keys are unique, so this comparator never reports two entries as
        // equal and an unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(n)
            .map(|(id, score)| (id.clone(), score))
            .collect()
    }
}
166
167impl ClosenessCentrality {
168    /// Compute closeness centrality for all nodes
169    ///
170    /// Closeness = (n-1) / sum(shortest_path_distances)
171    /// For disconnected graphs, uses harmonic closeness variant
172    pub fn compute(graph: &GraphStore) -> ClosenessResult {
173        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
174        let n = nodes.len();
175
176        if n <= 1 {
177            return ClosenessResult {
178                scores: nodes.into_iter().map(|id| (id, 1.0)).collect(),
179            };
180        }
181
182        let mut scores: HashMap<String, f64> = HashMap::new();
183
184        for source in &nodes {
185            // BFS to find shortest paths from this node
186            let mut distances: HashMap<String, usize> = HashMap::new();
187            let mut queue: VecDeque<(String, usize)> = VecDeque::new();
188
189            queue.push_back((source.clone(), 0));
190            distances.insert(source.clone(), 0);
191
192            while let Some((current, dist)) = queue.pop_front() {
193                for (_, neighbor, _) in graph.outgoing_edges(&current) {
194                    if !distances.contains_key(&neighbor) {
195                        distances.insert(neighbor.clone(), dist + 1);
196                        queue.push_back((neighbor, dist + 1));
197                    }
198                }
199            }
200
201            // Calculate closeness (harmonic variant for disconnected graphs)
202            let sum_reciprocal: f64 = distances
203                .iter()
204                .filter(|(k, _)| *k != source)
205                .map(|(_, d)| 1.0 / (*d as f64))
206                .sum();
207
208            let closeness = sum_reciprocal / (n - 1) as f64;
209            scores.insert(source.clone(), closeness);
210        }
211
212        ClosenessResult { scores }
213    }
214}
215
216// ============================================================================
217// Degree Centrality
218// ============================================================================
219
/// Degree centrality measures node importance by connection count.
///
/// In security analysis:
/// - High in-degree: popular target (many paths lead here)
/// - High out-degree: key pivot point (can reach many targets)
///
/// The type carries no state; it only namespaces the algorithm, which is
/// exposed through the associated function `DegreeCentrality::compute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DegreeCentrality;
226
/// Result of a degree centrality computation.
#[derive(Debug, Clone)]
pub struct DegreeCentralityResult {
    /// Node ID → in-degree (number of edges pointing at the node).
    pub in_degree: HashMap<String, usize>,
    /// Node ID → out-degree (number of edges leaving the node).
    pub out_degree: HashMap<String, usize>,
    /// Node ID → total degree (in + out).
    pub total_degree: HashMap<String, usize>,
}

impl DegreeCentralityResult {
    /// Returns at most `n` `(node_id, degree)` pairs ranked by total degree,
    /// highest first; equal degrees are ordered by ascending node ID.
    pub fn top_by_total(&self, n: usize) -> Vec<(String, usize)> {
        Self::top_n(&self.total_degree, n)
    }

    /// Returns at most `n` `(node_id, degree)` pairs ranked by in-degree,
    /// highest first; equal degrees are ordered by ascending node ID.
    pub fn top_by_in_degree(&self, n: usize) -> Vec<(String, usize)> {
        Self::top_n(&self.in_degree, n)
    }

    /// Returns at most `n` `(node_id, degree)` pairs ranked by out-degree,
    /// highest first; equal degrees are ordered by ascending node ID.
    pub fn top_by_out_degree(&self, n: usize) -> Vec<(String, usize)> {
        Self::top_n(&self.out_degree, n)
    }

    /// Shared ranking routine for the three public accessors.
    ///
    /// The node-ID tie-break makes the output independent of the arbitrary
    /// iteration order of the `HashMap`.
    fn top_n(degrees: &HashMap<String, usize>, n: usize) -> Vec<(String, usize)> {
        // Rank borrowed keys so only the `n` returned IDs are cloned.
        let mut ranked: Vec<(&String, usize)> =
            degrees.iter().map(|(id, degree)| (id, *degree)).collect();
        // Keys are unique, so the comparator is a strict total order and an
        // unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(n)
            .map(|(id, degree)| (id.clone(), degree))
            .collect()
    }
}
275
276impl DegreeCentrality {
277    /// Compute degree centrality for all nodes
278    pub fn compute(graph: &GraphStore) -> DegreeCentralityResult {
279        let mut in_degree: HashMap<String, usize> = HashMap::new();
280        let mut out_degree: HashMap<String, usize> = HashMap::new();
281
282        // Initialize all nodes with 0 degree
283        for node in graph.iter_nodes() {
284            in_degree.insert(node.id.clone(), 0);
285            out_degree.insert(node.id.clone(), 0);
286        }
287
288        // Count degrees
289        for node in graph.iter_nodes() {
290            let out_count = graph.outgoing_edges(&node.id).len();
291            out_degree.insert(node.id.clone(), out_count);
292
293            // Count incoming edges by iterating targets
294            for (_, target, _) in graph.outgoing_edges(&node.id) {
295                *in_degree.entry(target).or_insert(0) += 1;
296            }
297        }
298
299        // Calculate total degree
300        let total_degree: HashMap<String, usize> = in_degree
301            .keys()
302            .map(|k| {
303                let total = in_degree.get(k).unwrap_or(&0) + out_degree.get(k).unwrap_or(&0);
304                (k.clone(), total)
305            })
306            .collect();
307
308        DegreeCentralityResult {
309            in_degree,
310            out_degree,
311            total_degree,
312        }
313    }
314}
315
316// ============================================================================
317// Eigenvector Centrality (Power Iteration)
318// ============================================================================
319
/// Eigenvector centrality: importance based on neighbor importance.
///
/// Like PageRank but without damping. A node is important if it is connected
/// to other important nodes. The fields configure the power iteration that
/// computes the scores.
#[derive(Debug, Clone, PartialEq)]
pub struct EigenvectorCentrality {
    /// Convergence threshold: iteration stops once the summed absolute change
    /// of all scores between two rounds drops below this value.
    pub epsilon: f64,
    /// Maximum number of power-iteration rounds to run.
    pub max_iterations: usize,
}

impl Default for EigenvectorCentrality {
    /// Returns the default configuration: `epsilon = 1e-6`,
    /// `max_iterations = 100`.
    fn default() -> Self {
        Self {
            epsilon: 1e-6,
            max_iterations: 100,
        }
    }
}
339
/// Result of an eigenvector centrality computation.
#[derive(Debug, Clone)]
pub struct EigenvectorResult {
    /// Node ID → eigenvector centrality score.
    pub scores: HashMap<String, f64>,
    /// Number of power-iteration rounds that were executed.
    pub iterations: usize,
    /// Whether the iteration met the convergence threshold before stopping.
    pub converged: bool,
}

impl EigenvectorResult {
    /// Returns at most `n` `(node_id, score)` pairs, highest score first.
    ///
    /// Nodes with equal scores are ordered by ascending node ID, so the result
    /// is reproducible even though the backing `HashMap` iterates in arbitrary
    /// order. Scores are compared with `f64::total_cmp`, which is a total
    /// order (a positive NaN, should one ever appear, ranks above every
    /// finite score).
    pub fn top(&self, n: usize) -> Vec<(String, f64)> {
        // Rank borrowed keys so only the `n` returned IDs are cloned.
        let mut ranked: Vec<(&String, f64)> =
            self.scores.iter().map(|(id, score)| (id, *score)).collect();
        // Keys are unique, so this comparator never reports two entries as
        // equal and an unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(n)
            .map(|(id, score)| (id.clone(), score))
            .collect()
    }
}
360
361impl EigenvectorCentrality {
362    pub fn new() -> Self {
363        Self::default()
364    }
365
366    /// Compute eigenvector centrality using power iteration
367    pub fn compute(&self, graph: &GraphStore) -> EigenvectorResult {
368        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
369        let n = nodes.len();
370
371        if n == 0 {
372            return EigenvectorResult {
373                scores: HashMap::new(),
374                iterations: 0,
375                converged: true,
376            };
377        }
378
379        // Build adjacency (treat as undirected for eigenvector centrality)
380        let mut neighbors: HashMap<String, Vec<String>> = HashMap::new();
381        for node in &nodes {
382            let mut node_neighbors: HashSet<String> = HashSet::new();
383            for (_, target, _) in graph.outgoing_edges(node) {
384                node_neighbors.insert(target);
385            }
386            for (_, source, _) in graph.incoming_edges(node) {
387                node_neighbors.insert(source);
388            }
389            neighbors.insert(node.clone(), node_neighbors.into_iter().collect());
390        }
391
392        // Initialize scores uniformly
393        let init_score = 1.0 / (n as f64).sqrt();
394        let mut scores: HashMap<String, f64> =
395            nodes.iter().map(|id| (id.clone(), init_score)).collect();
396
397        let mut converged = false;
398        let mut iterations = 0;
399
400        for iter in 0..self.max_iterations {
401            iterations = iter + 1;
402            let mut new_scores: HashMap<String, f64> = HashMap::new();
403
404            // Calculate new scores (sum of neighbor scores)
405            for node in &nodes {
406                let sum: f64 = neighbors
407                    .get(node)
408                    .map(|nbrs| {
409                        nbrs.iter()
410                            .map(|n| scores.get(n).copied().unwrap_or(0.0))
411                            .sum()
412                    })
413                    .unwrap_or(0.0);
414                new_scores.insert(node.clone(), sum);
415            }
416
417            // Normalize (L2 norm)
418            let norm: f64 = new_scores.values().map(|v| v * v).sum::<f64>().sqrt();
419            if norm > 0.0 {
420                for score in new_scores.values_mut() {
421                    *score /= norm;
422                }
423            }
424
425            // Check convergence
426            let diff: f64 = nodes
427                .iter()
428                .map(|id| {
429                    let old = scores.get(id).copied().unwrap_or(0.0);
430                    let new = new_scores.get(id).copied().unwrap_or(0.0);
431                    (old - new).abs()
432                })
433                .sum();
434
435            scores = new_scores;
436
437            if diff < self.epsilon {
438                converged = true;
439                break;
440            }
441        }
442
443        EigenvectorResult {
444            scores,
445            iterations,
446            converged,
447        }
448    }
449}