// compare_ftir_mlp/ftir-mlp.rs

// Copyright (C) 2024 Hallvard Høyland Lavik
//
// NOTE: Gradient clipping is added for the classification task.
//
// Code for comparison between the various architectures.
// The respective losses and accuracies are stored to the file `~/output/compare/ftir-mlp.json`.
//
// In addition, some simple probing of the networks is done.
// Namely, validating the trained networks with and without feedback and skip connections.
//
// for (
//   REGULAR,
//   FEEDBACK[approach=1, loops=2],
//   FEEDBACK[approach=1, loops=3],
//   FEEDBACK[approach=2, loops=2],
//   FEEDBACK[approach=2, loops=3]
// ) do {
//
//   for (NOSKIP, SKIP) do {
//
//     for (CLASSIFICATION, REGRESSION) do {
//
//       for (run in RUNS) do {
//
//         create the network
//         train the network
//         validate the network
//         store the loss and accuracy
//         probe the network
//         store the probing results
//
//       }
//     }
//   }
// }
36
37use neurons::{activation, feedback, network, objective, optimizer, tensor};
38
39use std::{
40    collections::HashMap,
41    fs::File,
42    io::{BufRead, BufReader, Write},
43    sync::Arc,
44};
45
/// Number of independent training runs per (method, skip, problem) configuration.
const RUNS: usize = 5;
47
48fn data(
49    path: &str,
50) -> (
51    (
52        Vec<tensor::Tensor>,
53        Vec<tensor::Tensor>,
54        Vec<tensor::Tensor>,
55    ),
56    (
57        Vec<tensor::Tensor>,
58        Vec<tensor::Tensor>,
59        Vec<tensor::Tensor>,
60    ),
61    (
62        Vec<tensor::Tensor>,
63        Vec<tensor::Tensor>,
64        Vec<tensor::Tensor>,
65    ),
66) {
67    let reader = BufReader::new(File::open(&path).unwrap());
68
69    let mut x_train: Vec<tensor::Tensor> = Vec::new();
70    let mut y_train: Vec<tensor::Tensor> = Vec::new();
71    let mut class_train: Vec<tensor::Tensor> = Vec::new();
72
73    let mut x_test: Vec<tensor::Tensor> = Vec::new();
74    let mut y_test: Vec<tensor::Tensor> = Vec::new();
75    let mut class_test: Vec<tensor::Tensor> = Vec::new();
76
77    let mut x_val: Vec<tensor::Tensor> = Vec::new();
78    let mut y_val: Vec<tensor::Tensor> = Vec::new();
79    let mut class_val: Vec<tensor::Tensor> = Vec::new();
80
81    for line in reader.lines().skip(1) {
82        let line = line.unwrap();
83        let record: Vec<&str> = line.split(',').collect();
84
85        let mut data: Vec<f32> = Vec::new();
86        for i in 0..571 {
87            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
88        }
89        match record.get(573).unwrap() {
90            &"Train" => {
91                x_train.push(tensor::Tensor::single(data));
92                y_train.push(tensor::Tensor::single(vec![record
93                    .get(571)
94                    .unwrap()
95                    .parse::<f32>()
96                    .unwrap()]));
97                class_train.push(tensor::Tensor::one_hot(
98                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
99                    28,
100                ));
101            }
102            &"Test" => {
103                x_test.push(tensor::Tensor::single(data));
104                y_test.push(tensor::Tensor::single(vec![record
105                    .get(571)
106                    .unwrap()
107                    .parse::<f32>()
108                    .unwrap()]));
109                class_test.push(tensor::Tensor::one_hot(
110                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
111                    28,
112                ));
113            }
114            &"Val" => {
115                x_val.push(tensor::Tensor::single(data));
116                y_val.push(tensor::Tensor::single(vec![record
117                    .get(571)
118                    .unwrap()
119                    .parse::<f32>()
120                    .unwrap()]));
121                class_val.push(tensor::Tensor::one_hot(
122                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
123                    28,
124                ));
125            }
126            _ => panic!("> Unknown class."),
127        }
128    }
129
130    // let mut generator = random::Generator::create(12345);
131    // let mut indices: Vec<usize> = (0..x.len()).collect();
132    // generator.shuffle(&mut indices);
133
134    (
135        (x_train, y_train, class_train),
136        (x_test, y_test, class_test),
137        (x_val, y_val, class_val),
138    )
139}
140
/// Entry point: trains and probes every (method, skip, problem) combination
/// over `RUNS` repetitions and logs the results as JSON to
/// `./output/compare/ftir-mlp.json`.
///
/// Methods: REGULAR (plain MLP) plus four feedback variants, where the tag
/// encodes approach and loop count (e.g. "FB1x3" = approach 1, 3 loops).
/// After training, feedback and skip connections are ablated ("probed") and
/// the resulting validation/test scores stored as well.
///
/// # Panics
/// Panics if the dataset or output file cannot be opened/created, or on any
/// write failure (every fallible call is `unwrap`ed).
fn main() {
    // Load the ftir dataset
    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
        data("./examples/datasets/ftir.csv");

    // Borrow the tensors once; the training/validation calls below take the
    // reference vectors. NOTE(review): assumes `learn`/`validate` expect
    // `Vec<&tensor::Tensor>` slices — confirm against the neurons API.
    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();

    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();

    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();

    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
    println!("Validation data {}x{}\n", x_val.len(), x_val[0].shape,);

    // Create the results file.
    // `{{` / `}}` in the format strings below are escapes that emit literal
    // `{` / `}` braces into the JSON output.
    let mut file = File::create("./output/compare/ftir-mlp.json").unwrap();
    writeln!(file, "[").unwrap();
    writeln!(file, "  {{").unwrap();

    // Full grid: method x skip x problem x run.
    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
        .iter()
        .for_each(|method| {
            println!("Method: {}", method);
            vec![false, true].iter().for_each(|skip| {
                println!("  Skip: {}", skip);
                vec!["CLASSIFICATION", "REGRESSION"]
                    .iter()
                    .for_each(|problem| {
                        println!("   Problem: {}", problem);
                        writeln!(file, "    \"{}-{}-{}\": {{", method, skip, problem).unwrap();

                        for run in 1..RUNS + 1 {
                            println!("    Run: {}", run);
                            writeln!(file, "      \"run-{}\": {{", run).unwrap();

                            // Create the network based on the architecture.
                            // A fresh network is built for every run, so the
                            // probing mutations below never leak across runs.
                            let mut network: network::Network;
                            network = network::Network::new(tensor::Shape::Single(571));
                            network.dense(128, activation::Activation::ReLU, false, None);

                            // Check if the method is regular or feedback.
                            if method == &"REGULAR" || method.contains(&"FB1") {
                                network.dense(256, activation::Activation::ReLU, false, None);
                                network.dense(128, activation::Activation::ReLU, false, None);

                                // Add the feedback loop if applicable.
                                // The trailing digit of the tag is the loop
                                // count; `- 1` passes count-1 (e.g. "FB1x3"
                                // -> 2). NOTE(review): confirm `loopback`
                                // expects the number of *extra* passes.
                                if method.contains(&"FB1") {
                                    network.loopback(
                                        2,
                                        1,
                                        method.chars().last().unwrap().to_digit(10).unwrap()
                                            as usize
                                            - 1,
                                        Arc::new(|_loops| 1.0), // Constant feedback scale per loop.
                                        false,
                                    );
                                }
                            } else {
                                // Approach 2 (FB2x*): a dedicated feedback
                                // block wrapping the two hidden layers.
                                network.feedback(
                                    vec![
                                        feedback::Layer::Dense(
                                            256,
                                            activation::Activation::ReLU,
                                            false,
                                            None,
                                        ),
                                        feedback::Layer::Dense(
                                            128,
                                            activation::Activation::ReLU,
                                            false,
                                            None,
                                        ),
                                    ],
                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize,
                                    false,
                                    false,
                                    feedback::Accumulation::Mean,
                                );
                            }

                            // Set the output layer based on the problem.
                            if problem == &"REGRESSION" {
                                network.dense(1, activation::Activation::Linear, false, None);
                                network.set_objective(objective::Objective::RMSE, None);
                            } else {
                                // 28 classes; the (-5.0, 5.0) pair is
                                // presumably the gradient-clipping range
                                // mentioned in the file-header NOTE — confirm.
                                network.dense(28, activation::Activation::Softmax, false, None);
                                network.set_objective(
                                    objective::Objective::CrossEntropy,
                                    Some((-5.0, 5.0)),
                                );
                            }

                            // Add the skip connection if applicable.
                            // Connects layer 1 to the last (output) layer.
                            if *skip {
                                network.connect(1, network.layers.len() - 1);
                            }

                            network.set_optimizer(optimizer::Adam::create(
                                0.001, 0.9, 0.999, 1e-8, None,
                            ));

                            // Train the network
                            // Batch size 32, up to 1000 epochs.
                            // NOTE(review): the `100` in the validation tuple
                            // is presumably a patience/interval — confirm.
                            let (train_loss, val_loss, val_acc);
                            if problem == &"REGRESSION" {
                                (train_loss, val_loss, val_acc) = network.learn(
                                    &x_train,
                                    &y_train,
                                    Some((&x_val, &y_val, 100)),
                                    32,
                                    1000,
                                    None,
                                );
                            } else {
                                (train_loss, val_loss, val_acc) = network.learn(
                                    &x_train,
                                    &class_train,
                                    Some((&x_val, &class_val, 100)),
                                    32,
                                    1000,
                                    None,
                                );
                            }

                            // Store the loss and accuracy.
                            // The "train" object is opened here and closed
                            // later by whichever branch writes next.
                            writeln!(file, "        \"train\": {{").unwrap();
                            writeln!(file, "          \"trn-loss\": {:?},", train_loss).unwrap();
                            writeln!(file, "          \"val-loss\": {:?},", val_loss).unwrap();
                            writeln!(file, "          \"val-acc\": {:?}", val_acc).unwrap();

                            // Probe the network (if applicable).
                            if method != &"REGULAR" {
                                println!("    > Without feedback.");

                                // Store the network's loopbacks and layers to restore them later.
                                let loopbacks = network.loopbacks.clone();
                                let layers = network.layers.clone();

                                // Remove the feedback loop.
                                if method.contains(&"FB1") {
                                    network.loopbacks = HashMap::new();
                                } else {
                                    // Approach 2: the feedback block sits at
                                    // layer index 1; drop its repeated layers.
                                    match &mut network.layers.get_mut(1).unwrap() {
                                        network::Layer::Feedback(fb) => {
                                            // Only keep the first two layers.
                                            fb.layers = fb.layers.drain(0..2).collect();
                                        }
                                        _ => panic!("Invalid layer."),
                                    };
                                }

                                // Re-validate without feedback.
                                // NOTE(review): 1e-6 is presumably a numeric
                                // tolerance used by `validate` — confirm.
                                let (val_loss, val_acc);
                                if problem == &"REGRESSION" {
                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
                                } else {
                                    (val_loss, val_acc) =
                                        network.validate(&x_val, &class_val, 1e-6);
                                }
                                let (test_loss, test_acc);
                                if problem == &"REGRESSION" {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &y_test, 1e-6);
                                } else {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &class_test, 1e-6);
                                }

                                // Close "train", open "no-feedback".
                                // (The 7-space indent below is a cosmetic
                                // inconsistency in the emitted JSON.)
                                writeln!(file, "       }},").unwrap();
                                writeln!(file, "        \"no-feedback\": {{").unwrap();
                                writeln!(file, "          \"val-loss\": {},", val_loss).unwrap();
                                writeln!(file, "          \"val-acc\": {},", val_acc).unwrap();
                                writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
                                writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();

                                // Restore the feedback loop.
                                network.loopbacks = loopbacks;
                                network.layers = layers;
                            }
                            if *skip {
                                println!("    > Without skip.");
                                // Drop all skip connections for this probe.
                                // Not restored: the network is rebuilt at the
                                // top of the next run.
                                network.connect = HashMap::new();

                                let (val_loss, val_acc);
                                if problem == &"REGRESSION" {
                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
                                } else {
                                    (val_loss, val_acc) =
                                        network.validate(&x_val, &class_val, 1e-6);
                                }
                                let (test_loss, test_acc);
                                if problem == &"REGRESSION" {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &y_test, 1e-6);
                                } else {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &class_test, 1e-6);
                                }

                                // Close the previous object, open "no-skip".
                                writeln!(file, "        }},").unwrap();
                                writeln!(file, "        \"no-skip\": {{").unwrap();
                                writeln!(file, "          \"val-loss\": {},", val_loss).unwrap();
                                writeln!(file, "          \"val-acc\": {},", val_acc).unwrap();
                                writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
                                writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
                            }
                            // Close the innermost open object
                            // (train / no-feedback / no-skip).
                            writeln!(file, "        }}").unwrap();

                            // Close the run object. The very last grid cell
                            // (FB2x3, skip=true, REGRESSION) omits the trailing
                            // comma so the emitted JSON stays valid.
                            if run == RUNS {
                                writeln!(file, "      }}").unwrap();
                                if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
                                    writeln!(file, "    }}").unwrap();
                                } else {
                                    writeln!(file, "    }},").unwrap();
                                }
                            } else {
                                writeln!(file, "      }},").unwrap();
                            }
                        }
                    });
            });
        });
    writeln!(file, "  }}").unwrap();
    writeln!(file, "]").unwrap();
}