1use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
38
39use std::{
40 collections::HashMap,
41 fs::File,
42 io::{BufRead, BufReader, Write},
43 sync::Arc,
44};
45
46const RUNS: usize = 5;
47
48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49 let reader = BufReader::new(File::open(&path).unwrap());
50
51 let mut x: Vec<tensor::Tensor> = Vec::new();
52 let mut y: Vec<tensor::Tensor> = Vec::new();
53
54 for line in reader.lines().skip(1) {
55 let line = line.unwrap();
56 let record: Vec<&str> = line.split(',').collect();
57
58 let mut data: Vec<f32> = Vec::new();
59 for i in 2..14 {
60 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
61 }
62 x.push(tensor::Tensor::single(data));
63
64 y.push(tensor::Tensor::single(vec![record
65 .get(16)
66 .unwrap()
67 .parse::<f32>()
68 .unwrap()]));
69 }
70
71 let mut generator = random::Generator::create(12345);
72 let mut indices: Vec<usize> = (0..x.len()).collect();
73 generator.shuffle(&mut indices);
74
75 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
76 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
77
78 (x, y)
79}
80
81fn main() {
82 let (x, y) = data("./examples/datasets/bike/hour.csv");
84
85 let split = (x.len() as f32 * 0.8) as usize;
86 let x = x.split_at(split);
87 let y = y.split_at(split);
88
89 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
90 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
91 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
92 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
93
94 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
95 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
96
97 let mut file = File::create("./output/compare/bike.json").unwrap();
99 writeln!(file, "[").unwrap();
100 writeln!(file, " {{").unwrap();
101
102 vec![
103 "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
104 ]
105 .iter()
106 .for_each(|method| {
107 println!("Method: {}", method);
108 vec![false, true].iter().for_each(|skip| {
109 println!(" Skip: {}", skip);
110 vec!["REGRESSION"].iter().for_each(|problem| {
111 println!(" Problem: {}", problem);
112 writeln!(file, " \"{}-{}-{}\": {{", method, skip, problem).unwrap();
113
114 for run in 1..RUNS + 1 {
115 println!(" Run: {}", run);
116 writeln!(file, " \"run-{}\": {{", run).unwrap();
117
118 let mut network: network::Network;
120 network = network::Network::new(tensor::Shape::Single(12));
121
122 if method == &"REGULAR" || method.contains(&"FB1") {
124 network.dense(24, activation::Activation::ReLU, false, None);
125 network.dense(24, activation::Activation::ReLU, false, None);
126 network.dense(24, activation::Activation::ReLU, false, None);
127
128 if method.contains(&"FB1") {
130 network.loopback(
131 2,
132 1,
133 method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
134 Arc::new(|_loops| 1.0),
135 false,
136 );
137 }
138 } else {
139 network.dense(24, activation::Activation::ReLU, false, None);
140 network.feedback(
141 vec![
142 feedback::Layer::Dense(
143 24,
144 activation::Activation::ReLU,
145 false,
146 None,
147 ),
148 feedback::Layer::Dense(
149 24,
150 activation::Activation::ReLU,
151 false,
152 None,
153 ),
154 ],
155 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
156 false,
157 false,
158 feedback::Accumulation::Mean,
159 );
160 }
161
162 if problem == &"REGRESSION" {
164 network.dense(1, activation::Activation::Linear, false, None);
165 network.set_objective(objective::Objective::RMSE, None);
166 } else {
167 panic!("Invalid problem type.");
168 }
169
170 if *skip {
172 network.connect(1, network.layers.len() - 1);
173 }
174
175 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
176
177 let (train_loss, val_loss, val_acc);
179 if problem == &"REGRESSION" {
180 (train_loss, val_loss, val_acc) = network.learn(
181 &x_train,
182 &y_train,
183 Some((&x_test, &y_test, 25)),
184 64,
185 600,
186 None,
187 );
188 } else {
189 panic!("Invalid problem type.");
190 }
191
192 writeln!(file, " \"train\": {{").unwrap();
194 writeln!(file, " \"trn-loss\": {:?},", train_loss).unwrap();
195 writeln!(file, " \"val-loss\": {:?},", val_loss).unwrap();
196 writeln!(file, " \"val-acc\": {:?}", val_acc).unwrap();
197
198 if method != &"REGULAR" {
200 println!(" > Without feedback.");
201
202 let loopbacks = network.loopbacks.clone();
204 let layers = network.layers.clone();
205
206 if method.contains(&"FB1") {
208 network.loopbacks = HashMap::new();
209 } else {
210 match &mut network.layers.get_mut(1).unwrap() {
211 network::Layer::Feedback(fb) => {
212 fb.layers = fb.layers.drain(0..2).collect();
214 }
215 _ => panic!("Invalid layer."),
216 };
217 }
218
219 let (test_loss, test_acc);
220 if problem == &"REGRESSION" {
221 (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
222 } else {
223 panic!("Invalid problem type.");
224 }
225
226 writeln!(file, " }},").unwrap();
227 writeln!(file, " \"no-feedback\": {{").unwrap();
228 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
229 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
230
231 network.loopbacks = loopbacks;
233 network.layers = layers;
234 }
235 if *skip {
236 println!(" > Without skip.");
237 network.connect = HashMap::new();
238
239 let (test_loss, test_acc);
240 if problem == &"REGRESSION" {
241 (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
242 } else {
243 panic!("Invalid problem type.");
244 }
245
246 writeln!(file, " }},").unwrap();
247 writeln!(file, " \"no-skip\": {{").unwrap();
248 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
249 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
250 }
251 writeln!(file, " }}").unwrap();
252
253 if run == RUNS {
254 writeln!(file, " }}").unwrap();
255 if method == &"FB2x4" && *skip && problem == &"REGRESSION" {
256 writeln!(file, " }}").unwrap();
257 } else {
258 writeln!(file, " }},").unwrap();
259 }
260 } else {
261 writeln!(file, " }},").unwrap();
262 }
263 }
264 });
265 });
266 });
267 writeln!(file, " }}").unwrap();
268 writeln!(file, "]").unwrap();
269}