// timing_ftir_cnn/ftir-cnn.rs
// Copyright (C) 2024 Hallvard Høyland Lavik
//
// Code for comparison between the various architectures and their time differences.

use neurons::{activation, feedback, network, objective, optimizer, tensor};

use std::{
    fs::File,
    io::{BufRead, BufReader, BufWriter, Write},
    sync::Arc,
    time,
};

/// Number of repeated timing runs per configuration.
const RUNS: usize = 5;
/// Number of training epochs per run (kept low: we measure time, not accuracy).
const EPOCHS: i32 = 1;

17fn data(
18 path: &str,
19) -> (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23) {
24 let reader = BufReader::new(File::open(&path).unwrap());
25
26 let mut x: Vec<tensor::Tensor> = Vec::new();
27 let mut y: Vec<tensor::Tensor> = Vec::new();
28 let mut c: Vec<tensor::Tensor> = Vec::new();
29
30 for line in reader.lines().skip(1) {
31 let line = line.unwrap();
32 let record: Vec<&str> = line.split(',').collect();
33
34 let mut data: Vec<f32> = Vec::new();
35 for i in 0..571 {
36 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37 }
38 let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
39 x.push(tensor::Tensor::triple(data));
40 y.push(tensor::Tensor::single(vec![record
41 .get(571)
42 .unwrap()
43 .parse::<f32>()
44 .unwrap()]));
45 c.push(tensor::Tensor::one_hot(
46 record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
47 28,
48 ));
49 }
50
51 (x, y, c)
52}
53
54fn main() {
55 // Load the ftir dataset
56 let (x, y, c) = data("./examples/datasets/ftir.csv");
57
58 let x: Vec<&tensor::Tensor> = x.iter().collect();
59 let y: Vec<&tensor::Tensor> = y.iter().collect();
60 let c: Vec<&tensor::Tensor> = c.iter().collect();
61
62 // Create the results file.
63 let mut file = File::create("./output/timing/ftir-cnn.json").unwrap();
64 writeln!(file, "[").unwrap();
65 writeln!(file, " {{").unwrap();
66
67 vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
68 .iter()
69 .for_each(|method| {
70 println!("Method: {}", method);
71 vec![false, true].iter().for_each(|skip| {
72 println!(" Skip: {}", skip);
73 vec!["CLASSIFICATION", "REGRESSION"]
74 .iter()
75 .for_each(|problem| {
76 println!(" Problem: {}", problem);
77
78 let mut train_times: Vec<f64> = Vec::new();
79 let mut valid_times: Vec<f64> = Vec::new();
80
81 for _ in 0..RUNS {
82 // Create the network based on the architecture.
83 let mut network: network::Network;
84 network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
85
86 // Check if the method is regular or feedback.
87 if method == &"REGULAR" || method.contains(&"FB1") {
88 network.convolution(
89 1,
90 (1, 9),
91 (1, 1),
92 (0, 4),
93 (1, 1),
94 activation::Activation::ReLU,
95 None,
96 );
97 network.convolution(
98 1,
99 (1, 9),
100 (1, 1),
101 (0, 4),
102 (1, 1),
103 activation::Activation::ReLU,
104 None,
105 );
106 network.dense(32, activation::Activation::ReLU, false, None);
107
108 // Add the feedback loop if applicable.
109 if method.contains(&"FB1") {
110 network.loopback(
111 1,
112 0,
113 method.chars().last().unwrap().to_digit(10).unwrap()
114 as usize
115 - 1,
116 Arc::new(|_loops| 1.0),
117 false,
118 );
119 }
120 } else {
121 network.feedback(
122 vec![
123 feedback::Layer::Convolution(
124 1,
125 activation::Activation::ReLU,
126 (1, 9),
127 (1, 1),
128 (0, 4),
129 (1, 1),
130 None,
131 ),
132 feedback::Layer::Convolution(
133 1,
134 activation::Activation::ReLU,
135 (1, 9),
136 (1, 1),
137 (0, 4),
138 (1, 1),
139 None,
140 ),
141 ],
142 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
143 false,
144 false,
145 feedback::Accumulation::Mean,
146 );
147 network.dense(32, activation::Activation::ReLU, false, None);
148 }
149
150 // Set the output layer based on the problem.
151 if problem == &"REGRESSION" {
152 network.dense(1, activation::Activation::Linear, false, None);
153 network.set_objective(objective::Objective::RMSE, None);
154 } else {
155 network.dense(28, activation::Activation::Softmax, false, None);
156 network.set_objective(
157 objective::Objective::CrossEntropy,
158 Some((-5.0, 5.0)),
159 );
160 }
161
162 // Add the skip connection if applicable.
163 if *skip {
164 network.connect(0, network.layers.len() - 2);
165 }
166
167 network.set_optimizer(optimizer::Adam::create(
168 0.001, 0.9, 0.999, 1e-8, None,
169 ));
170
171 let start = time::Instant::now();
172
173 // Train the network
174 if problem == &"REGRESSION" {
175 (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
176 } else {
177 (_, _, _) = network.learn(&x, &c, None, 32, EPOCHS, None);
178 }
179
180 let duration = start.elapsed().as_secs_f64();
181 train_times.push(duration);
182
183 let start = time::Instant::now();
184
185 // Validate the network
186 (_) = network.predict_batch(&x);
187
188 let duration = start.elapsed().as_secs_f64();
189 valid_times.push(duration);
190 }
191
192 if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
193 writeln!(
194 file,
195 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
196 method, skip, problem, train_times, valid_times
197 )
198 .unwrap();
199 } else {
200 writeln!(
201 file,
202 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
203 method, skip, problem, train_times, valid_times
204 )
205 .unwrap();
206 }
207 });
208 });
209 });
210 writeln!(file, " }}").unwrap();
211 writeln!(file, "]").unwrap();
212}