Skip to main content

oxigdal_security/lineage/
graph.rs

1//! Lineage graph construction and management.
2
3use crate::error::{Result, SecurityError};
4use crate::lineage::{EdgeType, LineageEdge, LineageEvent, LineageNode, NodeType};
5use dashmap::DashMap;
6use petgraph::Direction;
7use petgraph::graph::{DiGraph, NodeIndex};
8use std::sync::Arc;
9
10/// Lineage graph.
11pub struct LineageGraph {
12    /// Graph structure.
13    graph: parking_lot::RwLock<DiGraph<LineageNode, LineageEdge>>,
14    /// Node ID to graph index mapping.
15    node_index: Arc<DashMap<String, NodeIndex>>,
16    /// Entity ID to node ID mapping.
17    entity_index: Arc<DashMap<String, Vec<String>>>,
18}
19
20impl LineageGraph {
21    /// Create a new lineage graph.
22    pub fn new() -> Self {
23        Self {
24            graph: parking_lot::RwLock::new(DiGraph::new()),
25            node_index: Arc::new(DashMap::new()),
26            entity_index: Arc::new(DashMap::new()),
27        }
28    }
29
30    /// Add a node to the graph.
31    pub fn add_node(&self, node: LineageNode) -> Result<String> {
32        let node_id = node.id.clone();
33        let entity_id = node.entity_id.clone();
34
35        let mut graph = self.graph.write();
36        let idx = graph.add_node(node);
37
38        self.node_index.insert(node_id.clone(), idx);
39        self.entity_index
40            .entry(entity_id)
41            .or_default()
42            .push(node_id.clone());
43
44        Ok(node_id)
45    }
46
47    /// Add an edge to the graph.
48    pub fn add_edge(&self, edge: LineageEdge) -> Result<String> {
49        let source_idx = *self
50            .node_index
51            .get(&edge.source_id)
52            .ok_or_else(|| SecurityError::lineage_tracking("Source node not found"))?;
53
54        let target_idx = *self
55            .node_index
56            .get(&edge.target_id)
57            .ok_or_else(|| SecurityError::lineage_tracking("Target node not found"))?;
58
59        let edge_id = edge.id.clone();
60        let mut graph = self.graph.write();
61        graph.add_edge(source_idx, target_idx, edge);
62
63        Ok(edge_id)
64    }
65
66    /// Get a node by ID.
67    pub fn get_node(&self, node_id: &str) -> Option<LineageNode> {
68        let idx = self.node_index.get(node_id)?;
69        let graph = self.graph.read();
70        graph.node_weight(*idx).cloned()
71    }
72
73    /// Get nodes by entity ID.
74    pub fn get_nodes_by_entity(&self, entity_id: &str) -> Vec<LineageNode> {
75        let node_ids = match self.entity_index.get(entity_id) {
76            Some(ids) => ids.clone(),
77            None => return Vec::new(),
78        };
79
80        node_ids.iter().filter_map(|id| self.get_node(id)).collect()
81    }
82
83    /// Get upstream nodes (dependencies).
84    pub fn get_upstream(&self, node_id: &str) -> Result<Vec<LineageNode>> {
85        let idx = *self
86            .node_index
87            .get(node_id)
88            .ok_or_else(|| SecurityError::lineage_tracking("Node not found"))?;
89
90        let graph = self.graph.read();
91        let upstream_indices: Vec<_> = graph.neighbors_directed(idx, Direction::Incoming).collect();
92
93        Ok(upstream_indices
94            .iter()
95            .filter_map(|&i| graph.node_weight(i).cloned())
96            .collect())
97    }
98
99    /// Get downstream nodes (dependents).
100    pub fn get_downstream(&self, node_id: &str) -> Result<Vec<LineageNode>> {
101        let idx = *self
102            .node_index
103            .get(node_id)
104            .ok_or_else(|| SecurityError::lineage_tracking("Node not found"))?;
105
106        let graph = self.graph.read();
107        let downstream_indices: Vec<_> =
108            graph.neighbors_directed(idx, Direction::Outgoing).collect();
109
110        Ok(downstream_indices
111            .iter()
112            .filter_map(|&i| graph.node_weight(i).cloned())
113            .collect())
114    }
115
116    /// Get all ancestors (recursive upstream).
117    pub fn get_ancestors(&self, node_id: &str) -> Result<Vec<LineageNode>> {
118        let mut ancestors = Vec::new();
119        let mut visited = std::collections::HashSet::new();
120        self.collect_ancestors(node_id, &mut ancestors, &mut visited)?;
121        Ok(ancestors)
122    }
123
124    fn collect_ancestors(
125        &self,
126        node_id: &str,
127        ancestors: &mut Vec<LineageNode>,
128        visited: &mut std::collections::HashSet<String>,
129    ) -> Result<()> {
130        if visited.contains(node_id) {
131            return Ok(());
132        }
133        visited.insert(node_id.to_string());
134
135        let upstream = self.get_upstream(node_id)?;
136        for node in upstream {
137            ancestors.push(node.clone());
138            self.collect_ancestors(&node.id, ancestors, visited)?;
139        }
140
141        Ok(())
142    }
143
144    /// Get all descendants (recursive downstream).
145    pub fn get_descendants(&self, node_id: &str) -> Result<Vec<LineageNode>> {
146        let mut descendants = Vec::new();
147        let mut visited = std::collections::HashSet::new();
148        self.collect_descendants(node_id, &mut descendants, &mut visited)?;
149        Ok(descendants)
150    }
151
152    fn collect_descendants(
153        &self,
154        node_id: &str,
155        descendants: &mut Vec<LineageNode>,
156        visited: &mut std::collections::HashSet<String>,
157    ) -> Result<()> {
158        if visited.contains(node_id) {
159            return Ok(());
160        }
161        visited.insert(node_id.to_string());
162
163        let downstream = self.get_downstream(node_id)?;
164        for node in downstream {
165            descendants.push(node.clone());
166            self.collect_descendants(&node.id, descendants, visited)?;
167        }
168
169        Ok(())
170    }
171
172    /// Record a lineage event.
173    pub fn record_event(&self, event: LineageEvent) -> Result<()> {
174        // Create operation node if specified
175        if let Some(ref operation_id) = event.operation {
176            let op_node = LineageNode::new(NodeType::Operation, operation_id.clone())
177                .with_metadata("event_type".to_string(), event.event_type.clone());
178            self.add_node(op_node)?;
179        }
180
181        // Create edges from inputs to operation
182        if let Some(ref operation_id) = event.operation {
183            if let Some(op_node_id) = self
184                .entity_index
185                .get(operation_id)
186                .and_then(|ids| ids.first().cloned())
187            {
188                for input_id in &event.inputs {
189                    if let Some(input_node_id) = self
190                        .entity_index
191                        .get(input_id)
192                        .and_then(|ids| ids.last().cloned())
193                    {
194                        let edge =
195                            LineageEdge::new(input_node_id, op_node_id.clone(), EdgeType::Used);
196                        self.add_edge(edge)?;
197                    }
198                }
199
200                // Create edges from operation to outputs
201                for output_id in &event.outputs {
202                    if let Some(output_node_id) = self
203                        .entity_index
204                        .get(output_id)
205                        .and_then(|ids| ids.last().cloned())
206                    {
207                        let edge = LineageEdge::new(
208                            op_node_id.clone(),
209                            output_node_id,
210                            EdgeType::GeneratedBy,
211                        );
212                        self.add_edge(edge)?;
213                    }
214                }
215            }
216        } else {
217            // Direct edges from inputs to outputs
218            for input_id in &event.inputs {
219                if let Some(input_node_id) = self
220                    .entity_index
221                    .get(input_id)
222                    .and_then(|ids| ids.last().cloned())
223                {
224                    for output_id in &event.outputs {
225                        if let Some(output_node_id) = self
226                            .entity_index
227                            .get(output_id)
228                            .and_then(|ids| ids.last().cloned())
229                        {
230                            let edge = LineageEdge::new(
231                                input_node_id.clone(),
232                                output_node_id,
233                                EdgeType::DerivedFrom,
234                            );
235                            self.add_edge(edge)?;
236                        }
237                    }
238                }
239            }
240        }
241
242        Ok(())
243    }
244
245    /// Get graph statistics.
246    pub fn stats(&self) -> (usize, usize) {
247        let graph = self.graph.read();
248        (graph.node_count(), graph.edge_count())
249    }
250
251    /// Clear the graph.
252    pub fn clear(&self) {
253        let mut graph = self.graph.write();
254        graph.clear();
255        self.node_index.clear();
256        self.entity_index.clear();
257    }
258}
259
260impl Default for LineageGraph {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266#[cfg(test)]
267mod tests {
268    use super::*;
269
270    #[test]
271    fn test_add_node() {
272        let graph = LineageGraph::new();
273        let node = LineageNode::new(NodeType::Dataset, "dataset-1".to_string());
274        let node_id = graph.add_node(node).expect("Failed to add node");
275
276        assert!(graph.get_node(&node_id).is_some());
277    }
278
279    #[test]
280    fn test_add_edge() {
281        let graph = LineageGraph::new();
282
283        let node1 = LineageNode::new(NodeType::Dataset, "dataset-1".to_string());
284        let node1_id = graph.add_node(node1).expect("Failed to add node");
285
286        let node2 = LineageNode::new(NodeType::Dataset, "dataset-2".to_string());
287        let node2_id = graph.add_node(node2).expect("Failed to add node");
288
289        let edge = LineageEdge::new(node1_id.clone(), node2_id.clone(), EdgeType::DerivedFrom);
290        graph.add_edge(edge).expect("Failed to add edge");
291
292        let downstream = graph
293            .get_downstream(&node1_id)
294            .expect("Failed to get downstream");
295        assert_eq!(downstream.len(), 1);
296        assert_eq!(downstream[0].id, node2_id);
297    }
298
299    #[test]
300    fn test_upstream_downstream() {
301        let graph = LineageGraph::new();
302
303        let node1 = LineageNode::new(NodeType::Dataset, "dataset-1".to_string());
304        let node1_id = graph.add_node(node1).expect("Failed to add node");
305
306        let node2 = LineageNode::new(NodeType::Dataset, "dataset-2".to_string());
307        let node2_id = graph.add_node(node2).expect("Failed to add node");
308
309        let edge = LineageEdge::new(node1_id.clone(), node2_id.clone(), EdgeType::DerivedFrom);
310        graph.add_edge(edge).expect("Failed to add edge");
311
312        let upstream = graph
313            .get_upstream(&node2_id)
314            .expect("Failed to get upstream");
315        assert_eq!(upstream.len(), 1);
316        assert_eq!(upstream[0].id, node1_id);
317
318        let downstream = graph
319            .get_downstream(&node1_id)
320            .expect("Failed to get downstream");
321        assert_eq!(downstream.len(), 1);
322        assert_eq!(downstream[0].id, node2_id);
323    }
324
325    #[test]
326    fn test_record_event() {
327        let graph = LineageGraph::new();
328
329        let input_node = LineageNode::new(NodeType::Dataset, "input-1".to_string());
330        graph.add_node(input_node).expect("Failed to add node");
331
332        let output_node = LineageNode::new(NodeType::Dataset, "output-1".to_string());
333        graph.add_node(output_node).expect("Failed to add node");
334
335        let event = LineageEvent::new("transform".to_string())
336            .with_input("input-1".to_string())
337            .with_output("output-1".to_string())
338            .with_operation("op-1".to_string());
339
340        graph.record_event(event).expect("Failed to record event");
341
342        let (nodes, edges) = graph.stats();
343        assert!(nodes >= 2); // At least input and output
344        assert!(edges >= 2); // At least input->op and op->output
345    }
346}