1use crate::error::{NpuError, Result};
2use ndarray::{ArrayD, IxDyn};
3use std::collections::HashMap;
4
/// Identifiers for the operator-fusion patterns a [`GraphOptimizer`] can be
/// configured with.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names alone; no pattern-matching logic is visible in this file, so confirm
/// them against the actual fusion implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionPattern {
    /// Presumably a convolution followed by batch normalisation and ReLU.
    ConvBatchNormReLU,
    /// Presumably a linear layer followed by ReLU.
    LinearReLU,
    /// Presumably a depthwise convolution followed by a pointwise convolution.
    DepthwisePointwise,
    /// Presumably an element-wise addition followed by ReLU.
    AddReLU,
}
13
14pub struct GraphOptimizer {
16 fusion_patterns: Vec<FusionPattern>,
17 constant_folding: bool,
18 dead_code_elimination: bool,
19}
20
21impl GraphOptimizer {
22 pub fn new() -> Self {
24 Self {
25 fusion_patterns: vec![
26 FusionPattern::ConvBatchNormReLU,
27 FusionPattern::LinearReLU,
28 FusionPattern::DepthwisePointwise,
29 FusionPattern::AddReLU,
30 ],
31 constant_folding: true,
32 dead_code_elimination: true,
33 }
34 }
35
36 pub fn optimize(&self, graph: &mut ComputationGraph) -> Result<()> {
38 self.apply_fusion(graph)?;
39 if self.constant_folding {
40 self.apply_constant_folding(graph)?;
41 }
42 if self.dead_code_elimination {
43 self.eliminate_dead_code(graph)?;
44 }
45 Ok(())
46 }
47
48 fn apply_fusion(&self, graph: &mut ComputationGraph) -> Result<()> {
49 graph.node_count += 1;
50 Ok(())
51 }
52
53 fn apply_constant_folding(&self, graph: &mut ComputationGraph) -> Result<()> {
54 graph.node_count += 1;
55 Ok(())
56 }
57
58 fn eliminate_dead_code(&self, graph: &mut ComputationGraph) -> Result<()> {
59 graph.node_count += 1;
60 Ok(())
61 }
62
63 pub fn get_report(&self) -> OptimizationReport {
65 OptimizationReport {
66 fusion_patterns_enabled: self.fusion_patterns.len(),
67 constant_folding_enabled: self.constant_folding,
68 dead_code_elimination_enabled: self.dead_code_elimination,
69 }
70 }
71}
72
73impl Default for GraphOptimizer {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79pub struct ComputationGraph {
81 pub nodes: HashMap<String, ComputeNode>,
82 pub edges: Vec<(String, String)>,
83 pub node_count: usize,
84}
85
86impl ComputationGraph {
87 pub fn new() -> Self {
89 Self {
90 nodes: HashMap::new(),
91 edges: Vec::new(),
92 node_count: 0,
93 }
94 }
95
96 pub fn add_node(&mut self, name: String, node: ComputeNode) -> Result<()> {
98 if self.nodes.contains_key(&name) {
99 return Err(NpuError::InvalidConfiguration(
100 format!("Node {} already exists", name),
101 ));
102 }
103 self.nodes.insert(name, node);
104 Ok(())
105 }
106
107 pub fn add_edge(&mut self, from: String, to: String) -> Result<()> {
109 if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
110 return Err(NpuError::InvalidConfiguration(
111 "Invalid node reference".to_string(),
112 ));
113 }
114 self.edges.push((from, to));
115 Ok(())
116 }
117
118 pub fn get_node_count(&self) -> usize {
120 self.nodes.len()
121 }
122
123 pub fn validate(&self) -> Result<()> {
125 for (from, to) in &self.edges {
126 if !self.nodes.contains_key(from) || !self.nodes.contains_key(to) {
127 return Err(NpuError::InvalidConfiguration(
128 "Invalid edge in graph".to_string(),
129 ));
130 }
131 }
132 Ok(())
133 }
134}
135
136impl Default for ComputationGraph {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
/// A single node of a computation graph.
///
/// Derives `PartialEq` so nodes can be compared directly (for example in
/// tests). `Eq` is intentionally not derived because `Constant` holds an
/// `f32`.
#[derive(Debug, Clone, PartialEq)]
pub enum ComputeNode {
    /// Convolution node carrying its kernel shape.
    Convolution { kernel_shape: Vec<usize> },
    /// Matrix-multiplication node carrying its output shape.
    MatMul { output_shape: Vec<usize> },
    /// Activation node; the activation is identified by a free-form string.
    Activation { activation_type: String },
    /// Scalar constant node.
    Constant { value: f32 },
    /// Graph input with the given shape.
    Input { shape: Vec<usize> },
    /// Graph output with the given shape.
    Output { shape: Vec<usize> },
}
152
/// Summary of a `GraphOptimizer`'s configuration, as returned by
/// `GraphOptimizer::get_report`.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so reports
/// can be stored and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizationReport {
    /// Number of fusion patterns registered on the optimizer.
    pub fusion_patterns_enabled: usize,
    /// Whether the constant-folding pass is enabled.
    pub constant_folding_enabled: bool,
    /// Whether the dead-code-elimination pass is enabled.
    pub dead_code_elimination_enabled: bool,
}