1use neurons::{activation, feedback, network, objective, optimizer, tensor};
36
37use std::{
38 collections::HashMap,
39 fs::File,
40 io::{BufReader, Read, Result, Write},
41 sync::Arc,
42};
43
/// Number of independent training runs performed for every
/// (method, skip, problem) configuration in `main`.
const RUNS: usize = 5;
45
/// Reads exactly four bytes from `reader` and decodes them as a big-endian
/// (most significant byte first) `u32`.
///
/// Used for the header fields of the MNIST files loaded in this example.
///
/// # Errors
///
/// Propagates any I/O error from `read_exact`, including an unexpected EOF
/// when fewer than four bytes remain.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut be_bytes = [0u8; 4];
    reader
        .read_exact(&mut be_bytes)
        .map(|()| u32::from_be_bytes(be_bytes))
}
51
52fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
53 let mut reader = BufReader::new(File::open(path)?);
54 let mut images: Vec<tensor::Tensor> = Vec::new();
55
56 let _magic_number = read(&mut reader)?;
57 let num_images = read(&mut reader)?;
58 let num_rows = read(&mut reader)?;
59 let num_cols = read(&mut reader)?;
60
61 for _ in 0..num_images {
62 let mut image: Vec<Vec<f32>> = Vec::new();
63 for _ in 0..num_rows {
64 let mut row: Vec<f32> = Vec::new();
65 for _ in 0..num_cols {
66 let mut pixel = [0];
67 reader.read_exact(&mut pixel)?;
68 row.push(pixel[0] as f32 / 255.0);
69 }
70 image.push(row);
71 }
72 images.push(tensor::Tensor::triple(vec![image]));
73 }
74
75 Ok(images)
76}
77
78fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
79 let mut reader = BufReader::new(File::open(file_path)?);
80 let _magic_number = read(&mut reader)?;
81 let num_labels = read(&mut reader)?;
82
83 let mut _labels = vec![0; num_labels as usize];
84 reader.read_exact(&mut _labels)?;
85
86 Ok(_labels
87 .iter()
88 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
89 .collect())
90}
91
92fn main() {
93 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
94 let class_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
95 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
96 let class_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
97
98 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
99 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
100 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
101 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
102
103 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
104 println!("Test data {}x{}\n", x_test.len(), x_test[0].shape,);
105
106 let mut file = File::create("./output/compare/mnist.json").unwrap();
108 writeln!(file, "[").unwrap();
109 writeln!(file, " {{").unwrap();
110
111 vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
112 .iter()
113 .for_each(|method| {
114 println!("Method: {}", method);
115 vec![false, true].iter().for_each(|skip| {
116 println!(" Skip: {}", skip);
117 vec!["CLASSIFICATION"].iter().for_each(|problem| {
118 println!(" Problem: {}", problem);
119 writeln!(file, " \"{}-{}-{}\": {{", method, skip, problem).unwrap();
120
121 for run in 1..RUNS + 1 {
122 println!(" Run: {}", run);
123 writeln!(file, " \"run-{}\": {{", run).unwrap();
124
125 let mut network: network::Network;
127 network = network::Network::new(tensor::Shape::Triple(1, 28, 28));
128 network.convolution(
129 1,
130 (3, 3),
131 (1, 1),
132 (1, 1),
133 (1, 1),
134 activation::Activation::ReLU,
135 None,
136 );
137
138 if method == &"REGULAR" || method.contains(&"FB1") {
140 network.convolution(
141 1,
142 (3, 3),
143 (1, 1),
144 (1, 1),
145 (1, 1),
146 activation::Activation::ReLU,
147 None,
148 );
149 network.convolution(
150 1,
151 (3, 3),
152 (1, 1),
153 (1, 1),
154 (1, 1),
155 activation::Activation::ReLU,
156 None,
157 );
158 network.maxpool((2, 2), (2, 2));
159
160 if method.contains(&"FB1") {
162 network.loopback(
163 2,
164 0,
165 method.chars().last().unwrap().to_digit(10).unwrap() as usize
166 - 1,
167 Arc::new(|_loops| 1.0),
168 false,
169 );
170 }
171 } else {
172 network.feedback(
173 vec![feedback::Layer::Convolution(
174 1,
175 activation::Activation::ReLU,
176 (3, 3),
177 (1, 1),
178 (1, 1),
179 (1, 1),
180 None,
181 )],
182 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
183 false,
184 false,
185 feedback::Accumulation::Mean,
186 );
187 network.convolution(
188 1,
189 (3, 3),
190 (1, 1),
191 (1, 1),
192 (1, 1),
193 activation::Activation::ReLU,
194 None,
195 );
196 network.maxpool((2, 2), (2, 2));
197 }
198
199 if problem == &"REGRESSION" {
201 panic!("Invalid problem type.");
202 } else {
203 network.dense(10, activation::Activation::Softmax, true, None);
204 network.set_objective(objective::Objective::CrossEntropy, None);
205 }
206
207 if *skip {
209 network.connect(1, network.layers.len() - 2);
210 }
211
212 network
213 .set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
214
215 let (train_loss, val_loss, val_acc);
217 if problem == &"REGRESSION" {
218 unimplemented!("Regression not implemented.");
219 } else {
220 (train_loss, val_loss, val_acc) = network.learn(
221 &x_train,
222 &class_train,
223 Some((&x_test, &class_test, 10)),
224 32,
225 40,
226 None,
227 );
228 }
229
230 writeln!(file, " \"train\": {{").unwrap();
232 writeln!(file, " \"trn-loss\": {:?},", train_loss).unwrap();
233 writeln!(file, " \"val-loss\": {:?},", val_loss).unwrap();
234 writeln!(file, " \"val-acc\": {:?}", val_acc).unwrap();
235
236 if method != &"REGULAR" {
238 println!(" > Without feedback.");
239
240 let loopbacks = network.loopbacks.clone();
242 let layers = network.layers.clone();
243
244 if method.contains(&"FB1") {
246 network.loopbacks = HashMap::new();
247 } else {
248 match &mut network.layers.get_mut(1).unwrap() {
249 network::Layer::Feedback(fb) => {
250 fb.layers = fb.layers.drain(0..2).collect();
252 }
253 _ => panic!("Invalid layer."),
254 };
255 }
256
257 let (test_loss, test_acc);
258 if problem == &"REGRESSION" {
259 unimplemented!("Regression not implemented.");
260 } else {
261 (test_loss, test_acc) =
262 network.validate(&x_test, &class_test, 1e-6);
263 }
264
265 writeln!(file, " }},").unwrap();
266 writeln!(file, " \"no-feedback\": {{").unwrap();
267 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
268 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
269
270 network.loopbacks = loopbacks;
272 network.layers = layers;
273 }
274 if *skip {
275 println!(" > Without skip.");
276 network.connect = HashMap::new();
277
278 let (test_loss, test_acc);
279 if problem == &"REGRESSION" {
280 unimplemented!("Regression not implemented.");
281 } else {
282 (test_loss, test_acc) =
283 network.validate(&x_test, &class_test, 1e-6);
284 }
285
286 writeln!(file, " }},").unwrap();
287 writeln!(file, " \"no-skip\": {{").unwrap();
288 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
289 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
290 }
291 writeln!(file, " }}").unwrap();
292
293 if run == RUNS {
294 writeln!(file, " }}").unwrap();
295 if method == &"FB2x3" && *skip && problem == &"CLASSIFICATION" {
296 writeln!(file, " }}").unwrap();
297 } else {
298 writeln!(file, " }},").unwrap();
299 }
300 } else {
301 writeln!(file, " }},").unwrap();
302 }
303 }
304 });
305 });
306 });
307 writeln!(file, " }}").unwrap();
308 writeln!(file, "]").unwrap();
309}