use super::tensor::Tensor;
use std::io::{Read, Write};

#[derive(Debug, Clone)]
pub enum Layer {
    Dense(DenseLayer),
    Conv2D(Conv2DLayer),
    MaxPool(MaxPoolLayer),
    BatchNorm(BatchNormLayer),
    Dropout(DropoutLayer),
    Flatten,
    ReLU,
    Sigmoid,
    Tanh,
    Softmax(usize),
    GELU,
    Residual(ResidualBlock),
    Attention(MultiHeadAttention),
}

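/// A fully connected layer: `forward` computes `input · weights + bias`.
///
/// A minimal usage sketch (marked `ignore` because the crate path seen by the
/// doc-test harness is assumed, not known):
///
/// ```ignore
/// let layer = DenseLayer::new(4, 3);
/// let out = layer.forward(&Tensor::ones(vec![2, 4]));
/// assert_eq!(out.shape, vec![2, 3]);
/// ```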
#[derive(Debug, Clone)]
pub struct DenseLayer {
    pub weights: Tensor,
    pub bias: Tensor,
}

#[derive(Debug, Clone)]
pub struct Conv2DLayer {
    pub filters: Tensor,
    pub bias: Tensor,
    pub stride: usize,
    pub padding: usize,
}

#[derive(Debug, Clone)]
pub struct MaxPoolLayer {
    pub kernel_size: usize,
    pub stride: usize,
}

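/// Inference-mode batch normalization state. Assuming `Tensor::batch_norm`
/// implements the standard transform, the per-feature output is
/// `gamma * (x - running_mean) / sqrt(running_var + eps) + beta`.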
#[derive(Debug, Clone)]
pub struct BatchNormLayer {
    pub gamma: Tensor,
    pub beta: Tensor,
    pub running_mean: Tensor,
    pub running_var: Tensor,
    pub eps: f32,
}

#[derive(Debug, Clone)]
pub struct DropoutLayer {
    pub p: f32,
    pub training: bool,
}

#[derive(Debug, Clone)]
pub struct ResidualBlock {
    pub layers: Vec<Layer>,
}

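/// Multi-head scaled dot-product attention over a `[seq_len, d_model]` input.
/// Each head works on its own `d_k = d_model / heads` wide slice of the Q/K/V
/// projections, computing `softmax(Q_h K_h^T / sqrt(d_k)) V_h`; the per-head
/// outputs are concatenated and projected through `w_o`.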
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    pub heads: usize,
    pub d_model: usize,
    pub d_k: usize,
    pub w_q: Tensor,
    pub w_k: Tensor,
    pub w_v: Tensor,
    pub w_o: Tensor,
}

impl DenseLayer {
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier/Glorot-style init: map uniform [0, 1) samples to [-scale, scale).
        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
        let w = Tensor::rand(vec![in_features, out_features], (in_features * out_features) as u64);
        let weights = Tensor {
            shape: w.shape.clone(),
            data: w.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
        };
        let bias = Tensor::zeros(vec![out_features]);
        Self { weights, bias }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        // Promote a 1-D input to a [1, in_features] batch so matmul applies,
        // then squeeze the batch dimension back out at the end.
        let is_1d = input.shape.len() == 1;
        let input_2d = if is_1d {
            input.reshape(vec![1, input.shape[0]])
        } else {
            input.clone()
        };
        let mut out = Tensor::matmul(&input_2d, &self.weights);
        let batch = out.shape[0];
        let out_f = out.shape[1];
        for b in 0..batch {
            for j in 0..out_f {
                out.data[b * out_f + j] += self.bias.data[j];
            }
        }
        if is_1d { out.reshape(vec![out_f]) } else { out }
    }

    pub fn parameter_count(&self) -> usize {
        self.weights.data.len() + self.bias.data.len()
    }
}

impl Conv2DLayer {
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        let n = out_channels * in_channels * kernel_size * kernel_size;
        // He-style init scaled by fan-in, mapped to [-scale, scale).
        let scale = (2.0 / (in_channels * kernel_size * kernel_size) as f32).sqrt();
        let r = Tensor::rand(vec![out_channels, in_channels, kernel_size, kernel_size], n as u64);
        let filters = Tensor {
            shape: r.shape.clone(),
            data: r.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
        };
        let bias = Tensor::zeros(vec![out_channels]);
        Self { filters, bias, stride: 1, padding: 0 }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut out = input.conv2d(&self.filters, self.stride, self.padding);
        // The output is treated as [out_channels, h, w]: broadcast the bias
        // for each channel across its spatial positions.
        let c_out = out.shape[0];
        let spatial: usize = out.shape[1..].iter().product();
        for c in 0..c_out {
            for s in 0..spatial {
                out.data[c * spatial + s] += self.bias.data[c];
            }
        }
        out
    }

    pub fn parameter_count(&self) -> usize {
        self.filters.data.len() + self.bias.data.len()
    }
}

impl MultiHeadAttention {
    pub fn new(heads: usize, d_model: usize) -> Self {
        // Guard against silent truncation when splitting d_model across heads.
        assert_eq!(d_model % heads, 0, "d_model must be divisible by heads");
        let d_k = d_model / heads;
        let init = |seed: u64| {
            let r = Tensor::rand(vec![d_model, d_model], seed);
            let scale = (1.0 / d_model as f32).sqrt();
            Tensor {
                shape: r.shape.clone(),
                data: r.data.iter().map(|v| (v - 0.5) * 2.0 * scale).collect(),
            }
        };
        Self {
            heads,
            d_model,
            d_k,
            w_q: init(1001),
            w_k: init(2002),
            w_v: init(3003),
            w_o: init(4004),
        }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        assert_eq!(input.shape.len(), 2);
        let seq_len = input.shape[0];
        let d_model = input.shape[1];
        assert_eq!(d_model, self.d_model);

        let q = Tensor::matmul(input, &self.w_q);
        let k = Tensor::matmul(input, &self.w_k);
        let v = Tensor::matmul(input, &self.w_v);

        let d_k = self.d_k;
        let scale = 1.0 / (d_k as f32).sqrt();

        let mut concat_heads = vec![0.0f32; seq_len * d_model];

        for h in 0..self.heads {
            // Slice this head's d_k-wide columns out of Q, K, and V.
            let offset = h * d_k;
            let mut qh = vec![0.0f32; seq_len * d_k];
            let mut kh = vec![0.0f32; seq_len * d_k];
            let mut vh = vec![0.0f32; seq_len * d_k];
            for s in 0..seq_len {
                for j in 0..d_k {
                    qh[s * d_k + j] = q.data[s * d_model + offset + j];
                    kh[s * d_k + j] = k.data[s * d_model + offset + j];
                    vh[s * d_k + j] = v.data[s * d_model + offset + j];
                }
            }
            let qh = Tensor::from_vec(qh, vec![seq_len, d_k]);
            let kh_t = Tensor::from_vec(kh, vec![seq_len, d_k]).transpose();
            let vh = Tensor::from_vec(vh, vec![seq_len, d_k]);

            // softmax(Q K^T / sqrt(d_k)) V for this head.
            let scores = Tensor::matmul(&qh, &kh_t).scale(scale);
            let attn = scores.softmax(1);
            let context = Tensor::matmul(&attn, &vh);

            // Write the head's context back into its slice of the concatenation.
            for s in 0..seq_len {
                for j in 0..d_k {
                    concat_heads[s * d_model + offset + j] = context.data[s * d_k + j];
                }
            }
        }

        let concat = Tensor::from_vec(concat_heads, vec![seq_len, d_model]);
        Tensor::matmul(&concat, &self.w_o)
    }

    pub fn parameter_count(&self) -> usize {
        self.w_q.data.len() + self.w_k.data.len() + self.w_v.data.len() + self.w_o.data.len()
    }
}

impl Layer {
    pub fn forward(&self, input: &Tensor) -> Tensor {
        match self {
            Layer::Dense(l) => l.forward(input),
            Layer::Conv2D(l) => l.forward(input),
            Layer::MaxPool(l) => input.max_pool2d(l.kernel_size, l.stride),
            Layer::BatchNorm(l) => {
                input.batch_norm(&l.running_mean, &l.running_var, &l.gamma, &l.beta, l.eps)
            }
            // Note: the fixed seed makes the dropout mask identical on every call.
            Layer::Dropout(l) => input.dropout(l.p, 12345, l.training),
            Layer::Flatten => input.flatten(),
            Layer::ReLU => input.relu(),
            Layer::Sigmoid => input.sigmoid(),
            Layer::Tanh => input.tanh_act(),
            Layer::Softmax(axis) => input.softmax(*axis),
            Layer::GELU => input.gelu(),
            // Residual connection: y = x + f(x).
            Layer::Residual(block) => {
                let mut out = input.clone();
                for layer in &block.layers {
                    out = layer.forward(&out);
                }
                input.add(&out)
            }
            Layer::Attention(attn) => attn.forward(input),
        }
    }

    pub fn parameter_count(&self) -> usize {
        match self {
            Layer::Dense(l) => l.parameter_count(),
            Layer::Conv2D(l) => l.parameter_count(),
            // Only the trainable gamma/beta count; running stats are buffers.
            Layer::BatchNorm(l) => l.gamma.data.len() + l.beta.data.len(),
            Layer::Attention(a) => a.parameter_count(),
            Layer::Residual(block) => block.layers.iter().map(|l| l.parameter_count()).sum(),
            _ => 0,
        }
    }

    pub fn name(&self) -> &str {
        match self {
            Layer::Dense(_) => "Dense",
            Layer::Conv2D(_) => "Conv2D",
            Layer::MaxPool(_) => "MaxPool",
            Layer::BatchNorm(_) => "BatchNorm",
            Layer::Dropout(_) => "Dropout",
            Layer::Flatten => "Flatten",
            Layer::ReLU => "ReLU",
            Layer::Sigmoid => "Sigmoid",
            Layer::Tanh => "Tanh",
            Layer::Softmax(_) => "Softmax",
            Layer::GELU => "GELU",
            Layer::Residual(_) => "Residual",
            Layer::Attention(_) => "Attention",
        }
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    pub layers: Vec<Layer>,
    pub name: String,
}

impl Model {
    pub fn new(name: &str) -> Self {
        Self { layers: Vec::new(), name: name.to_string() }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        x
    }

    pub fn parameter_count(&self) -> usize {
        self.layers.iter().map(|l| l.parameter_count()).sum()
    }

    fn collect_weights(&self) -> Vec<&Tensor> {
        fn collect<'a>(layers: &'a [Layer], out: &mut Vec<&'a Tensor>) {
            for layer in layers {
                match layer {
                    Layer::Dense(l) => {
                        out.push(&l.weights);
                        out.push(&l.bias);
                    }
                    Layer::Conv2D(l) => {
                        out.push(&l.filters);
                        out.push(&l.bias);
                    }
                    Layer::BatchNorm(l) => {
                        out.push(&l.gamma);
                        out.push(&l.beta);
                        out.push(&l.running_mean);
                        out.push(&l.running_var);
                    }
                    Layer::Attention(a) => {
                        out.push(&a.w_q);
                        out.push(&a.w_k);
                        out.push(&a.w_v);
                        out.push(&a.w_o);
                    }
                    // Recurse so weights inside residual blocks are saved too.
                    Layer::Residual(block) => collect(&block.layers, out),
                    _ => {}
                }
            }
        }
        let mut weights = Vec::new();
        collect(&self.layers, &mut weights);
        weights
    }

    pub fn save_weights(&self, path: &str) -> Result<(), String> {
        // On-disk format: tensor count (u32), then per tensor: ndim (u32),
        // each dimension (u32), and the raw f32 data, all little-endian.
        let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?;
        let weights = self.collect_weights();
        let count = weights.len() as u32;
        file.write_all(&count.to_le_bytes()).map_err(|e| e.to_string())?;
        for w in weights {
            let ndim = w.shape.len() as u32;
            file.write_all(&ndim.to_le_bytes()).map_err(|e| e.to_string())?;
            for &d in &w.shape {
                file.write_all(&(d as u32).to_le_bytes()).map_err(|e| e.to_string())?;
            }
            for &v in &w.data {
                file.write_all(&v.to_le_bytes()).map_err(|e| e.to_string())?;
            }
        }
        Ok(())
    }

    pub fn load_weights(&mut self, path: &str) -> Result<(), String> {
        let mut file = std::fs::File::open(path).map_err(|e| e.to_string())?;
        let mut buf4 = [0u8; 4];

        file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
        let count = u32::from_le_bytes(buf4) as usize;

        let mut tensors = Vec::with_capacity(count);
        for _ in 0..count {
            file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
            let ndim = u32::from_le_bytes(buf4) as usize;
            let mut shape = Vec::with_capacity(ndim);
            for _ in 0..ndim {
                file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
                shape.push(u32::from_le_bytes(buf4) as usize);
            }
            let n: usize = shape.iter().product();
            let mut data = Vec::with_capacity(n);
            for _ in 0..n {
                file.read_exact(&mut buf4).map_err(|e| e.to_string())?;
                data.push(f32::from_le_bytes(buf4));
            }
            tensors.push(Tensor { shape, data });
        }

        // Assign tensors back in the same order collect_weights produced them,
        // recursing into residual blocks so the two stay in sync.
        fn assign(layers: &mut [Layer], tensors: &[Tensor], idx: &mut usize) {
            for layer in layers {
                match layer {
                    Layer::Dense(l) => {
                        if *idx + 1 < tensors.len() {
                            l.weights = tensors[*idx].clone();
                            l.bias = tensors[*idx + 1].clone();
                            *idx += 2;
                        }
                    }
                    Layer::Conv2D(l) => {
                        if *idx + 1 < tensors.len() {
                            l.filters = tensors[*idx].clone();
                            l.bias = tensors[*idx + 1].clone();
                            *idx += 2;
                        }
                    }
                    Layer::BatchNorm(l) => {
                        if *idx + 3 < tensors.len() {
                            l.gamma = tensors[*idx].clone();
                            l.beta = tensors[*idx + 1].clone();
                            l.running_mean = tensors[*idx + 2].clone();
                            l.running_var = tensors[*idx + 3].clone();
                            *idx += 4;
                        }
                    }
                    Layer::Attention(a) => {
                        if *idx + 3 < tensors.len() {
                            a.w_q = tensors[*idx].clone();
                            a.w_k = tensors[*idx + 1].clone();
                            a.w_v = tensors[*idx + 2].clone();
                            a.w_o = tensors[*idx + 3].clone();
                            *idx += 4;
                        }
                    }
                    Layer::Residual(block) => assign(&mut block.layers, tensors, idx),
                    _ => {}
                }
            }
        }
        let mut idx = 0;
        assign(&mut self.layers, &tensors, &mut idx);
        Ok(())
    }
}

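/// Renders a fixed-width text table of layer names and parameter counts.
/// Note that `print` builds and returns the summary string; it does not
/// write to stdout.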
pub struct ModelSummary;

impl ModelSummary {
    pub fn print(model: &Model) -> String {
        let mut lines = Vec::new();
        lines.push(format!("Model: {}", model.name));
        lines.push(format!("{:-<60}", ""));
        lines.push(format!("{:<20} {:>20} {:>15}", "Layer", "Output Shape", "Params"));
        lines.push(format!("{:-<60}", ""));
        for (i, layer) in model.layers.iter().enumerate() {
            let params = layer.parameter_count();
            lines.push(format!("{:<20} {:>20} {:>15}", format!("{}_{}", layer.name(), i), "dynamic", params));
        }
        lines.push(format!("{:-<60}", ""));
        lines.push(format!("Total parameters: {}", model.parameter_count()));
        lines.join("\n")
    }
}

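/// Builder that assembles a `Model` by chaining layer constructors.
///
/// A minimal usage sketch (marked `ignore` because the crate path seen by the
/// doc-test harness is assumed, not known):
///
/// ```ignore
/// let model = Sequential::new("mlp")
///     .dense(4, 8)
///     .relu()
///     .dense(8, 3)
///     .build();
/// assert_eq!(model.layers.len(), 3);
/// ```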
pub struct Sequential {
    layers: Vec<Layer>,
    name: String,
}

impl Sequential {
    pub fn new(name: &str) -> Self {
        Self { layers: Vec::new(), name: name.to_string() }
    }

    pub fn dense(mut self, in_features: usize, out_features: usize) -> Self {
        self.layers.push(Layer::Dense(DenseLayer::new(in_features, out_features)));
        self
    }

    pub fn conv2d(mut self, in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        self.layers.push(Layer::Conv2D(Conv2DLayer::new(in_channels, out_channels, kernel_size)));
        self
    }

    pub fn max_pool(mut self, kernel_size: usize, stride: usize) -> Self {
        self.layers.push(Layer::MaxPool(MaxPoolLayer { kernel_size, stride }));
        self
    }

    pub fn batch_norm(mut self, num_features: usize) -> Self {
        self.layers.push(Layer::BatchNorm(BatchNormLayer {
            gamma: Tensor::ones(vec![num_features]),
            beta: Tensor::zeros(vec![num_features]),
            running_mean: Tensor::zeros(vec![num_features]),
            running_var: Tensor::ones(vec![num_features]),
            eps: 1e-5,
        }));
        self
    }

    pub fn dropout(mut self, p: f32) -> Self {
        self.layers.push(Layer::Dropout(DropoutLayer { p, training: true }));
        self
    }

    pub fn flatten(mut self) -> Self {
        self.layers.push(Layer::Flatten);
        self
    }

    pub fn relu(mut self) -> Self {
        self.layers.push(Layer::ReLU);
        self
    }

    pub fn sigmoid(mut self) -> Self {
        self.layers.push(Layer::Sigmoid);
        self
    }

    pub fn tanh_act(mut self) -> Self {
        self.layers.push(Layer::Tanh);
        self
    }

    /// Pushes a softmax over axis 0; use `softmax_axis` to pick the axis
    /// explicitly (e.g. axis 1 for `[batch, features]` outputs).
    pub fn softmax(mut self) -> Self {
        self.layers.push(Layer::Softmax(0));
        self
    }

    pub fn softmax_axis(mut self, axis: usize) -> Self {
        self.layers.push(Layer::Softmax(axis));
        self
    }

    pub fn gelu(mut self) -> Self {
        self.layers.push(Layer::GELU);
        self
    }

    pub fn residual(mut self, layers: Vec<Layer>) -> Self {
        self.layers.push(Layer::Residual(ResidualBlock { layers }));
        self
    }

    pub fn attention(mut self, heads: usize, d_model: usize) -> Self {
        self.layers.push(Layer::Attention(MultiHeadAttention::new(heads, d_model)));
        self
    }

    pub fn layer(mut self, layer: Layer) -> Self {
        self.layers.push(layer);
        self
    }

    pub fn build(self) -> Model {
        Model { layers: self.layers, name: self.name }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_forward_shape() {
        let layer = DenseLayer::new(4, 3);
        let input = Tensor::ones(vec![2, 4]);
        let out = layer.forward(&input);
        assert_eq!(out.shape, vec![2, 3]);
    }

    #[test]
    fn test_dense_forward_1d() {
        let layer = DenseLayer::new(3, 2);
        let input = Tensor::ones(vec![3]);
        let out = layer.forward(&input);
        assert_eq!(out.shape, vec![2]);
    }

    #[test]
    fn test_sequential_build() {
        let model = Sequential::new("test")
            .dense(10, 5)
            .relu()
            .dense(5, 2)
            .softmax()
            .build();
        assert_eq!(model.layers.len(), 4);
        assert_eq!(model.name, "test");
    }

    #[test]
    fn test_model_forward_shape() {
        let model = Sequential::new("mlp")
            .dense(4, 8)
            .relu()
            .dense(8, 3)
            .build();
        let input = Tensor::ones(vec![2, 4]);
        let out = model.forward(&input);
        assert_eq!(out.shape, vec![2, 3]);
    }

    #[test]
    fn test_parameter_count() {
        let model = Sequential::new("mlp")
            .dense(10, 5)
            .dense(5, 2)
            .build();
        assert_eq!(model.parameter_count(), 55 + 12);
    }
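
    // Each projection is d_model x d_model, so four of them. This assumes
    // Tensor::rand allocates shape.iter().product() elements, as the dense
    // parameter-count test above already relies on.
    #[test]
    fn test_attention_parameter_count() {
        let attn = MultiHeadAttention::new(2, 4);
        assert_eq!(attn.parameter_count(), 4 * 4 * 4);
    }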

    #[test]
    fn test_residual_connection() {
        let block_layers = vec![
            Layer::Dense(DenseLayer {
                weights: Tensor::zeros(vec![4, 4]),
                bias: Tensor::zeros(vec![4]),
            }),
            Layer::ReLU,
        ];
        let model = Sequential::new("res")
            .residual(block_layers)
            .build();
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let out = model.forward(&input);
        assert_eq!(out.shape, vec![1, 4]);
        // Zero weights make the block output zero, so the residual is the input.
        assert_eq!(out.data, vec![1.0, 2.0, 3.0, 4.0]);
    }
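
    // A residual block reports the parameter count of its inner layers.
    #[test]
    fn test_residual_parameter_count() {
        let block = vec![Layer::Dense(DenseLayer::new(4, 4)), Layer::ReLU];
        let model = Sequential::new("res").residual(block).build();
        assert_eq!(model.parameter_count(), 4 * 4 + 4);
    }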

    #[test]
    fn test_attention_forward_shape() {
        let attn = MultiHeadAttention::new(2, 4);
        let input = Tensor::rand(vec![3, 4], 42);
        let out = attn.forward(&input);
        assert_eq!(out.shape, vec![3, 4]);
    }

    #[test]
    fn test_model_summary() {
        let model = Sequential::new("demo")
            .dense(10, 5)
            .relu()
            .build();
        let summary = ModelSummary::print(&model);
        assert!(summary.contains("demo"));
        assert!(summary.contains("Dense"));
        assert!(summary.contains("ReLU"));
    }
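
    // Activation and reshape layers fall through to the `_ => 0` arm.
    #[test]
    fn test_activation_layers_have_no_parameters() {
        for layer in [Layer::Flatten, Layer::ReLU, Layer::Sigmoid, Layer::Tanh, Layer::GELU] {
            assert_eq!(layer.parameter_count(), 0);
        }
    }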

    #[test]
    fn test_save_load_weights() {
        let model = Sequential::new("test")
            .dense(3, 2)
            .build();
        let path = std::env::temp_dir().join("proof_engine_test_weights.bin");
        let path_str = path.to_str().unwrap();
        model.save_weights(path_str).unwrap();

        let mut model2 = Sequential::new("test")
            .dense(3, 2)
            .build();
        model2.load_weights(path_str).unwrap();

        if let (Layer::Dense(l1), Layer::Dense(l2)) = (&model.layers[0], &model2.layers[0]) {
            assert_eq!(l1.weights.data, l2.weights.data);
            assert_eq!(l1.bias.data, l2.bias.data);
        }
        let _ = std::fs::remove_file(path);
    }
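
    // Weights nested in a residual block should survive a save/load round
    // trip, since collect_weights and load_weights both recurse into blocks.
    // The temp-file name here is arbitrary, chosen to match the test above.
    #[test]
    fn test_save_load_residual_weights() {
        let model = Sequential::new("res")
            .residual(vec![Layer::Dense(DenseLayer::new(2, 2))])
            .build();
        let path = std::env::temp_dir().join("proof_engine_test_residual_weights.bin");
        let path_str = path.to_str().unwrap();
        model.save_weights(path_str).unwrap();

        let mut model2 = Sequential::new("res")
            .residual(vec![Layer::Dense(DenseLayer::new(2, 2))])
            .build();
        model2.load_weights(path_str).unwrap();

        if let (Layer::Residual(b1), Layer::Residual(b2)) = (&model.layers[0], &model2.layers[0]) {
            if let (Layer::Dense(l1), Layer::Dense(l2)) = (&b1.layers[0], &b2.layers[0]) {
                assert_eq!(l1.weights.data, l2.weights.data);
                assert_eq!(l1.bias.data, l2.bias.data);
            }
        }
        let _ = std::fs::remove_file(path);
    }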

    #[test]
    fn test_conv2d_layer_forward() {
        let layer = Conv2DLayer::new(1, 2, 3);
        let input = Tensor::ones(vec![1, 5, 5]);
        let out = layer.forward(&input);
        // 5x5 input, 3x3 kernel, stride 1, no padding -> 3x3 per output channel.
        assert_eq!(out.shape, vec![2, 3, 3]);
    }
}