// compare_ftir_cnn/ftir-cnn.rs
//
// Copyright (C) 2024 Hallvard Høyland Lavik
//
// NOTE: Gradient clipping is added for the classification task.
//
// Code for comparison between the various architectures.
// The respective losses and accuracies are stored to the file `~/output/compare/ftir-cnn.json`.
//
// In addition, some simple probing of the networks is done.
// Namely, validating the trained networks with and without feedback and skip connections.
//
// for (
//   REGULAR,
//   FEEDBACK[approach=1, loops=2],
//   FEEDBACK[approach=1, loops=3],
//   FEEDBACK[approach=2, loops=2],
//   FEEDBACK[approach=2, loops=3]
// ) do {
//
//   for (NOSKIP, SKIP) do {
//
//     for (CLASSIFICATION, REGRESSION) do {
//
//       for (run in RUNS) do {
//
//         create the network
//         train the network
//         validate the network
//         store the loss and accuracy
//         probe the network
//         store the probing results
//
//       }
//     }
//   }
// }

use neurons::{activation, feedback, network, objective, optimizer, tensor};

use std::{
    collections::HashMap,
    fs::File,
    io::{BufRead, BufReader, Write},
    sync::Arc,
};

/// Number of independent training runs per configuration.
const RUNS: usize = 5;

48fn data(
49    path: &str,
50) -> (
51    (
52        Vec<tensor::Tensor>,
53        Vec<tensor::Tensor>,
54        Vec<tensor::Tensor>,
55    ),
56    (
57        Vec<tensor::Tensor>,
58        Vec<tensor::Tensor>,
59        Vec<tensor::Tensor>,
60    ),
61    (
62        Vec<tensor::Tensor>,
63        Vec<tensor::Tensor>,
64        Vec<tensor::Tensor>,
65    ),
66) {
67    let reader = BufReader::new(File::open(&path).unwrap());
68
69    let mut x_train: Vec<tensor::Tensor> = Vec::new();
70    let mut y_train: Vec<tensor::Tensor> = Vec::new();
71    let mut class_train: Vec<tensor::Tensor> = Vec::new();
72
73    let mut x_test: Vec<tensor::Tensor> = Vec::new();
74    let mut y_test: Vec<tensor::Tensor> = Vec::new();
75    let mut class_test: Vec<tensor::Tensor> = Vec::new();
76
77    let mut x_val: Vec<tensor::Tensor> = Vec::new();
78    let mut y_val: Vec<tensor::Tensor> = Vec::new();
79    let mut class_val: Vec<tensor::Tensor> = Vec::new();
80
81    for line in reader.lines().skip(1) {
82        let line = line.unwrap();
83        let record: Vec<&str> = line.split(',').collect();
84
85        let mut data: Vec<f32> = Vec::new();
86        for i in 0..571 {
87            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
88        }
89        let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
90        match record.get(573).unwrap() {
91            &"Train" => {
92                x_train.push(tensor::Tensor::triple(data));
93                y_train.push(tensor::Tensor::single(vec![record
94                    .get(571)
95                    .unwrap()
96                    .parse::<f32>()
97                    .unwrap()]));
98                class_train.push(tensor::Tensor::one_hot(
99                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
100                    28,
101                ));
102            }
103            &"Test" => {
104                x_test.push(tensor::Tensor::triple(data));
105                y_test.push(tensor::Tensor::single(vec![record
106                    .get(571)
107                    .unwrap()
108                    .parse::<f32>()
109                    .unwrap()]));
110                class_test.push(tensor::Tensor::one_hot(
111                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
112                    28,
113                ));
114            }
115            &"Val" => {
116                x_val.push(tensor::Tensor::triple(data));
117                y_val.push(tensor::Tensor::single(vec![record
118                    .get(571)
119                    .unwrap()
120                    .parse::<f32>()
121                    .unwrap()]));
122                class_val.push(tensor::Tensor::one_hot(
123                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
124                    28,
125                ));
126            }
127            _ => panic!("> Unknown class."),
128        }
129    }
130
131    // let mut generator = random::Generator::create(12345);
132    // let mut indices: Vec<usize> = (0..x.len()).collect();
133    // generator.shuffle(&mut indices);
134
135    (
136        (x_train, y_train, class_train),
137        (x_test, y_test, class_test),
138        (x_val, y_val, class_val),
139    )
140}
141
142fn main() {
143    // Load the ftir dataset
144    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
145        data("./examples/datasets/ftir.csv");
146
147    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
148    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
149    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
150
151    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
152    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
153    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
154
155    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
156    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
157    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
158
159    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
160    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
161    println!("Validation data {}x{}\n", x_val.len(), x_val[0].shape,);
162
163    // Create the results file.
164    let mut file = File::create("./output/compare/ftir-cnn.json").unwrap();
165    writeln!(file, "[").unwrap();
166    writeln!(file, "  {{").unwrap();
167
168    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
169        .iter()
170        .for_each(|method| {
171            println!("Method: {}", method);
172            vec![false, true].iter().for_each(|skip| {
173                println!("  Skip: {}", skip);
174                vec!["CLASSIFICATION", "REGRESSION"]
175                    .iter()
176                    .for_each(|problem| {
177                        println!("   Problem: {}", problem);
178                        writeln!(file, "    \"{}-{}-{}\": {{", method, skip, problem).unwrap();
179
180                        for run in 1..RUNS + 1 {
181                            println!("    Run: {}", run);
182                            writeln!(file, "      \"run-{}\": {{", run).unwrap();
183
184                            // Create the network based on the architecture.
185                            let mut network: network::Network;
186                            network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
187
188                            // Check if the method is regular or feedback.
189                            if method == &"REGULAR" || method.contains(&"FB1") {
190                                network.convolution(
191                                    1,
192                                    (1, 9),
193                                    (1, 1),
194                                    (0, 4),
195                                    (1, 1),
196                                    activation::Activation::ReLU,
197                                    None,
198                                );
199                                network.convolution(
200                                    1,
201                                    (1, 9),
202                                    (1, 1),
203                                    (0, 4),
204                                    (1, 1),
205                                    activation::Activation::ReLU,
206                                    None,
207                                );
208                                network.dense(32, activation::Activation::ReLU, false, None);
209
210                                // Add the feedback loop if applicable.
211                                if method.contains(&"FB1") {
212                                    network.loopback(
213                                        1,
214                                        0,
215                                        method.chars().last().unwrap().to_digit(10).unwrap()
216                                            as usize
217                                            - 1,
218                                        Arc::new(|_loops| 1.0),
219                                        false,
220                                    );
221                                }
222                            } else {
223                                network.feedback(
224                                    vec![
225                                        feedback::Layer::Convolution(
226                                            1,
227                                            activation::Activation::ReLU,
228                                            (1, 9),
229                                            (1, 1),
230                                            (0, 4),
231                                            (1, 1),
232                                            None,
233                                        ),
234                                        feedback::Layer::Convolution(
235                                            1,
236                                            activation::Activation::ReLU,
237                                            (1, 9),
238                                            (1, 1),
239                                            (0, 4),
240                                            (1, 1),
241                                            None,
242                                        ),
243                                    ],
244                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize,
245                                    false,
246                                    false,
247                                    feedback::Accumulation::Mean,
248                                );
249                                network.dense(32, activation::Activation::ReLU, false, None);
250                            }
251
252                            // Set the output layer based on the problem.
253                            if problem == &"REGRESSION" {
254                                network.dense(1, activation::Activation::Linear, false, None);
255                                network.set_objective(objective::Objective::RMSE, None);
256                            } else {
257                                network.dense(28, activation::Activation::Softmax, false, None);
258                                network.set_objective(
259                                    objective::Objective::CrossEntropy,
260                                    Some((-5.0, 5.0)),
261                                );
262                            }
263
264                            // Add the skip connection if applicable.
265                            if *skip {
266                                network.connect(0, network.layers.len() - 2);
267                            }
268
269                            network.set_optimizer(optimizer::Adam::create(
270                                0.001, 0.9, 0.999, 1e-8, None,
271                            ));
272
273                            // Train the network
274                            let (train_loss, val_loss, val_acc);
275                            if problem == &"REGRESSION" {
276                                (train_loss, val_loss, val_acc) = network.learn(
277                                    &x_train,
278                                    &y_train,
279                                    Some((&x_val, &y_val, 100)),
280                                    32,
281                                    2500,
282                                    None,
283                                );
284                            } else {
285                                (train_loss, val_loss, val_acc) = network.learn(
286                                    &x_train,
287                                    &class_train,
288                                    Some((&x_val, &class_val, 100)),
289                                    32,
290                                    2500,
291                                    None,
292                                );
293                            }
294
295                            // Store the loss and accuracy.
296                            writeln!(file, "        \"train\": {{").unwrap();
297                            writeln!(file, "          \"trn-loss\": {:?},", train_loss).unwrap();
298                            writeln!(file, "          \"val-loss\": {:?},", val_loss).unwrap();
299                            writeln!(file, "          \"val-acc\": {:?}", val_acc).unwrap();
300
301                            // Probe the network (if applicable).
302                            if method != &"REGULAR" {
303                                println!("    > Without feedback.");
304
305                                // Store the network's loopbacks and layers to restore them later.
306                                let loopbacks = network.loopbacks.clone();
307                                let layers = network.layers.clone();
308
309                                // Remove the feedback loop.
310                                if method.contains(&"FB1") {
311                                    network.loopbacks = HashMap::new();
312                                } else {
313                                    match &mut network.layers.get_mut(0).unwrap() {
314                                        network::Layer::Feedback(fb) => {
315                                            // Only keep the first two layers.
316                                            fb.layers = fb.layers.drain(0..2).collect();
317                                        }
318                                        _ => panic!("Invalid layer."),
319                                    };
320                                }
321
322                                let (val_loss, val_acc);
323                                if problem == &"REGRESSION" {
324                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
325                                } else {
326                                    (val_loss, val_acc) =
327                                        network.validate(&x_val, &class_val, 1e-6);
328                                }
329                                let (test_loss, test_acc);
330                                if problem == &"REGRESSION" {
331                                    (test_loss, test_acc) =
332                                        network.validate(&x_test, &y_test, 1e-6);
333                                } else {
334                                    (test_loss, test_acc) =
335                                        network.validate(&x_test, &class_test, 1e-6);
336                                }
337
338                                writeln!(file, "        }},").unwrap();
339                                writeln!(file, "        \"no-feedback\": {{").unwrap();
340                                writeln!(file, "          \"val-loss\": {},", val_loss).unwrap();
341                                writeln!(file, "          \"val-acc\": {},", val_acc).unwrap();
342                                writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
343                                writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
344
345                                // Restore the feedback loop.
346                                network.loopbacks = loopbacks;
347                                network.layers = layers;
348                            }
349                            if *skip {
350                                println!("    > Without skip.");
351                                network.connect = HashMap::new();
352
353                                let (val_loss, val_acc);
354                                if problem == &"REGRESSION" {
355                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
356                                } else {
357                                    (val_loss, val_acc) =
358                                        network.validate(&x_val, &class_val, 1e-6);
359                                }
360                                let (test_loss, test_acc);
361                                if problem == &"REGRESSION" {
362                                    (test_loss, test_acc) =
363                                        network.validate(&x_test, &y_test, 1e-6);
364                                } else {
365                                    (test_loss, test_acc) =
366                                        network.validate(&x_test, &class_test, 1e-6);
367                                }
368
369                                writeln!(file, "        }},").unwrap();
370                                writeln!(file, "        \"no-skip\": {{").unwrap();
371                                writeln!(file, "          \"val-loss\": {},", val_loss).unwrap();
372                                writeln!(file, "          \"val-acc\": {},", val_acc).unwrap();
373                                writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
374                                writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
375                            }
376                            writeln!(file, "        }}").unwrap();
377
378                            if run == RUNS {
379                                writeln!(file, "      }}").unwrap();
380                                if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
381                                    writeln!(file, "    }}").unwrap();
382                                } else {
383                                    writeln!(file, "    }},").unwrap();
384                                }
385                            } else {
386                                writeln!(file, "      }},").unwrap();
387                            }
388                        }
389                    });
390            });
391        });
392    writeln!(file, "  }}").unwrap();
393    writeln!(file, "]").unwrap();
394}