convolution/
convolution.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, plot, tensor};
4
5fn main() {
6    let mut network = network::Network::new(tensor::Shape::Triple(1, 24, 24));
7
8    network.convolution(
9        5,
10        (3, 3),
11        (1, 1),
12        (0, 0),
13        (1, 1),
14        activation::Activation::ReLU,
15        Some(0.1),
16    );
17    network.convolution(
18        1,
19        (3, 3),
20        (1, 1),
21        (0, 0),
22        (1, 1),
23        activation::Activation::ReLU,
24        Some(0.1),
25    );
26
27    println!("{}", network);
28
29    let x = tensor::Tensor::random(tensor::Shape::Triple(1, 24, 24), 0.0, 1.0);
30    println!("x: {}", &x.shape);
31
32    let (pre, post, _, _) = network.forward(&x);
33    println!("pre-activation: {}", &pre[pre.len() - 1].shape);
34    println!("post-activation: {}", &post[post.len() - 1].shape);
35
36    plot::heatmap(&x, "Input", "./output/convolution-input.png");
37    plot::heatmap(
38        &pre[pre.len() - 1],
39        "Pre-activation",
40        "./output/convolution-pre.png",
41    );
42    plot::heatmap(
43        &post[post.len() - 1],
44        "Post-activation",
45        "./output/convolution-post.png",
46    );
47}