reddb_server/storage/engine/algorithms/pagerank.rs

//! PageRank Algorithms
//!
//! PageRank for identifying critical nodes in the graph.
//! High PageRank nodes are important because many other nodes link to them.
//!
//! Includes:
//! - PageRank: Standard PageRank algorithm
//! - PersonalizedPageRank: Teleport only to specified seed nodes

use std::collections::HashMap;

use super::super::graph_store::GraphStore;

// ============================================================================
// PageRank
// ============================================================================

/// PageRank algorithm for identifying critical nodes
///
/// Nodes with high PageRank are "important" because many other nodes link to them.
/// In attack path analysis, high PageRank nodes are critical chokepoints.
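/// # Example
///
/// A minimal usage sketch, assuming a `GraphStore` that has already been
/// populated elsewhere; the parameter values are illustrative, not tuned
/// recommendations.
///
/// ```ignore
/// let result = PageRank::new()
///     .alpha(0.85)
///     .epsilon(1e-8)
///     .max_iterations(200)
///     .run(&graph);
///
/// for (node_id, score) in result.top(10) {
///     println!("{node_id}: {score:.6}");
/// }
/// ```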
pub struct PageRank {
    /// Damping factor (probability of following a link vs teleporting)
    pub alpha: f64,
    /// Convergence threshold
    pub epsilon: f64,
    /// Maximum iterations
    pub max_iterations: usize,
}

impl Default for PageRank {
    fn default() -> Self {
        Self {
            alpha: 0.85,
            epsilon: 1e-6,
            max_iterations: 100,
        }
    }
}

/// Result of PageRank computation
#[derive(Debug, Clone)]
pub struct PageRankResult {
    /// Node ID → PageRank score
    pub scores: HashMap<String, f64>,
    /// Number of iterations performed (stops early on convergence)
    pub iterations: usize,
    /// Whether the algorithm converged
    pub converged: bool,
}

impl PageRankResult {
    /// Get top N nodes by PageRank score
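    ///
    /// Scores are returned in descending order. A hypothetical sketch, assuming
    /// `result` came from an earlier [`PageRank::run`] call:
    ///
    /// ```ignore
    /// let top_five = result.top(5); // Vec<(node_id, score)>, highest score first
    /// ```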
    pub fn top(&self, n: usize) -> Vec<(String, f64)> {
        let mut sorted: Vec<_> = self.scores.iter().map(|(k, v)| (k.clone(), *v)).collect();
        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        sorted.truncate(n);
        sorted
    }

    /// Get score for a specific node
    pub fn score(&self, node_id: &str) -> Option<f64> {
        self.scores.get(node_id).copied()
    }
}

impl PageRank {
    /// Create PageRank with default parameters
    pub fn new() -> Self {
        Self::default()
    }

    /// Set damping factor (default: 0.85)
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha.clamp(0.0, 1.0);
        self
    }

    /// Set convergence threshold (default: 1e-6)
    pub fn epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set maximum iterations (default: 100)
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Run PageRank on the graph
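    ///
    /// Uses power iteration: each pass recomputes
    /// `score(v) = (1 - alpha) / n + alpha * sum(score(u) / out_degree(u))`
    /// over the in-neighbors `u` of `v`, with the rank mass held by dangling
    /// nodes (no outgoing edges) redistributed uniformly. Iteration stops when
    /// the total (L1) change in scores drops below `epsilon`, or after
    /// `max_iterations` passes.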
    pub fn run(&self, graph: &GraphStore) -> PageRankResult {
        // Collect all nodes
        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
        let n = nodes.len();

        if n == 0 {
            return PageRankResult {
                scores: HashMap::new(),
                iterations: 0,
                converged: true,
            };
        }

        // Build adjacency: node_id → outgoing targets
        let mut outgoing: HashMap<String, Vec<String>> = HashMap::new();
        for node_id in &nodes {
            let edges = graph.outgoing_edges(node_id);
            let targets: Vec<String> = edges.into_iter().map(|(_, target, _)| target).collect();
            outgoing.insert(node_id.clone(), targets);
        }

        // Initialize scores uniformly
        let initial_score = 1.0 / n as f64;
        let mut scores: HashMap<String, f64> =
            nodes.iter().map(|id| (id.clone(), initial_score)).collect();

        let teleport = (1.0 - self.alpha) / n as f64;
        let mut converged = false;
        let mut iterations = 0;

        for iter in 0..self.max_iterations {
            iterations = iter + 1;
            let mut new_scores: HashMap<String, f64> = HashMap::new();

            // Calculate new scores
            for node_id in &nodes {
                let mut score = teleport;

                // Sum contributions from incoming edges
                for (source_id, targets) in &outgoing {
                    if targets.contains(node_id) {
                        let source_score = scores.get(source_id).copied().unwrap_or(0.0);
                        let out_degree = targets.len() as f64;
                        if out_degree > 0.0 {
                            score += self.alpha * source_score / out_degree;
                        }
                    }
                }

                new_scores.insert(node_id.clone(), score);
            }

            // Handle dangling nodes (no outgoing edges) - distribute their score
            let dangling_sum: f64 = nodes
                .iter()
                .filter(|id| outgoing.get(*id).map(|v| v.is_empty()).unwrap_or(true))
                .map(|id| scores.get(id).copied().unwrap_or(0.0))
                .sum();

            let dangling_contrib = self.alpha * dangling_sum / n as f64;
            for score in new_scores.values_mut() {
                *score += dangling_contrib;
            }

            // Check convergence
            let diff: f64 = nodes
                .iter()
                .map(|id| {
                    let old = scores.get(id).copied().unwrap_or(0.0);
                    let new = new_scores.get(id).copied().unwrap_or(0.0);
                    (old - new).abs()
                })
                .sum();

            scores = new_scores;

            if diff < self.epsilon {
                converged = true;
                break;
            }
        }

        PageRankResult {
            scores,
            iterations,
            converged,
        }
    }
}

// ============================================================================
// Personalized PageRank
// ============================================================================

/// Personalized PageRank - teleport only to specified seed nodes
///
/// Useful for finding nodes that are "important" relative to a specific
/// entry point (e.g., "what hosts are most reachable from this compromised server?")
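///
/// # Example
///
/// A sketch under assumed data: `graph` is a `GraphStore` populated elsewhere,
/// and `"web-01"` is a hypothetical node ID for a compromised host.
///
/// ```ignore
/// let ppr = PersonalizedPageRank::new(vec!["web-01".to_string()]);
/// let result = ppr.run(&graph);
///
/// // Nodes most reachable from the compromised host, highest score first.
/// for (node_id, score) in result.top(10) {
///     println!("{node_id}: {score:.6}");
/// }
/// ```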
pub struct PersonalizedPageRank {
    /// Damping factor
    pub alpha: f64,
    /// Convergence threshold
    pub epsilon: f64,
    /// Maximum iterations
    pub max_iterations: usize,
    /// Seed nodes (teleportation targets)
    seeds: Vec<String>,
    /// Seed weights (must sum to 1.0)
    weights: Vec<f64>,
}

impl PersonalizedPageRank {
    /// Create personalized PageRank with uniform seed weights
    pub fn new(seeds: Vec<String>) -> Self {
        let n = seeds.len().max(1) as f64;
        let weights = vec![1.0 / n; seeds.len()];
        Self {
            alpha: 0.85,
            epsilon: 1e-6,
            max_iterations: 100,
            seeds,
            weights,
        }
    }

    /// Create with custom seed weights (normalized so they sum to 1.0)
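    ///
    /// Weights that do not already sum to 1.0 are normalized. A hypothetical
    /// sketch with two assumed seed node IDs, biasing teleportation 3:1 toward
    /// the first seed:
    ///
    /// ```ignore
    /// let ppr = PersonalizedPageRank::with_weights(
    ///     vec!["vpn-gw".to_string(), "mail-01".to_string()],
    ///     vec![0.75, 0.25],
    /// );
    /// ```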
    pub fn with_weights(seeds: Vec<String>, weights: Vec<f64>) -> Self {
        // Normalize weights to sum to 1.0
        let sum: f64 = weights.iter().sum();
        let normalized = if sum > 0.0 {
            weights.iter().map(|w| w / sum).collect()
        } else {
            vec![1.0 / seeds.len().max(1) as f64; seeds.len()]
        };

        Self {
            alpha: 0.85,
            epsilon: 1e-6,
            max_iterations: 100,
            seeds,
            weights: normalized,
        }
    }

    /// Set damping factor
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha.clamp(0.0, 1.0);
        self
    }

    /// Set convergence threshold
    pub fn epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set maximum iterations
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Run personalized PageRank
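    ///
    /// Same power iteration as [`PageRank::run`], except teleportation (and the
    /// rank mass of dangling nodes) flows back to the seed nodes in proportion
    /// to their weights instead of being spread uniformly over all nodes.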
    pub fn run(&self, graph: &GraphStore) -> PageRankResult {
        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
        let n = nodes.len();

        if n == 0 || self.seeds.is_empty() {
            return PageRankResult {
                scores: HashMap::new(),
                iterations: 0,
                converged: true,
            };
        }

        // Build seed weight lookup
        let seed_weights: HashMap<String, f64> = self
            .seeds
            .iter()
            .zip(self.weights.iter())
            .map(|(s, w)| (s.clone(), *w))
            .collect();

        // Build adjacency
        let mut outgoing: HashMap<String, Vec<String>> = HashMap::new();
        for node_id in &nodes {
            let edges = graph.outgoing_edges(node_id);
            let targets: Vec<String> = edges.into_iter().map(|(_, target, _)| target).collect();
            outgoing.insert(node_id.clone(), targets);
        }

        // Initialize scores - start concentrated on seeds
        let mut scores: HashMap<String, f64> = HashMap::new();
        for node_id in &nodes {
            let initial = seed_weights.get(node_id).copied().unwrap_or(0.0);
            scores.insert(node_id.clone(), initial);
        }

        let mut converged = false;
        let mut iterations = 0;

        for iter in 0..self.max_iterations {
            iterations = iter + 1;
            let mut new_scores: HashMap<String, f64> = HashMap::new();

            for node_id in &nodes {
                // Teleport: go to seeds with their weights
                let teleport =
                    (1.0 - self.alpha) * seed_weights.get(node_id).copied().unwrap_or(0.0);
                let mut score = teleport;

                // Sum contributions from incoming edges
                for (source_id, targets) in &outgoing {
                    if targets.contains(node_id) {
                        let source_score = scores.get(source_id).copied().unwrap_or(0.0);
                        let out_degree = targets.len() as f64;
                        if out_degree > 0.0 {
                            score += self.alpha * source_score / out_degree;
                        }
                    }
                }

                new_scores.insert(node_id.clone(), score);
            }

            // Handle dangling nodes - distribute to seeds
            let dangling_sum: f64 = nodes
                .iter()
                .filter(|id| outgoing.get(*id).map(|v| v.is_empty()).unwrap_or(true))
                .map(|id| scores.get(id).copied().unwrap_or(0.0))
                .sum();

            for (seed, weight) in &seed_weights {
                if let Some(score) = new_scores.get_mut(seed) {
                    *score += self.alpha * dangling_sum * weight;
                }
            }

            // Check convergence
            let diff: f64 = nodes
                .iter()
                .map(|id| {
                    let old = scores.get(id).copied().unwrap_or(0.0);
                    let new = new_scores.get(id).copied().unwrap_or(0.0);
                    (old - new).abs()
                })
                .sum();

            scores = new_scores;

            if diff < self.epsilon {
                converged = true;
                break;
            }
        }

        PageRankResult {
            scores,
            iterations,
            converged,
        }
    }
}