1use crate::types::{AttributeValue, GraphEdge, GraphNode, ModelGraph, NodeId, SubGraph};
7use anyhow::{Result, anyhow};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl ModelGraph {
11 pub fn new() -> Self {
13 Self {
14 nodes: Vec::new(),
15 edges: Vec::new(),
16 inputs: Vec::new(),
17 outputs: Vec::new(),
18 metadata: HashMap::new(),
19 }
20 }
21
22 pub fn add_node(&mut self, mut node: GraphNode) -> NodeId {
30 let node_id = self.nodes.len();
31 node.id = node_id;
32 self.nodes.push(node);
33 node_id
34 }
35
36 pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
41 if edge.from_node >= self.nodes.len() || edge.to_node >= self.nodes.len() {
43 return Err(anyhow!("Edge references non-existent nodes"));
44 }
45 self.edges.push(edge);
46 Ok(())
47 }
48
49 pub fn get_node(&self, node_id: NodeId) -> Option<&GraphNode> {
51 self.nodes.get(node_id)
52 }
53
54 pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut GraphNode> {
56 self.nodes.get_mut(node_id)
57 }
58
59 pub fn nodes(&self) -> &[GraphNode] {
61 &self.nodes
62 }
63
64 pub fn nodes_mut(&mut self) -> &mut Vec<GraphNode> {
66 &mut self.nodes
67 }
68
69 pub fn node_count(&self) -> usize {
71 self.nodes.len()
72 }
73
74 pub fn from_nodes(nodes: Vec<GraphNode>) -> Self {
76 Self {
77 nodes,
78 edges: Vec::new(),
79 inputs: Vec::new(),
80 outputs: Vec::new(),
81 metadata: HashMap::new(),
82 }
83 }
84
85 pub fn find_nodes_by_op(&self, op_type: &str) -> Vec<NodeId> {
87 self.nodes
88 .iter()
89 .filter_map(|node| {
90 if node.op_type == op_type {
91 Some(node.id)
92 } else {
93 None
94 }
95 })
96 .collect()
97 }
98
99 pub fn get_node_edges(&self, node_id: NodeId) -> (Vec<&GraphEdge>, Vec<&GraphEdge>) {
101 let incoming: Vec<&GraphEdge> = self
102 .edges
103 .iter()
104 .filter(|edge| edge.to_node == node_id)
105 .collect();
106
107 let outgoing: Vec<&GraphEdge> = self
108 .edges
109 .iter()
110 .filter(|edge| edge.from_node == node_id)
111 .collect();
112
113 (incoming, outgoing)
114 }
115
116 pub fn validate(&self) -> Result<()> {
118 let mut seen_ids = HashSet::new();
120 for node in &self.nodes {
121 if !seen_ids.insert(node.id) {
122 return Err(anyhow!("Duplicate node ID: {}", node.id));
123 }
124 }
125
126 for edge in &self.edges {
128 if edge.from_node >= self.nodes.len() {
129 return Err(anyhow!(
130 "Edge references non-existent from_node: {}",
131 edge.from_node
132 ));
133 }
134 if edge.to_node >= self.nodes.len() {
135 return Err(anyhow!(
136 "Edge references non-existent to_node: {}",
137 edge.to_node
138 ));
139 }
140 }
141
142 if self.has_cycles()? {
144 return Err(anyhow!("Graph contains cycles"));
145 }
146
147 self.validate_input_output_tensors()?;
149
150 Ok(())
151 }
152
153 fn has_cycles(&self) -> Result<bool> {
155 let mut state = vec![NodeState::Unvisited; self.nodes.len()];
156
157 for node_id in 0..self.nodes.len() {
158 if state[node_id] == NodeState::Unvisited {
159 if self.has_cycles_dfs(node_id, &mut state)? {
160 return Ok(true);
161 }
162 }
163 }
164 Ok(false)
165 }
166
167 fn has_cycles_dfs(&self, node_id: NodeId, state: &mut Vec<NodeState>) -> Result<bool> {
168 state[node_id] = NodeState::Visiting;
169
170 let (_, outgoing) = self.get_node_edges(node_id);
171 for edge in outgoing {
172 match state[edge.to_node] {
173 NodeState::Visiting => return Ok(true), NodeState::Unvisited => {
175 if self.has_cycles_dfs(edge.to_node, state)? {
176 return Ok(true);
177 }
178 }
179 NodeState::Visited => {} }
181 }
182
183 state[node_id] = NodeState::Visited;
184 Ok(false)
185 }
186
187 fn validate_input_output_tensors(&self) -> Result<()> {
189 let mut all_tensor_names: HashSet<String> = HashSet::new();
190
191 for node in &self.nodes {
193 for input in &node.inputs {
194 all_tensor_names.insert(input.clone());
195 }
196 for output in &node.outputs {
197 all_tensor_names.insert(output.clone());
198 }
199 }
200
201 for input in &self.inputs {
203 if !all_tensor_names.contains(input) {
204 return Err(anyhow!("Graph input '{}' is not used by any node", input));
205 }
206 }
207
208 for output in &self.outputs {
210 if !all_tensor_names.contains(output) {
211 return Err(anyhow!(
212 "Graph output '{}' is not produced by any node",
213 output
214 ));
215 }
216 }
217
218 Ok(())
219 }
220
221 pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
223 let mut in_degree = vec![0; self.nodes.len()];
224
225 for edge in &self.edges {
227 in_degree[edge.to_node] += 1;
228 }
229
230 let mut queue = VecDeque::new();
232 for (node_id, °ree) in in_degree.iter().enumerate() {
233 if degree == 0 {
234 queue.push_back(node_id);
235 }
236 }
237
238 let mut result = Vec::new();
239
240 while let Some(node_id) = queue.pop_front() {
241 result.push(node_id);
242
243 let (_, outgoing) = self.get_node_edges(node_id);
245 for edge in outgoing {
246 in_degree[edge.to_node] -= 1;
247 if in_degree[edge.to_node] == 0 {
248 queue.push_back(edge.to_node);
249 }
250 }
251 }
252
253 if result.len() != self.nodes.len() {
254 return Err(anyhow!(
255 "Graph contains cycles - cannot perform topological sort"
256 ));
257 }
258
259 Ok(result)
260 }
261
262 pub fn extract_subgraph(&self, node_ids: &[NodeId]) -> Result<SubGraph> {
264 let node_set: HashSet<NodeId> = node_ids.iter().cloned().collect();
265
266 for &node_id in node_ids {
268 if node_id >= self.nodes.len() {
269 return Err(anyhow!("Node ID {} does not exist", node_id));
270 }
271 }
272
273 let mut id_mapping = HashMap::new();
275 let mut subgraph_nodes = Vec::new();
276
277 for (new_id, &old_id) in node_ids.iter().enumerate() {
278 id_mapping.insert(old_id, new_id);
279 let mut node = self.nodes[old_id].clone();
280 node.id = new_id;
281 subgraph_nodes.push(node);
282 }
283
284 let mut subgraph_edges = Vec::new();
286 for edge in &self.edges {
287 if node_set.contains(&edge.from_node) && node_set.contains(&edge.to_node) {
288 let mut new_edge = edge.clone();
289 new_edge.from_node = id_mapping[&edge.from_node];
290 new_edge.to_node = id_mapping[&edge.to_node];
291 subgraph_edges.push(new_edge);
292 }
293 }
294
295 let mut subgraph_inputs = HashSet::new();
297 let mut subgraph_outputs = HashSet::new();
298
299 for node in &subgraph_nodes {
300 for input in &node.inputs {
302 let mut is_external = true;
303 for other_node in &subgraph_nodes {
304 if other_node.outputs.contains(input) {
305 is_external = false;
306 break;
307 }
308 }
309 if is_external {
310 subgraph_inputs.insert(input.clone());
311 }
312 }
313
314 for output in &node.outputs {
316 let mut is_external = true;
317 for other_node in &subgraph_nodes {
318 if other_node.inputs.contains(output) {
319 is_external = false;
320 break;
321 }
322 }
323 if is_external {
324 subgraph_outputs.insert(output.clone());
325 }
326 }
327 }
328
329 Ok(SubGraph {
330 nodes: subgraph_nodes,
331 edges: subgraph_edges,
332 inputs: subgraph_inputs.into_iter().collect(),
333 outputs: subgraph_outputs.into_iter().collect(),
334 })
335 }
336
337 pub fn count_ops(&self) -> HashMap<String, usize> {
339 let mut counts = HashMap::new();
340 for node in &self.nodes {
341 *counts.entry(node.op_type.clone()).or_insert(0) += 1;
342 }
343 counts
344 }
345
346 pub fn statistics(&self) -> GraphStatistics {
348 let node_count = self.nodes.len();
349 let edge_count = self.edges.len();
350 let op_counts = self.count_ops();
351 let input_count = self.inputs.len();
352 let output_count = self.outputs.len();
353
354 let depth = self.calculate_depth();
355
356 GraphStatistics {
357 node_count,
358 edge_count,
359 op_counts,
360 input_count,
361 output_count,
362 depth,
363 }
364 }
365
366 fn calculate_depth(&self) -> usize {
368 if let Ok(topo_order) = self.topological_sort() {
369 let mut depths = vec![0; self.nodes.len()];
370
371 for &node_id in &topo_order {
372 let (incoming, _) = self.get_node_edges(node_id);
373 if incoming.is_empty() {
374 depths[node_id] = 0;
375 } else {
376 let max_input_depth = incoming
377 .iter()
378 .map(|edge| depths[edge.from_node])
379 .max()
380 .unwrap_or(0);
381 depths[node_id] = max_input_depth + 1;
382 }
383 }
384
385 depths.into_iter().max().unwrap_or(0)
386 } else {
387 0 }
389 }
390}
391
392impl Default for ModelGraph {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
/// Per-node marker used by the depth-first cycle search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeState {
    /// The search has not reached this node yet.
    Unvisited,
    /// The node is on the path currently being explored.
    Visiting,
    /// The node and everything reachable from it have been explored.
    Visited,
}
405
/// Summary figures describing a graph.
///
/// `Default` yields all-zero counts with an empty `op_counts` map, and
/// `PartialEq`/`Eq` allow two snapshots to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphStatistics {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges in the graph.
    pub edge_count: usize,
    /// Number of nodes per operation type, keyed by `op_type`.
    pub op_counts: HashMap<String, usize>,
    /// Number of declared graph inputs.
    pub input_count: usize,
    /// Number of declared graph outputs.
    pub output_count: usize,
    /// Number of edges on the longest path; 0 when it cannot be computed.
    pub depth: usize,
}
422
423pub struct GraphBuilder {
425 graph: ModelGraph,
426}
427
428impl GraphBuilder {
429 pub fn new() -> Self {
431 Self {
432 graph: ModelGraph::new(),
433 }
434 }
435
436 pub fn add_op(&mut self, op_type: &str, name: Option<String>) -> NodeId {
438 let node = GraphNode {
439 id: 0, op_type: op_type.to_string(),
441 attributes: HashMap::new(),
442 inputs: Vec::new(),
443 outputs: Vec::new(),
444 name,
445 };
446 self.graph.add_node(node)
447 }
448
449 pub fn add_input(&mut self, node_id: NodeId, tensor_name: &str) -> &mut Self {
451 if let Some(node) = self.graph.get_node_mut(node_id) {
452 node.inputs.push(tensor_name.to_string());
453 }
454 self
455 }
456
457 pub fn add_output(&mut self, node_id: NodeId, tensor_name: &str) -> &mut Self {
459 if let Some(node) = self.graph.get_node_mut(node_id) {
460 node.outputs.push(tensor_name.to_string());
461 }
462 self
463 }
464
465 pub fn add_attribute(
467 &mut self,
468 node_id: NodeId,
469 name: &str,
470 value: AttributeValue,
471 ) -> &mut Self {
472 if let Some(node) = self.graph.get_node_mut(node_id) {
473 node.attributes.insert(name.to_string(), value);
474 }
475 self
476 }
477
478 pub fn connect(
480 &mut self,
481 from_node: NodeId,
482 to_node: NodeId,
483 tensor_name: &str,
484 ) -> Result<&mut Self> {
485 let edge = GraphEdge {
486 from_node,
487 to_node,
488 tensor_name: tensor_name.to_string(),
489 tensor_shape: None,
490 tensor_dtype: crate::types::DataType::F32, };
492 self.graph.add_edge(edge)?;
493 Ok(self)
494 }
495
496 pub fn set_inputs(&mut self, inputs: Vec<String>) -> &mut Self {
498 self.graph.inputs = inputs;
499 self
500 }
501
502 pub fn set_outputs(&mut self, outputs: Vec<String>) -> &mut Self {
504 self.graph.outputs = outputs;
505 self
506 }
507
508 pub fn build(self) -> Result<ModelGraph> {
510 self.graph.validate()?;
511 Ok(self.graph)
512 }
513}
514
515impl Default for GraphBuilder {
516 fn default() -> Self {
517 Self::new()
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Builds an attribute-free node. The `id` passed here is only a
    /// placeholder: `add_node` assigns the real one.
    fn make_node(
        id: NodeId,
        op_type: &str,
        inputs: &[&str],
        outputs: &[&str],
        name: &str,
    ) -> GraphNode {
        let owned =
            |names: &[&str]| -> Vec<String> { names.iter().map(|n| (*n).to_string()).collect() };
        GraphNode {
            id,
            op_type: op_type.to_string(),
            attributes: HashMap::new(),
            inputs: owned(inputs),
            outputs: owned(outputs),
            name: Some(name.to_string()),
        }
    }

    /// Builds an `F32` edge without shape information.
    fn f32_edge(from_node: NodeId, to_node: NodeId, tensor_name: &str) -> GraphEdge {
        GraphEdge {
            from_node,
            to_node,
            tensor_name: tensor_name.to_string(),
            tensor_shape: None,
            tensor_dtype: DataType::F32,
        }
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = ModelGraph::new();
        assert_eq!(graph.nodes.len(), 0);
        assert_eq!(graph.edges.len(), 0);

        let node_id = graph.add_node(make_node(0, "Conv", &["input1"], &["output1"], "conv1"));
        assert_eq!(node_id, 0);
        assert_eq!(graph.nodes.len(), 1);
    }

    #[test]
    fn test_edge_addition() -> Result<()> {
        let mut graph = ModelGraph::new();

        let id1 = graph.add_node(make_node(0, "Input", &[], &["tensor1"], "input"));
        let id2 = graph.add_node(make_node(1, "Conv", &["tensor1"], &["tensor2"], "conv"));

        // Same as the plain helper edge, but with an explicit shape.
        let edge = GraphEdge {
            tensor_shape: Some(vec![1, 3, 224, 224]),
            ..f32_edge(id1, id2, "tensor1")
        };

        graph.add_edge(edge)?;
        assert_eq!(graph.edges.len(), 1);

        Ok(())
    }

    #[test]
    fn test_topological_sort() -> Result<()> {
        let mut graph = ModelGraph::new();

        // Linear chain A -> B -> C.
        let id_a = graph.add_node(make_node(0, "Input", &[], &["a_out"], "A"));
        let id_b = graph.add_node(make_node(1, "Conv", &["a_out"], &["b_out"], "B"));
        let id_c = graph.add_node(make_node(2, "ReLU", &["b_out"], &["c_out"], "C"));

        graph.add_edge(f32_edge(id_a, id_b, "a_out"))?;
        graph.add_edge(f32_edge(id_b, id_c, "b_out"))?;

        let topo_order = graph.topological_sort()?;
        assert_eq!(topo_order, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn test_graph_builder() -> Result<()> {
        let mut builder = GraphBuilder::new();

        let input_id = builder.add_op("Input", Some("input_layer".to_string()));
        builder.add_output(input_id, "input_tensor");

        let conv_id = builder.add_op("Conv", Some("conv_layer".to_string()));
        builder
            .add_input(conv_id, "input_tensor")
            .add_output(conv_id, "conv_output")
            .add_attribute(conv_id, "kernel_size", AttributeValue::IntArray(vec![3, 3]));

        builder.connect(input_id, conv_id, "input_tensor")?;
        builder
            .set_inputs(vec!["input_tensor".to_string()])
            .set_outputs(vec!["conv_output".to_string()]);

        let graph = builder.build()?;
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.inputs, vec!["input_tensor"]);
        assert_eq!(graph.outputs, vec!["conv_output"]);

        Ok(())
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = ModelGraph::new();

        // Three nodes wired into the loop A -> B -> C -> A.
        let id_a = graph.add_node(make_node(0, "A", &["c_out"], &["a_out"], "A"));
        let id_b = graph.add_node(make_node(1, "B", &["a_out"], &["b_out"], "B"));
        let id_c = graph.add_node(make_node(2, "C", &["b_out"], &["c_out"], "C"));

        graph.add_edge(f32_edge(id_a, id_b, "a_out")).unwrap();
        graph.add_edge(f32_edge(id_b, id_c, "b_out")).unwrap();
        graph.add_edge(f32_edge(id_c, id_a, "c_out")).unwrap();

        assert!(graph.validate().is_err());
        assert!(graph.has_cycles().unwrap());
    }

    #[test]
    fn test_subgraph_extraction() -> Result<()> {
        let mut graph = ModelGraph::new();

        let input_id = graph.add_node(make_node(0, "Input", &[], &["input_out"], "input"));
        let conv1_id =
            graph.add_node(make_node(1, "Conv", &["input_out"], &["conv1_out"], "conv1"));
        let conv2_id =
            graph.add_node(make_node(2, "Conv", &["conv1_out"], &["conv2_out"], "conv2"));

        graph.add_edge(f32_edge(input_id, conv1_id, "input_out"))?;
        graph.add_edge(f32_edge(conv1_id, conv2_id, "conv1_out"))?;

        // Selecting only the two convolutions leaves the input node outside.
        let subgraph = graph.extract_subgraph(&[conv1_id, conv2_id])?;

        assert_eq!(subgraph.nodes.len(), 2);
        assert_eq!(subgraph.edges.len(), 1);
        assert_eq!(subgraph.inputs, vec!["input_out"]);
        assert_eq!(subgraph.outputs, vec!["conv2_out"]);

        Ok(())
    }
}