//! Neural network model construction and execution.

use super::tensor::Tensor;
use std::io::{Read, Write};

/// A single layer in a neural network.
///
/// Parameterized variants carry their own weights; the unit variants
/// (`Flatten`, `ReLU`, `Sigmoid`, `Tanh`, `GELU`) are stateless markers.
#[derive(Debug, Clone)]
pub enum Layer {
    /// Fully connected layer: `out = input @ weights + bias`.
    Dense(DenseLayer),
    /// 2-D convolution with per-channel bias.
    Conv2D(Conv2DLayer),
    /// 2-D max pooling (no parameters).
    MaxPool(MaxPoolLayer),
    /// Batch normalization using stored running statistics.
    BatchNorm(BatchNormLayer),
    /// Dropout; active only when its `training` flag is set.
    Dropout(DropoutLayer),
    /// Collapse the input into a 1-D tensor.
    Flatten,
    ReLU,
    Sigmoid,
    Tanh,
    Softmax(usize), // softmax axis
    GELU,
    /// Sub-network whose output is added back onto its input.
    Residual(ResidualBlock),
    /// Multi-head self-attention over a (seq_len, d_model) input.
    Attention(MultiHeadAttention),
}

/// Parameters of a fully connected layer.
#[derive(Debug, Clone)]
pub struct DenseLayer {
    /// Weight matrix, shape (in_features, out_features).
    pub weights: Tensor,
    /// Bias vector, shape (out_features,).
    pub bias: Tensor,
}

/// Parameters of a 2-D convolution layer.
#[derive(Debug, Clone)]
pub struct Conv2DLayer {
    /// Filter bank, shape (c_out, c_in, kh, kw).
    pub filters: Tensor,
    /// Per-output-channel bias, shape (c_out,).
    pub bias: Tensor,
    /// Step between filter applications.
    pub stride: usize,
    /// Zero-padding applied before convolving.
    pub padding: usize,
}

/// Configuration of a max-pooling layer (no trainable parameters).
#[derive(Debug, Clone)]
pub struct MaxPoolLayer {
    /// Side length of the square pooling window.
    pub kernel_size: usize,
    /// Step between successive windows.
    pub stride: usize,
}

/// Batch-normalization parameters and inference-time statistics.
#[derive(Debug, Clone)]
pub struct BatchNormLayer {
    /// Learned per-feature scale.
    pub gamma: Tensor,
    /// Learned per-feature shift.
    pub beta: Tensor,
    /// Running mean used when normalizing (see `Layer::forward`).
    pub running_mean: Tensor,
    /// Running variance used when normalizing.
    pub running_var: Tensor,
    /// Small constant for numerical stability (builder default: 1e-5).
    pub eps: f32,
}

/// Dropout configuration.
#[derive(Debug, Clone)]
pub struct DropoutLayer {
    /// Rate forwarded to `Tensor::dropout` — presumably the drop
    /// probability; confirm against the tensor implementation.
    pub p: f32,
    /// Forwarded to `Tensor::dropout`; the builder constructs layers
    /// with this set to `true`.
    pub training: bool,
}

/// A stack of layers whose output is added back to the block's input
/// (skip connection; see the `Layer::Residual` forward arm).
#[derive(Debug, Clone)]
pub struct ResidualBlock {
    /// Inner layers, applied in order.
    pub layers: Vec<Layer>,
}

/// Weights of a multi-head self-attention block.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// Number of attention heads.
    pub heads: usize,
    /// Model (embedding) width.
    pub d_model: usize,
    /// Per-head dimension: `d_model / heads` (integer division).
    pub d_k: usize,
    /// Query projection, shape (d_model, d_model).
    pub w_q: Tensor,
    /// Key projection, same shape as `w_q`.
    pub w_k: Tensor,
    /// Value projection, same shape as `w_q`.
    pub w_v: Tensor,
    /// Output projection applied to the concatenated heads.
    pub w_o: Tensor,
}

75impl DenseLayer {
76    pub fn new(in_features: usize, out_features: usize) -> Self {
77        // Xavier initialization
78        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
79        let w = Tensor::rand(vec![in_features, out_features], (in_features * out_features) as u64);
80        let weights = Tensor {
81            shape: w.shape.clone(),
82            data: w.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
83        };
84        let bias = Tensor::zeros(vec![out_features]);
85        Self { weights, bias }
86    }
87
88    pub fn forward(&self, input: &Tensor) -> Tensor {
89        // input: (batch, in_features) or (in_features,)
90        let is_1d = input.shape.len() == 1;
91        let input_2d = if is_1d {
92            input.reshape(vec![1, input.shape[0]])
93        } else {
94            input.clone()
95        };
96        // out = input @ weights + bias
97        let mut out = Tensor::matmul(&input_2d, &self.weights);
98        // add bias to each row
99        let batch = out.shape[0];
100        let out_f = out.shape[1];
101        for b in 0..batch {
102            for j in 0..out_f {
103                out.data[b * out_f + j] += self.bias.data[j];
104            }
105        }
106        if is_1d { out.reshape(vec![out_f]) } else { out }
107    }
108
109    pub fn parameter_count(&self) -> usize {
110        self.weights.data.len() + self.bias.data.len()
111    }
112}
113
114impl Conv2DLayer {
115    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
116        let n = out_channels * in_channels * kernel_size * kernel_size;
117        let scale = (2.0 / (in_channels * kernel_size * kernel_size) as f32).sqrt();
118        let r = Tensor::rand(vec![out_channels, in_channels, kernel_size, kernel_size], n as u64);
119        let filters = Tensor {
120            shape: r.shape.clone(),
121            data: r.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
122        };
123        let bias = Tensor::zeros(vec![out_channels]);
124        Self { filters, bias, stride: 1, padding: 0 }
125    }
126
127    pub fn forward(&self, input: &Tensor) -> Tensor {
128        let mut out = input.conv2d(&self.filters, self.stride, self.padding);
129        // add bias per channel
130        let c_out = out.shape[0];
131        let spatial: usize = out.shape[1..].iter().product();
132        for c in 0..c_out {
133            for s in 0..spatial {
134                out.data[c * spatial + s] += self.bias.data[c];
135            }
136        }
137        out
138    }
139
140    pub fn parameter_count(&self) -> usize {
141        self.filters.data.len() + self.bias.data.len()
142    }
143}
144
impl MultiHeadAttention {
    /// Create an attention block with deterministic, uniformly-initialized
    /// projection matrices scaled by `1/sqrt(d_model)`.
    ///
    /// NOTE(review): assumes `heads` evenly divides `d_model`; otherwise
    /// `d_k * heads < d_model` and the trailing columns of the projections
    /// are never read or written by `forward` — confirm callers uphold this.
    pub fn new(heads: usize, d_model: usize) -> Self {
        // Per-head key/query dimension (integer division).
        let d_k = d_model / heads;
        // Fixed per-matrix seeds keep construction reproducible.
        let init = |seed: u64| {
            let r = Tensor::rand(vec![d_model, d_model], seed);
            let scale = (1.0 / d_model as f32).sqrt();
            Tensor {
                shape: r.shape.clone(),
                data: r.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
            }
        };
        Self {
            heads,
            d_model,
            d_k,
            w_q: init(1001),
            w_k: init(2002),
            w_v: init(3003),
            w_o: init(4004),
        }
    }

    /// Forward pass. Input shape: (seq_len, d_model). Returns same shape.
    ///
    /// Computes scaled dot-product attention per head —
    /// `softmax(Q K^T / sqrt(d_k)) V` — writes each head's context into its
    /// column slice of a concat buffer, then projects through `w_o`.
    pub fn forward(&self, input: &Tensor) -> Tensor {
        assert_eq!(input.shape.len(), 2);
        let seq_len = input.shape[0];
        let d_model = input.shape[1];
        assert_eq!(d_model, self.d_model);

        // Project the whole sequence once; per-head slices are cut below.
        let q = Tensor::matmul(input, &self.w_q);
        let k = Tensor::matmul(input, &self.w_k);
        let v = Tensor::matmul(input, &self.w_v);

        let d_k = self.d_k;
        let scale = 1.0 / (d_k as f32).sqrt();

        // accumulate multi-head output
        let mut concat_heads = vec![0.0f32; seq_len * d_model];

        for h in 0..self.heads {
            // Column offset of this head inside the d_model-wide projections.
            let offset = h * d_k;
            // extract head slices (seq_len, d_k) for Q, K, V
            let mut qh = vec![0.0f32; seq_len * d_k];
            let mut kh = vec![0.0f32; seq_len * d_k];
            let mut vh = vec![0.0f32; seq_len * d_k];
            for s in 0..seq_len {
                for j in 0..d_k {
                    qh[s * d_k + j] = q.data[s * d_model + offset + j];
                    kh[s * d_k + j] = k.data[s * d_model + offset + j];
                    vh[s * d_k + j] = v.data[s * d_model + offset + j];
                }
            }
            let qh = Tensor::from_vec(qh, vec![seq_len, d_k]);
            let kh_t = Tensor::from_vec(kh, vec![seq_len, d_k]).transpose();
            let vh = Tensor::from_vec(vh, vec![seq_len, d_k]);

            // scores = Q @ K^T / sqrt(d_k)
            let scores = Tensor::matmul(&qh, &kh_t).scale(scale);
            // attention weights via softmax over last axis
            let attn = scores.softmax(1);
            // context = attn @ V
            let context = Tensor::matmul(&attn, &vh);

            // write this head's context into its slice of the concat buffer
            for s in 0..seq_len {
                for j in 0..d_k {
                    concat_heads[s * d_model + offset + j] = context.data[s * d_k + j];
                }
            }
        }

        // Final output projection over the concatenated heads.
        let concat = Tensor::from_vec(concat_heads, vec![seq_len, d_model]);
        Tensor::matmul(&concat, &self.w_o)
    }

    /// Number of scalars across the four projection matrices.
    pub fn parameter_count(&self) -> usize {
        self.w_q.data.len() + self.w_k.data.len() + self.w_v.data.len() + self.w_o.data.len()
    }
}

225impl Layer {
226    pub fn forward(&self, input: &Tensor) -> Tensor {
227        match self {
228            Layer::Dense(l) => l.forward(input),
229            Layer::Conv2D(l) => l.forward(input),
230            Layer::MaxPool(l) => input.max_pool2d(l.kernel_size, l.stride),
231            Layer::BatchNorm(l) => {
232                input.batch_norm(&l.running_mean, &l.running_var, &l.gamma, &l.beta, l.eps)
233            }
234            Layer::Dropout(l) => input.dropout(l.p, 12345, l.training),
235            Layer::Flatten => input.flatten(),
236            Layer::ReLU => input.relu(),
237            Layer::Sigmoid => input.sigmoid(),
238            Layer::Tanh => input.tanh_act(),
239            Layer::Softmax(axis) => input.softmax(*axis),
240            Layer::GELU => input.gelu(),
241            Layer::Residual(block) => {
242                let mut out = input.clone();
243                for layer in &block.layers {
244                    out = layer.forward(&out);
245                }
246                input.add(&out)
247            }
248            Layer::Attention(attn) => attn.forward(input),
249        }
250    }
251
252    pub fn parameter_count(&self) -> usize {
253        match self {
254            Layer::Dense(l) => l.parameter_count(),
255            Layer::Conv2D(l) => l.parameter_count(),
256            Layer::BatchNorm(l) => l.gamma.data.len() + l.beta.data.len(),
257            Layer::Attention(a) => a.parameter_count(),
258            Layer::Residual(block) => block.layers.iter().map(|l| l.parameter_count()).sum(),
259            _ => 0,
260        }
261    }
262
263    pub fn name(&self) -> &str {
264        match self {
265            Layer::Dense(_) => "Dense",
266            Layer::Conv2D(_) => "Conv2D",
267            Layer::MaxPool(_) => "MaxPool",
268            Layer::BatchNorm(_) => "BatchNorm",
269            Layer::Dropout(_) => "Dropout",
270            Layer::Flatten => "Flatten",
271            Layer::ReLU => "ReLU",
272            Layer::Sigmoid => "Sigmoid",
273            Layer::Tanh => "Tanh",
274            Layer::Softmax(_) => "Softmax",
275            Layer::GELU => "GELU",
276            Layer::Residual(_) => "Residual",
277            Layer::Attention(_) => "Attention",
278        }
279    }
280}
281
/// A sequential neural network model.
#[derive(Debug, Clone)]
pub struct Model {
    /// Layers applied in order by `forward`.
    pub layers: Vec<Layer>,
    /// Display name used in summaries.
    pub name: String,
}

289impl Model {
290    pub fn new(name: &str) -> Self {
291        Self { layers: Vec::new(), name: name.to_string() }
292    }
293
294    pub fn forward(&self, input: &Tensor) -> Tensor {
295        let mut x = input.clone();
296        for layer in &self.layers {
297            x = layer.forward(&x);
298        }
299        x
300    }
301
302    pub fn parameter_count(&self) -> usize {
303        self.layers.iter().map(|l| l.parameter_count()).sum()
304    }
305
306    /// Collect all weight tensors from the model in order.
307    fn collect_weights(&self) -> Vec<&Tensor> {
308        let mut weights = Vec::new();
309        for layer in &self.layers {
310            match layer {
311                Layer::Dense(l) => { weights.push(&l.weights); weights.push(&l.bias); }
312                Layer::Conv2D(l) => { weights.push(&l.filters); weights.push(&l.bias); }
313                Layer::BatchNorm(l) => {
314                    weights.push(&l.gamma); weights.push(&l.beta);
315                    weights.push(&l.running_mean); weights.push(&l.running_var);
316                }
317                Layer::Attention(a) => {
318                    weights.push(&a.w_q); weights.push(&a.w_k);
319                    weights.push(&a.w_v); weights.push(&a.w_o);
320                }
321                Layer::Residual(block) => {
322                    // For simplicity, build a temp Model to reuse logic
323                    let m = Model { layers: block.layers.clone(), name: String::new() };
324                    // We can't easily return refs into block here without
325                    // restructuring, so we skip residual sub-weights in save/load.
326                    let _ = m;
327                }
328                _ => {}
329            }
330        }
331        weights
332    }
333
334    /// Save weights to a simple binary format:
335    /// For each tensor: [ndim: u32] [shape[0]: u32] ... [shape[n-1]: u32] [data as f32 LE bytes]
336    pub fn save_weights(&self, path: &str) -> Result<(), String> {
337        let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?;
338        let weights = self.collect_weights();
339        let count = weights.len() as u32;
340        file.write_all(&count.to_le_bytes()).map_err(|e| e.to_string())?;
341        for w in weights {
342            let ndim = w.shape.len() as u32;
343            file.write_all(&ndim.to_le_bytes()).map_err(|e| e.to_string())?;
344            for &d in &w.shape {
345                file.write_all(&(d as u32).to_le_bytes()).map_err(|e| e.to_string())?;
346            }
347            for &v in &w.data {
348                file.write_all(&v.to_le_bytes()).map_err(|e| e.to_string())?;
349            }
350        }
351        Ok(())
352    }
353
354    /// Load weights from the binary format written by `save_weights`.
355    pub fn load_weights(&mut self, path: &str) -> Result<(), String> {
356        let mut file = std::fs::File::open(path).map_err(|e| e.to_string())?;
357        let mut buf4 = [0u8; 4];
358
359        file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
360        let count = u32::from_le_bytes(buf4) as usize;
361
362        let mut tensors = Vec::with_capacity(count);
363        for _ in 0..count {
364            file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
365            let ndim = u32::from_le_bytes(buf4) as usize;
366            let mut shape = Vec::with_capacity(ndim);
367            for _ in 0..ndim {
368                file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
369                shape.push(u32::from_le_bytes(buf4) as usize);
370            }
371            let n: usize = shape.iter().product();
372            let mut data = Vec::with_capacity(n);
373            for _ in 0..n {
374                file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
375                data.push(f32::from_le_bytes(buf4));
376            }
377            tensors.push(Tensor { shape, data });
378        }
379
380        // Assign weights back to layers
381        let mut idx = 0;
382        for layer in &mut self.layers {
383            match layer {
384                Layer::Dense(l) => {
385                    if idx + 1 < tensors.len() {
386                        l.weights = tensors[idx].clone();
387                        l.bias = tensors[idx + 1].clone();
388                        idx += 2;
389                    }
390                }
391                Layer::Conv2D(l) => {
392                    if idx + 1 < tensors.len() {
393                        l.filters = tensors[idx].clone();
394                        l.bias = tensors[idx + 1].clone();
395                        idx += 2;
396                    }
397                }
398                Layer::BatchNorm(l) => {
399                    if idx + 3 < tensors.len() {
400                        l.gamma = tensors[idx].clone();
401                        l.beta = tensors[idx + 1].clone();
402                        l.running_mean = tensors[idx + 2].clone();
403                        l.running_var = tensors[idx + 3].clone();
404                        idx += 4;
405                    }
406                }
407                Layer::Attention(a) => {
408                    if idx + 3 < tensors.len() {
409                        a.w_q = tensors[idx].clone();
410                        a.w_k = tensors[idx + 1].clone();
411                        a.w_v = tensors[idx + 2].clone();
412                        a.w_o = tensors[idx + 3].clone();
413                        idx += 4;
414                    }
415                }
416                _ => {}
417            }
418        }
419        Ok(())
420    }
421}
422
/// Human-readable model summary.
///
/// Zero-sized namespace type for the associated `print` function.
pub struct ModelSummary;

426impl ModelSummary {
427    pub fn print(model: &Model) -> String {
428        let mut lines = Vec::new();
429        lines.push(format!("Model: {}", model.name));
430        lines.push(format!("{:-<60}", ""));
431        lines.push(format!("{:<20} {:>20} {:>15}", "Layer", "Output Shape", "Params"));
432        lines.push(format!("{:-<60}", ""));
433        for (i, layer) in model.layers.iter().enumerate() {
434            let params = layer.parameter_count();
435            lines.push(format!("{:<20} {:>20} {:>15}", format!("{}_{}", layer.name(), i), "dynamic", params));
436        }
437        lines.push(format!("{:-<60}", ""));
438        lines.push(format!("Total parameters: {}", model.parameter_count()));
439        lines.join("\n")
440    }
441}
442
/// Builder for constructing models layer by layer.
pub struct Sequential {
    // Layers accumulated so far, in insertion order.
    layers: Vec<Layer>,
    // Name given to the built `Model`.
    name: String,
}

449impl Sequential {
450    pub fn new(name: &str) -> Self {
451        Self { layers: Vec::new(), name: name.to_string() }
452    }
453
454    pub fn dense(mut self, in_features: usize, out_features: usize) -> Self {
455        self.layers.push(Layer::Dense(DenseLayer::new(in_features, out_features)));
456        self
457    }
458
459    pub fn conv2d(mut self, in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
460        self.layers.push(Layer::Conv2D(Conv2DLayer::new(in_channels, out_channels, kernel_size)));
461        self
462    }
463
464    pub fn max_pool(mut self, kernel_size: usize, stride: usize) -> Self {
465        self.layers.push(Layer::MaxPool(MaxPoolLayer { kernel_size, stride }));
466        self
467    }
468
469    pub fn batch_norm(mut self, num_features: usize) -> Self {
470        self.layers.push(Layer::BatchNorm(BatchNormLayer {
471            gamma: Tensor::ones(vec![num_features]),
472            beta: Tensor::zeros(vec![num_features]),
473            running_mean: Tensor::zeros(vec![num_features]),
474            running_var: Tensor::ones(vec![num_features]),
475            eps: 1e-5,
476        }));
477        self
478    }
479
480    pub fn dropout(mut self, p: f32) -> Self {
481        self.layers.push(Layer::Dropout(DropoutLayer { p, training: true }));
482        self
483    }
484
485    pub fn flatten(mut self) -> Self {
486        self.layers.push(Layer::Flatten);
487        self
488    }
489
490    pub fn relu(mut self) -> Self {
491        self.layers.push(Layer::ReLU);
492        self
493    }
494
495    pub fn sigmoid(mut self) -> Self {
496        self.layers.push(Layer::Sigmoid);
497        self
498    }
499
500    pub fn tanh_act(mut self) -> Self {
501        self.layers.push(Layer::Tanh);
502        self
503    }
504
505    pub fn softmax(mut self) -> Self {
506        // default: softmax over the last axis, represented as axis=0 for 1-D
507        self.layers.push(Layer::Softmax(0));
508        self
509    }
510
511    pub fn softmax_axis(mut self, axis: usize) -> Self {
512        self.layers.push(Layer::Softmax(axis));
513        self
514    }
515
516    pub fn gelu(mut self) -> Self {
517        self.layers.push(Layer::GELU);
518        self
519    }
520
521    pub fn residual(mut self, layers: Vec<Layer>) -> Self {
522        self.layers.push(Layer::Residual(ResidualBlock { layers }));
523        self
524    }
525
526    pub fn attention(mut self, heads: usize, d_model: usize) -> Self {
527        self.layers.push(Layer::Attention(MultiHeadAttention::new(heads, d_model)));
528        self
529    }
530
531    pub fn layer(mut self, layer: Layer) -> Self {
532        self.layers.push(layer);
533        self
534    }
535
536    pub fn build(self) -> Model {
537        Model { layers: self.layers, name: self.name }
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    // Dense layer maps (batch, in) -> (batch, out).
    #[test]
    fn test_dense_forward_shape() {
        let layer = DenseLayer::new(4, 3);
        let input = Tensor::ones(vec![2, 4]);
        let out = layer.forward(&input);
        assert_eq!(out.shape, vec![2, 3]);
    }

    // A 1-D input produces a 1-D output (rank is preserved).
    #[test]
    fn test_dense_forward_1d() {
        let layer = DenseLayer::new(3, 2);
        let input = Tensor::ones(vec![3]);
        let out = layer.forward(&input);
        assert_eq!(out.shape, vec![2]);
    }

    #[test]
    fn test_sequential_build() {
        let model = Sequential::new("test")
            .dense(10, 5)
            .relu()
            .dense(5, 2)
            .softmax()
            .build();
        assert_eq!(model.layers.len(), 4);
        assert_eq!(model.name, "test");
    }

    #[test]
    fn test_model_forward_shape() {
        let model = Sequential::new("mlp")
            .dense(4, 8)
            .relu()
            .dense(8, 3)
            .build();
        let input = Tensor::ones(vec![2, 4]);
        let out = model.forward(&input);
        assert_eq!(out.shape, vec![2, 3]);
    }

    #[test]
    fn test_parameter_count() {
        let model = Sequential::new("mlp")
            .dense(10, 5) // 10*5 + 5 = 55
            .dense(5, 2)  // 5*2 + 2 = 12
            .build();
        assert_eq!(model.parameter_count(), 55 + 12);
    }

    #[test]
    fn test_residual_connection() {
        // residual block: dense(4,4) + relu, then added to input
        let block_layers = vec![
            Layer::Dense(DenseLayer {
                weights: Tensor::zeros(vec![4, 4]),
                bias: Tensor::zeros(vec![4]),
            }),
            Layer::ReLU,
        ];
        let model = Sequential::new("res")
            .residual(block_layers)
            .build();
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let out = model.forward(&input);
        // With zero weights, dense outputs zeros, relu(zeros)=zeros, residual = input + 0 = input
        assert_eq!(out.shape, vec![1, 4]);
        assert_eq!(out.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    // Attention preserves the (seq_len, d_model) shape.
    #[test]
    fn test_attention_forward_shape() {
        let attn = MultiHeadAttention::new(2, 4);
        let input = Tensor::rand(vec![3, 4], 42); // seq_len=3, d_model=4
        let out = attn.forward(&input);
        assert_eq!(out.shape, vec![3, 4]);
    }

    #[test]
    fn test_model_summary() {
        let model = Sequential::new("demo")
            .dense(10, 5)
            .relu()
            .build();
        let summary = ModelSummary::print(&model);
        assert!(summary.contains("demo"));
        assert!(summary.contains("Dense"));
        assert!(summary.contains("ReLU"));
    }

    // Round-trip: saved weights load back bit-identical into a model
    // of the same architecture.
    #[test]
    fn test_save_load_weights() {
        let model = Sequential::new("test")
            .dense(3, 2)
            .build();
        let path = std::env::temp_dir().join("proof_engine_test_weights.bin");
        let path_str = path.to_str().unwrap();
        model.save_weights(path_str).unwrap();

        let mut model2 = Sequential::new("test")
            .dense(3, 2)
            .build();
        model2.load_weights(path_str).unwrap();

        // weights should match
        if let (Layer::Dense(l1), Layer::Dense(l2)) = (&model.layers[0], &model2.layers[0]) {
            assert_eq!(l1.weights.data, l2.weights.data);
            assert_eq!(l1.bias.data, l2.bias.data);
        }
        let _ = std::fs::remove_file(path);
    }

    // 1x5x5 input, 2 filters of 3x3, stride 1, no padding -> 2x3x3.
    #[test]
    fn test_conv2d_layer_forward() {
        let layer = Conv2DLayer::new(1, 2, 3);
        let input = Tensor::ones(vec![1, 5, 5]);
        let out = layer.forward(&input);
        assert_eq!(out.shape, vec![2, 3, 3]);
    }
}