1use neurons::{activation, feedback, network, objective, optimizer, tensor};
38
39use std::{
40 collections::HashMap,
41 fs::File,
42 io::{BufRead, BufReader, Write},
43 sync::Arc,
44};
45
/// Number of independent training runs per configuration (results of every
/// run are recorded separately in the output JSON).
const RUNS: usize = 5;
47
48fn data(
49 path: &str,
50) -> (
51 (
52 Vec<tensor::Tensor>,
53 Vec<tensor::Tensor>,
54 Vec<tensor::Tensor>,
55 ),
56 (
57 Vec<tensor::Tensor>,
58 Vec<tensor::Tensor>,
59 Vec<tensor::Tensor>,
60 ),
61 (
62 Vec<tensor::Tensor>,
63 Vec<tensor::Tensor>,
64 Vec<tensor::Tensor>,
65 ),
66) {
67 let reader = BufReader::new(File::open(&path).unwrap());
68
69 let mut x_train: Vec<tensor::Tensor> = Vec::new();
70 let mut y_train: Vec<tensor::Tensor> = Vec::new();
71 let mut class_train: Vec<tensor::Tensor> = Vec::new();
72
73 let mut x_test: Vec<tensor::Tensor> = Vec::new();
74 let mut y_test: Vec<tensor::Tensor> = Vec::new();
75 let mut class_test: Vec<tensor::Tensor> = Vec::new();
76
77 let mut x_val: Vec<tensor::Tensor> = Vec::new();
78 let mut y_val: Vec<tensor::Tensor> = Vec::new();
79 let mut class_val: Vec<tensor::Tensor> = Vec::new();
80
81 for line in reader.lines().skip(1) {
82 let line = line.unwrap();
83 let record: Vec<&str> = line.split(',').collect();
84
85 let mut data: Vec<f32> = Vec::new();
86 for i in 0..571 {
87 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
88 }
89 match record.get(573).unwrap() {
90 &"Train" => {
91 x_train.push(tensor::Tensor::single(data));
92 y_train.push(tensor::Tensor::single(vec![record
93 .get(571)
94 .unwrap()
95 .parse::<f32>()
96 .unwrap()]));
97 class_train.push(tensor::Tensor::one_hot(
98 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
100 ));
101 }
102 &"Test" => {
103 x_test.push(tensor::Tensor::single(data));
104 y_test.push(tensor::Tensor::single(vec![record
105 .get(571)
106 .unwrap()
107 .parse::<f32>()
108 .unwrap()]));
109 class_test.push(tensor::Tensor::one_hot(
110 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
112 ));
113 }
114 &"Val" => {
115 x_val.push(tensor::Tensor::single(data));
116 y_val.push(tensor::Tensor::single(vec![record
117 .get(571)
118 .unwrap()
119 .parse::<f32>()
120 .unwrap()]));
121 class_val.push(tensor::Tensor::one_hot(
122 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
124 ));
125 }
126 _ => panic!("> Unknown class."),
127 }
128 }
129
130 (
135 (x_train, y_train, class_train),
136 (x_test, y_test, class_test),
137 (x_val, y_val, class_val),
138 )
139}
140
/// Experiment driver: trains MLP variants on the FTIR dataset and records
/// results incrementally as JSON in `./output/compare/ftir-mlp.json`.
///
/// The grid covers:
/// * `method`  — `REGULAR` (plain MLP), `FB1xN` (loopback connection),
///   `FB2xN` (dedicated feedback block); the trailing digit N is parsed
///   from the method name and sets the repeat/unroll count,
/// * `skip`    — with/without a connection from layer 1 to the last layer,
/// * `problem` — `CLASSIFICATION` (28-way Softmax) vs. `REGRESSION`
///   (single linear output),
/// each repeated `RUNS` times.
fn main() {
    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
        data("./examples/datasets/ftir.csv");

    // Re-borrow every tensor: `learn`/`validate` below consume vectors of
    // tensor references rather than owned tensors.
    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();

    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();

    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();

    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
    println!("Validation data {}x{}\n", x_val.len(), x_val[0].shape,);

    // The JSON document is written by hand, token by token, as the loops
    // run; the brace/comma `writeln!`s further down must mirror the loop
    // structure exactly or the output becomes invalid JSON.
    let mut file = File::create("./output/compare/ftir-mlp.json").unwrap();
    writeln!(file, "[").unwrap();
    writeln!(file, " {{").unwrap();

    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
        .iter()
        .for_each(|method| {
            println!("Method: {}", method);
            vec![false, true].iter().for_each(|skip| {
                println!(" Skip: {}", skip);
                vec!["CLASSIFICATION", "REGRESSION"]
                    .iter()
                    .for_each(|problem| {
                        println!(" Problem: {}", problem);
                        // One JSON object per (method, skip, problem) combination.
                        writeln!(file, " \"{}-{}-{}\": {{", method, skip, problem).unwrap();

                        for run in 1..RUNS + 1 {
                            println!(" Run: {}", run);
                            writeln!(file, " \"run-{}\": {{", run).unwrap();

                            // A fresh network per run: 571 inputs -> 128 ReLU,
                            // then a method-dependent middle section.
                            let mut network: network::Network;
                            network = network::Network::new(tensor::Shape::Single(571));
                            network.dense(128, activation::Activation::ReLU, false, None);

                            if method == &"REGULAR" || method.contains(&"FB1") {
                                // Plain 256/128 ReLU stack; FB1 additionally
                                // adds a loopback over it.
                                network.dense(256, activation::Activation::ReLU, false, None);
                                network.dense(128, activation::Activation::ReLU, false, None);

                                if method.contains(&"FB1") {
                                    // Loopback between layers 2 and 1; the
                                    // method's trailing digit minus one sets
                                    // the repeat count, with a constant 1.0
                                    // scaling closure.
                                    network.loopback(
                                        2,
                                        1,
                                        method.chars().last().unwrap().to_digit(10).unwrap()
                                            as usize
                                            - 1,
                                        Arc::new(|_loops| 1.0),
                                        false,
                                    );
                                }
                            } else {
                                // FB2: the same 256/128 stack wrapped in a
                                // feedback block, unrolled `N` times (trailing
                                // digit of the method name), mean-accumulated.
                                network.feedback(
                                    vec![
                                        feedback::Layer::Dense(
                                            256,
                                            activation::Activation::ReLU,
                                            false,
                                            None,
                                        ),
                                        feedback::Layer::Dense(
                                            128,
                                            activation::Activation::ReLU,
                                            false,
                                            None,
                                        ),
                                    ],
                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize,
                                    false,
                                    false,
                                    feedback::Accumulation::Mean,
                                );
                            }

                            // Output head + objective depend on the problem.
                            if problem == &"REGRESSION" {
                                network.dense(1, activation::Activation::Linear, false, None);
                                network.set_objective(objective::Objective::RMSE, None);
                            } else {
                                network.dense(28, activation::Activation::Softmax, false, None);
                                // NOTE(review): (-5.0, 5.0) is presumably a
                                // clamp/clipping range for the objective —
                                // confirm against `set_objective`.
                                network.set_objective(
                                    objective::Objective::CrossEntropy,
                                    Some((-5.0, 5.0)),
                                );
                            }

                            if *skip {
                                // Skip connection: layer 1 to the last layer.
                                network.connect(1, network.layers.len() - 1);
                            }

                            network.set_optimizer(optimizer::Adam::create(
                                0.001, 0.9, 0.999, 1e-8, None,
                            ));

                            // Train: batch size 32, 1000 epochs, validating
                            // every 100 (per the `Some((.., .., 100))` tuple).
                            let (train_loss, val_loss, val_acc);
                            if problem == &"REGRESSION" {
                                (train_loss, val_loss, val_acc) = network.learn(
                                    &x_train,
                                    &y_train,
                                    Some((&x_val, &y_val, 100)),
                                    32,
                                    1000,
                                    None,
                                );
                            } else {
                                (train_loss, val_loss, val_acc) = network.learn(
                                    &x_train,
                                    &class_train,
                                    Some((&x_val, &class_val, 100)),
                                    32,
                                    1000,
                                    None,
                                );
                            }

                            writeln!(file, " \"train\": {{").unwrap();
                            writeln!(file, " \"trn-loss\": {:?},", train_loss).unwrap();
                            writeln!(file, " \"val-loss\": {:?},", val_loss).unwrap();
                            writeln!(file, " \"val-acc\": {:?}", val_acc).unwrap();

                            if method != &"REGULAR" {
                                // Ablation: re-evaluate the trained network
                                // with its feedback mechanism stripped out.
                                println!(" > Without feedback.");

                                // Save state so it can be restored afterwards.
                                let loopbacks = network.loopbacks.clone();
                                let layers = network.layers.clone();

                                if method.contains(&"FB1") {
                                    // FB1: drop all loopback connections.
                                    network.loopbacks = HashMap::new();
                                } else {
                                    // FB2: keep only the first two layers of
                                    // the feedback block (one pass, no repeat).
                                    match &mut network.layers.get_mut(1).unwrap() {
                                        network::Layer::Feedback(fb) => {
                                            fb.layers = fb.layers.drain(0..2).collect();
                                        }
                                        _ => panic!("Invalid layer."),
                                    };
                                }

                                // Shadowed names: these are the ablated scores,
                                // not the training-time ones above.
                                let (val_loss, val_acc);
                                if problem == &"REGRESSION" {
                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
                                } else {
                                    (val_loss, val_acc) =
                                        network.validate(&x_val, &class_val, 1e-6);
                                }
                                let (test_loss, test_acc);
                                if problem == &"REGRESSION" {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &y_test, 1e-6);
                                } else {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &class_test, 1e-6);
                                }

                                // Close "train", open "no-feedback".
                                writeln!(file, " }},").unwrap();
                                writeln!(file, " \"no-feedback\": {{").unwrap();
                                writeln!(file, " \"val-loss\": {},", val_loss).unwrap();
                                writeln!(file, " \"val-acc\": {},", val_acc).unwrap();
                                writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
                                writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();

                                // Restore the feedback machinery before the
                                // (possible) skip ablation below.
                                network.loopbacks = loopbacks;
                                network.layers = layers;
                            }
                            if *skip {
                                // Ablation: re-evaluate without the skip
                                // connection. Not restored afterwards — the
                                // network is discarded at the end of the run.
                                println!(" > Without skip.");
                                network.connect = HashMap::new();

                                let (val_loss, val_acc);
                                if problem == &"REGRESSION" {
                                    (val_loss, val_acc) = network.validate(&x_val, &y_val, 1e-6);
                                } else {
                                    (val_loss, val_acc) =
                                        network.validate(&x_val, &class_val, 1e-6);
                                }
                                let (test_loss, test_acc);
                                if problem == &"REGRESSION" {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &y_test, 1e-6);
                                } else {
                                    (test_loss, test_acc) =
                                        network.validate(&x_test, &class_test, 1e-6);
                                }

                                // Close the previous sub-object, open "no-skip".
                                writeln!(file, " }},").unwrap();
                                writeln!(file, " \"no-skip\": {{").unwrap();
                                writeln!(file, " \"val-loss\": {},", val_loss).unwrap();
                                writeln!(file, " \"val-acc\": {},", val_acc).unwrap();
                                writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
                                writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
                            }
                            // Close the last open sub-object
                            // ("train" / "no-feedback" / "no-skip").
                            writeln!(file, " }}").unwrap();

                            if run == RUNS {
                                // Last run: close the run object without a
                                // comma, then close the combination object —
                                // the very last combination in iteration order
                                // (FB2x3 / skip / REGRESSION) also omits its
                                // trailing comma to keep the JSON valid.
                                writeln!(file, " }}").unwrap();
                                if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
                                    writeln!(file, " }}").unwrap();
                                } else {
                                    writeln!(file, " }},").unwrap();
                                }
                            } else {
                                writeln!(file, " }},").unwrap();
                            }
                        }
                    });
            });
        });
    // Close the top-level object and the enclosing array.
    writeln!(file, " }}").unwrap();
    writeln!(file, "]").unwrap();
}