Skip to main content

datasynth_graph/models/
graph.rs

1//! Graph container model.
2
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5
6use super::edges::{EdgeId, EdgeType, GraphEdge};
7use super::nodes::{GraphNode, NodeId, NodeType};
8
9/// A graph containing nodes and edges.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Graph {
12    /// Graph name.
13    pub name: String,
14    /// Graph type.
15    pub graph_type: GraphType,
16    /// Nodes indexed by ID.
17    pub nodes: HashMap<NodeId, GraphNode>,
18    /// Edges indexed by ID.
19    pub edges: HashMap<EdgeId, GraphEdge>,
20    /// Adjacency list (node -> outgoing edges).
21    pub adjacency: HashMap<NodeId, Vec<EdgeId>>,
22    /// Reverse adjacency (node -> incoming edges).
23    pub reverse_adjacency: HashMap<NodeId, Vec<EdgeId>>,
24    /// Node type index.
25    pub nodes_by_type: HashMap<NodeType, Vec<NodeId>>,
26    /// Edge type index.
27    pub edges_by_type: HashMap<EdgeType, Vec<EdgeId>>,
28    /// Metadata.
29    pub metadata: GraphMetadata,
30    /// Next node ID.
31    next_node_id: NodeId,
32    /// Next edge ID.
33    next_edge_id: EdgeId,
34}
35
36impl Graph {
37    /// Creates a new graph.
38    pub fn new(name: &str, graph_type: GraphType) -> Self {
39        Self {
40            name: name.to_string(),
41            graph_type,
42            nodes: HashMap::new(),
43            edges: HashMap::new(),
44            adjacency: HashMap::new(),
45            reverse_adjacency: HashMap::new(),
46            nodes_by_type: HashMap::new(),
47            edges_by_type: HashMap::new(),
48            metadata: GraphMetadata::default(),
49            next_node_id: 1,
50            next_edge_id: 1,
51        }
52    }
53
54    /// Adds a node to the graph, returning its ID.
55    pub fn add_node(&mut self, mut node: GraphNode) -> NodeId {
56        let id = self.next_node_id;
57        self.next_node_id += 1;
58        node.id = id;
59
60        // Update type index
61        self.nodes_by_type
62            .entry(node.node_type.clone())
63            .or_default()
64            .push(id);
65
66        // Initialize adjacency
67        self.adjacency.insert(id, Vec::new());
68        self.reverse_adjacency.insert(id, Vec::new());
69
70        self.nodes.insert(id, node);
71        id
72    }
73
74    /// Adds an edge to the graph, returning its ID.
75    pub fn add_edge(&mut self, mut edge: GraphEdge) -> EdgeId {
76        let id = self.next_edge_id;
77        self.next_edge_id += 1;
78        edge.id = id;
79
80        // Update adjacency
81        self.adjacency.entry(edge.source).or_default().push(id);
82        self.reverse_adjacency
83            .entry(edge.target)
84            .or_default()
85            .push(id);
86
87        // Update type index
88        self.edges_by_type
89            .entry(edge.edge_type.clone())
90            .or_default()
91            .push(id);
92
93        self.edges.insert(id, edge);
94        id
95    }
96
97    /// Gets a node by ID.
98    pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
99        self.nodes.get(&id)
100    }
101
102    /// Gets a mutable node by ID.
103    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
104        self.nodes.get_mut(&id)
105    }
106
107    /// Gets an edge by ID.
108    pub fn get_edge(&self, id: EdgeId) -> Option<&GraphEdge> {
109        self.edges.get(&id)
110    }
111
112    /// Gets a mutable edge by ID.
113    pub fn get_edge_mut(&mut self, id: EdgeId) -> Option<&mut GraphEdge> {
114        self.edges.get_mut(&id)
115    }
116
117    /// Returns all nodes of a given type.
118    pub fn nodes_of_type(&self, node_type: &NodeType) -> Vec<&GraphNode> {
119        self.nodes_by_type
120            .get(node_type)
121            .map(|ids| ids.iter().filter_map(|id| self.nodes.get(id)).collect())
122            .unwrap_or_default()
123    }
124
125    /// Returns all edges of a given type.
126    pub fn edges_of_type(&self, edge_type: &EdgeType) -> Vec<&GraphEdge> {
127        self.edges_by_type
128            .get(edge_type)
129            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
130            .unwrap_or_default()
131    }
132
133    /// Returns outgoing edges from a node.
134    pub fn outgoing_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
135        self.adjacency
136            .get(&node_id)
137            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
138            .unwrap_or_default()
139    }
140
141    /// Returns incoming edges to a node.
142    pub fn incoming_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
143        self.reverse_adjacency
144            .get(&node_id)
145            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
146            .unwrap_or_default()
147    }
148
149    /// Returns neighbors of a node.
150    pub fn neighbors(&self, node_id: NodeId) -> Vec<NodeId> {
151        let mut neighbors = HashSet::new();
152
153        // Outgoing
154        if let Some(edges) = self.adjacency.get(&node_id) {
155            for edge_id in edges {
156                if let Some(edge) = self.edges.get(edge_id) {
157                    neighbors.insert(edge.target);
158                }
159            }
160        }
161
162        // Incoming
163        if let Some(edges) = self.reverse_adjacency.get(&node_id) {
164            for edge_id in edges {
165                if let Some(edge) = self.edges.get(edge_id) {
166                    neighbors.insert(edge.source);
167                }
168            }
169        }
170
171        neighbors.into_iter().collect()
172    }
173
174    /// Returns the number of nodes.
175    pub fn node_count(&self) -> usize {
176        self.nodes.len()
177    }
178
179    /// Returns the number of edges.
180    pub fn edge_count(&self) -> usize {
181        self.edges.len()
182    }
183
184    /// Returns the out-degree of a node.
185    pub fn out_degree(&self, node_id: NodeId) -> usize {
186        self.adjacency.get(&node_id).map(|e| e.len()).unwrap_or(0)
187    }
188
189    /// Returns the in-degree of a node.
190    pub fn in_degree(&self, node_id: NodeId) -> usize {
191        self.reverse_adjacency
192            .get(&node_id)
193            .map(|e| e.len())
194            .unwrap_or(0)
195    }
196
197    /// Returns the total degree of a node.
198    pub fn degree(&self, node_id: NodeId) -> usize {
199        self.out_degree(node_id) + self.in_degree(node_id)
200    }
201
202    /// Returns anomalous nodes.
203    pub fn anomalous_nodes(&self) -> Vec<&GraphNode> {
204        self.nodes.values().filter(|n| n.is_anomaly).collect()
205    }
206
207    /// Returns anomalous edges.
208    pub fn anomalous_edges(&self) -> Vec<&GraphEdge> {
209        self.edges.values().filter(|e| e.is_anomaly).collect()
210    }
211
212    /// Computes graph statistics.
213    pub fn compute_statistics(&mut self) {
214        self.metadata.node_count = self.nodes.len();
215        self.metadata.edge_count = self.edges.len();
216
217        // Count by type
218        self.metadata.node_type_counts = self
219            .nodes_by_type
220            .iter()
221            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
222            .collect();
223
224        self.metadata.edge_type_counts = self
225            .edges_by_type
226            .iter()
227            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
228            .collect();
229
230        // Anomaly counts
231        self.metadata.anomalous_node_count = self.anomalous_nodes().len();
232        self.metadata.anomalous_edge_count = self.anomalous_edges().len();
233
234        // Density
235        if self.metadata.node_count > 1 {
236            let max_edges = self.metadata.node_count * (self.metadata.node_count - 1);
237            self.metadata.density = self.metadata.edge_count as f64 / max_edges as f64;
238        }
239
240        // Feature dimensions
241        if let Some(node) = self.nodes.values().next() {
242            self.metadata.node_feature_dim = node.features.len();
243        }
244        if let Some(edge) = self.edges.values().next() {
245            self.metadata.edge_feature_dim = edge.features.len();
246        }
247    }
248
249    /// Returns the edge index as a pair of vectors (source_ids, target_ids).
250    pub fn edge_index(&self) -> (Vec<NodeId>, Vec<NodeId>) {
251        let mut sources = Vec::with_capacity(self.edges.len());
252        let mut targets = Vec::with_capacity(self.edges.len());
253
254        for edge in self.edges.values() {
255            sources.push(edge.source);
256            targets.push(edge.target);
257        }
258
259        (sources, targets)
260    }
261
262    /// Returns the node feature matrix (nodes x features).
263    pub fn node_features(&self) -> Vec<Vec<f64>> {
264        let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
265        node_ids.sort();
266
267        node_ids
268            .iter()
269            .filter_map(|id| self.nodes.get(id))
270            .map(|n| n.features.clone())
271            .collect()
272    }
273
274    /// Returns the edge feature matrix (edges x features).
275    pub fn edge_features(&self) -> Vec<Vec<f64>> {
276        let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
277        edge_ids.sort();
278
279        edge_ids
280            .iter()
281            .filter_map(|id| self.edges.get(id))
282            .map(|e| e.features.clone())
283            .collect()
284    }
285
286    /// Returns node labels.
287    pub fn node_labels(&self) -> Vec<Vec<String>> {
288        let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
289        node_ids.sort();
290
291        node_ids
292            .iter()
293            .filter_map(|id| self.nodes.get(id))
294            .map(|n| n.labels.clone())
295            .collect()
296    }
297
298    /// Returns edge labels.
299    pub fn edge_labels(&self) -> Vec<Vec<String>> {
300        let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
301        edge_ids.sort();
302
303        edge_ids
304            .iter()
305            .filter_map(|id| self.edges.get(id))
306            .map(|e| e.labels.clone())
307            .collect()
308    }
309
310    /// Returns node anomaly flags.
311    pub fn node_anomaly_mask(&self) -> Vec<bool> {
312        let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
313        node_ids.sort();
314
315        node_ids
316            .iter()
317            .filter_map(|id| self.nodes.get(id))
318            .map(|n| n.is_anomaly)
319            .collect()
320    }
321
322    /// Returns edge anomaly flags.
323    pub fn edge_anomaly_mask(&self) -> Vec<bool> {
324        let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
325        edge_ids.sort();
326
327        edge_ids
328            .iter()
329            .filter_map(|id| self.edges.get(id))
330            .map(|e| e.is_anomaly)
331            .collect()
332    }
333}
334
335/// Type of graph.
336#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
337pub enum GraphType {
338    /// Transaction network (accounts as nodes, transactions as edges).
339    Transaction,
340    /// Approval network (users as nodes, approvals as edges).
341    Approval,
342    /// Entity relationship (companies as nodes, ownership as edges).
343    EntityRelationship,
344    /// Heterogeneous graph (multiple node and edge types).
345    Heterogeneous,
346    /// Custom graph type.
347    Custom(String),
348}
349
350/// Metadata about a graph.
351#[derive(Debug, Clone, Default, Serialize, Deserialize)]
352pub struct GraphMetadata {
353    /// Number of nodes.
354    pub node_count: usize,
355    /// Number of edges.
356    pub edge_count: usize,
357    /// Node counts by type.
358    pub node_type_counts: HashMap<String, usize>,
359    /// Edge counts by type.
360    pub edge_type_counts: HashMap<String, usize>,
361    /// Number of anomalous nodes.
362    pub anomalous_node_count: usize,
363    /// Number of anomalous edges.
364    pub anomalous_edge_count: usize,
365    /// Graph density.
366    pub density: f64,
367    /// Node feature dimension.
368    pub node_feature_dim: usize,
369    /// Edge feature dimension.
370    pub edge_feature_dim: usize,
371    /// Additional properties.
372    pub properties: HashMap<String, String>,
373}
374
375/// A heterogeneous graph with multiple node and edge types.
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct HeterogeneousGraph {
378    /// Graph name.
379    pub name: String,
380    /// Subgraphs by relation type (source_type, edge_type, target_type).
381    pub relations: HashMap<(String, String, String), Graph>,
382    /// All node IDs by type.
383    pub all_nodes: HashMap<String, Vec<NodeId>>,
384    /// Metadata.
385    pub metadata: GraphMetadata,
386}
387
388impl HeterogeneousGraph {
389    /// Creates a new heterogeneous graph.
390    pub fn new(name: &str) -> Self {
391        Self {
392            name: name.to_string(),
393            relations: HashMap::new(),
394            all_nodes: HashMap::new(),
395            metadata: GraphMetadata::default(),
396        }
397    }
398
399    /// Adds a relation (edge type between node types).
400    pub fn add_relation(
401        &mut self,
402        source_type: &str,
403        edge_type: &str,
404        target_type: &str,
405        graph: Graph,
406    ) {
407        let key = (
408            source_type.to_string(),
409            edge_type.to_string(),
410            target_type.to_string(),
411        );
412        self.relations.insert(key, graph);
413    }
414
415    /// Gets a relation graph.
416    pub fn get_relation(
417        &self,
418        source_type: &str,
419        edge_type: &str,
420        target_type: &str,
421    ) -> Option<&Graph> {
422        let key = (
423            source_type.to_string(),
424            edge_type.to_string(),
425            target_type.to_string(),
426        );
427        self.relations.get(&key)
428    }
429
430    /// Returns all relation keys.
431    pub fn relation_types(&self) -> Vec<(String, String, String)> {
432        self.relations.keys().cloned().collect()
433    }
434
435    /// Computes statistics for the heterogeneous graph.
436    pub fn compute_statistics(&mut self) {
437        let mut total_nodes = 0;
438        let mut total_edges = 0;
439
440        for graph in self.relations.values() {
441            total_nodes += graph.node_count();
442            total_edges += graph.edge_count();
443        }
444
445        self.metadata.node_count = total_nodes;
446        self.metadata.edge_count = total_edges;
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an account node with a placeholder ID; `add_node` assigns the real one.
    fn account(code: &str, name: &str) -> GraphNode {
        GraphNode::new(0, NodeType::Account, code.to_string(), name.to_string())
    }

    /// Builds a transaction edge with a placeholder ID; `add_edge` assigns the real one.
    fn transaction(source: NodeId, target: NodeId) -> GraphEdge {
        GraphEdge::new(0, source, target, EdgeType::Transaction)
    }

    /// Two nodes joined by one edge are reflected in the counts.
    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let id1 = graph.add_node(account("1000", "Cash"));
        let id2 = graph.add_node(account("2000", "AP"));
        graph.add_edge(transaction(id1, id2));

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    /// Degrees and neighbors follow the edges added to a three-node graph.
    #[test]
    fn test_adjacency() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("1", "A"));
        let n2 = graph.add_node(account("2", "B"));
        let n3 = graph.add_node(account("3", "C"));

        for (source, target) in [(n1, n2), (n1, n3), (n2, n3)] {
            graph.add_edge(transaction(source, target));
        }

        assert_eq!(graph.out_degree(n1), 2);
        assert_eq!(graph.in_degree(n3), 2);
        assert_eq!(graph.neighbors(n1).len(), 2);
    }

    /// A single edge produces a one-column edge index with its endpoints.
    #[test]
    fn test_edge_index() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("1", "A"));
        let n2 = graph.add_node(account("2", "B"));
        graph.add_edge(transaction(n1, n2));

        let (sources, targets) = graph.edge_index();
        assert_eq!(sources.len(), 1);
        assert_eq!(targets.len(), 1);
        assert_eq!(sources[0], n1);
        assert_eq!(targets[0], n2);
    }
}