timing_ftir_cnn/
ftir-cnn.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2//
3// Code for comparison between the various architectures and their time differences.
4
5use neurons::{activation, feedback, network, objective, optimizer, tensor};
6
7use std::{
8    fs::File,
9    io::{BufRead, BufReader, Write},
10    sync::Arc,
11    time,
12};
13
/// Number of timed repetitions performed by `main` for every
/// method / skip / problem combination.
const RUNS: usize = 5;
/// Epoch count handed to `network.learn` for each timed training run.
const EPOCHS: i32 = 1;
16
17fn data(
18    path: &str,
19) -> (
20    Vec<tensor::Tensor>,
21    Vec<tensor::Tensor>,
22    Vec<tensor::Tensor>,
23) {
24    let reader = BufReader::new(File::open(&path).unwrap());
25
26    let mut x: Vec<tensor::Tensor> = Vec::new();
27    let mut y: Vec<tensor::Tensor> = Vec::new();
28    let mut c: Vec<tensor::Tensor> = Vec::new();
29
30    for line in reader.lines().skip(1) {
31        let line = line.unwrap();
32        let record: Vec<&str> = line.split(',').collect();
33
34        let mut data: Vec<f32> = Vec::new();
35        for i in 0..571 {
36            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37        }
38        let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
39        x.push(tensor::Tensor::triple(data));
40        y.push(tensor::Tensor::single(vec![record
41            .get(571)
42            .unwrap()
43            .parse::<f32>()
44            .unwrap()]));
45        c.push(tensor::Tensor::one_hot(
46            record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47            28,
48        ));
49    }
50
51    (x, y, c)
52}
53
54fn main() {
55    // Load the ftir dataset
56    let (x, y, c) = data("./examples/datasets/ftir.csv");
57
58    let x: Vec<&tensor::Tensor> = x.iter().collect();
59    let y: Vec<&tensor::Tensor> = y.iter().collect();
60    let c: Vec<&tensor::Tensor> = c.iter().collect();
61
62    // Create the results file.
63    let mut file = File::create("./output/timing/ftir-cnn.json").unwrap();
64    writeln!(file, "[").unwrap();
65    writeln!(file, "  {{").unwrap();
66
67    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
68        .iter()
69        .for_each(|method| {
70            println!("Method: {}", method);
71            vec![false, true].iter().for_each(|skip| {
72                println!("  Skip: {}", skip);
73                vec!["CLASSIFICATION", "REGRESSION"]
74                    .iter()
75                    .for_each(|problem| {
76                        println!("   Problem: {}", problem);
77
78                        let mut train_times: Vec<f64> = Vec::new();
79                        let mut valid_times: Vec<f64> = Vec::new();
80
81                        for _ in 0..RUNS {
82                            // Create the network based on the architecture.
83                            let mut network: network::Network;
84                            network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
85
86                            // Check if the method is regular or feedback.
87                            if method == &"REGULAR" || method.contains(&"FB1") {
88                                network.convolution(
89                                    1,
90                                    (1, 9),
91                                    (1, 1),
92                                    (0, 4),
93                                    (1, 1),
94                                    activation::Activation::ReLU,
95                                    None,
96                                );
97                                network.convolution(
98                                    1,
99                                    (1, 9),
100                                    (1, 1),
101                                    (0, 4),
102                                    (1, 1),
103                                    activation::Activation::ReLU,
104                                    None,
105                                );
106                                network.dense(32, activation::Activation::ReLU, false, None);
107
108                                // Add the feedback loop if applicable.
109                                if method.contains(&"FB1") {
110                                    network.loopback(
111                                        1,
112                                        0,
113                                        method.chars().last().unwrap().to_digit(10).unwrap()
114                                            as usize
115                                            - 1,
116                                        Arc::new(|_loops| 1.0),
117                                        false,
118                                    );
119                                }
120                            } else {
121                                network.feedback(
122                                    vec![
123                                        feedback::Layer::Convolution(
124                                            1,
125                                            activation::Activation::ReLU,
126                                            (1, 9),
127                                            (1, 1),
128                                            (0, 4),
129                                            (1, 1),
130                                            None,
131                                        ),
132                                        feedback::Layer::Convolution(
133                                            1,
134                                            activation::Activation::ReLU,
135                                            (1, 9),
136                                            (1, 1),
137                                            (0, 4),
138                                            (1, 1),
139                                            None,
140                                        ),
141                                    ],
142                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize,
143                                    false,
144                                    false,
145                                    feedback::Accumulation::Mean,
146                                );
147                                network.dense(32, activation::Activation::ReLU, false, None);
148                            }
149
150                            // Set the output layer based on the problem.
151                            if problem == &"REGRESSION" {
152                                network.dense(1, activation::Activation::Linear, false, None);
153                                network.set_objective(objective::Objective::RMSE, None);
154                            } else {
155                                network.dense(28, activation::Activation::Softmax, false, None);
156                                network.set_objective(
157                                    objective::Objective::CrossEntropy,
158                                    Some((-5.0, 5.0)),
159                                );
160                            }
161
162                            // Add the skip connection if applicable.
163                            if *skip {
164                                network.connect(0, network.layers.len() - 2);
165                            }
166
167                            network.set_optimizer(optimizer::Adam::create(
168                                0.001, 0.9, 0.999, 1e-8, None,
169                            ));
170
171                            let start = time::Instant::now();
172
173                            // Train the network
174                            if problem == &"REGRESSION" {
175                                (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
176                            } else {
177                                (_, _, _) = network.learn(&x, &c, None, 32, EPOCHS, None);
178                            }
179
180                            let duration = start.elapsed().as_secs_f64();
181                            train_times.push(duration);
182
183                            let start = time::Instant::now();
184
185                            // Validate the network
186                            (_) = network.predict_batch(&x);
187
188                            let duration = start.elapsed().as_secs_f64();
189                            valid_times.push(duration);
190                        }
191
192                        if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
193                            writeln!(
194                                file,
195                                "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
196                                method, skip, problem, train_times, valid_times
197                            )
198                            .unwrap();
199                        } else {
200                            writeln!(
201                                file,
202                                "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
203                                method, skip, problem, train_times, valid_times
204                            )
205                            .unwrap();
206                        }
207                    });
208            });
209        });
210    writeln!(file, "  }}").unwrap();
211    writeln!(file, "]").unwrap();
212}