1use neurons::{activation, feedback, network, objective, optimizer, tensor};
6
7use std::{
8 fs::File,
9 io::{BufRead, BufReader, Write},
10 sync::Arc,
11 time,
12};
13
/// Number of independent build/train/predict repetitions timed per configuration.
const RUNS: usize = 5;
/// Epoch count passed to `Network::learn` in every timed training run.
const EPOCHS: i32 = 1;
16
17fn data(
18 path: &str,
19) -> (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23) {
24 let reader = BufReader::new(File::open(&path).unwrap());
25
26 let mut x: Vec<tensor::Tensor> = Vec::new();
27 let mut y: Vec<tensor::Tensor> = Vec::new();
28 let mut c: Vec<tensor::Tensor> = Vec::new();
29
30 for line in reader.lines().skip(1) {
31 let line = line.unwrap();
32 let record: Vec<&str> = line.split(',').collect();
33
34 let mut data: Vec<f32> = Vec::new();
35 for i in 0..571 {
36 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
37 }
38
39 x.push(tensor::Tensor::single(data));
40 y.push(tensor::Tensor::single(vec![record
41 .get(571)
42 .unwrap()
43 .parse::<f32>()
44 .unwrap()]));
45 c.push(tensor::Tensor::one_hot(
46 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
48 ));
49 }
50 (x, y, c)
51}
52
53fn main() {
54 let (x, y, c) = data("./examples/datasets/ftir.csv");
56
57 let x: Vec<&tensor::Tensor> = x.iter().collect();
58 let y: Vec<&tensor::Tensor> = y.iter().collect();
59 let c: Vec<&tensor::Tensor> = c.iter().collect();
60
61 let mut file = File::create("./output/timing/ftir-mlp.json").unwrap();
63 writeln!(file, "[").unwrap();
64 writeln!(file, " {{").unwrap();
65
66 vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
67 .iter()
68 .for_each(|method| {
69 println!("Method: {}", method);
70 vec![false, true].iter().for_each(|skip| {
71 println!(" Skip: {}", skip);
72 vec!["CLASSIFICATION", "REGRESSION"]
73 .iter()
74 .for_each(|problem| {
75 println!(" Problem: {}", problem);
76
77 let mut train_times: Vec<f64> = Vec::new();
78 let mut valid_times: Vec<f64> = Vec::new();
79
80 for _ in 0..RUNS {
81 let mut network: network::Network;
83 network = network::Network::new(tensor::Shape::Single(571));
84 network.dense(128, activation::Activation::ReLU, false, None);
85
86 if method == &"REGULAR" || method.contains(&"FB1") {
88 network.dense(256, activation::Activation::ReLU, false, None);
89 network.dense(128, activation::Activation::ReLU, false, None);
90
91 if method.contains(&"FB1") {
93 network.loopback(
94 2,
95 1,
96 method.chars().last().unwrap().to_digit(10).unwrap()
97 as usize
98 - 1,
99 Arc::new(|_loops| 1.0),
100 false,
101 );
102 }
103 } else {
104 network.feedback(
105 vec![
106 feedback::Layer::Dense(
107 256,
108 activation::Activation::ReLU,
109 false,
110 None,
111 ),
112 feedback::Layer::Dense(
113 128,
114 activation::Activation::ReLU,
115 false,
116 None,
117 ),
118 ],
119 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
120 false,
121 false,
122 feedback::Accumulation::Mean,
123 );
124 }
125
126 if problem == &"REGRESSION" {
128 network.dense(1, activation::Activation::Linear, false, None);
129 network.set_objective(objective::Objective::RMSE, None);
130 } else {
131 network.dense(28, activation::Activation::Softmax, false, None);
132 network.set_objective(
133 objective::Objective::CrossEntropy,
134 Some((-5.0, 5.0)),
135 );
136 }
137
138 if *skip {
140 network.connect(1, network.layers.len() - 1);
141 }
142
143 network.set_optimizer(optimizer::Adam::create(
144 0.001, 0.9, 0.999, 1e-8, None,
145 ));
146
147 let start = time::Instant::now();
148
149 if problem == &"REGRESSION" {
151 (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
152 } else {
153 (_, _, _) = network.learn(&x, &c, None, 32, EPOCHS, None);
154 }
155
156 let duration = start.elapsed().as_secs_f64();
157 train_times.push(duration);
158
159 let start = time::Instant::now();
160
161 (_) = network.predict_batch(&x);
163
164 let duration = start.elapsed().as_secs_f64();
165 valid_times.push(duration);
166 }
167
168 if method == &"FB2x3" && *skip && problem == &"REGRESSION" {
169 writeln!(
170 file,
171 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
172 method, skip, problem, train_times, valid_times
173 )
174 .unwrap();
175 } else {
176 writeln!(
177 file,
178 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
179 method, skip, problem, train_times, valid_times
180 )
181 .unwrap();
182 }
183 });
184 });
185 });
186 writeln!(file, " }}").unwrap();
187 writeln!(file, "]").unwrap();
188}