npu_rs/
optimizer.rs

1use crate::error::{NpuError, Result};
2use ndarray::{ArrayD, IxDyn};
3use std::collections::HashMap;
4
/// Operator fusion rules the optimizer can be configured with.
///
/// Each variant names a sequence of operators that is a candidate for being
/// merged into a single fused operator. The descriptions below follow the
/// variant names; the matching logic itself is not defined on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionPattern {
    /// Convolution, then batch normalization, then ReLU.
    ConvBatchNormReLU,
    /// Linear layer, then ReLU.
    LinearReLU,
    /// Depthwise convolution, then pointwise convolution.
    DepthwisePointwise,
    /// Element-wise add, then ReLU.
    AddReLU,
}
13
14/// Graph optimization engine.
15pub struct GraphOptimizer {
16    fusion_patterns: Vec<FusionPattern>,
17    constant_folding: bool,
18    dead_code_elimination: bool,
19}
20
21impl GraphOptimizer {
22    /// Create a new graph optimizer.
23    pub fn new() -> Self {
24        Self {
25            fusion_patterns: vec![
26                FusionPattern::ConvBatchNormReLU,
27                FusionPattern::LinearReLU,
28                FusionPattern::DepthwisePointwise,
29                FusionPattern::AddReLU,
30            ],
31            constant_folding: true,
32            dead_code_elimination: true,
33        }
34    }
35
36    /// Optimize a computation graph.
37    pub fn optimize(&self, graph: &mut ComputationGraph) -> Result<()> {
38        self.apply_fusion(graph)?;
39        if self.constant_folding {
40            self.apply_constant_folding(graph)?;
41        }
42        if self.dead_code_elimination {
43            self.eliminate_dead_code(graph)?;
44        }
45        Ok(())
46    }
47
48    fn apply_fusion(&self, graph: &mut ComputationGraph) -> Result<()> {
49        graph.node_count += 1;
50        Ok(())
51    }
52
53    fn apply_constant_folding(&self, graph: &mut ComputationGraph) -> Result<()> {
54        graph.node_count += 1;
55        Ok(())
56    }
57
58    fn eliminate_dead_code(&self, graph: &mut ComputationGraph) -> Result<()> {
59        graph.node_count += 1;
60        Ok(())
61    }
62
63    /// Get optimization report.
64    pub fn get_report(&self) -> OptimizationReport {
65        OptimizationReport {
66            fusion_patterns_enabled: self.fusion_patterns.len(),
67            constant_folding_enabled: self.constant_folding,
68            dead_code_elimination_enabled: self.dead_code_elimination,
69        }
70    }
71}
72
73impl Default for GraphOptimizer {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79/// Computation graph representation.
80pub struct ComputationGraph {
81    pub nodes: HashMap<String, ComputeNode>,
82    pub edges: Vec<(String, String)>,
83    pub node_count: usize,
84}
85
86impl ComputationGraph {
87    /// Create a new computation graph.
88    pub fn new() -> Self {
89        Self {
90            nodes: HashMap::new(),
91            edges: Vec::new(),
92            node_count: 0,
93        }
94    }
95
96    /// Add a node to the graph.
97    pub fn add_node(&mut self, name: String, node: ComputeNode) -> Result<()> {
98        if self.nodes.contains_key(&name) {
99            return Err(NpuError::InvalidConfiguration(
100                format!("Node {} already exists", name),
101            ));
102        }
103        self.nodes.insert(name, node);
104        Ok(())
105    }
106
107    /// Add an edge between two nodes.
108    pub fn add_edge(&mut self, from: String, to: String) -> Result<()> {
109        if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
110            return Err(NpuError::InvalidConfiguration(
111                "Invalid node reference".to_string(),
112            ));
113        }
114        self.edges.push((from, to));
115        Ok(())
116    }
117
118    /// Get node count.
119    pub fn get_node_count(&self) -> usize {
120        self.nodes.len()
121    }
122
123    /// Validate graph connectivity.
124    pub fn validate(&self) -> Result<()> {
125        for (from, to) in &self.edges {
126            if !self.nodes.contains_key(from) || !self.nodes.contains_key(to) {
127                return Err(NpuError::InvalidConfiguration(
128                    "Invalid edge in graph".to_string(),
129                ));
130            }
131        }
132        Ok(())
133    }
134}
135
136impl Default for ComputationGraph {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
/// Computation node types.
///
/// `PartialEq` compares variants structurally. `Eq` is not derived because
/// `Constant` carries an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub enum ComputeNode {
    /// Convolution node carrying the shape of its kernel.
    Convolution { kernel_shape: Vec<usize> },
    /// Matrix-multiplication node carrying the shape of its output.
    MatMul { output_shape: Vec<usize> },
    /// Activation node; the kind of activation is given by name.
    Activation { activation_type: String },
    /// Scalar constant.
    Constant { value: f32 },
    /// Graph input with the given shape.
    Input { shape: Vec<usize> },
    /// Graph output with the given shape.
    Output { shape: Vec<usize> },
}
152
/// Optimization report: a snapshot of an optimizer's configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizationReport {
    /// Number of fusion patterns registered on the optimizer.
    pub fusion_patterns_enabled: usize,
    /// Whether the constant-folding pass is enabled.
    pub constant_folding_enabled: bool,
    /// Whether the dead-code-elimination pass is enabled.
    pub dead_code_elimination_enabled: bool,
}