// compare_iris/iris.rs

// Copyright (C) 2024 Hallvard Høyland Lavik
//
// Code for comparison between the various architectures.
// The respective loss and accuracies are stored to the file `~/output/compare/iris.json`.
//
// In addition, some simple probing of the networks is done.
// Namely, validating the trained networks with and without feedback and skip connections.
//
// for (
//   REGULAR,
//   FEEDBACK[approach=1, loops=2],
//   FEEDBACK[approach=1, loops=3],
//   FEEDBACK[approach=1, loops=4],
//   FEEDBACK[approach=2, loops=2],
//   FEEDBACK[approach=2, loops=3],
//   FEEDBACK[approach=2, loops=4]
// ) do {
//
//   for (NOSKIP, SKIP) do {
//
//     for (CLASSIFICATION) do {
//
//       for (run in RUNS) do {
//
//         create the network
//         train the network
//         validate the network
//         store the loss and accuracy
//         probe the network
//         store the probing results
//
//       }
//     }
//   }
// }

use neurons::{activation, feedback, network, objective, optimizer, random, tensor};

use std::{
    collections::HashMap,
    fs::File,
    io::{BufRead, BufReader, Write},
    sync::Arc,
};

/// Number of independent training runs per network configuration.
const RUNS: usize = 5;

48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49    let reader = BufReader::new(File::open(&path).unwrap());
50
51    let mut x: Vec<Vec<f32>> = Vec::new();
52    let mut y: Vec<usize> = Vec::new();
53
54    for line in reader.lines().skip(1) {
55        let line = line.unwrap();
56        let record: Vec<&str> = line.split(',').collect();
57        x.push(vec![
58            record.get(1).unwrap().parse::<f32>().unwrap(),
59            record.get(2).unwrap().parse::<f32>().unwrap(),
60            record.get(3).unwrap().parse::<f32>().unwrap(),
61            record.get(4).unwrap().parse::<f32>().unwrap(),
62        ]);
63        y.push(match record.get(5).unwrap() {
64            &"Iris-setosa" => 0,
65            &"Iris-versicolor" => 1,
66            &"Iris-virginica" => 2,
67            _ => panic!("> Unknown class."),
68        });
69    }
70
71    let mut generator = random::Generator::create(12345);
72    let mut indices: Vec<usize> = (0..x.len()).collect();
73    generator.shuffle(&mut indices);
74
75    let x: Vec<tensor::Tensor> = indices
76        .iter()
77        .map(|&i| tensor::Tensor::single(x[i].clone()))
78        .collect();
79    let y: Vec<tensor::Tensor> = indices
80        .iter()
81        .map(|&i| tensor::Tensor::one_hot(y[i], 3))
82        .collect();
83
84    (x, y)
85}
86
87fn main() {
88    // Load the iris dataset
89    let (x, y) = data("./examples/datasets/iris.csv");
90
91    let split = (x.len() as f32 * 0.8) as usize;
92    let x = x.split_at(split);
93    let y = y.split_at(split);
94
95    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
96    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
97    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
98    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
99
100    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
101    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
102
103    // Create the results file.
104    let mut file = File::create("./output/compare/iris.json").unwrap();
105    writeln!(file, "[").unwrap();
106    writeln!(file, "  {{").unwrap();
107
108    vec![
109        "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
110    ]
111    .iter()
112    .for_each(|method| {
113        println!("Method: {}", method);
114        vec![false, true].iter().for_each(|skip| {
115            println!("  Skip: {}", skip);
116            vec!["CLASSIFICATION"].iter().for_each(|problem| {
117                println!("   Problem: {}", problem);
118                writeln!(file, "    \"{}-{}-{}\": {{", method, skip, problem).unwrap();
119
120                for run in 1..RUNS + 1 {
121                    println!("    Run: {}", run);
122                    writeln!(file, "      \"run-{}\": {{", run).unwrap();
123
124                    // Create the network based on the architecture.
125                    let mut network: network::Network;
126                    network = network::Network::new(tensor::Shape::Single(4));
127
128                    // Check if the method is regular or feedback.
129                    if method == &"REGULAR" || method.contains(&"FB1") {
130                        network.dense(25, activation::Activation::ReLU, false, None);
131                        network.dense(25, activation::Activation::ReLU, false, None);
132                        network.dense(25, activation::Activation::ReLU, false, None);
133
134                        // Add the feedback loop if applicable.
135                        if method.contains(&"FB1") {
136                            network.loopback(
137                                2,
138                                1,
139                                method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
140                                Arc::new(|_loops| 1.0),
141                                false,
142                            );
143                        }
144                    } else {
145                        network.dense(25, activation::Activation::ReLU, false, None);
146                        network.feedback(
147                            vec![
148                                feedback::Layer::Dense(
149                                    25,
150                                    activation::Activation::ReLU,
151                                    false,
152                                    None,
153                                ),
154                                feedback::Layer::Dense(
155                                    25,
156                                    activation::Activation::ReLU,
157                                    false,
158                                    None,
159                                ),
160                            ],
161                            method.chars().last().unwrap().to_digit(10).unwrap() as usize,
162                            false,
163                            false,
164                            feedback::Accumulation::Mean,
165                        );
166                    }
167
168                    // Set the output layer based on the problem.
169                    if problem == &"REGRESSION" {
170                        panic!("Invalid problem type.");
171                    } else {
172                        network.dense(3, activation::Activation::Softmax, false, None);
173                        network.set_objective(objective::Objective::CrossEntropy, None);
174                    }
175
176                    // Add the skip connection if applicable.
177                    if *skip {
178                        network.connect(1, network.layers.len() - 1);
179                    }
180
181                    network.set_optimizer(optimizer::Adam::create(0.0001, 0.95, 0.999, 1e-7, None));
182
183                    // Train the network
184                    let (train_loss, val_loss, val_acc);
185                    if problem == &"REGRESSION" {
186                        panic!("Invalid problem type.");
187                    } else {
188                        (train_loss, val_loss, val_acc) = network.learn(
189                            &x_train,
190                            &y_train,
191                            Some((&x_test, &y_test, 10)),
192                            1,
193                            100,
194                            None,
195                        );
196                    }
197
198                    // Store the loss and accuracy.
199                    writeln!(file, "        \"train\": {{").unwrap();
200                    writeln!(file, "          \"trn-loss\": {:?},", train_loss).unwrap();
201                    writeln!(file, "          \"val-loss\": {:?},", val_loss).unwrap();
202                    writeln!(file, "          \"val-acc\": {:?}", val_acc).unwrap();
203
204                    // Probe the network (if applicable).
205                    if method != &"REGULAR" {
206                        println!("    > Without feedback.");
207
208                        // Store the network's loopbacks and layers to restore them later.
209                        let loopbacks = network.loopbacks.clone();
210                        let layers = network.layers.clone();
211
212                        // Remove the feedback loop.
213                        if method.contains(&"FB1") {
214                            network.loopbacks = HashMap::new();
215                        } else {
216                            match &mut network.layers.get_mut(1).unwrap() {
217                                network::Layer::Feedback(fb) => {
218                                    // Only keep the first layer.
219                                    fb.layers = fb.layers.drain(0..2).collect();
220                                }
221                                _ => panic!("Invalid layer."),
222                            };
223                        }
224
225                        let (test_loss, test_acc);
226                        if problem == &"REGRESSION" {
227                            panic!("Invalid problem type.");
228                        } else {
229                            (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
230                        }
231
232                        writeln!(file, "        }},").unwrap();
233                        writeln!(file, "        \"no-feedback\": {{").unwrap();
234                        writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
235                        writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
236
237                        // Restore the feedback loop.
238                        network.loopbacks = loopbacks;
239                        network.layers = layers;
240                    }
241                    if *skip {
242                        println!("    > Without skip.");
243                        network.connect = HashMap::new();
244
245                        let (test_loss, test_acc);
246                        if problem == &"REGRESSION" {
247                            panic!("Invalid problem type.");
248                        } else {
249                            (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
250                        }
251
252                        writeln!(file, "        }},").unwrap();
253                        writeln!(file, "        \"no-skip\": {{").unwrap();
254                        writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
255                        writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
256                    }
257                    writeln!(file, "        }}").unwrap();
258
259                    if run == RUNS {
260                        writeln!(file, "      }}").unwrap();
261                        if method == &"FB2x4" && *skip && problem == &"CLASSIFICATION" {
262                            writeln!(file, "    }}").unwrap();
263                        } else {
264                            writeln!(file, "    }},").unwrap();
265                        }
266                    } else {
267                        writeln!(file, "      }},").unwrap();
268                    }
269                }
270            });
271        });
272    });
273    writeln!(file, "  }}").unwrap();
274    writeln!(file, "]").unwrap();
275}