// compare_ftir_cnn/ftir-cnn.rs
1// Copyright (C) 2024 Hallvard Høyland Lavik
2//
3// NOTE: Gradient clipping is added for the classification task.
4//
5// Code for comparison between the various architectures.
// The respective losses and accuracies are stored to the file `~/output/compare/ftir-cnn.json`.
7//
// In addition, some simple probing of the networks is done.
9// Namely, validating the trained networks with and without feedback and skip connections.
10//
11// for (
12// REGULAR,
13// FEEDBACK[approach=1, loops=2],
14// FEEDBACK[approach=1, loops=3],
15// FEEDBACK[approach=2, loops=2],
16// FEEDBACK[approach=2, loops=3]
17// ) do {
18//
19// for (NOSKIP, SKIP) do {
20//
21// for (CLASSIFICATION, REGRESSION) do {
22//
23// for (run in RUNS) do {
24//
25// create the network
26// train the network
27// validate the network
28// store the loss and accuracy
29// probe the network
30// store the probing results
31//
32// }
33// }
34// }
35// }
36
37use neurons::{activation, feedback, network, objective, optimizer, tensor};
38
39use std::{
40 collections::HashMap,
41 fs::File,
42 io::{BufRead, BufReader, Write},
43 sync::Arc,
44};
45
46const RUNS: usize = 5;
47
48fn data(
49 path: &str,
50) -> (
51 (
52 Vec<tensor::Tensor>,
53 Vec<tensor::Tensor>,
54 Vec<tensor::Tensor>,
55 ),
56 (
57 Vec<tensor::Tensor>,
58 Vec<tensor::Tensor>,
59 Vec<tensor::Tensor>,
60 ),
61 (
62 Vec<tensor::Tensor>,
63 Vec<tensor::Tensor>,
64 Vec<tensor::Tensor>,
65 ),
66) {
67 let reader = BufReader::new(File::open(&path).unwrap());
68
69 let mut x_train: Vec<tensor::Tensor> = Vec::new();
70 let mut y_train: Vec<tensor::Tensor> = Vec::new();
71 let mut class_train: Vec<tensor::Tensor> = Vec::new();
72
73 let mut x_test: Vec<tensor::Tensor> = Vec::new();
74 let mut y_test: Vec<tensor::Tensor> = Vec::new();
75 let mut class_test: Vec<tensor::Tensor> = Vec::new();
76
77 let mut x_val: Vec<tensor::Tensor> = Vec::new();
78 let mut y_val: Vec<tensor::Tensor> = Vec::new();
79 let mut class_val: Vec<tensor::Tensor> = Vec::new();
80
81 for line in reader.lines().skip(1) {
82 let line = line.unwrap();
83 let record: Vec<&str> = line.split(',').collect();
84
85 let mut data: Vec<f32> = Vec::new();
86 for i in 0..571 {
87 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
88 }
89 let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
90 match record.get(573).unwrap() {
91 &"Train" => {
92 x_train.push(tensor::Tensor::triple(data));
93 y_train.push(tensor::Tensor::single(vec![record
94 .get(571)
95 .unwrap()
96 .parse::<f32>()
97 .unwrap()]));
98 class_train.push(tensor::Tensor::one_hot(
99 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
100 28,
101 ));
102 }
103 &"Test" => {
104 x_test.push(tensor::Tensor::triple(data));
105 y_test.push(tensor::Tensor::single(vec![record
106 .get(571)
107 .unwrap()
108 .parse::<f32>()
109 .unwrap()]));
110 class_test.push(tensor::Tensor::one_hot(
111 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
112 28,
113 ));
114 }
115 &"Val" => {
116 x_val.push(tensor::Tensor::triple(data));
117 y_val.push(tensor::Tensor::single(vec![record
118 .get(571)
119 .unwrap()
120 .parse::<f32>()
121 .unwrap()]));
122 class_val.push(tensor::Tensor::one_hot(
123 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
124 28,
125 ));
126 }
127 _ => panic!("> Unknown class."),
128 }
129 }
130
131 // let mut generator = random::Generator::create(12345);
132 // let mut indices: Vec<usize> = (0..x.len()).collect();
133 // generator.shuffle(&mut indices);
134
135 (
136 (x_train, y_train, class_train),
137 (x_test, y_test, class_test),
138 (x_val, y_val, class_val),
139 )
140}
141
142fn main() {
143 // Load the ftir dataset
144 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
145 data("./examples/datasets/ftir.csv");
146
147 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
148 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
149 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
150
151 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
152 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
153 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
154
155 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
156 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
157 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
158
159 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
160 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
161 println!("Validation data {}x{}\n", x_val.len(), x_val[0].shape,);
162
163 // Create the results file.
164 let mut file = File::create("./output/compare/ftir-cnn.json").unwrap();
165 writeln!(file, "[").unwrap();
166 writeln!(file, " {{").unwrap();
167
168 vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
169 .iter()
170 .for_each(|method| {
171 println!("Method: {}", method);
172 vec![false, true].iter().for_each(|skip| {
173 println!(" Skip: {}", skip);
174 vec!["CLASSIFICATION", "REGRESSION"]
175 .iter()
176 .for_each(|problem| {
177 println!(" Problem: {}", problem);
178 writeln!(file, " \"{}-{}-{}\": {{", method, skip, problem).unwrap();
179
180 for run in 1..RUNS + 1 {
181 println!(" Run: {}", run);
182 writeln!(file, " \"run-{}\": {{", run).unwrap();
183
184 // Create the network based on the architecture.
185 let mut network: network::Network;
186 network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
187
188 // Check if the method is regular or feedback.
189 if method == &"REGULAR" || method.contains(&"FB1") {
190 network.convolution(
191 1,
192 (1, 9),
193 (1, 1),
194 (0, 4),
195 (1, 1),
196 activation::Activation::ReLU,
197 None,
198 );
199 network.convolution(
200 1,
201 (1, 9),
202 (1, 1),
203 (0, 4),
204 (1, 1),
205 activation::Activation::ReLU,
206 None,
207 );
208 network.dense(32, activation::Activation::ReLU, false, None);
209
210 // Add the feedback loop if applicable.
211 if method.contains(&"FB1") {
212 network.loopback(
213 1,
214 0,
215 method.chars().last().unwrap().to_digit(10).unwrap()
216 as usize
217 - 1,
218 Arc::new(|_loops| 1.0),
219 false,
220 );
221 }
222 } else {
223 network.feedback(
224 vec![
225 feedback::Layer::Convolution(
226 1,
227 activation::Activation::ReLU,
228 (1, 9),
229 (1, 1),
230 (0, 4),
231 (1, 1),
232 None,
233 ),
234 feedback::Layer::Convolution(
235 1,
236 activation::Activation::ReLU,
237 (1, 9),
238 (1, 1),
239 (0, 4),
240 (1, 1),
241 None,
242 ),
243 ],
244 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
245 false,
246 false,
247 feedback::Accumulation::Mean,
248 );
249 network.dense(32, activation::Activation::ReLU, false, None);
250 }
251
252 // Set the output layer based on the problem.
253 if problem == &"REGRESSION" {
254 network.dense(1, activation::Activation::Linear, false, None);
255 network.set_objective(objective::Objective::RMSE, None);
256 } else {
257 network.dense(28, activation::Activation::Softmax, false, None);
258 network.set_objective(
259 objective::Objective::CrossEntropy,
260 Some((-5.0, 5.0)),
261 );
262 }
263
264 // Add the skip connection if applicable.
265 if *skip {
266 network.connect(0, network.layers.len() - 2);
267 }
268
269 network.set_optimizer(optimizer::Adam::create(
270 0.001, 0.9, 0.999, 1e-8, None,
271 ));
272
273 // Train the network
274 let (train_loss, val_loss, val_acc);
275 if problem == &"REGRESSION" {
276 (train_loss, val_loss, val_acc) = network.learn(
277 &x_train,
278 &y_train,
279 Some((&x_val, &y_val, 100)),
280 32,
281 2500,
282 None,
283 );
284 } else {
285 (train_loss, val_loss, val_acc) = network.learn(
286 &x_train,
287 &class_train,
288 Some((&x_val, &class_val, 100)),
289 32,
290 2500,
291 None,
292 );
293 }
294
295 // Store the loss and accuracy.
296 writeln!(file, " \"train\": {{").unwrap();
297 writeln!(file, " \"trn-loss\": {:?},", train_loss).unwrap();
298 writeln!(file, " \"val-loss\": {:?},", val_loss).unwrap();
299 writeln!(file, " \"val-acc\": {:?}", val_acc).unwrap();
300
301 // Probe the network (if applicable).
302 if method != &"REGULAR" {
303 println!(" > Without feedback.");
304
305 // Store the network's loopbacks and layers to restore them later.
306 let loopbacks = network.loopbacks.clone();
307 let layers = network.layers.clone();
308
309 // Remove the feedback loop.
310 if method.contains(&"FB1") {
311 network.loopbacks = HashMap::new();
312 } else {
313 match &mut network.layers.get_mut(0).unwrap() {
314 network::Layer::Feedback(fb) => {
315 // Only keep the first two layers.
316 fb.layers = fb.layers.drain(0..2).collect();
317 }
318 _ => panic!("Invalid layer."),
319 };
320 }
321
322 let (val_loss, val_acc);
323 if problem == &"REGRESSION" {
324 (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
325 } else {
326 (val_loss, val_acc) =
327 network.validate(&x_val, &class_val, 1e-6);
328 }
329 let (test_loss, test_acc);
330 if problem == &"REGRESSION" {
331 (test_loss, test_acc) =
332 network.validate(&x_test, &y_test, 1e-6);
333 } else {
334 (test_loss, test_acc) =
335 network.validate(&x_test, &class_test, 1e-6);
336 }
337
338 writeln!(file, " }},").unwrap();
339 writeln!(file, " \"no-feedback\": {{").unwrap();
340 writeln!(file, " \"val-loss\": {},", val_loss).unwrap();
341 writeln!(file, " \"val-acc\": {},", val_acc).unwrap();
342 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
343 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
344
345 // Restore the feedback loop.
346 network.loopbacks = loopbacks;
347 network.layers = layers;
348 }
349 if *skip {
350 println!(" > Without skip.");
351 network.connect = HashMap::new();
352
353 let (val_loss, val_acc);
354 if problem == &"REGRESSION" {
355 (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
356 } else {
357 (val_loss, val_acc) =
358 network.validate(&x_val, &class_val, 1e-6);
359 }
360 let (test_loss, test_acc);
361 if problem == &"REGRESSION" {
362 (test_loss, test_acc) =
363 network.validate(&x_test, &y_test, 1e-6);
364 } else {
365 (test_loss, test_acc) =
366 network.validate(&x_test, &class_test, 1e-6);
367 }
368
369 writeln!(file, " }},").unwrap();
370 writeln!(file, " \"no-skip\": {{").unwrap();
371 writeln!(file, " \"val-loss\": {},", val_loss).unwrap();
372 writeln!(file, " \"val-acc\": {},", val_acc).unwrap();
373 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
374 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
375 }
376 writeln!(file, " }}").unwrap();
377
378 if run == RUNS {
379 writeln!(file, " }}").unwrap();
380 if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
381 writeln!(file, " }}").unwrap();
382 } else {
383 writeln!(file, " }},").unwrap();
384 }
385 } else {
386 writeln!(file, " }},").unwrap();
387 }
388 }
389 });
390 });
391 });
392 writeln!(file, " }}").unwrap();
393 writeln!(file, "]").unwrap();
394}