timing_ftir_mlp/
ftir-mlp.rs

// Copyright (C) 2024 Hallvard Høyland Lavik
//
// Code for comparison between the various architectures and their time differences.

5use neurons::{activation, feedback, network, objective, optimizer, tensor};
6
7use std::{
8    fs::File,
9    io::{BufRead, BufReader, Write},
10    sync::Arc,
11    time,
12};
13
/// Number of repeated timing runs per configuration (times are recorded per run).
const RUNS: usize = 5;
/// Training epochs per run; kept minimal since only relative timing is compared.
const EPOCHS: i32 = 1;
16
17fn data(
18    path: &str,
19) -> (
20    Vec<tensor::Tensor>,
21    Vec<tensor::Tensor>,
22    Vec<tensor::Tensor>,
23) {
24    let reader = BufReader::new(File::open(&path).unwrap());
25
26    let mut x: Vec<tensor::Tensor> = Vec::new();
27    let mut y: Vec<tensor::Tensor> = Vec::new();
28    let mut c: Vec<tensor::Tensor> = Vec::new();
29
30    for line in reader.lines().skip(1) {
31        let line = line.unwrap();
32        let record: Vec<&str> = line.split(',').collect();
33
34        let mut data: Vec<f32> = Vec::new();
35        for i in 0..571 {
36            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37        }
38
39        x.push(tensor::Tensor::single(data));
40        y.push(tensor::Tensor::single(vec![record
41            .get(571)
42            .unwrap()
43            .parse::<f32>()
44            .unwrap()]));
45        c.push(tensor::Tensor::one_hot(
46            record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47            28,
48        ));
49    }
50    (x, y, c)
51}
52
53fn main() {
54    // Load the ftir dataset
55    let (x, y, c) = data("./examples/datasets/ftir.csv");
56
57    let x: Vec<&tensor::Tensor> = x.iter().collect();
58    let y: Vec<&tensor::Tensor> = y.iter().collect();
59    let c: Vec<&tensor::Tensor> = c.iter().collect();
60
61    // Create the results file.
62    let mut file = File::create("./output/timing/ftir-mlp.json").unwrap();
63    writeln!(file, "[").unwrap();
64    writeln!(file, "  {{").unwrap();
65
66    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
67        .iter()
68        .for_each(|method| {
69            println!("Method: {}", method);
70            vec![false, true].iter().for_each(|skip| {
71                println!("  Skip: {}", skip);
72                vec!["CLASSIFICATION", "REGRESSION"]
73                    .iter()
74                    .for_each(|problem| {
75                        println!("   Problem: {}", problem);
76
77                        let mut train_times: Vec<f64> = Vec::new();
78                        let mut valid_times: Vec<f64> = Vec::new();
79
80                        for _ in 0..RUNS {
81                            // Create the network based on the architecture.
82                            let mut network: network::Network;
83                            network = network::Network::new(tensor::Shape::Single(571));
84                            network.dense(128, activation::Activation::ReLU, false, None);
85
86                            // Check if the method is regular or feedback.
87                            if method == &"REGULAR" || method.contains(&"FB1") {
88                                network.dense(256, activation::Activation::ReLU, false, None);
89                                network.dense(128, activation::Activation::ReLU, false, None);
90
91                                // Add the feedback loop if applicable.
92                                if method.contains(&"FB1") {
93                                    network.loopback(
94                                        2,
95                                        1,
96                                        method.chars().last().unwrap().to_digit(10).unwrap()
97                                            as usize
98                                            - 1,
99                                        Arc::new(|_loops| 1.0),
100                                        false,
101                                    );
102                                }
103                            } else {
104                                network.feedback(
105                                    vec![
106                                        feedback::Layer::Dense(
107                                            256,
108                                            activation::Activation::ReLU,
109                                            false,
110                                            None,
111                                        ),
112                                        feedback::Layer::Dense(
113                                            128,
114                                            activation::Activation::ReLU,
115                                            false,
116                                            None,
117                                        ),
118                                    ],
119                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize,
120                                    false,
121                                    false,
122                                    feedback::Accumulation::Mean,
123                                );
124                            }
125
126                            // Set the output layer based on the problem.
127                            if problem == &"REGRESSION" {
128                                network.dense(1, activation::Activation::Linear, false, None);
129                                network.set_objective(objective::Objective::RMSE, None);
130                            } else {
131                                network.dense(28, activation::Activation::Softmax, false, None);
132                                network.set_objective(
133                                    objective::Objective::CrossEntropy,
134                                    Some((-5.0, 5.0)),
135                                );
136                            }
137
138                            // Add the skip connection if applicable.
139                            if *skip {
140                                network.connect(1, network.layers.len() - 1);
141                            }
142
143                            network.set_optimizer(optimizer::Adam::create(
144                                0.001, 0.9, 0.999, 1e-8, None,
145                            ));
146
147                            let start = time::Instant::now();
148
149                            // Train the network
150                            if problem == &"REGRESSION" {
151                                (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
152                            } else {
153                                (_, _, _) = network.learn(&x, &c, None, 32, EPOCHS, None);
154                            }
155
156                            let duration = start.elapsed().as_secs_f64();
157                            train_times.push(duration);
158
159                            let start = time::Instant::now();
160
161                            // Validate the network
162                            (_) = network.predict_batch(&x);
163
164                            let duration = start.elapsed().as_secs_f64();
165                            valid_times.push(duration);
166                        }
167
168                        if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
169                            writeln!(
170                                file,
171                                "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
172                                method, skip, problem, train_times, valid_times
173                            )
174                            .unwrap();
175                        } else {
176                            writeln!(
177                                file,
178                                "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
179                                method, skip, problem, train_times, valid_times
180                            )
181                            .unwrap();
182                        }
183                    });
184            });
185        });
186    writeln!(file, "  }}").unwrap();
187    writeln!(file, "]").unwrap();
188}