1use serde::{Serialize, Deserialize};
2use crate::error::Result;
3
/// Numeric precision used for model weights and activations.
///
/// NOTE(review): exact on-device encodings are defined by the consumer of
/// this type and are not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantFormat {
    /// 32-bit IEEE-754 float — no quantization.
    Float32,
    /// 16-bit float.
    Float16,
    /// 8-bit integer quantization.
    Int8,
    /// 4-bit integer quantization.
    Int4,
}
12
/// Compiler/graph optimization level, in increasing aggressiveness
/// (`None` < `O1` < `O2` < `O3`).
///
/// NOTE(review): the specific passes enabled at each level are defined by
/// whatever consumes this enum — not visible here; confirm before relying
/// on a particular meaning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    O1,
    O2,
    O3,
}
21
/// Serializable description of a model: identity, I/O shapes, and the
/// precision/optimization settings to run it with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Human-readable model identifier.
    pub name: String,
    /// Expected input tensor shape (see `Default`: NHWC-style `[1, 224, 224, 3]`
    /// — assumed layout, TODO confirm).
    pub input_shape: Vec<usize>,
    /// Expected output tensor shape.
    pub output_shape: Vec<usize>,
    /// Numeric precision for weights/activations.
    pub quant_format: QuantFormat,
    /// Compiler optimization level.
    pub optimization_level: OptimizationLevel,
    /// Whether runtime caching is enabled (consumer-defined semantics).
    pub use_cache: bool,
}
32
33impl Default for ModelConfig {
34 fn default() -> Self {
35 Self {
36 name: "default_model".to_string(),
37 input_shape: vec![1, 224, 224, 3],
38 output_shape: vec![1, 1000],
39 quant_format: QuantFormat::Float32,
40 optimization_level: OptimizationLevel::O2,
41 use_cache: true,
42 }
43 }
44}
45
/// Thin runtime wrapper around a [`ModelConfig`], exposing shape accessors
/// and input validation.
pub struct ModelRuntime {
    // The configuration this runtime was constructed with (immutable after
    // construction — no mutating methods are defined in this file).
    config: ModelConfig,
}
50
51impl ModelRuntime {
52 pub fn new(config: ModelConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn load_from_path(_path: &str) -> Result<Self> {
59 let config = ModelConfig::default();
60 Ok(Self::new(config))
61 }
62
63 pub fn get_config(&self) -> &ModelConfig {
65 &self.config
66 }
67
68 pub fn input_shape(&self) -> &[usize] {
70 &self.config.input_shape
71 }
72
73 pub fn output_shape(&self) -> &[usize] {
75 &self.config.output_shape
76 }
77
78 pub fn validate_input(&self, shape: &[usize]) -> Result<()> {
80 if shape == self.config.input_shape {
81 Ok(())
82 } else {
83 Err(crate::error::NpuError::InvalidShape(
84 format!(
85 "Input shape mismatch: {:?} != {:?}",
86 shape, self.config.input_shape
87 ),
88 ))
89 }
90 }
91}
92
/// Kind of operation a [`Layer`] performs.
///
/// Only `FullyConnected` and `Convolution` contribute to the TOPs estimate
/// in `Layer::estimate_tops`; the rest are treated as free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    FullyConnected,
    Convolution,
    DepthwiseConvolution,
    PointwiseConvolution,
    Activation,
    BatchNorm,
    Pooling,
    Concat,
    Add,
}
106
/// A single layer in a [`NeuralNetwork`]: its kind plus input/output shapes.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Layer identifier (not required to be unique — nothing here enforces it).
    pub name: String,
    pub layer_type: LayerType,
    /// Input tensor shape. `estimate_tops` indexes it as `[batch, h, w, c]`
    /// for convolutions — assumed NHWC, TODO confirm.
    pub input_shape: Vec<usize>,
    /// Output tensor shape.
    pub output_shape: Vec<usize>,
}
115
116impl Layer {
117 pub fn new(name: String, layer_type: LayerType, input_shape: Vec<usize>, output_shape: Vec<usize>) -> Self {
119 Self {
120 name,
121 layer_type,
122 input_shape,
123 output_shape,
124 }
125 }
126
127 pub fn estimate_tops(&self) -> f32 {
129 match self.layer_type {
130 LayerType::FullyConnected => {
131 if self.input_shape.len() >= 2 && self.output_shape.len() >= 1 {
132 let m = self.input_shape[0];
133 let k = self.input_shape[1];
134 let n = self.output_shape[1];
135 (2 * m * k * n) as f32 / 1e12
136 } else {
137 0.0
138 }
139 }
140 LayerType::Convolution => {
141 if self.input_shape.len() >= 3 && self.output_shape.len() >= 3 {
142 let batch = self.input_shape[0];
143 let h = self.input_shape[1];
144 let w = self.input_shape[2];
145 let c_in = self.input_shape[3];
146 let c_out = self.output_shape[3];
147 (2 * batch * h * w * c_in * c_out) as f32 / 1e12
148 } else {
149 0.0
150 }
151 }
152 _ => 0.0,
153 }
154 }
155}
156
/// An ordered sequence of [`Layer`]s with a name; layers are appended via
/// `add_layer` and summed for a total TOPs estimate.
pub struct NeuralNetwork {
    // Network identifier, exposed read-only through `name()`.
    name: String,
    // Layers in insertion order.
    layers: Vec<Layer>,
}
162
163impl NeuralNetwork {
164 pub fn new(name: String) -> Self {
166 Self {
167 name,
168 layers: Vec::new(),
169 }
170 }
171
172 pub fn add_layer(&mut self, layer: Layer) {
174 self.layers.push(layer);
175 }
176
177 pub fn get_layers(&self) -> &[Layer] {
179 &self.layers
180 }
181
182 pub fn total_tops(&self) -> f32 {
184 self.layers.iter().map(|l| l.estimate_tops()).sum()
185 }
186
187 pub fn name(&self) -> &str {
189 &self.name
190 }
191
192 pub fn layer_count(&self) -> usize {
194 self.layers.len()
195 }
196}