timing_mnist/
mnist.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2//
3// Code for comparison between the various architectures and their time differences.
4
5use neurons::{activation, feedback, network, objective, optimizer, tensor};
6
7use std::{
8    fs::File,
9    io::{BufReader, Read, Result, Write},
10    sync::Arc,
11    time,
12};
13
14const RUNS: usize = 5;
15const EPOCHS: i32 = 1;
16
/// Reads the next four bytes from `reader` and decodes them as a big-endian `u32`.
///
/// # Errors
/// Returns whatever error `read_exact` reports, for instance when fewer than
/// four bytes remain in the stream.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}
22
23fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
24    let mut reader = BufReader::new(File::open(path)?);
25    let mut images: Vec<tensor::Tensor> = Vec::new();
26
27    let _magic_number = read(&mut reader)?;
28    let num_images = read(&mut reader)?;
29    let num_rows = read(&mut reader)?;
30    let num_cols = read(&mut reader)?;
31
32    for _ in 0..num_images {
33        let mut image: Vec<Vec<f32>> = Vec::new();
34        for _ in 0..num_rows {
35            let mut row: Vec<f32> = Vec::new();
36            for _ in 0..num_cols {
37                let mut pixel = [0];
38                reader.read_exact(&mut pixel)?;
39                row.push(pixel[0] as f32 / 255.0);
40            }
41            image.push(row);
42        }
43        images.push(tensor::Tensor::triple(vec![image]));
44    }
45
46    Ok(images)
47}
48
49fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
50    let mut reader = BufReader::new(File::open(file_path)?);
51    let _magic_number = read(&mut reader)?;
52    let num_labels = read(&mut reader)?;
53
54    let mut _labels = vec![0; num_labels as usize];
55    reader.read_exact(&mut _labels)?;
56
57    Ok(_labels
58        .iter()
59        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
60        .collect())
61}
62
63fn main() {
64    let mut x = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
65    let mut y = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
66    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
67    let class_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
68
69    x.extend(x_test);
70    y.extend(class_test);
71
72    let x: Vec<&tensor::Tensor> = x.iter().collect();
73    let y: Vec<&tensor::Tensor> = y.iter().collect();
74
75    // Create the results file.
76    let mut file = File::create("./output/timing/mnist.json").unwrap();
77    writeln!(file, "[").unwrap();
78    writeln!(file, "  {{").unwrap();
79
80    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
81        .iter()
82        .for_each(|method| {
83            println!("Method: {}", method);
84            vec![false, true].iter().for_each(|skip| {
85                println!("  Skip: {}", skip);
86                vec!["CLASSIFICATION"].iter().for_each(|problem| {
87                    println!("   Problem: {}", problem);
88
89                    let mut train_times: Vec<f64> = Vec::new();
90                    let mut valid_times: Vec<f64> = Vec::new();
91
92                    for _ in 0..RUNS {
93                        // Create the network based on the architecture.
94                        let mut network: network::Network;
95                        network = network::Network::new(tensor::Shape::Triple(1, 28, 28));
96                        network.convolution(
97                            1,
98                            (3, 3),
99                            (1, 1),
100                            (1, 1),
101                            (1, 1),
102                            activation::Activation::ReLU,
103                            None,
104                        );
105
106                        // Check if the method is regular or feedback.
107                        if method == &"REGULAR" || method.contains(&"FB1") {
108                            network.convolution(
109                                1,
110                                (3, 3),
111                                (1, 1),
112                                (1, 1),
113                                (1, 1),
114                                activation::Activation::ReLU,
115                                None,
116                            );
117                            network.convolution(
118                                1,
119                                (3, 3),
120                                (1, 1),
121                                (1, 1),
122                                (1, 1),
123                                activation::Activation::ReLU,
124                                None,
125                            );
126                            network.maxpool((2, 2), (2, 2));
127
128                            // Add the feedback loop if applicable.
129                            if method.contains(&"FB1") {
130                                network.loopback(
131                                    2,
132                                    0,
133                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize
134                                        - 1,
135                                    Arc::new(|_loops| 1.0),
136                                    false,
137                                );
138                            }
139                        } else {
140                            network.feedback(
141                                vec![feedback::Layer::Convolution(
142                                    1,
143                                    activation::Activation::ReLU,
144                                    (3, 3),
145                                    (1, 1),
146                                    (1, 1),
147                                    (1, 1),
148                                    None,
149                                )],
150                                method.chars().last().unwrap().to_digit(10).unwrap() as usize,
151                                false,
152                                false,
153                                feedback::Accumulation::Mean,
154                            );
155                            network.convolution(
156                                1,
157                                (3, 3),
158                                (1, 1),
159                                (1, 1),
160                                (1, 1),
161                                activation::Activation::ReLU,
162                                None,
163                            );
164                            network.maxpool((2, 2), (2, 2));
165                        }
166
167                        // Set the output layer based on the problem.
168                        if problem == &"REGRESSION" {
169                            panic!("Invalid problem type.");
170                        } else {
171                            network.dense(10, activation::Activation::Softmax, true, None);
172                            network.set_objective(objective::Objective::CrossEntropy, None);
173                        }
174
175                        // Add the skip connection if applicable.
176                        if *skip {
177                            network.connect(1, network.layers.len() - 2);
178                        }
179
180                        network
181                            .set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
182
183                        let start = time::Instant::now();
184
185                        // Train the network
186                        if problem == &"REGRESSION" {
187                            panic!("Invalid problem type.");
188                        } else {
189                            (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
190                        }
191
192                        let duration = start.elapsed().as_secs_f64();
193                        train_times.push(duration);
194
195                        let start = time::Instant::now();
196
197                        // Validate the network
198                        if problem == &"REGRESSION" {
199                            panic!("Invalid problem type.");
200                        } else {
201                            (_) = network.predict_batch(&x);
202                        }
203
204                        let duration = start.elapsed().as_secs_f64();
205                        valid_times.push(duration);
206                    }
207
208                    if method == &"FB2x3" && *skip && problem == &"CLASSIFICATION" {
209                        writeln!(
210                            file,
211                            "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
212                            method, skip, problem, train_times, valid_times
213                        )
214                        .unwrap();
215                    } else {
216                        writeln!(
217                            file,
218                            "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
219                            method, skip, problem, train_times, valid_times
220                        )
221                        .unwrap();
222                    }
223                });
224            });
225        });
226    writeln!(file, "  }}").unwrap();
227    writeln!(file, "]").unwrap();
228}