//! npu_rs/model.rs — model configuration, layer definitions, and the network graph.

1use serde::{Serialize, Deserialize};
2use crate::error::Result;
3
4/// Quantization format for models.
/// Quantization format for models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantFormat {
    /// 32-bit floating point (unquantized).
    Float32,
    /// 16-bit floating point.
    Float16,
    /// 8-bit integer quantization.
    Int8,
    /// 4-bit integer quantization.
    Int4,
}
12
13/// Optimization level.
/// Optimization level.
///
/// NOTE(review): the exact transforms each level enables are not visible in
/// this file — presumably higher levels apply more aggressive optimization;
/// confirm against the compiler/runtime backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    /// Optimization level 1 (least aggressive).
    O1,
    /// Optimization level 2 (the `ModelConfig` default).
    O2,
    /// Optimization level 3 (most aggressive).
    O3,
}
21
22/// Model configuration.
/// Model configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier.
    pub name: String,
    /// Expected input tensor shape (the default is `[1, 224, 224, 3]`,
    /// i.e. NHWC image input — see `Default` impl).
    pub input_shape: Vec<usize>,
    /// Produced output tensor shape (default `[1, 1000]`).
    pub output_shape: Vec<usize>,
    /// Numeric format the model's weights/activations are quantized to.
    pub quant_format: QuantFormat,
    /// Optimization level to apply.
    pub optimization_level: OptimizationLevel,
    /// Enables caching. NOTE(review): what is cached (compiled model?
    /// intermediate results?) is not visible in this file — confirm.
    pub use_cache: bool,
}
32
33impl Default for ModelConfig {
34    fn default() -> Self {
35        Self {
36            name: "default_model".to_string(),
37            input_shape: vec![1, 224, 224, 3],
38            output_shape: vec![1, 1000],
39            quant_format: QuantFormat::Float32,
40            optimization_level: OptimizationLevel::O2,
41            use_cache: true,
42        }
43    }
44}
45
46/// Model runtime for executing inference.
/// Model runtime for executing inference.
pub struct ModelRuntime {
    /// Configuration this runtime was constructed with.
    config: ModelConfig,
}
50
51impl ModelRuntime {
52    /// Create a new model runtime.
53    pub fn new(config: ModelConfig) -> Self {
54        Self { config }
55    }
56
57    /// Load model from path.
58    pub fn load_from_path(_path: &str) -> Result<Self> {
59        let config = ModelConfig::default();
60        Ok(Self::new(config))
61    }
62
63    /// Get model configuration.
64    pub fn get_config(&self) -> &ModelConfig {
65        &self.config
66    }
67
68    /// Get input shape.
69    pub fn input_shape(&self) -> &[usize] {
70        &self.config.input_shape
71    }
72
73    /// Get output shape.
74    pub fn output_shape(&self) -> &[usize] {
75        &self.config.output_shape
76    }
77
78    /// Validate input dimensions.
79    pub fn validate_input(&self, shape: &[usize]) -> Result<()> {
80        if shape == self.config.input_shape {
81            Ok(())
82        } else {
83            Err(crate::error::NpuError::InvalidShape(
84                format!(
85                    "Input shape mismatch: {:?} != {:?}",
86                    shape, self.config.input_shape
87                ),
88            ))
89        }
90    }
91}
92
93/// Layer types supported by the NPU.
/// Layer types supported by the NPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    FullyConnected,
    Convolution,
    DepthwiseConvolution,
    PointwiseConvolution,
    Activation,
    BatchNorm,
    Pooling,
    Concat,
    Add,
}

/// Layer definition.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Layer identifier.
    pub name: String,
    /// Kind of operation this layer performs.
    pub layer_type: LayerType,
    /// Input tensor shape. For the cost model below, `FullyConnected`
    /// reads `[m, k]` and `Convolution` reads NHWC `[n, h, w, c_in]`.
    pub input_shape: Vec<usize>,
    /// Output tensor shape. `FullyConnected` reads `[_, n]`;
    /// `Convolution` reads `[_, _, _, c_out]`.
    pub output_shape: Vec<usize>,
}

impl Layer {
    /// Create a new layer.
    pub fn new(name: String, layer_type: LayerType, input_shape: Vec<usize>, output_shape: Vec<usize>) -> Self {
        Self {
            name,
            layer_type,
            input_shape,
            output_shape,
        }
    }

    /// Estimate tera-operations (TOPS) for this layer.
    ///
    /// Returns `0.0` for layer types without a cost model, and for shapes
    /// too short to evaluate the formula (instead of panicking).
    pub fn estimate_tops(&self) -> f32 {
        match self.layer_type {
            LayerType::FullyConnected => {
                // 2*m*k*n multiply-accumulates for an [m, k] x [k, n] matmul.
                // BUG FIX: the original guard accepted output_shape.len() >= 1
                // but then indexed output_shape[1], panicking on 1-D outputs;
                // require rank >= 2 on both shapes.
                if self.input_shape.len() >= 2 && self.output_shape.len() >= 2 {
                    let m = self.input_shape[0] as f64;
                    let k = self.input_shape[1] as f64;
                    let n = self.output_shape[1] as f64;
                    // Multiply in f64 so large shapes cannot overflow a
                    // usize product (the original multiplied in usize).
                    (2.0 * m * k * n / 1e12) as f32
                } else {
                    0.0
                }
            }
            LayerType::Convolution => {
                // 2*N*H*W*c_in*c_out — a 1x1-kernel-equivalent NHWC cost.
                // BUG FIX: the original guard checked len() >= 3 on both
                // shapes but indexed [3] on both, panicking on rank-3
                // shapes; require rank >= 4.
                if self.input_shape.len() >= 4 && self.output_shape.len() >= 4 {
                    let batch = self.input_shape[0] as f64;
                    let h = self.input_shape[1] as f64;
                    let w = self.input_shape[2] as f64;
                    let c_in = self.input_shape[3] as f64;
                    let c_out = self.output_shape[3] as f64;
                    (2.0 * batch * h * w * c_in * c_out / 1e12) as f32
                } else {
                    0.0
                }
            }
            // No cost model for the remaining layer types yet.
            _ => 0.0,
        }
    }
}
156
157/// Neural network model graph.
/// Neural network model graph.
pub struct NeuralNetwork {
    /// Network identifier.
    name: String,
    /// Layers in the order they were added via `add_layer`.
    layers: Vec<Layer>,
}
162
163impl NeuralNetwork {
164    /// Create a new neural network.
165    pub fn new(name: String) -> Self {
166        Self {
167            name,
168            layers: Vec::new(),
169        }
170    }
171
172    /// Add a layer to the network.
173    pub fn add_layer(&mut self, layer: Layer) {
174        self.layers.push(layer);
175    }
176
177    /// Get all layers.
178    pub fn get_layers(&self) -> &[Layer] {
179        &self.layers
180    }
181
182    /// Compute total estimated TOPS.
183    pub fn total_tops(&self) -> f32 {
184        self.layers.iter().map(|l| l.estimate_tops()).sum()
185    }
186
187    /// Get network name.
188    pub fn name(&self) -> &str {
189        &self.name
190    }
191
192    /// Get layer count.
193    pub fn layer_count(&self) -> usize {
194        self.layers.len()
195    }
196}