// examples/compare_mnist/mnist.rs

// Copyright (C) 2024 Hallvard Høyland Lavik
//
// Code for comparison between the various architectures.
// The respective losses and accuracies are stored to the file `~/output/compare/mnist.json`.
//
// In addition, some simple probing of the networks is done.
// Namely, validating the trained networks with and without feedback and skip connections.
//
// for (
//   REGULAR,
//   FEEDBACK[approach=1, loops=2],
//   FEEDBACK[approach=1, loops=3],
//   FEEDBACK[approach=2, loops=2],
//   FEEDBACK[approach=2, loops=3]
// ) do {
//
//   for (NOSKIP, SKIP) do {
//
//     for (CLASSIFICATION, REGRESSION) do {
//
//       for (run in RUNS) do {
//
//         create the network
//         train the network
//         validate the network
//         store the loss and accuracy
//         probe the network
//         store the probing results
//
//       }
//     }
//   }
// }

35use neurons::{activation, feedback, network, objective, optimizer, tensor};
36
37use std::{
38    collections::HashMap,
39    fs::File,
40    io::{BufReader, Read, Result, Write},
41    sync::Arc,
42};
43
/// Number of independent training runs per configuration
/// (method x skip x problem combination).
const RUNS: usize = 5;

/// Reads a single big-endian `u32` from `reader`.
///
/// The IDX file format used by MNIST stores its header fields (magic number,
/// item count, dimensions) as 4-byte big-endian integers.
///
/// # Errors
///
/// Returns any I/O error encountered while reading the four bytes
/// (including `UnexpectedEof` if fewer than four bytes remain).
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut buffer = [0u8; 4];
    reader.read_exact(&mut buffer)?;
    Ok(u32::from_be_bytes(buffer))
}

52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53    let mut reader = BufReader::new(File::open(path)?);
54    let mut images: Vec<tensor::Tensor> = Vec::new();
55
56    let _magic_number = read(&mut reader)?;
57    let num_images = read(&mut reader)?;
58    let num_rows = read(&mut reader)?;
59    let num_cols = read(&mut reader)?;
60
61    for _ in 0..num_images {
62        let mut image: Vec<Vec<f32>> = Vec::new();
63        for _ in 0..num_rows {
64            let mut row: Vec<f32> = Vec::new();
65            for _ in 0..num_cols {
66                let mut pixel = [0];
67                reader.read_exact(&mut pixel)?;
68                row.push(pixel[0] as f32 / 255.0);
69            }
70            image.push(row);
71        }
72        images.push(tensor::Tensor::triple(vec![image]));
73    }
74
75    Ok(images)
76}
77
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79    let mut reader = BufReader::new(File::open(file_path)?);
80    let _magic_number = read(&mut reader)?;
81    let num_labels = read(&mut reader)?;
82
83    let mut _labels = vec![0; num_labels as usize];
84    reader.read_exact(&mut _labels)?;
85
86    Ok(_labels
87        .iter()
88        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89        .collect())
90}
91
92fn main() {
93    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
94    let class_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
95    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
96    let class_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
97
98    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
99    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
100    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
101    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
102
103    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
104    println!("Test data {}x{}\n", x_test.len(), x_test[0].shape,);
105
106    // Create the results file.
107    let mut file = File::create("./output/compare/mnist.json").unwrap();
108    writeln!(file, "[").unwrap();
109    writeln!(file, "  {{").unwrap();
110
111    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
112        .iter()
113        .for_each(|method| {
114            println!("Method: {}", method);
115            vec![false, true].iter().for_each(|skip| {
116                println!("  Skip: {}", skip);
117                vec!["CLASSIFICATION"].iter().for_each(|problem| {
118                    println!("   Problem: {}", problem);
119                    writeln!(file, "    \"{}-{}-{}\": {{", method, skip, problem).unwrap();
120
121                    for run in 1..RUNS + 1 {
122                        println!("    Run: {}", run);
123                        writeln!(file, "      \"run-{}\": {{", run).unwrap();
124
125                        // Create the network based on the architecture.
126                        let mut network: network::Network;
127                        network = network::Network::new(tensor::Shape::Triple(1, 28, 28));
128                        network.convolution(
129                            1,
130                            (3, 3),
131                            (1, 1),
132                            (1, 1),
133                            (1, 1),
134                            activation::Activation::ReLU,
135                            None,
136                        );
137
138                        // Check if the method is regular or feedback.
139                        if method == &"REGULAR" || method.contains(&"FB1") {
140                            network.convolution(
141                                1,
142                                (3, 3),
143                                (1, 1),
144                                (1, 1),
145                                (1, 1),
146                                activation::Activation::ReLU,
147                                None,
148                            );
149                            network.convolution(
150                                1,
151                                (3, 3),
152                                (1, 1),
153                                (1, 1),
154                                (1, 1),
155                                activation::Activation::ReLU,
156                                None,
157                            );
158                            network.maxpool((2, 2), (2, 2));
159
160                            // Add the feedback loop if applicable.
161                            if method.contains(&"FB1") {
162                                network.loopback(
163                                    2,
164                                    0,
165                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize
166                                        - 1,
167                                    Arc::new(|_loops| 1.0),
168                                    false,
169                                );
170                            }
171                        } else {
172                            network.feedback(
173                                vec![feedback::Layer::Convolution(
174                                    1,
175                                    activation::Activation::ReLU,
176                                    (3, 3),
177                                    (1, 1),
178                                    (1, 1),
179                                    (1, 1),
180                                    None,
181                                )],
182                                method.chars().last().unwrap().to_digit(10).unwrap() as usize,
183                                false,
184                                false,
185                                feedback::Accumulation::Mean,
186                            );
187                            network.convolution(
188                                1,
189                                (3, 3),
190                                (1, 1),
191                                (1, 1),
192                                (1, 1),
193                                activation::Activation::ReLU,
194                                None,
195                            );
196                            network.maxpool((2, 2), (2, 2));
197                        }
198
199                        // Set the output layer based on the problem.
200                        if problem == &"REGRESSION" {
201                            panic!("Invalid problem type.");
202                        } else {
203                            network.dense(10, activation::Activation::Softmax, true, None);
204                            network.set_objective(objective::Objective::CrossEntropy, None);
205                        }
206
207                        // Add the skip connection if applicable.
208                        if *skip {
209                            network.connect(1, network.layers.len() - 2);
210                        }
211
212                        network
213                            .set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
214
215                        // Train the network
216                        let (train_loss, val_loss, val_acc);
217                        if problem == &"REGRESSION" {
218                            unimplemented!("Regression not implemented.");
219                        } else {
220                            (train_loss, val_loss, val_acc) = network.learn(
221                                &x_train,
222                                &class_train,
223                                Some((&x_test, &class_test, 10)),
224                                32,
225                                40,
226                                None,
227                            );
228                        }
229
230                        // Store the loss and accuracy.
231                        writeln!(file, "        \"train\": {{").unwrap();
232                        writeln!(file, "          \"trn-loss\": {:?},", train_loss).unwrap();
233                        writeln!(file, "          \"val-loss\": {:?},", val_loss).unwrap();
234                        writeln!(file, "          \"val-acc\": {:?}", val_acc).unwrap();
235
236                        // Probe the network (if applicable).
237                        if method != &"REGULAR" {
238                            println!("    > Without feedback.");
239
240                            // Store the network's loopbacks and layers to restore them later.
241                            let loopbacks = network.loopbacks.clone();
242                            let layers = network.layers.clone();
243
244                            // Remove the feedback loop.
245                            if method.contains(&"FB1") {
246                                network.loopbacks = HashMap::new();
247                            } else {
248                                match &mut network.layers.get_mut(1).unwrap() {
249                                    network::Layer::Feedback(fb) => {
250                                        // Only keep the first two layers.
251                                        fb.layers = fb.layers.drain(0..2).collect();
252                                    }
253                                    _ => panic!("Invalid layer."),
254                                };
255                            }
256
257                            let (test_loss, test_acc);
258                            if problem == &"REGRESSION" {
259                                unimplemented!("Regression not implemented.");
260                            } else {
261                                (test_loss, test_acc) =
262                                    network.validate(&x_test, &class_test, 1e-6);
263                            }
264
265                            writeln!(file, "        }},").unwrap();
266                            writeln!(file, "        \"no-feedback\": {{").unwrap();
267                            writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
268                            writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
269
270                            // Restore the feedback loop.
271                            network.loopbacks = loopbacks;
272                            network.layers = layers;
273                        }
274                        if *skip {
275                            println!("    > Without skip.");
276                            network.connect = HashMap::new();
277
278                            let (test_loss, test_acc);
279                            if problem == &"REGRESSION" {
280                                unimplemented!("Regression not implemented.");
281                            } else {
282                                (test_loss, test_acc) =
283                                    network.validate(&x_test, &class_test, 1e-6);
284                            }
285
286                            writeln!(file, "        }},").unwrap();
287                            writeln!(file, "        \"no-skip\": {{").unwrap();
288                            writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
289                            writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
290                        }
291                        writeln!(file, "        }}").unwrap();
292
293                        if run == RUNS {
294                            writeln!(file, "      }}").unwrap();
295                            if method == &"FB2x3" && *skip && problem == &"CLASSIFICATION" {
296                                writeln!(file, "      }}").unwrap();
297                            } else {
298                                writeln!(file, "      }},").unwrap();
299                            }
300                        } else {
301                            writeln!(file, "      }},").unwrap();
302                        }
303                    }
304                });
305            });
306        });
307    writeln!(file, "  }}").unwrap();
308    writeln!(file, "]").unwrap();
309}