1use super::tensor::Tensor;
4use super::model::*;
5use std::io::Read;
6use std::time::Instant;
7
/// Execution target requested for an `InferenceEngine`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device {
    /// Run forward passes on the host CPU.
    CPU,
    /// GPU compute backend — NOTE(review): no GPU-specific code path is
    /// visible in this file; selecting this currently changes nothing in
    /// execution. Confirm against the rest of the crate.
    GPUCompute,
}
14
/// Runs forward passes of a `Model` and records per-call statistics.
pub struct InferenceEngine {
    /// The model whose layers are executed on each inference call.
    pub model: Model,
    /// Requested execution target (see `Device` for caveats).
    pub device: Device,
    /// Statistics from the most recent `infer`/`batch_infer` call;
    /// overwritten on every call.
    pub stats: InferenceStats,
}
21
/// Metrics captured during the most recent inference call.
#[derive(Debug, Clone, Default)]
pub struct InferenceStats {
    /// Wall-clock latency of the forward pass(es), in milliseconds.
    pub latency_ms: f64,
    /// Approximate footprint of input + output buffers in bytes
    /// (f32 elements × 4; intermediate activations are not counted).
    pub memory_bytes: usize,
    /// Estimated floating-point operation count for the pass(es).
    pub flops: usize,
}
29
30impl InferenceEngine {
31 pub fn new(model: Model, device: Device) -> Self {
32 Self {
33 model,
34 device,
35 stats: InferenceStats::default(),
36 }
37 }
38
39 pub fn infer(&mut self, input: &Tensor) -> Tensor {
41 let start = Instant::now();
42 let result = self.model.forward(input);
43 let elapsed = start.elapsed();
44 self.stats.latency_ms = elapsed.as_secs_f64() * 1000.0;
45 self.stats.memory_bytes = result.data.len() * 4 + input.data.len() * 4;
46 self.stats.flops = self.estimate_flops(input);
47 result
48 }
49
50 pub fn batch_infer(&mut self, inputs: &[Tensor]) -> Vec<Tensor> {
52 let start = Instant::now();
53 let results: Vec<Tensor> = inputs.iter().map(|inp| self.model.forward(inp)).collect();
54 let elapsed = start.elapsed();
55 self.stats.latency_ms = elapsed.as_secs_f64() * 1000.0;
56 self.stats.memory_bytes = results.iter().map(|r| r.data.len() * 4).sum::<usize>()
57 + inputs.iter().map(|i| i.data.len() * 4).sum::<usize>();
58 self.stats.flops = inputs.iter().map(|i| self.estimate_flops(i)).sum();
59 results
60 }
61
62 pub fn warm_up(&mut self, input_shape: Vec<usize>, runs: usize) {
64 let dummy = Tensor::zeros(input_shape);
65 for _ in 0..runs {
66 let _ = self.model.forward(&dummy);
67 }
68 }
69
70 fn estimate_flops(&self, input: &Tensor) -> usize {
72 let mut flops = 0usize;
73 let mut current_size: usize = input.data.len();
74 for layer in &self.model.layers {
75 match layer {
76 Layer::Dense(d) => {
77 let m = current_size / d.weights.shape[0];
78 let k = d.weights.shape[0];
79 let n = d.weights.shape[1];
80 flops += 2 * m * k * n;
81 current_size = m * n;
82 }
83 Layer::Conv2D(c) => {
84 let c_out = c.filters.shape[0];
85 let c_in = c.filters.shape[1];
86 let kh = c.filters.shape[2];
87 let kw = c.filters.shape[3];
88 flops += current_size * c_out * kh * kw * 2 / c_in.max(1);
90 }
91 Layer::Attention(a) => {
92 flops += 4 * a.d_model * a.d_model * 2;
94 }
95 _ => {
96 flops += current_size;
98 }
99 }
100 }
101 flops
102 }
103}
104
/// Operator kinds for a parsed ONNX-style graph node.
///
/// NOTE(review): `OnnxOp`/`OnnxNode` are not constructed anywhere in this
/// file — they look like scaffolding for a richer loader than the one
/// `OnnxLoader::load_onnx` currently implements. Confirm before removing.
#[derive(Debug, Clone)]
#[allow(dead_code)] // reserved for a future attribute-aware loader
enum OnnxOp {
    /// General matrix multiply with optional transposes and scaling
    /// (fields renamed to snake_case to satisfy the `non_snake_case` lint).
    Gemm { trans_a: bool, trans_b: bool, alpha: f32, beta: f32 },
    /// 2-D convolution with per-axis strides and paddings.
    Conv { strides: Vec<usize>, pads: Vec<usize> },
    Relu,
    MaxPool { kernel_shape: Vec<usize>, strides: Vec<usize> },
    BatchNorm { eps: f32 },
    Reshape,
    Softmax { axis: i32 },
    Add,
    Mul,
}
120
/// A single node in a parsed ONNX-style graph: an operator plus the names
/// of the value tensors it consumes and produces.
/// NOTE(review): currently unused by `OnnxLoader::load_onnx` — confirm
/// whether other code depends on it before removing.
#[derive(Debug, Clone)]
struct OnnxNode {
    op: OnnxOp,
    inputs: Vec<String>,
    outputs: Vec<String>,
}
128
/// Namespace for (de)serializing models in a simplified ONNX-like binary
/// format. Stateless — both methods are associated functions.
pub struct OnnxLoader;
131
132impl OnnxLoader {
133 pub fn load_onnx(path: &str) -> Result<Model, String> {
144 let mut file = std::fs::File::open(path).map_err(|e| format!("cannot open {path}: {e}"))?;
145 let mut buf4 = [0u8; 4];
146 let mut buf1 = [0u8; 1];
147
148 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
150 if &buf4 != b"ONNX" {
151 return Err("invalid ONNX magic".into());
152 }
153
154 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
156 let num_nodes = u32::from_le_bytes(buf4) as usize;
157
158 let mut layers = Vec::new();
159
160 for _ in 0..num_nodes {
161 file.read_exact(&mut buf1).map_err(|e| e.to_string())?;
162 let op_type = buf1[0];
163
164 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
165 let num_weights = u32::from_le_bytes(buf4) as usize;
166
167 let mut tensors = Vec::new();
168 for _ in 0..num_weights {
169 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
170 let ndim = u32::from_le_bytes(buf4) as usize;
171 let mut shape = Vec::with_capacity(ndim);
172 for _ in 0..ndim {
173 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
174 shape.push(u32::from_le_bytes(buf4) as usize);
175 }
176 let n: usize = shape.iter().product();
177 let mut data = Vec::with_capacity(n);
178 for _ in 0..n {
179 file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
180 data.push(f32::from_le_bytes(buf4));
181 }
182 tensors.push(Tensor { shape, data });
183 }
184
185 let layer = match op_type {
186 0 => {
187 if tensors.len() >= 2 {
189 Layer::Dense(DenseLayer {
190 weights: tensors[0].clone(),
191 bias: tensors[1].clone(),
192 })
193 } else {
194 return Err("Gemm requires 2 weight tensors".into());
195 }
196 }
197 1 => {
198 if tensors.len() >= 2 {
200 Layer::Conv2D(Conv2DLayer {
201 filters: tensors[0].clone(),
202 bias: tensors[1].clone(),
203 stride: 1,
204 padding: 0,
205 })
206 } else {
207 return Err("Conv requires 2 weight tensors".into());
208 }
209 }
210 2 => Layer::ReLU,
211 3 => Layer::MaxPool(MaxPoolLayer { kernel_size: 2, stride: 2 }),
212 4 => {
213 if tensors.len() >= 4 {
215 Layer::BatchNorm(BatchNormLayer {
216 gamma: tensors[0].clone(),
217 beta: tensors[1].clone(),
218 running_mean: tensors[2].clone(),
219 running_var: tensors[3].clone(),
220 eps: 1e-5,
221 })
222 } else {
223 return Err("BatchNorm requires 4 tensors".into());
224 }
225 }
226 5 => Layer::Flatten, 6 => Layer::Softmax(0),
228 7 | 8 => {
229 Layer::ReLU }
233 _ => return Err(format!("unknown op type {op_type}")),
234 };
235 layers.push(layer);
236 }
237
238 Ok(Model { layers, name: "onnx_model".to_string() })
239 }
240
241 pub fn save_onnx(model: &Model, path: &str) -> Result<(), String> {
243 use std::io::Write;
244 let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?;
245 file.write_all(b"ONNX").map_err(|e| e.to_string())?;
246 let num_nodes = model.layers.len() as u32;
247 file.write_all(&num_nodes.to_le_bytes()).map_err(|e| e.to_string())?;
248
249 for layer in &model.layers {
250 let (op_type, tensors): (u8, Vec<&Tensor>) = match layer {
251 Layer::Dense(l) => (0, vec![&l.weights, &l.bias]),
252 Layer::Conv2D(l) => (1, vec![&l.filters, &l.bias]),
253 Layer::ReLU => (2, vec![]),
254 Layer::MaxPool(_) => (3, vec![]),
255 Layer::BatchNorm(l) => (4, vec![&l.gamma, &l.beta, &l.running_mean, &l.running_var]),
256 Layer::Flatten => (5, vec![]),
257 Layer::Softmax(_) => (6, vec![]),
258 _ => (2, vec![]), };
260 file.write_all(&[op_type]).map_err(|e| e.to_string())?;
261 let nw = tensors.len() as u32;
262 file.write_all(&nw.to_le_bytes()).map_err(|e| e.to_string())?;
263 for t in tensors {
264 let ndim = t.shape.len() as u32;
265 file.write_all(&ndim.to_le_bytes()).map_err(|e| e.to_string())?;
266 for &d in &t.shape {
267 file.write_all(&(d as u32).to_le_bytes()).map_err(|e| e.to_string())?;
268 }
269 for &v in &t.data {
270 file.write_all(&v.to_le_bytes()).map_err(|e| e.to_string())?;
271 }
272 }
273 }
274 Ok(())
275 }
276}
277
278pub fn quantize_model(model: &Model, bits: u32) -> Model {
283 let max_val = (1 << (bits - 1)) as f32 - 1.0;
284 let min_val = -max_val - 1.0;
285
286 let mut new_layers = Vec::new();
287 for layer in &model.layers {
288 let new_layer = match layer {
289 Layer::Dense(l) => {
290 let (qw, qb) = (quantize_tensor(&l.weights, min_val, max_val),
291 quantize_tensor(&l.bias, min_val, max_val));
292 Layer::Dense(DenseLayer { weights: qw, bias: qb })
293 }
294 Layer::Conv2D(l) => {
295 let qf = quantize_tensor(&l.filters, min_val, max_val);
296 let qb = quantize_tensor(&l.bias, min_val, max_val);
297 Layer::Conv2D(Conv2DLayer { filters: qf, bias: qb, stride: l.stride, padding: l.padding })
298 }
299 other => other.clone(),
300 };
301 new_layers.push(new_layer);
302 }
303 Model { layers: new_layers, name: format!("{}_q{}", model.name, bits) }
304}
305
306fn quantize_tensor(t: &Tensor, min_val: f32, max_val: f32) -> Tensor {
307 let abs_max = t.data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
308 if abs_max == 0.0 {
309 return t.clone();
310 }
311 let scale = max_val / abs_max;
312 let inv_scale = abs_max / max_val;
313 let data: Vec<f32> = t.data.iter().map(|&v| {
314 let q = (v * scale).round().clamp(min_val, max_val);
315 q * inv_scale
316 }).collect();
317 Tensor { shape: t.shape.clone(), data }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    // Single forward pass: checks output shape and that latency was recorded.
    #[test]
    fn test_infer() {
        let model = Sequential::new("test")
            .dense(4, 3)
            .relu()
            .build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let input = Tensor::ones(vec![1, 4]);
        let out = engine.infer(&input);
        assert_eq!(out.shape, vec![1, 3]);
        assert!(engine.stats.latency_ms >= 0.0);
    }

    // Batch path: one output per input, each with the dense layer's shape.
    #[test]
    fn test_batch_infer() {
        let model = Sequential::new("test")
            .dense(3, 2)
            .build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let inputs = vec![
            Tensor::ones(vec![1, 3]),
            Tensor::zeros(vec![1, 3]),
        ];
        let outputs = engine.batch_infer(&inputs);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape, vec![1, 2]);
        assert_eq!(outputs[1].shape, vec![1, 2]);
    }

    // Smoke test: warm_up must complete without panicking (no assertions —
    // it produces no observable result).
    #[test]
    fn test_warm_up() {
        let model = Sequential::new("test").dense(4, 2).build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        engine.warm_up(vec![1, 4], 5);
    }

    // Quantized model keeps the interface: same output shape, renamed model.
    #[test]
    fn test_quantize_model() {
        let model = Sequential::new("test")
            .dense(4, 3)
            .relu()
            .build();
        let qmodel = quantize_model(&model, 8);
        assert!(qmodel.name.contains("q8"));
        let input = Tensor::ones(vec![1, 4]);
        let out = qmodel.forward(&input);
        assert_eq!(out.shape, vec![1, 3]);
    }

    // Save then load: layer count matches and dense weight DATA roundtrips
    // exactly (hyperparameters like softmax axis are known to be lossy).
    #[test]
    fn test_onnx_save_load_roundtrip() {
        let model = Sequential::new("onnx_test")
            .dense(4, 3)
            .relu()
            .dense(3, 2)
            .softmax()
            .build();

        let path = std::env::temp_dir().join("proof_engine_test.onnx");
        let path_str = path.to_str().unwrap();

        OnnxLoader::save_onnx(&model, path_str).unwrap();
        let loaded = OnnxLoader::load_onnx(path_str).unwrap();

        assert_eq!(loaded.layers.len(), model.layers.len());

        if let (Layer::Dense(orig), Layer::Dense(loaded_l)) = (&model.layers[0], &loaded.layers[0]) {
            assert_eq!(orig.weights.data, loaded_l.weights.data);
        }

        // Best-effort cleanup; a leftover temp file is harmless.
        let _ = std::fs::remove_file(path);
    }

    // A file without the "ONNX" magic must be rejected, not parsed.
    #[test]
    fn test_onnx_load_bad_magic() {
        let path = std::env::temp_dir().join("proof_engine_bad.onnx");
        std::fs::write(&path, b"NOPE1234").unwrap();
        let result = OnnxLoader::load_onnx(path.to_str().unwrap());
        assert!(result.is_err());
        let _ = std::fs::remove_file(path);
    }

    // A real forward pass must leave non-zero FLOP and memory stats behind.
    #[test]
    fn test_inference_stats() {
        let model = Sequential::new("s").dense(2, 2).build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let _ = engine.infer(&Tensor::ones(vec![1, 2]));
        assert!(engine.stats.flops > 0);
        assert!(engine.stats.memory_bytes > 0);
    }
}