Skip to main content

entrenar/monitor/inference/provenance/
reconstructor.rs

1//! Incident reconstructor for forensic analysis.
2
3use super::attack::{Anomaly, AttackPath};
4use super::graph::ProvenanceGraph;
5use super::node::{NodeId, ProvenanceNode};
6
7/// Incident reconstructor using provenance graph
8pub struct IncidentReconstructor<'a> {
9    graph: &'a ProvenanceGraph,
10}
11
12impl<'a> IncidentReconstructor<'a> {
13    /// Create a new reconstructor
14    pub fn new(graph: &'a ProvenanceGraph) -> Self {
15        Self { graph }
16    }
17
18    /// Trace backwards from incident node to root causes
19    pub fn reconstruct_path(&self, incident_node: NodeId, max_depth: usize) -> AttackPath {
20        let mut nodes = Vec::new();
21        let mut edges = Vec::new();
22        let mut visited = std::collections::HashSet::new();
23        let mut queue = std::collections::VecDeque::new();
24
25        queue.push_back((incident_node, 0usize));
26        visited.insert(incident_node);
27
28        while let Some((node_id, depth)) = queue.pop_front() {
29            if depth > max_depth {
30                continue;
31            }
32
33            if let Some(node) = self.graph.get_node(node_id) {
34                nodes.push((node_id, node.clone()));
35            }
36
37            for edge in self.graph.incoming_edges(node_id) {
38                edges.push(edge.clone());
39
40                if !visited.contains(&edge.from) {
41                    visited.insert(edge.from);
42                    queue.push_back((edge.from, depth + 1));
43                }
44            }
45        }
46
47        // Reverse to get causal order (root → incident)
48        nodes.reverse();
49
50        // Calculate duration
51        let duration_ns = self.calculate_duration(&nodes);
52
53        AttackPath { nodes, edges, duration_ns, anomaly_indices: Vec::new() }
54    }
55
56    /// Calculate duration from timestamps
57    fn calculate_duration(&self, nodes: &[(NodeId, ProvenanceNode)]) -> u64 {
58        let timestamps: Vec<u64> = nodes.iter().filter_map(|(_, n)| n.timestamp_ns()).collect();
59
60        if timestamps.len() < 2 {
61            return 0;
62        }
63
64        let min = *timestamps.iter().min().unwrap_or(&0);
65        let max = *timestamps.iter().max().unwrap_or(&0);
66        max - min
67    }
68
69    /// Identify anomalies in a path
70    pub fn identify_anomalies(&self, path: &AttackPath, confidence_threshold: f32) -> Vec<Anomaly> {
71        let mut anomalies = Vec::new();
72
73        for (idx, (node_id, node)) in path.nodes.iter().enumerate() {
74            // Check for low confidence inferences
75            if let ProvenanceNode::Inference { confidence, .. } = node {
76                if *confidence < confidence_threshold {
77                    anomalies.push(Anomaly {
78                        node_id: *node_id,
79                        description: format!(
80                            "Low confidence inference: {:.1}% (threshold: {:.1}%)",
81                            confidence * 100.0,
82                            confidence_threshold * 100.0
83                        ),
84                        severity: 1.0 - *confidence,
85                    });
86                }
87            }
88
89            // Check for suspicious fusion with many inputs
90            if let ProvenanceNode::Fusion { input_refs, .. } = node {
91                if input_refs.len() > 10 {
92                    anomalies.push(Anomaly {
93                        node_id: *node_id,
94                        description: format!(
95                            "Unusually many fusion inputs: {} (expected <10)",
96                            input_refs.len()
97                        ),
98                        severity: 0.3,
99                    });
100                }
101            }
102
103            // Flag nodes with no predecessors (except inputs)
104            if !matches!(node, ProvenanceNode::Input { .. }) {
105                let preds = self.graph.predecessors(*node_id);
106                if preds.is_empty() {
107                    anomalies.push(Anomaly {
108                        node_id: *node_id,
109                        description: format!("{} node has no predecessors", node.type_name()),
110                        severity: 0.5,
111                    });
112                }
113            }
114
115            let _ = idx; // Used in future enhancements
116        }
117
118        anomalies
119    }
120}