mnist_feedback/
feedback.rs

// Copyright (C) 2024 Hallvard Høyland Lavik

use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};

use std::fs::File;
use std::io::{BufReader, Read, Result};
7
/// Reads the next four bytes from `reader` and decodes them as a
/// big-endian `u32` (the integer encoding used by the IDX file format).
///
/// # Errors
/// Propagates any I/O error, including an unexpected end of stream.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}
13
14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15    let mut reader = BufReader::new(File::open(path)?);
16    let mut images: Vec<tensor::Tensor> = Vec::new();
17
18    let _magic_number = read(&mut reader)?;
19    let num_images = read(&mut reader)?;
20    let num_rows = read(&mut reader)?;
21    let num_cols = read(&mut reader)?;
22
23    for _ in 0..num_images {
24        let mut image: Vec<Vec<f32>> = Vec::new();
25        for _ in 0..num_rows {
26            let mut row: Vec<f32> = Vec::new();
27            for _ in 0..num_cols {
28                let mut pixel = [0];
29                reader.read_exact(&mut pixel)?;
30                row.push(pixel[0] as f32 / 255.0);
31            }
32            image.push(row);
33        }
34        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35    }
36
37    Ok(images)
38}
39
40fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
41    let mut reader = BufReader::new(File::open(file_path)?);
42    let _magic_number = read(&mut reader)?;
43    let num_labels = read(&mut reader)?;
44
45    let mut _labels = vec![0; num_labels as usize];
46    reader.read_exact(&mut _labels)?;
47
48    Ok(_labels
49        .iter()
50        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
51        .collect())
52}
53
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (1, 1),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.feedback(
82        vec![feedback::Layer::Convolution(
83            1,
84            activation::Activation::ReLU,
85            (3, 3),
86            (1, 1),
87            (1, 1),
88            (1, 1),
89            None,
90        )],
91        3,
92        false,
93        false,
94        feedback::Accumulation::Mean,
95    );
96    network.convolution(
97        1,
98        (3, 3),
99        (1, 1),
100        (1, 1),
101        (1, 1),
102        activation::Activation::ReLU,
103        None,
104    );
105    network.maxpool((2, 2), (2, 2));
106    network.dense(10, activation::Activation::Softmax, true, None);
107
108    // Include skip connection bypassing the feedback block
109    network.connect(1, 2);
110    network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
111
112    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113    network.set_objective(
114        objective::Objective::CrossEntropy, // Objective function
115        None,                               // Gradient clipping
116    );
117
118    println!("{}", network);
119
120    // Train the network
121    let (train_loss, val_loss, val_acc) = network.learn(
122        &x_train,
123        &y_train,
124        Some((&x_test, &y_test, 10)),
125        32,
126        25,
127        Some(5),
128    );
129    plot::loss(
130        &train_loss,
131        &val_loss,
132        &val_acc,
133        "FEEDBACK : MNIST",
134        "./output/mnist/feedback.png",
135    );
136
137    // Validate the network
138    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
139    println!(
140        "Final validation accuracy: {:.2} % and loss: {:.5}",
141        val_acc * 100.0,
142        val_loss
143    );
144
145    // Use the network
146    let prediction = network.predict(x_test.get(0).unwrap());
147    println!(
148        "Prediction on input: Target: {}. Output: {}.",
149        y_test[0].argmax(),
150        prediction.argmax()
151    );
152
153    // let x = x_test.get(5).unwrap();
154    // let y = y_test.get(5).unwrap();
155    // plot::heatmap(
156    //     &x,
157    //     &format!("Target: {}", y.argmax()),
158    //     "./output/mnist/input.png",
159    // );
160
161    // Plot the pre- and post-activation heatmaps for each (image) layer.
162    // let (pre, post, _) = network.forward(x);
163    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
164    //     let pre_title = format!("layer_{}_pre", i);
165    //     let post_title = format!("layer_{}_post", i);
166    //     let pre_file = format!("layer_{}_pre.png", i);
167    //     let post_file = format!("layer_{}_post.png", i);
168    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
169    //     plot::heatmap(&i_post, &post_title, &post_file);
170    // }
171}