Skip to main content

reddb_server/storage/engine/algorithms/
structural.rs

//! Structural Graph Algorithms
//!
//! Algorithms for analyzing graph structure:
//! - HITS: Hubs and Authorities
//! - Strongly Connected Components (Tarjan's)
//! - Weakly Connected Components
//! - Triangle Counting
//! - Clustering Coefficient
9
10use std::collections::{HashMap, HashSet, VecDeque};
11
12use super::super::graph_store::GraphStore;
13
// ============================================================================
// HITS (Hubs and Authorities)
// ============================================================================

/// Configuration for the HITS (hubs and authorities) algorithm.
///
/// HITS gives every node two scores:
/// - Authority: high for nodes pointed to by many hubs (valuable targets).
/// - Hub: high for nodes that point to many authorities (good pivot points).
pub struct HITS {
    /// Upper bound on the number of update rounds.
    pub max_iterations: usize,
    /// Convergence threshold, compared against the summed absolute change in
    /// hub scores between two consecutive rounds.
    pub epsilon: f64,
}

impl Default for HITS {
    /// Returns the stock configuration: at most 100 rounds, threshold `1e-6`.
    fn default() -> Self {
        const MAX_ITERATIONS: usize = 100;
        const EPSILON: f64 = 1e-6;

        HITS {
            max_iterations: MAX_ITERATIONS,
            epsilon: EPSILON,
        }
    }
}
37
/// Result of a HITS computation.
#[derive(Debug, Clone)]
pub struct HITSResult {
    /// Node ID → hub score
    pub hub_scores: HashMap<String, f64>,
    /// Node ID → authority score
    pub authority_scores: HashMap<String, f64>,
    /// Number of iterations that were run
    pub iterations: usize,
    /// Whether the iteration converged
    pub converged: bool,
}

impl HITSResult {
    /// Returns the `n` nodes with the highest hub scores, best first.
    ///
    /// Equal scores are ordered by node ID so the result is deterministic;
    /// NaN scores are ranked last.
    pub fn top_hubs(&self, n: usize) -> Vec<(String, f64)> {
        Self::top_n(&self.hub_scores, n)
    }

    /// Returns the `n` nodes with the highest authority scores, best first.
    ///
    /// Equal scores are ordered by node ID so the result is deterministic;
    /// NaN scores are ranked last.
    pub fn top_authorities(&self, n: usize) -> Vec<(String, f64)> {
        Self::top_n(&self.authority_scores, n)
    }

    /// Ranks `scores` in descending order and returns the first `n` entries.
    fn top_n(scores: &HashMap<String, f64>, n: usize) -> Vec<(String, f64)> {
        // Sort borrowed IDs and clone only the entries that are returned.
        let mut ranked: Vec<(&String, f64)> =
            scores.iter().map(|(id, score)| (id, *score)).collect();

        ranked.sort_by(|a, b| {
            // `partial_cmp` alone is not a total order once NaN is involved,
            // which `sort_by` requires. Rank NaN explicitly below every real
            // score so the comparator stays total.
            let by_score = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            };
            // Map iteration order is arbitrary; break ties by ID.
            by_score.then_with(|| a.0.cmp(b.0))
        });

        ranked
            .into_iter()
            .take(n)
            .map(|(id, score)| (id.clone(), score))
            .collect()
    }
}
76
77impl HITS {
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Compute HITS hub and authority scores
83    pub fn compute(&self, graph: &GraphStore) -> HITSResult {
84        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
85        let n = nodes.len();
86
87        if n == 0 {
88            return HITSResult {
89                hub_scores: HashMap::new(),
90                authority_scores: HashMap::new(),
91                iterations: 0,
92                converged: true,
93            };
94        }
95
96        // Build adjacency
97        let mut outgoing: HashMap<String, Vec<String>> = HashMap::new();
98        let mut incoming: HashMap<String, Vec<String>> = HashMap::new();
99
100        for node in &nodes {
101            let out: Vec<String> = graph
102                .outgoing_edges(node)
103                .into_iter()
104                .map(|(_, t, _)| t)
105                .collect();
106            outgoing.insert(node.clone(), out);
107
108            let inc: Vec<String> = graph
109                .incoming_edges(node)
110                .into_iter()
111                .map(|(_, s, _)| s)
112                .collect();
113            incoming.insert(node.clone(), inc);
114        }
115
116        // Initialize scores
117        let init = 1.0 / (n as f64).sqrt();
118        let mut hub: HashMap<String, f64> = nodes.iter().map(|id| (id.clone(), init)).collect();
119        let mut auth: HashMap<String, f64> = nodes.iter().map(|id| (id.clone(), init)).collect();
120
121        let mut converged = false;
122        let mut iterations = 0;
123
124        for iter in 0..self.max_iterations {
125            iterations = iter + 1;
126
127            // Update authority scores: auth(p) = sum of hub(q) for all q that point to p
128            let mut new_auth: HashMap<String, f64> = HashMap::new();
129            for node in &nodes {
130                let sum: f64 = incoming
131                    .get(node)
132                    .map(|inc| inc.iter().map(|s| hub.get(s).copied().unwrap_or(0.0)).sum())
133                    .unwrap_or(0.0);
134                new_auth.insert(node.clone(), sum);
135            }
136
137            // Normalize authority scores
138            let auth_norm: f64 = new_auth.values().map(|v| v * v).sum::<f64>().sqrt();
139            if auth_norm > 0.0 {
140                for v in new_auth.values_mut() {
141                    *v /= auth_norm;
142                }
143            }
144
145            // Update hub scores: hub(p) = sum of auth(q) for all q that p points to
146            let mut new_hub: HashMap<String, f64> = HashMap::new();
147            for node in &nodes {
148                let sum: f64 = outgoing
149                    .get(node)
150                    .map(|out| {
151                        out.iter()
152                            .map(|t| new_auth.get(t).copied().unwrap_or(0.0))
153                            .sum()
154                    })
155                    .unwrap_or(0.0);
156                new_hub.insert(node.clone(), sum);
157            }
158
159            // Normalize hub scores
160            let hub_norm: f64 = new_hub.values().map(|v| v * v).sum::<f64>().sqrt();
161            if hub_norm > 0.0 {
162                for v in new_hub.values_mut() {
163                    *v /= hub_norm;
164                }
165            }
166
167            // Check convergence
168            let hub_diff: f64 = nodes
169                .iter()
170                .map(|id| {
171                    (hub.get(id).copied().unwrap_or(0.0) - new_hub.get(id).copied().unwrap_or(0.0))
172                        .abs()
173                })
174                .sum();
175
176            hub = new_hub;
177            auth = new_auth;
178
179            if hub_diff < self.epsilon {
180                converged = true;
181                break;
182            }
183        }
184
185        HITSResult {
186            hub_scores: hub,
187            authority_scores: auth,
188            iterations,
189            converged,
190        }
191    }
192}
193
// ============================================================================
// Strongly Connected Components (Tarjan's Algorithm)
// ============================================================================

/// Strongly connected components using Tarjan's algorithm
///
/// In a directed graph, an SCC is a maximal set of nodes where every node
/// is reachable from every other node.
pub struct StronglyConnectedComponents;

/// Result of an SCC computation.
#[derive(Debug, Clone)]
pub struct SCCResult {
    /// List of strongly connected components (each a list of node IDs)
    pub components: Vec<Vec<String>>,
    /// Number of SCCs
    pub count: usize,
}

impl SCCResult {
    /// Returns the component with the most nodes, or `None` when there are no
    /// components. If several components share the maximum size, the last of
    /// them in `components` is returned.
    pub fn largest(&self) -> Option<&Vec<String>> {
        self.components.iter().max_by_key(|c| c.len())
    }

    /// Returns the index in `components` of the first component containing
    /// `node_id`, or `None` if no component contains it.
    pub fn component_of(&self, node_id: &str) -> Option<usize> {
        // Compare against the borrowed `&str` directly instead of allocating
        // an owned `String` for every component probed.
        self.components
            .iter()
            .position(|component| component.iter().any(|member| member == node_id))
    }
}
226
227impl StronglyConnectedComponents {
228    /// Find all strongly connected components using Tarjan's algorithm
229    pub fn find(graph: &GraphStore) -> SCCResult {
230        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
231
232        let mut index_counter = 0;
233        let mut stack: Vec<String> = Vec::new();
234        let mut on_stack: HashSet<String> = HashSet::new();
235        let mut indices: HashMap<String, usize> = HashMap::new();
236        let mut lowlinks: HashMap<String, usize> = HashMap::new();
237        let mut components: Vec<Vec<String>> = Vec::new();
238
239        fn strongconnect(
240            graph: &GraphStore,
241            node: &str,
242            index_counter: &mut usize,
243            stack: &mut Vec<String>,
244            on_stack: &mut HashSet<String>,
245            indices: &mut HashMap<String, usize>,
246            lowlinks: &mut HashMap<String, usize>,
247            components: &mut Vec<Vec<String>>,
248        ) {
249            indices.insert(node.to_string(), *index_counter);
250            lowlinks.insert(node.to_string(), *index_counter);
251            *index_counter += 1;
252            stack.push(node.to_string());
253            on_stack.insert(node.to_string());
254
255            for (_, neighbor, _) in graph.outgoing_edges(node) {
256                if !indices.contains_key(&neighbor) {
257                    // Neighbor not visited, recurse
258                    strongconnect(
259                        graph,
260                        &neighbor,
261                        index_counter,
262                        stack,
263                        on_stack,
264                        indices,
265                        lowlinks,
266                        components,
267                    );
268                    let neighbor_ll = *lowlinks.get(&neighbor).unwrap();
269                    let node_ll = lowlinks.get_mut(node).unwrap();
270                    *node_ll = (*node_ll).min(neighbor_ll);
271                } else if on_stack.contains(&neighbor) {
272                    // Neighbor on stack, update lowlink
273                    let neighbor_idx = *indices.get(&neighbor).unwrap();
274                    let node_ll = lowlinks.get_mut(node).unwrap();
275                    *node_ll = (*node_ll).min(neighbor_idx);
276                }
277            }
278
279            // If node is root of SCC
280            if lowlinks.get(node) == indices.get(node) {
281                let mut component = Vec::new();
282                loop {
283                    let w = stack.pop().unwrap();
284                    on_stack.remove(&w);
285                    component.push(w.clone());
286                    if w == node {
287                        break;
288                    }
289                }
290                components.push(component);
291            }
292        }
293
294        for node in &nodes {
295            if !indices.contains_key(node) {
296                strongconnect(
297                    graph,
298                    node,
299                    &mut index_counter,
300                    &mut stack,
301                    &mut on_stack,
302                    &mut indices,
303                    &mut lowlinks,
304                    &mut components,
305                );
306            }
307        }
308
309        // Sort components by size descending
310        components.sort_by_key(|b| std::cmp::Reverse(b.len()));
311
312        SCCResult {
313            count: components.len(),
314            components,
315        }
316    }
317}
318
// ============================================================================
// Weakly Connected Components
// ============================================================================

/// Weakly connected components: the directed graph is treated as undirected.
///
/// A weakly connected component is a set of nodes in which any two nodes are
/// joined by a path once edge direction is ignored.
pub struct WeaklyConnectedComponents;

/// Result of a weakly-connected-components computation.
#[derive(Debug, Clone)]
pub struct WCCResult {
    /// Each component as a list of node IDs
    pub components: Vec<Vec<String>>,
    /// Number of components
    pub count: usize,
    /// Node ID → index of its component in `components`
    pub node_to_component: HashMap<String, usize>,
}

impl WCCResult {
    /// Returns the component with the most nodes, or `None` when there are no
    /// components. Among components of equal maximum size the last one wins.
    pub fn largest(&self) -> Option<&Vec<String>> {
        self.components
            .iter()
            .max_by(|left, right| left.len().cmp(&right.len()))
    }

    /// Returns the component that `node` belongs to, or `None` when the node
    /// is unknown or its recorded index does not exist in `components`.
    pub fn component_of(&self, node: &str) -> Option<&Vec<String>> {
        let index = *self.node_to_component.get(node)?;
        self.components.get(index)
    }
}
353
354impl WeaklyConnectedComponents {
355    /// Find all weakly connected components
356    pub fn find(graph: &GraphStore) -> WCCResult {
357        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
358
359        // Build undirected adjacency
360        let mut neighbors: HashMap<String, Vec<String>> = HashMap::new();
361        for node in &nodes {
362            let mut nbrs: Vec<String> = Vec::new();
363            for (_, target, _) in graph.outgoing_edges(node) {
364                if target != *node {
365                    nbrs.push(target);
366                }
367            }
368            for (_, source, _) in graph.incoming_edges(node) {
369                if source != *node && !nbrs.contains(&source) {
370                    nbrs.push(source);
371                }
372            }
373            neighbors.insert(node.clone(), nbrs);
374        }
375
376        // BFS to find components
377        let mut visited: HashSet<String> = HashSet::new();
378        let mut components: Vec<Vec<String>> = Vec::new();
379        let mut node_to_component: HashMap<String, usize> = HashMap::new();
380
381        for node in &nodes {
382            if visited.contains(node) {
383                continue;
384            }
385
386            let mut component: Vec<String> = Vec::new();
387            let mut queue: VecDeque<String> = VecDeque::new();
388            queue.push_back(node.clone());
389            visited.insert(node.clone());
390
391            while let Some(current) = queue.pop_front() {
392                component.push(current.clone());
393
394                if let Some(nbrs) = neighbors.get(&current) {
395                    for nbr in nbrs {
396                        if !visited.contains(nbr) {
397                            visited.insert(nbr.clone());
398                            queue.push_back(nbr.clone());
399                        }
400                    }
401                }
402            }
403
404            let component_idx = components.len();
405            for n in &component {
406                node_to_component.insert(n.clone(), component_idx);
407            }
408            components.push(component);
409        }
410
411        // Sort by size descending
412        let mut indexed: Vec<(usize, Vec<String>)> = components.into_iter().enumerate().collect();
413        indexed.sort_by_key(|b| std::cmp::Reverse(b.1.len()));
414
415        // Rebuild with new indices
416        let mut new_node_to_component: HashMap<String, usize> = HashMap::new();
417        let new_components: Vec<Vec<String>> = indexed
418            .into_iter()
419            .enumerate()
420            .map(|(new_idx, (_, comp))| {
421                for n in &comp {
422                    new_node_to_component.insert(n.clone(), new_idx);
423                }
424                comp
425            })
426            .collect();
427
428        WCCResult {
429            count: new_components.len(),
430            components: new_components,
431            node_to_component: new_node_to_component,
432        }
433    }
434}
435
// ============================================================================
// Triangle Counting
// ============================================================================

/// Counts triangles in the graph.
///
/// Triangles indicate tightly connected clusters.
/// A high triangle count in attack graphs suggests multiple redundant paths.
pub struct TriangleCounting;

/// Result of triangle counting.
#[derive(Debug, Clone)]
pub struct TriangleResult {
    /// Total number of triangles
    pub count: usize,
    /// Node ID → number of triangles containing this node
    pub per_node: HashMap<String, usize>,
    /// The triangles themselves (as triples of node IDs)
    pub triangles: Vec<(String, String, String)>,
}
456
457impl TriangleCounting {
458    /// Count all triangles in the graph (treating as undirected)
459    pub fn count(graph: &GraphStore) -> TriangleResult {
460        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
461
462        // Build undirected adjacency
463        let mut neighbors: HashMap<String, HashSet<String>> = HashMap::new();
464        for node in &nodes {
465            let mut nbrs: HashSet<String> = HashSet::new();
466            for (_, target, _) in graph.outgoing_edges(node) {
467                nbrs.insert(target);
468            }
469            for (_, source, _) in graph.incoming_edges(node) {
470                nbrs.insert(source);
471            }
472            neighbors.insert(node.clone(), nbrs);
473        }
474
475        let mut triangles: Vec<(String, String, String)> = Vec::new();
476        let mut per_node: HashMap<String, usize> = nodes.iter().map(|n| (n.clone(), 0)).collect();
477
478        // For each node, check if any two neighbors are connected
479        for node in &nodes {
480            if let Some(node_nbrs) = neighbors.get(node) {
481                let nbr_list: Vec<&String> = node_nbrs.iter().collect();
482                for i in 0..nbr_list.len() {
483                    for j in (i + 1)..nbr_list.len() {
484                        let a = nbr_list[i];
485                        let b = nbr_list[j];
486
487                        // Check if a and b are connected
488                        if neighbors.get(a).map(|s| s.contains(b)).unwrap_or(false) {
489                            // Sort to avoid duplicates
490                            let mut triple = vec![node.clone(), a.clone(), b.clone()];
491                            triple.sort();
492
493                            // Check if we've already found this triangle
494                            let is_new = !triangles.iter().any(|(x, y, z)| {
495                                let mut existing = vec![x.clone(), y.clone(), z.clone()];
496                                existing.sort();
497                                existing == triple
498                            });
499
500                            if is_new {
501                                triangles.push((
502                                    triple[0].clone(),
503                                    triple[1].clone(),
504                                    triple[2].clone(),
505                                ));
506                                *per_node.entry(triple[0].clone()).or_insert(0) += 1;
507                                *per_node.entry(triple[1].clone()).or_insert(0) += 1;
508                                *per_node.entry(triple[2].clone()).or_insert(0) += 1;
509                            }
510                        }
511                    }
512                }
513            }
514        }
515
516        TriangleResult {
517            count: triangles.len(),
518            per_node,
519            triangles,
520        }
521    }
522}
523
// ============================================================================
// Clustering Coefficient
// ============================================================================

/// Local and global clustering coefficient
///
/// Measures how much the neighbors of a node are connected to each other.
/// High clustering = tightly connected neighborhood = potential attack surface.
pub struct ClusteringCoefficient;

/// Result of a clustering coefficient computation.
#[derive(Debug, Clone)]
pub struct ClusteringResult {
    /// Node ID → local clustering coefficient (0 to 1)
    pub local: HashMap<String, f64>,
    /// Global clustering coefficient (average of local)
    pub global: f64,
}

impl ClusteringResult {
    /// Returns the `n` nodes with the highest local clustering coefficient,
    /// best first.
    ///
    /// Equal coefficients are ordered by node ID so the result is
    /// deterministic; NaN values are ranked last.
    pub fn top(&self, n: usize) -> Vec<(String, f64)> {
        // Sort borrowed IDs and clone only the entries that are returned.
        let mut ranked: Vec<(&String, f64)> =
            self.local.iter().map(|(id, value)| (id, *value)).collect();

        ranked.sort_by(|a, b| {
            // `partial_cmp` alone is not a total order once NaN is involved,
            // which `sort_by` requires. Rank NaN explicitly below every real
            // value so the comparator stays total.
            let by_value = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            };
            // Map iteration order is arbitrary; break ties by ID.
            by_value.then_with(|| a.0.cmp(b.0))
        });

        ranked
            .into_iter()
            .take(n)
            .map(|(id, value)| (id.clone(), value))
            .collect()
    }
}
552
553impl ClusteringCoefficient {
554    /// Compute local and global clustering coefficients
555    pub fn compute(graph: &GraphStore) -> ClusteringResult {
556        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
557
558        // Build undirected adjacency
559        let mut neighbors: HashMap<String, HashSet<String>> = HashMap::new();
560        for node in &nodes {
561            let mut nbrs: HashSet<String> = HashSet::new();
562            for (_, target, _) in graph.outgoing_edges(node) {
563                if target != *node {
564                    nbrs.insert(target);
565                }
566            }
567            for (_, source, _) in graph.incoming_edges(node) {
568                if source != *node {
569                    nbrs.insert(source);
570                }
571            }
572            neighbors.insert(node.clone(), nbrs);
573        }
574
575        let mut local: HashMap<String, f64> = HashMap::new();
576
577        for node in &nodes {
578            if let Some(node_nbrs) = neighbors.get(node) {
579                let k = node_nbrs.len();
580                if k < 2 {
581                    local.insert(node.clone(), 0.0);
582                    continue;
583                }
584
585                // Count edges between neighbors
586                let mut edges_between = 0;
587                let nbr_list: Vec<&String> = node_nbrs.iter().collect();
588                for i in 0..nbr_list.len() {
589                    for j in (i + 1)..nbr_list.len() {
590                        if neighbors
591                            .get(nbr_list[i])
592                            .map(|s| s.contains(nbr_list[j]))
593                            .unwrap_or(false)
594                        {
595                            edges_between += 1;
596                        }
597                    }
598                }
599
600                // Local clustering coefficient = 2 * edges_between / (k * (k-1))
601                let max_edges = k * (k - 1) / 2;
602                let cc = if max_edges > 0 {
603                    edges_between as f64 / max_edges as f64
604                } else {
605                    0.0
606                };
607                local.insert(node.clone(), cc);
608            } else {
609                local.insert(node.clone(), 0.0);
610            }
611        }
612
613        // Global clustering coefficient (average of local)
614        let global = if !local.is_empty() {
615            local.values().sum::<f64>() / local.len() as f64
616        } else {
617            0.0
618        };
619
620        ClusteringResult { local, global }
621    }
622}