Skip to main content

ronn_core/
graph.rs

1//! Graph manipulation and validation utilities.
2//!
3//! This module provides utilities for working with model graphs, including
4//! validation, topological ordering, traversal, and subgraph extraction.
5
6use crate::types::{AttributeValue, GraphEdge, GraphNode, ModelGraph, NodeId, SubGraph};
7use anyhow::{Result, anyhow};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl ModelGraph {
11    /// Create a new empty model graph.
12    pub fn new() -> Self {
13        Self {
14            nodes: Vec::new(),
15            edges: Vec::new(),
16            inputs: Vec::new(),
17            outputs: Vec::new(),
18            metadata: HashMap::new(),
19        }
20    }
21
22    /// Add a node to the graph.
23    ///
24    /// # Arguments
25    /// * `node` - The graph node to add
26    ///
27    /// # Returns
28    /// The node ID that was assigned
29    pub fn add_node(&mut self, mut node: GraphNode) -> NodeId {
30        let node_id = self.nodes.len();
31        node.id = node_id;
32        self.nodes.push(node);
33        node_id
34    }
35
36    /// Add an edge to the graph.
37    ///
38    /// # Arguments
39    /// * `edge` - The graph edge to add
40    pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
41        // Validate that the nodes exist
42        if edge.from_node >= self.nodes.len() || edge.to_node >= self.nodes.len() {
43            return Err(anyhow!("Edge references non-existent nodes"));
44        }
45        self.edges.push(edge);
46        Ok(())
47    }
48
49    /// Get a node by ID.
50    pub fn get_node(&self, node_id: NodeId) -> Option<&GraphNode> {
51        self.nodes.get(node_id)
52    }
53
54    /// Get a mutable node by ID.
55    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut GraphNode> {
56        self.nodes.get_mut(node_id)
57    }
58
59    /// Get all nodes in the graph.
60    pub fn nodes(&self) -> &[GraphNode] {
61        &self.nodes
62    }
63
64    /// Get mutable access to all nodes.
65    pub fn nodes_mut(&mut self) -> &mut Vec<GraphNode> {
66        &mut self.nodes
67    }
68
69    /// Get the number of nodes in the graph.
70    pub fn node_count(&self) -> usize {
71        self.nodes.len()
72    }
73
74    /// Create a ModelGraph from a list of nodes.
75    pub fn from_nodes(nodes: Vec<GraphNode>) -> Self {
76        Self {
77            nodes,
78            edges: Vec::new(),
79            inputs: Vec::new(),
80            outputs: Vec::new(),
81            metadata: HashMap::new(),
82        }
83    }
84
85    /// Find nodes by operation type.
86    pub fn find_nodes_by_op(&self, op_type: &str) -> Vec<NodeId> {
87        self.nodes
88            .iter()
89            .filter_map(|node| {
90                if node.op_type == op_type {
91                    Some(node.id)
92                } else {
93                    None
94                }
95            })
96            .collect()
97    }
98
99    /// Get all edges connected to a node.
100    pub fn get_node_edges(&self, node_id: NodeId) -> (Vec<&GraphEdge>, Vec<&GraphEdge>) {
101        let incoming: Vec<&GraphEdge> = self
102            .edges
103            .iter()
104            .filter(|edge| edge.to_node == node_id)
105            .collect();
106
107        let outgoing: Vec<&GraphEdge> = self
108            .edges
109            .iter()
110            .filter(|edge| edge.from_node == node_id)
111            .collect();
112
113        (incoming, outgoing)
114    }
115
116    /// Validate the graph structure.
117    pub fn validate(&self) -> Result<()> {
118        // Check for duplicate node IDs
119        let mut seen_ids = HashSet::new();
120        for node in &self.nodes {
121            if !seen_ids.insert(node.id) {
122                return Err(anyhow!("Duplicate node ID: {}", node.id));
123            }
124        }
125
126        // Validate edges reference existing nodes
127        for edge in &self.edges {
128            if edge.from_node >= self.nodes.len() {
129                return Err(anyhow!(
130                    "Edge references non-existent from_node: {}",
131                    edge.from_node
132                ));
133            }
134            if edge.to_node >= self.nodes.len() {
135                return Err(anyhow!(
136                    "Edge references non-existent to_node: {}",
137                    edge.to_node
138                ));
139            }
140        }
141
142        // Check for cycles using DFS
143        if self.has_cycles()? {
144            return Err(anyhow!("Graph contains cycles"));
145        }
146
147        // Validate input/output tensor names are used in nodes
148        self.validate_input_output_tensors()?;
149
150        Ok(())
151    }
152
153    /// Check if the graph has cycles.
154    fn has_cycles(&self) -> Result<bool> {
155        let mut state = vec![NodeState::Unvisited; self.nodes.len()];
156
157        for node_id in 0..self.nodes.len() {
158            if state[node_id] == NodeState::Unvisited {
159                if self.has_cycles_dfs(node_id, &mut state)? {
160                    return Ok(true);
161                }
162            }
163        }
164        Ok(false)
165    }
166
167    fn has_cycles_dfs(&self, node_id: NodeId, state: &mut Vec<NodeState>) -> Result<bool> {
168        state[node_id] = NodeState::Visiting;
169
170        let (_, outgoing) = self.get_node_edges(node_id);
171        for edge in outgoing {
172            match state[edge.to_node] {
173                NodeState::Visiting => return Ok(true), // Back edge found - cycle detected
174                NodeState::Unvisited => {
175                    if self.has_cycles_dfs(edge.to_node, state)? {
176                        return Ok(true);
177                    }
178                }
179                NodeState::Visited => {} // Safe to ignore
180            }
181        }
182
183        state[node_id] = NodeState::Visited;
184        Ok(false)
185    }
186
187    /// Validate that input/output tensor names are used in the graph.
188    fn validate_input_output_tensors(&self) -> Result<()> {
189        let mut all_tensor_names: HashSet<String> = HashSet::new();
190
191        // Collect all tensor names used in nodes
192        for node in &self.nodes {
193            for input in &node.inputs {
194                all_tensor_names.insert(input.clone());
195            }
196            for output in &node.outputs {
197                all_tensor_names.insert(output.clone());
198            }
199        }
200
201        // Check that graph inputs exist in tensor names
202        for input in &self.inputs {
203            if !all_tensor_names.contains(input) {
204                return Err(anyhow!("Graph input '{}' is not used by any node", input));
205            }
206        }
207
208        // Check that graph outputs exist in tensor names
209        for output in &self.outputs {
210            if !all_tensor_names.contains(output) {
211                return Err(anyhow!(
212                    "Graph output '{}' is not produced by any node",
213                    output
214                ));
215            }
216        }
217
218        Ok(())
219    }
220
221    /// Get topological ordering of nodes.
222    pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
223        let mut in_degree = vec![0; self.nodes.len()];
224
225        // Calculate in-degrees
226        for edge in &self.edges {
227            in_degree[edge.to_node] += 1;
228        }
229
230        // Queue nodes with no dependencies
231        let mut queue = VecDeque::new();
232        for (node_id, &degree) in in_degree.iter().enumerate() {
233            if degree == 0 {
234                queue.push_back(node_id);
235            }
236        }
237
238        let mut result = Vec::new();
239
240        while let Some(node_id) = queue.pop_front() {
241            result.push(node_id);
242
243            // Process outgoing edges
244            let (_, outgoing) = self.get_node_edges(node_id);
245            for edge in outgoing {
246                in_degree[edge.to_node] -= 1;
247                if in_degree[edge.to_node] == 0 {
248                    queue.push_back(edge.to_node);
249                }
250            }
251        }
252
253        if result.len() != self.nodes.len() {
254            return Err(anyhow!(
255                "Graph contains cycles - cannot perform topological sort"
256            ));
257        }
258
259        Ok(result)
260    }
261
262    /// Extract a subgraph containing specified nodes.
263    pub fn extract_subgraph(&self, node_ids: &[NodeId]) -> Result<SubGraph> {
264        let node_set: HashSet<NodeId> = node_ids.iter().cloned().collect();
265
266        // Validate all node IDs exist
267        for &node_id in node_ids {
268            if node_id >= self.nodes.len() {
269                return Err(anyhow!("Node ID {} does not exist", node_id));
270            }
271        }
272
273        // Create mapping from old node IDs to new node IDs
274        let mut id_mapping = HashMap::new();
275        let mut subgraph_nodes = Vec::new();
276
277        for (new_id, &old_id) in node_ids.iter().enumerate() {
278            id_mapping.insert(old_id, new_id);
279            let mut node = self.nodes[old_id].clone();
280            node.id = new_id;
281            subgraph_nodes.push(node);
282        }
283
284        // Extract relevant edges
285        let mut subgraph_edges = Vec::new();
286        for edge in &self.edges {
287            if node_set.contains(&edge.from_node) && node_set.contains(&edge.to_node) {
288                let mut new_edge = edge.clone();
289                new_edge.from_node = id_mapping[&edge.from_node];
290                new_edge.to_node = id_mapping[&edge.to_node];
291                subgraph_edges.push(new_edge);
292            }
293        }
294
295        // Determine subgraph inputs and outputs
296        let mut subgraph_inputs = HashSet::new();
297        let mut subgraph_outputs = HashSet::new();
298
299        for node in &subgraph_nodes {
300            // Inputs that don't come from within the subgraph are external inputs
301            for input in &node.inputs {
302                let mut is_external = true;
303                for other_node in &subgraph_nodes {
304                    if other_node.outputs.contains(input) {
305                        is_external = false;
306                        break;
307                    }
308                }
309                if is_external {
310                    subgraph_inputs.insert(input.clone());
311                }
312            }
313
314            // Outputs that don't go to nodes within the subgraph are external outputs
315            for output in &node.outputs {
316                let mut is_external = true;
317                for other_node in &subgraph_nodes {
318                    if other_node.inputs.contains(output) {
319                        is_external = false;
320                        break;
321                    }
322                }
323                if is_external {
324                    subgraph_outputs.insert(output.clone());
325                }
326            }
327        }
328
329        Ok(SubGraph {
330            nodes: subgraph_nodes,
331            edges: subgraph_edges,
332            inputs: subgraph_inputs.into_iter().collect(),
333            outputs: subgraph_outputs.into_iter().collect(),
334        })
335    }
336
337    /// Count nodes by operation type.
338    pub fn count_ops(&self) -> HashMap<String, usize> {
339        let mut counts = HashMap::new();
340        for node in &self.nodes {
341            *counts.entry(node.op_type.clone()).or_insert(0) += 1;
342        }
343        counts
344    }
345
346    /// Get graph statistics.
347    pub fn statistics(&self) -> GraphStatistics {
348        let node_count = self.nodes.len();
349        let edge_count = self.edges.len();
350        let op_counts = self.count_ops();
351        let input_count = self.inputs.len();
352        let output_count = self.outputs.len();
353
354        let depth = self.calculate_depth();
355
356        GraphStatistics {
357            node_count,
358            edge_count,
359            op_counts,
360            input_count,
361            output_count,
362            depth,
363        }
364    }
365
366    /// Calculate the maximum depth of the graph.
367    fn calculate_depth(&self) -> usize {
368        if let Ok(topo_order) = self.topological_sort() {
369            let mut depths = vec![0; self.nodes.len()];
370
371            for &node_id in &topo_order {
372                let (incoming, _) = self.get_node_edges(node_id);
373                if incoming.is_empty() {
374                    depths[node_id] = 0;
375                } else {
376                    let max_input_depth = incoming
377                        .iter()
378                        .map(|edge| depths[edge.from_node])
379                        .max()
380                        .unwrap_or(0);
381                    depths[node_id] = max_input_depth + 1;
382                }
383            }
384
385            depths.into_iter().max().unwrap_or(0)
386        } else {
387            0 // If there are cycles, depth is undefined
388        }
389    }
390}
391
392impl Default for ModelGraph {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
/// Per-node marking used by the depth-first cycle search.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum NodeState {
    /// Not reached by the search yet.
    Unvisited,
    /// On the current search path; reaching it again means a cycle.
    Visiting,
    /// All of its successors have been explored.
    Visited,
}
405
/// Summary statistics of a [`ModelGraph`].
///
/// Derives `Default` (all counts zero, empty op map) and `PartialEq`/`Eq` so results
/// can be constructed and compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphStatistics {
    /// Total number of nodes.
    pub node_count: usize,
    /// Total number of edges.
    pub edge_count: usize,
    /// Count of each operation type.
    pub op_counts: HashMap<String, usize>,
    /// Number of graph inputs.
    pub input_count: usize,
    /// Number of graph outputs.
    pub output_count: usize,
    /// Maximum depth of the graph, in edges along the longest path.
    pub depth: usize,
}
422
423/// Graph builder for convenient graph construction.
424pub struct GraphBuilder {
425    graph: ModelGraph,
426}
427
428impl GraphBuilder {
429    /// Create a new graph builder.
430    pub fn new() -> Self {
431        Self {
432            graph: ModelGraph::new(),
433        }
434    }
435
436    /// Add a node with the given operation type.
437    pub fn add_op(&mut self, op_type: &str, name: Option<String>) -> NodeId {
438        let node = GraphNode {
439            id: 0, // Will be set by add_node
440            op_type: op_type.to_string(),
441            attributes: HashMap::new(),
442            inputs: Vec::new(),
443            outputs: Vec::new(),
444            name,
445        };
446        self.graph.add_node(node)
447    }
448
449    /// Add an input tensor to a node.
450    pub fn add_input(&mut self, node_id: NodeId, tensor_name: &str) -> &mut Self {
451        if let Some(node) = self.graph.get_node_mut(node_id) {
452            node.inputs.push(tensor_name.to_string());
453        }
454        self
455    }
456
457    /// Add an output tensor to a node.
458    pub fn add_output(&mut self, node_id: NodeId, tensor_name: &str) -> &mut Self {
459        if let Some(node) = self.graph.get_node_mut(node_id) {
460            node.outputs.push(tensor_name.to_string());
461        }
462        self
463    }
464
465    /// Add an attribute to a node.
466    pub fn add_attribute(
467        &mut self,
468        node_id: NodeId,
469        name: &str,
470        value: AttributeValue,
471    ) -> &mut Self {
472        if let Some(node) = self.graph.get_node_mut(node_id) {
473            node.attributes.insert(name.to_string(), value);
474        }
475        self
476    }
477
478    /// Connect two nodes with a tensor.
479    pub fn connect(
480        &mut self,
481        from_node: NodeId,
482        to_node: NodeId,
483        tensor_name: &str,
484    ) -> Result<&mut Self> {
485        let edge = GraphEdge {
486            from_node,
487            to_node,
488            tensor_name: tensor_name.to_string(),
489            tensor_shape: None,
490            tensor_dtype: crate::types::DataType::F32, // Default
491        };
492        self.graph.add_edge(edge)?;
493        Ok(self)
494    }
495
496    /// Set graph inputs.
497    pub fn set_inputs(&mut self, inputs: Vec<String>) -> &mut Self {
498        self.graph.inputs = inputs;
499        self
500    }
501
502    /// Set graph outputs.
503    pub fn set_outputs(&mut self, outputs: Vec<String>) -> &mut Self {
504        self.graph.outputs = outputs;
505        self
506    }
507
508    /// Build the final graph.
509    pub fn build(self) -> Result<ModelGraph> {
510        self.graph.validate()?;
511        Ok(self.graph)
512    }
513}
514
515impl Default for GraphBuilder {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Build a node with the given op type and tensor names; `id` is a placeholder.
    fn make_node(op_type: &str, inputs: &[&str], outputs: &[&str]) -> GraphNode {
        GraphNode {
            id: 0,
            op_type: op_type.to_string(),
            attributes: HashMap::new(),
            inputs: inputs.iter().map(|s| s.to_string()).collect(),
            outputs: outputs.iter().map(|s| s.to_string()).collect(),
            name: None,
        }
    }

    /// Build an untyped-shape F32 edge carrying `tensor_name`.
    fn make_edge(from_node: NodeId, to_node: NodeId, tensor_name: &str) -> GraphEdge {
        GraphEdge {
            from_node,
            to_node,
            tensor_name: tensor_name.to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        }
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = ModelGraph::new();
        assert_eq!(graph.nodes.len(), 0);
        assert_eq!(graph.edges.len(), 0);

        let node = GraphNode {
            id: 0,
            op_type: "Conv".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["input1".to_string()],
            outputs: vec!["output1".to_string()],
            name: Some("conv1".to_string()),
        };

        let node_id = graph.add_node(node);
        assert_eq!(node_id, 0);
        assert_eq!(graph.nodes.len(), 1);
    }

    #[test]
    fn test_edge_addition() -> Result<()> {
        let mut graph = ModelGraph::new();

        // Add two nodes
        let node1 = GraphNode {
            id: 0,
            op_type: "Input".to_string(),
            attributes: HashMap::new(),
            inputs: vec![],
            outputs: vec!["tensor1".to_string()],
            name: Some("input".to_string()),
        };

        let node2 = GraphNode {
            id: 1,
            op_type: "Conv".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["tensor1".to_string()],
            outputs: vec!["tensor2".to_string()],
            name: Some("conv".to_string()),
        };

        let id1 = graph.add_node(node1);
        let id2 = graph.add_node(node2);

        let edge = GraphEdge {
            from_node: id1,
            to_node: id2,
            tensor_name: "tensor1".to_string(),
            tensor_shape: Some(vec![1, 3, 224, 224]),
            tensor_dtype: DataType::F32,
        };

        graph.add_edge(edge)?;
        assert_eq!(graph.edges.len(), 1);

        Ok(())
    }

    #[test]
    fn test_add_edge_rejects_unknown_nodes() {
        let mut graph = ModelGraph::new();
        let a = graph.add_node(make_node("Input", &[], &["a_out"]));

        assert!(graph.add_edge(make_edge(a, a + 1, "a_out")).is_err());
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_topological_sort() -> Result<()> {
        let mut graph = ModelGraph::new();

        // Create a simple linear graph: A -> B -> C
        let node_a = GraphNode {
            id: 0,
            op_type: "Input".to_string(),
            attributes: HashMap::new(),
            inputs: vec![],
            outputs: vec!["a_out".to_string()],
            name: Some("A".to_string()),
        };

        let node_b = GraphNode {
            id: 1,
            op_type: "Conv".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["a_out".to_string()],
            outputs: vec!["b_out".to_string()],
            name: Some("B".to_string()),
        };

        let node_c = GraphNode {
            id: 2,
            op_type: "ReLU".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["b_out".to_string()],
            outputs: vec!["c_out".to_string()],
            name: Some("C".to_string()),
        };

        let id_a = graph.add_node(node_a);
        let id_b = graph.add_node(node_b);
        let id_c = graph.add_node(node_c);

        graph.add_edge(GraphEdge {
            from_node: id_a,
            to_node: id_b,
            tensor_name: "a_out".to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        })?;

        graph.add_edge(GraphEdge {
            from_node: id_b,
            to_node: id_c,
            tensor_name: "b_out".to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        })?;

        let topo_order = graph.topological_sort()?;
        assert_eq!(topo_order, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn test_statistics_and_find_nodes_by_op() -> Result<()> {
        // Linear chain Input -> Conv -> Conv has depth 2 (edges on the longest path).
        let mut graph = ModelGraph::new();
        let a = graph.add_node(make_node("Input", &[], &["a_out"]));
        let b = graph.add_node(make_node("Conv", &["a_out"], &["b_out"]));
        let c = graph.add_node(make_node("Conv", &["b_out"], &["c_out"]));
        graph.add_edge(make_edge(a, b, "a_out"))?;
        graph.add_edge(make_edge(b, c, "b_out"))?;

        let stats = graph.statistics();
        assert_eq!(stats.node_count, 3);
        assert_eq!(stats.edge_count, 2);
        assert_eq!(stats.input_count, 0);
        assert_eq!(stats.output_count, 0);
        assert_eq!(stats.depth, 2);
        assert_eq!(stats.op_counts.get("Conv"), Some(&2));
        assert_eq!(stats.op_counts.get("Input"), Some(&1));

        assert_eq!(graph.find_nodes_by_op("Conv"), vec![b, c]);
        assert!(graph.find_nodes_by_op("Missing").is_empty());

        Ok(())
    }

    #[test]
    fn test_graph_builder() -> Result<()> {
        let mut builder = GraphBuilder::new();

        let input_id = builder.add_op("Input", Some("input_layer".to_string()));
        builder.add_output(input_id, "input_tensor");

        let conv_id = builder.add_op("Conv", Some("conv_layer".to_string()));
        builder
            .add_input(conv_id, "input_tensor")
            .add_output(conv_id, "conv_output")
            .add_attribute(conv_id, "kernel_size", AttributeValue::IntArray(vec![3, 3]));

        builder.connect(input_id, conv_id, "input_tensor")?;
        builder
            .set_inputs(vec!["input_tensor".to_string()])
            .set_outputs(vec!["conv_output".to_string()]);

        let graph = builder.build()?;
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.inputs, vec!["input_tensor"]);
        assert_eq!(graph.outputs, vec!["conv_output"]);

        Ok(())
    }

    #[test]
    fn test_validate_rejects_duplicate_ids() {
        // `from_nodes` keeps the provided IDs, so both nodes keep the placeholder ID 0.
        let graph = ModelGraph::from_nodes(vec![
            make_node("A", &[], &["a_out"]),
            make_node("B", &["a_out"], &["b_out"]),
        ]);
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unknown_graph_output() {
        let mut graph = ModelGraph::new();
        graph.add_node(make_node("Input", &[], &["a_out"]));
        graph.outputs = vec!["missing".to_string()];
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = ModelGraph::new();

        // Create a cycle: A -> B -> C -> A
        let node_a = GraphNode {
            id: 0,
            op_type: "A".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["c_out".to_string()],
            outputs: vec!["a_out".to_string()],
            name: Some("A".to_string()),
        };

        let node_b = GraphNode {
            id: 1,
            op_type: "B".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["a_out".to_string()],
            outputs: vec!["b_out".to_string()],
            name: Some("B".to_string()),
        };

        let node_c = GraphNode {
            id: 2,
            op_type: "C".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["b_out".to_string()],
            outputs: vec!["c_out".to_string()],
            name: Some("C".to_string()),
        };

        let id_a = graph.add_node(node_a);
        let id_b = graph.add_node(node_b);
        let id_c = graph.add_node(node_c);

        // Add edges to form a cycle
        graph
            .add_edge(GraphEdge {
                from_node: id_a,
                to_node: id_b,
                tensor_name: "a_out".to_string(),
                tensor_shape: None,
                tensor_dtype: DataType::F32,
            })
            .unwrap();

        graph
            .add_edge(GraphEdge {
                from_node: id_b,
                to_node: id_c,
                tensor_name: "b_out".to_string(),
                tensor_shape: None,
                tensor_dtype: DataType::F32,
            })
            .unwrap();

        graph
            .add_edge(GraphEdge {
                from_node: id_c,
                to_node: id_a,
                tensor_name: "c_out".to_string(),
                tensor_shape: None,
                tensor_dtype: DataType::F32,
            })
            .unwrap();

        // This should fail validation due to the cycle
        assert!(graph.validate().is_err());
        assert!(graph.has_cycles().unwrap());
    }

    #[test]
    fn test_subgraph_extraction() -> Result<()> {
        let mut graph = ModelGraph::new();

        // Create a linear graph: Input -> Conv1 -> Conv2 (conv2_out is not consumed)
        let input_id = graph.add_node(GraphNode {
            id: 0,
            op_type: "Input".to_string(),
            attributes: HashMap::new(),
            inputs: vec![],
            outputs: vec!["input_out".to_string()],
            name: Some("input".to_string()),
        });

        let conv1_id = graph.add_node(GraphNode {
            id: 1,
            op_type: "Conv".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["input_out".to_string()],
            outputs: vec!["conv1_out".to_string()],
            name: Some("conv1".to_string()),
        });

        let conv2_id = graph.add_node(GraphNode {
            id: 2,
            op_type: "Conv".to_string(),
            attributes: HashMap::new(),
            inputs: vec!["conv1_out".to_string()],
            outputs: vec!["conv2_out".to_string()],
            name: Some("conv2".to_string()),
        });

        // Add edges
        graph.add_edge(GraphEdge {
            from_node: input_id,
            to_node: conv1_id,
            tensor_name: "input_out".to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        })?;

        graph.add_edge(GraphEdge {
            from_node: conv1_id,
            to_node: conv2_id,
            tensor_name: "conv1_out".to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        })?;

        // Extract subgraph containing just the conv layers
        let subgraph = graph.extract_subgraph(&[conv1_id, conv2_id])?;

        assert_eq!(subgraph.nodes.len(), 2);
        assert_eq!(subgraph.edges.len(), 1);
        assert_eq!(subgraph.inputs, vec!["input_out"]);
        assert_eq!(subgraph.outputs, vec!["conv2_out"]);

        Ok(())
    }
}