// compare_bike/bike.rs

// Copyright (C) 2024 Hallvard Høyland Lavik
//
// Code for comparison between the various architectures.
// The respective losses and accuracies are stored to the file `~/output/compare/bike.json`.
//
// In addition, some simple probing of the networks is done.
// Namely, validating the trained networks with and without feedback and skip connections.
//
// for (
//   REGULAR,
//   FEEDBACK[approach=1, loops=2],
//   FEEDBACK[approach=1, loops=3],
//   FEEDBACK[approach=1, loops=4],
//   FEEDBACK[approach=2, loops=2],
//   FEEDBACK[approach=2, loops=3],
//   FEEDBACK[approach=2, loops=4]
// ) do {
//
//   for (NOSKIP, SKIP) do {
//
//     for (REGRESSION) do {
//
//       for (run in RUNS) do {
//
//         create the network
//         train the network
//         validate the network
//         store the loss and accuracy
//         probe the network
//         store the probing results
//
//       }
//     }
//   }
// }
36
37use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
38
39use std::{
40    collections::HashMap,
41    fs::File,
42    io::{BufRead, BufReader, Write},
43    sync::Arc,
44};
45
/// Number of independent training runs per configuration (method × skip × problem).
const RUNS: usize = 5;
47
48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49    let reader = BufReader::new(File::open(&path).unwrap());
50
51    let mut x: Vec<tensor::Tensor> = Vec::new();
52    let mut y: Vec<tensor::Tensor> = Vec::new();
53
54    for line in reader.lines().skip(1) {
55        let line = line.unwrap();
56        let record: Vec<&str> = line.split(',').collect();
57
58        let mut data: Vec<f32> = Vec::new();
59        for i in 2..14 {
60            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
61        }
62        x.push(tensor::Tensor::single(data));
63
64        y.push(tensor::Tensor::single(vec![record
65            .get(16)
66            .unwrap()
67            .parse::<f32>()
68            .unwrap()]));
69    }
70
71    let mut generator = random::Generator::create(12345);
72    let mut indices: Vec<usize> = (0..x.len()).collect();
73    generator.shuffle(&mut indices);
74
75    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
76    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
77
78    (x, y)
79}
80
/// Runs the architecture comparison on the bike-sharing dataset.
///
/// For every method (regular, and feedback approach 1/2 with 2-4 loops),
/// with and without a skip connection, a network is trained `RUNS` times.
/// Training/validation losses and probing results (re-validating with the
/// feedback mechanism or skip connection removed) are written as hand-built
/// JSON to `./output/compare/bike.json`.
fn main() {
    // Load the bike dataset
    let (x, y) = data("./examples/datasets/bike/hour.csv");

    // 80/20 train/test split; `data` already shuffled deterministically.
    let split = (x.len() as f32 * 0.8) as usize;
    let x = x.split_at(split);
    let y = y.split_at(split);

    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();

    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);

    // Create the results file.
    // NOTE(review): the JSON structure (brackets, braces, commas) is emitted by
    // hand below; the trailing-comma logic at the bottom of the run loop must
    // stay in sync with the iteration order here.
    let mut file = File::create("./output/compare/bike.json").unwrap();
    writeln!(file, "[").unwrap();
    writeln!(file, "  {{").unwrap();

    // Method names encode the architecture as "FB<approach>x<loops>".
    vec![
        "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
    ]
    .iter()
    .for_each(|method| {
        println!("Method: {}", method);
        vec![false, true].iter().for_each(|skip| {
            println!("  Skip: {}", skip);
            vec!["REGRESSION"].iter().for_each(|problem| {
                println!("   Problem: {}", problem);
                writeln!(file, "    \"{}-{}-{}\": {{", method, skip, problem).unwrap();

                for run in 1..RUNS + 1 {
                    println!("    Run: {}", run);
                    writeln!(file, "      \"run-{}\": {{", run).unwrap();

                    // Create the network based on the architecture.
                    // Input shape matches the 12 features produced by `data`.
                    let mut network: network::Network;
                    network = network::Network::new(tensor::Shape::Single(12));

                    // Check if the method is regular or feedback.
                    // Approach 1 ("FB1") uses plain dense layers plus a
                    // `loopback` connection; approach 2 wraps the layers in a
                    // `feedback` block instead.
                    if method == &"REGULAR" || method.contains(&"FB1") {
                        network.dense(24, activation::Activation::ReLU, false, None);
                        network.dense(24, activation::Activation::ReLU, false, None);
                        network.dense(24, activation::Activation::ReLU, false, None);

                        // Add the feedback loop if applicable.
                        if method.contains(&"FB1") {
                            // The trailing digit of the method name is the loop
                            // count; `loops - 1` extra passes are added,
                            // presumably from layer 2 back to layer 1 — confirm
                            // against `loopback`'s signature.
                            network.loopback(
                                2,
                                1,
                                method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
                                Arc::new(|_loops| 1.0),
                                false,
                            );
                        }
                    } else {
                        network.dense(24, activation::Activation::ReLU, false, None);
                        // Approach 2: a feedback block of two dense layers,
                        // looped `loops` times, accumulated by mean.
                        network.feedback(
                            vec![
                                feedback::Layer::Dense(
                                    24,
                                    activation::Activation::ReLU,
                                    false,
                                    None,
                                ),
                                feedback::Layer::Dense(
                                    24,
                                    activation::Activation::ReLU,
                                    false,
                                    None,
                                ),
                            ],
                            method.chars().last().unwrap().to_digit(10).unwrap() as usize,
                            false,
                            false,
                            feedback::Accumulation::Mean,
                        );
                    }

                    // Set the output layer based on the problem.
                    // Single linear output with RMSE objective (regression).
                    if problem == &"REGRESSION" {
                        network.dense(1, activation::Activation::Linear, false, None);
                        network.set_objective(objective::Objective::RMSE, None);
                    } else {
                        panic!("Invalid problem type.");
                    }

                    // Add the skip connection if applicable.
                    // Connects layer 1 to the final (output) layer.
                    if *skip {
                        network.connect(1, network.layers.len() - 1);
                    }

                    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));

                    // Train the network
                    // Batch size 64, 600 epochs; the `25` is presumably a
                    // validation interval or patience — confirm against
                    // `learn`'s signature.
                    let (train_loss, val_loss, val_acc);
                    if problem == &"REGRESSION" {
                        (train_loss, val_loss, val_acc) = network.learn(
                            &x_train,
                            &y_train,
                            Some((&x_test, &y_test, 25)),
                            64,
                            600,
                            None,
                        );
                    } else {
                        panic!("Invalid problem type.");
                    }

                    // Store the loss and accuracy.
                    writeln!(file, "        \"train\": {{").unwrap();
                    writeln!(file, "          \"trn-loss\": {:?},", train_loss).unwrap();
                    writeln!(file, "          \"val-loss\": {:?},", val_loss).unwrap();
                    writeln!(file, "          \"val-acc\": {:?}", val_acc).unwrap();

                    // Probe the network (if applicable): re-validate the trained
                    // network with its feedback mechanism removed, then restore it.
                    if method != &"REGULAR" {
                        println!("    > Without feedback.");

                        // Store the network's loopbacks and layers to restore them later.
                        let loopbacks = network.loopbacks.clone();
                        let layers = network.layers.clone();

                        // Remove the feedback loop.
                        if method.contains(&"FB1") {
                            network.loopbacks = HashMap::new();
                        } else {
                            match &mut network.layers.get_mut(1).unwrap() {
                                network::Layer::Feedback(fb) => {
                                    // Keep only the first two layers, i.e. what
                                    // looks like a single pass of the two-layer
                                    // feedback block.
                                    // NOTE(review): `drain(0..2)` keeps TWO
                                    // layers, while the original comment said
                                    // "first layer" — confirm one pass is two
                                    // layers in `feedback`'s internals.
                                    fb.layers = fb.layers.drain(0..2).collect();
                                }
                                _ => panic!("Invalid layer."),
                            };
                        }

                        let (test_loss, test_acc);
                        if problem == &"REGRESSION" {
                            (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
                        } else {
                            panic!("Invalid problem type.");
                        }

                        writeln!(file, "        }},").unwrap();
                        writeln!(file, "        \"no-feedback\": {{").unwrap();
                        writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
                        writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();

                        // Restore the feedback loop.
                        network.loopbacks = loopbacks;
                        network.layers = layers;
                    }
                    // Probe again with the skip connection removed (if applicable).
                    // The skip is not restored afterwards; the network is
                    // rebuilt at the top of the next run.
                    if *skip {
                        println!("    > Without skip.");
                        network.connect = HashMap::new();

                        let (test_loss, test_acc);
                        if problem == &"REGRESSION" {
                            (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
                        } else {
                            panic!("Invalid problem type.");
                        }

                        writeln!(file, "        }},").unwrap();
                        writeln!(file, "        \"no-skip\": {{").unwrap();
                        writeln!(file, "          \"tst-loss\": {},", test_loss).unwrap();
                        writeln!(file, "          \"tst-acc\": {}", test_acc).unwrap();
                    }
                    writeln!(file, "        }}").unwrap();

                    // Close the JSON objects, omitting the trailing comma after
                    // the very last configuration so the output stays valid JSON.
                    if run == RUNS {
                        writeln!(file, "      }}").unwrap();
                        if method == &"FB2x4" && *skip && problem == &"REGRESSION" {
                            writeln!(file, "    }}").unwrap();
                        } else {
                            writeln!(file, "    }},").unwrap();
                        }
                    } else {
                        writeln!(file, "      }},").unwrap();
                    }
                }
            });
        });
    });
    writeln!(file, "  }}").unwrap();
    writeln!(file, "]").unwrap();
}