proof_engine/ml/inference.rs

//! Inference engine: run models, batch inference, ONNX loading, quantization.
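//!
//! A minimal usage sketch, marked `ignore` because it assumes the builder API used by
//! the tests below (`Sequential::new`, `.dense`, `.relu`, `.build` from the sibling
//! `model` module) and the crate name `proof_engine`:
//!
//! ```ignore
//! use proof_engine::ml::inference::{Device, InferenceEngine};
//! use proof_engine::ml::model::Sequential;
//! use proof_engine::ml::tensor::Tensor;
//!
//! let model = Sequential::new("demo").dense(4, 3).relu().build();
//! let mut engine = InferenceEngine::new(model, Device::CPU);
//! engine.warm_up(vec![1, 4], 3);
//! let out = engine.infer(&Tensor::ones(vec![1, 4]));
//! assert_eq!(out.shape, vec![1, 3]);
//! println!("latency: {:.3} ms, ~{} FLOPs", engine.stats.latency_ms, engine.stats.flops);
//! ```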

use super::tensor::Tensor;
use super::model::*;
use std::io::Read;
use std::time::Instant;

/// Compute device target.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device {
    CPU,
    GPUCompute,
}

/// Inference engine wrapping a model and target device.
pub struct InferenceEngine {
    pub model: Model,
    pub device: Device,
    pub stats: InferenceStats,
}

/// Statistics from the most recent inference call (single or batched).
#[derive(Debug, Clone, Default)]
pub struct InferenceStats {
    pub latency_ms: f64,
    pub memory_bytes: usize,
    pub flops: usize,
}

impl InferenceEngine {
    pub fn new(model: Model, device: Device) -> Self {
        Self {
            model,
            device,
            stats: InferenceStats::default(),
        }
    }

    /// Run a single inference pass.
    pub fn infer(&mut self, input: &Tensor) -> Tensor {
        let start = Instant::now();
        let result = self.model.forward(input);
        let elapsed = start.elapsed();
        self.stats.latency_ms = elapsed.as_secs_f64() * 1000.0;
        self.stats.memory_bytes = result.data.len() * 4 + input.data.len() * 4; // 4 bytes per f32 element
        self.stats.flops = self.estimate_flops(input);
        result
    }

    /// Batched inference: run each input through the model. The recorded stats
    /// cover the whole batch.
    pub fn batch_infer(&mut self, inputs: &[Tensor]) -> Vec<Tensor> {
        let start = Instant::now();
        let results: Vec<Tensor> = inputs.iter().map(|inp| self.model.forward(inp)).collect();
        let elapsed = start.elapsed();
        self.stats.latency_ms = elapsed.as_secs_f64() * 1000.0;
        self.stats.memory_bytes = results.iter().map(|r| r.data.len() * 4).sum::<usize>()
            + inputs.iter().map(|i| i.data.len() * 4).sum::<usize>();
        self.stats.flops = inputs.iter().map(|i| self.estimate_flops(i)).sum();
        results
    }

    /// Warm up the inference pipeline by running dummy inputs.
    pub fn warm_up(&mut self, input_shape: Vec<usize>, runs: usize) {
        let dummy = Tensor::zeros(input_shape);
        for _ in 0..runs {
            let _ = self.model.forward(&dummy);
        }
    }

    /// Rough FLOPs estimation based on layer types.
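    ///
    /// For example, a `Dense` layer with weights of shape `[4, 3]` applied to a
    /// batch of one costs roughly `2 * 1 * 4 * 3 = 24` FLOPs (one multiply and
    /// one add per weight).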
    fn estimate_flops(&self, input: &Tensor) -> usize {
        let mut flops = 0usize;
        let mut current_size: usize = input.data.len();
        for layer in &self.model.layers {
            match layer {
                Layer::Dense(d) => {
                    let m = current_size / d.weights.shape[0];
                    let k = d.weights.shape[0];
                    let n = d.weights.shape[1];
                    flops += 2 * m * k * n;
                    current_size = m * n;
                }
                Layer::Conv2D(c) => {
                    let c_out = c.filters.shape[0];
                    let c_in = c.filters.shape[1].max(1);
                    let kh = c.filters.shape[2];
                    let kw = c.filters.shape[3];
                    // rough: output_spatial * c_out * c_in * kh * kw * 2,
                    // assuming the output spatial size matches the input's
                    let output_spatial = current_size / c_in;
                    flops += output_spatial * c_out * c_in * kh * kw * 2;
                    current_size = output_spatial * c_out;
                }
                Layer::Attention(a) => {
                    // Q, K, V and output projections for every token in the input;
                    // the attention score matmuls themselves are not counted
                    let tokens = current_size / a.d_model.max(1);
                    flops += tokens * 4 * a.d_model * a.d_model * 2;
                }
                _ => {
                    // element-wise ops: ~N flops
                    flops += current_size;
                }
            }
        }
        flops
    }
}

// ── ONNX Loader ─────────────────────────────────────────────────────────

/// Supported ONNX operation types (simplified). Currently unused: the loader maps
/// raw op codes directly to layers rather than building these nodes.
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum OnnxOp {
    Gemm { trans_a: bool, trans_b: bool, alpha: f32, beta: f32 },
    Conv { strides: Vec<usize>, pads: Vec<usize> },
    Relu,
    MaxPool { kernel_shape: Vec<usize>, strides: Vec<usize> },
    BatchNorm { eps: f32 },
    Reshape,
    Softmax { axis: i32 },
    Add,
    Mul,
}

/// Minimal ONNX-like graph node. Like `OnnxOp`, currently unused by the loader.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct OnnxNode {
    op: OnnxOp,
    inputs: Vec<String>,
    outputs: Vec<String>,
}

/// ONNX model loader.
pub struct OnnxLoader;

impl OnnxLoader {
    /// Load a model from a binary file. This does not parse real ONNX protobuf;
    /// it reads our own simplified binary format, which is enough for simple models.
    ///
    /// The format is:
    /// - magic: b"ONNX" (4 bytes)
    /// - num_nodes: u32 LE
    /// - For each node:
    ///   - op_type: u8 (0=Gemm, 1=Conv, 2=Relu, 3=MaxPool, 4=BatchNorm, 5=Reshape, 6=Softmax, 7=Add, 8=Mul)
    ///   - num_weights: u32 LE
    ///   - For each weight tensor: [ndim: u32 LE] [shape dims: u32 LE each] [data: f32 LE each]
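    ///
    /// For example, a model containing a single Relu node is 13 bytes:
    /// `4F 4E 4E 58` (magic "ONNX"), `01 00 00 00` (one node), `02` (Relu),
    /// `00 00 00 00` (zero weight tensors).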
    pub fn load_onnx(path: &str) -> Result<Model, String> {
        let mut file = std::fs::File::open(path).map_err(|e| format!("cannot open {path}: {e}"))?;
        let mut buf4 = [0u8; 4];
        let mut buf1 = [0u8; 1];

        // magic
        file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
        if &buf4 != b"ONNX" {
            return Err("invalid ONNX magic".into());
        }

        // num_nodes
        file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
        let num_nodes = u32::from_le_bytes(buf4) as usize;

        let mut layers = Vec::new();

        for _ in 0..num_nodes {
            file.read_exact(&mut buf1).map_err(|e| e.to_string())?;
            let op_type = buf1[0];

            file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
            let num_weights = u32::from_le_bytes(buf4) as usize;

            let mut tensors = Vec::new();
            for _ in 0..num_weights {
                file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
                let ndim = u32::from_le_bytes(buf4) as usize;
                let mut shape = Vec::with_capacity(ndim);
                for _ in 0..ndim {
                    file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
                    shape.push(u32::from_le_bytes(buf4) as usize);
                }
                let n: usize = shape.iter().product();
                let mut data = Vec::with_capacity(n);
                for _ in 0..n {
                    file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
                    data.push(f32::from_le_bytes(buf4));
                }
                tensors.push(Tensor { shape, data });
            }

            let layer = match op_type {
                0 => {
                    // Gemm -> Dense
                    if tensors.len() >= 2 {
                        Layer::Dense(DenseLayer {
                            weights: tensors[0].clone(),
                            bias: tensors[1].clone(),
                        })
                    } else {
                        return Err("Gemm requires 2 weight tensors".into());
                    }
                }
                1 => {
                    // Conv; the format carries no attributes, so stride and
                    // padding fall back to defaults
                    if tensors.len() >= 2 {
                        Layer::Conv2D(Conv2DLayer {
                            filters: tensors[0].clone(),
                            bias: tensors[1].clone(),
                            stride: 1,
                            padding: 0,
                        })
                    } else {
                        return Err("Conv requires 2 weight tensors".into());
                    }
                }
                2 => Layer::ReLU,
                3 => Layer::MaxPool(MaxPoolLayer { kernel_size: 2, stride: 2 }), // default pooling attributes
                4 => {
                    // BatchNorm
                    if tensors.len() >= 4 {
                        Layer::BatchNorm(BatchNormLayer {
                            gamma: tensors[0].clone(),
                            beta: tensors[1].clone(),
                            running_mean: tensors[2].clone(),
                            running_var: tensors[3].clone(),
                            eps: 1e-5,
                        })
                    } else {
                        return Err("BatchNorm requires 4 tensors".into());
                    }
                }
                5 => Layer::Flatten, // Reshape treated as flatten
                6 => Layer::Softmax(0),
                7 | 8 => {
                    // Add / Mul operands are resolved at a higher level; the
                    // simplified loader stores a ReLU as a stand-in (not a true identity)
                    Layer::ReLU
                }
                _ => return Err(format!("unknown op type {op_type}")),
            };
            layers.push(layer);
        }

        Ok(Model { layers, name: "onnx_model".to_string() })
    }

    /// Write a model in our simplified ONNX binary format. Only op codes and weight
    /// tensors are serialized; layer attributes (stride, padding, eps) are not, so
    /// the loader restores defaults for them.
    pub fn save_onnx(model: &Model, path: &str) -> Result<(), String> {
        use std::io::Write;
        let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?;
        file.write_all(b"ONNX").map_err(|e| e.to_string())?;
        let num_nodes = model.layers.len() as u32;
        file.write_all(&num_nodes.to_le_bytes()).map_err(|e| e.to_string())?;

        for layer in &model.layers {
            let (op_type, tensors): (u8, Vec<&Tensor>) = match layer {
                Layer::Dense(l) => (0, vec![&l.weights, &l.bias]),
                Layer::Conv2D(l) => (1, vec![&l.filters, &l.bias]),
                Layer::ReLU => (2, vec![]),
                Layer::MaxPool(_) => (3, vec![]),
                Layer::BatchNorm(l) => (4, vec![&l.gamma, &l.beta, &l.running_mean, &l.running_var]),
                Layer::Flatten => (5, vec![]),
                Layer::Softmax(_) => (6, vec![]),
                _ => (2, vec![]), // default to relu-like
            };
            file.write_all(&[op_type]).map_err(|e| e.to_string())?;
            let nw = tensors.len() as u32;
            file.write_all(&nw.to_le_bytes()).map_err(|e| e.to_string())?;
            for t in tensors {
                let ndim = t.shape.len() as u32;
                file.write_all(&ndim.to_le_bytes()).map_err(|e| e.to_string())?;
                for &d in &t.shape {
                    file.write_all(&(d as u32).to_le_bytes()).map_err(|e| e.to_string())?;
                }
                for &v in &t.data {
                    file.write_all(&v.to_le_bytes()).map_err(|e| e.to_string())?;
                }
            }
        }
        Ok(())
    }
}

// ── Quantization ────────────────────────────────────────────────────────

/// Simple weight quantization: round weights to the signed `bits`-bit integer range,
/// then dequantize back to f32. This simulates the effect of lower-precision storage.
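///
/// For example, with `bits = 8` the quantized range is [-128, 127]. A weight of -0.5
/// in a tensor whose largest magnitude is 1.0 is scaled by 127 to -63.5, rounded to
/// -64, and dequantized to -64 / 127 ≈ -0.504.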
pub fn quantize_model(model: &Model, bits: u32) -> Model {
    // signed `bits`-bit range, e.g. bits = 8 gives [-128.0, 127.0]
    let max_val = (1 << (bits - 1)) as f32 - 1.0;
    let min_val = -max_val - 1.0;

    let mut new_layers = Vec::new();
    for layer in &model.layers {
        let new_layer = match layer {
            Layer::Dense(l) => {
                let qw = quantize_tensor(&l.weights, min_val, max_val);
                let qb = quantize_tensor(&l.bias, min_val, max_val);
                Layer::Dense(DenseLayer { weights: qw, bias: qb })
            }
            Layer::Conv2D(l) => {
                let qf = quantize_tensor(&l.filters, min_val, max_val);
                let qb = quantize_tensor(&l.bias, min_val, max_val);
                Layer::Conv2D(Conv2DLayer { filters: qf, bias: qb, stride: l.stride, padding: l.padding })
            }
            other => other.clone(),
        };
        new_layers.push(new_layer);
    }
    Model { layers: new_layers, name: format!("{}_q{}", model.name, bits) }
}

/// Symmetric per-tensor fake quantization: scale by `max_val / abs_max`, round and
/// clamp to the integer range, then scale back to f32.
fn quantize_tensor(t: &Tensor, min_val: f32, max_val: f32) -> Tensor {
    let abs_max = t.data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    if abs_max == 0.0 {
        return t.clone();
    }
    let scale = max_val / abs_max;
    let inv_scale = abs_max / max_val;
    let data: Vec<f32> = t.data.iter().map(|&v| {
        let q = (v * scale).round().clamp(min_val, max_val);
        q * inv_scale
    }).collect();
    Tensor { shape: t.shape.clone(), data }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer() {
        let model = Sequential::new("test")
            .dense(4, 3)
            .relu()
            .build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let input = Tensor::ones(vec![1, 4]);
        let out = engine.infer(&input);
        assert_eq!(out.shape, vec![1, 3]);
        assert!(engine.stats.latency_ms >= 0.0);
    }

    #[test]
    fn test_batch_infer() {
        let model = Sequential::new("test")
            .dense(3, 2)
            .build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let inputs = vec![
            Tensor::ones(vec![1, 3]),
            Tensor::zeros(vec![1, 3]),
        ];
        let outputs = engine.batch_infer(&inputs);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape, vec![1, 2]);
        assert_eq!(outputs[1].shape, vec![1, 2]);
    }

    #[test]
    fn test_warm_up() {
        let model = Sequential::new("test").dense(4, 2).build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        engine.warm_up(vec![1, 4], 5);
        // just verify it doesn't panic
    }

    #[test]
    fn test_quantize_model() {
        let model = Sequential::new("test")
            .dense(4, 3)
            .relu()
            .build();
        let qmodel = quantize_model(&model, 8);
        assert!(qmodel.name.contains("q8"));
        // forward still works
        let input = Tensor::ones(vec![1, 4]);
        let out = qmodel.forward(&input);
        assert_eq!(out.shape, vec![1, 3]);
    }

    #[test]
    fn test_onnx_save_load_roundtrip() {
        let model = Sequential::new("onnx_test")
            .dense(4, 3)
            .relu()
            .dense(3, 2)
            .softmax()
            .build();

        let path = std::env::temp_dir().join("proof_engine_test.onnx");
        let path_str = path.to_str().unwrap();

        OnnxLoader::save_onnx(&model, path_str).unwrap();
        let loaded = OnnxLoader::load_onnx(path_str).unwrap();

        assert_eq!(loaded.layers.len(), model.layers.len());

        // Verify dense weights match
        if let (Layer::Dense(orig), Layer::Dense(loaded_l)) = (&model.layers[0], &loaded.layers[0]) {
            assert_eq!(orig.weights.data, loaded_l.weights.data);
        }

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_onnx_load_bad_magic() {
        let path = std::env::temp_dir().join("proof_engine_bad.onnx");
        std::fs::write(&path, b"NOPE1234").unwrap();
        let result = OnnxLoader::load_onnx(path.to_str().unwrap());
        assert!(result.is_err());
        let _ = std::fs::remove_file(path);
    }
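
    // Example test sketch for the byte format documented on `load_onnx`: a
    // hand-assembled file with magic "ONNX", one node, op code 2 (Relu) and
    // zero weight tensors should decode to a single ReLU layer.
    #[test]
    fn test_onnx_load_minimal_relu() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"ONNX");
        bytes.extend_from_slice(&1u32.to_le_bytes()); // num_nodes
        bytes.push(2u8); // op_type: Relu
        bytes.extend_from_slice(&0u32.to_le_bytes()); // num_weights
        let path = std::env::temp_dir().join("proof_engine_min_relu.onnx");
        std::fs::write(&path, &bytes).unwrap();
        let loaded = OnnxLoader::load_onnx(path.to_str().unwrap()).unwrap();
        assert_eq!(loaded.layers.len(), 1);
        assert!(matches!(loaded.layers[0], Layer::ReLU));
        let _ = std::fs::remove_file(path);
    }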

    #[test]
    fn test_inference_stats() {
        let model = Sequential::new("s").dense(2, 2).build();
        let mut engine = InferenceEngine::new(model, Device::CPU);
        let _ = engine.infer(&Tensor::ones(vec![1, 2]));
        assert!(engine.stats.flops > 0);
        assert!(engine.stats.memory_bytes > 0);
    }
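
    // Example sketch checking the Dense FLOPs estimate. The model is built directly
    // from the public `Model`/`DenseLayer` fields used elsewhere in this module, with
    // weights stored as [in, out] (the layout `estimate_flops` assumes), so the
    // expected count is 2 * m * k * n = 2 * 1 * 4 * 3 = 24 for a [1, 4] input.
    #[test]
    fn test_flops_estimate_dense() {
        let dense = DenseLayer {
            weights: Tensor::ones(vec![4, 3]),
            bias: Tensor::zeros(vec![3]),
        };
        let model = Model { layers: vec![Layer::Dense(dense)], name: "flops".to_string() };
        let engine = InferenceEngine::new(model, Device::CPU);
        assert_eq!(engine.estimate_flops(&Tensor::ones(vec![1, 4])), 24);
    }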
}