1use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
38
39use std::{
40 collections::HashMap,
41 fs::File,
42 io::{BufRead, BufReader, Write},
43 sync::Arc,
44};
45
/// How many times every configuration is trained from scratch; the metrics of each
/// run are recorded separately in the output JSON.
const RUNS: usize = 5;
47
48fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
49 let reader = BufReader::new(File::open(&path).unwrap());
50
51 let mut x: Vec<Vec<f32>> = Vec::new();
52 let mut y: Vec<usize> = Vec::new();
53
54 for line in reader.lines().skip(1) {
55 let line = line.unwrap();
56 let record: Vec<&str> = line.split(',').collect();
57 x.push(vec![
58 record.get(1).unwrap().parse::<f32>().unwrap(),
59 record.get(2).unwrap().parse::<f32>().unwrap(),
60 record.get(3).unwrap().parse::<f32>().unwrap(),
61 record.get(4).unwrap().parse::<f32>().unwrap(),
62 ]);
63 y.push(match record.get(5).unwrap() {
64 &"Iris-setosa" => 0,
65 &"Iris-versicolor" => 1,
66 &"Iris-virginica" => 2,
67 _ => panic!("> Unknown class."),
68 });
69 }
70
71 let mut generator = random::Generator::create(12345);
72 let mut indices: Vec<usize> = (0..x.len()).collect();
73 generator.shuffle(&mut indices);
74
75 let x: Vec<tensor::Tensor> = indices
76 .iter()
77 .map(|&i| tensor::Tensor::single(x[i].clone()))
78 .collect();
79 let y: Vec<tensor::Tensor> = indices
80 .iter()
81 .map(|&i| tensor::Tensor::one_hot(y[i], 3))
82 .collect();
83
84 (x, y)
85}
86
87fn main() {
88 let (x, y) = data("./examples/datasets/iris.csv");
90
91 let split = (x.len() as f32 * 0.8) as usize;
92 let x = x.split_at(split);
93 let y = y.split_at(split);
94
95 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
96 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
97 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
98 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
99
100 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
101 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
102
103 let mut file = File::create("./output/compare/iris.json").unwrap();
105 writeln!(file, "[").unwrap();
106 writeln!(file, " {{").unwrap();
107
108 vec![
109 "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
110 ]
111 .iter()
112 .for_each(|method| {
113 println!("Method: {}", method);
114 vec![false, true].iter().for_each(|skip| {
115 println!(" Skip: {}", skip);
116 vec!["CLASSIFICATION"].iter().for_each(|problem| {
117 println!(" Problem: {}", problem);
118 writeln!(file, " \"{}-{}-{}\": {{", method, skip, problem).unwrap();
119
120 for run in 1..RUNS + 1 {
121 println!(" Run: {}", run);
122 writeln!(file, " \"run-{}\": {{", run).unwrap();
123
124 let mut network: network::Network;
126 network = network::Network::new(tensor::Shape::Single(4));
127
128 if method == &"REGULAR" || method.contains(&"FB1") {
130 network.dense(25, activation::Activation::ReLU, false, None);
131 network.dense(25, activation::Activation::ReLU, false, None);
132 network.dense(25, activation::Activation::ReLU, false, None);
133
134 if method.contains(&"FB1") {
136 network.loopback(
137 2,
138 1,
139 method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
140 Arc::new(|_loops| 1.0),
141 false,
142 );
143 }
144 } else {
145 network.dense(25, activation::Activation::ReLU, false, None);
146 network.feedback(
147 vec![
148 feedback::Layer::Dense(
149 25,
150 activation::Activation::ReLU,
151 false,
152 None,
153 ),
154 feedback::Layer::Dense(
155 25,
156 activation::Activation::ReLU,
157 false,
158 None,
159 ),
160 ],
161 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
162 false,
163 false,
164 feedback::Accumulation::Mean,
165 );
166 }
167
168 if problem == &"REGRESSION" {
170 panic!("Invalid problem type.");
171 } else {
172 network.dense(3, activation::Activation::Softmax, false, None);
173 network.set_objective(objective::Objective::CrossEntropy, None);
174 }
175
176 if *skip {
178 network.connect(1, network.layers.len() - 1);
179 }
180
181 network.set_optimizer(optimizer::Adam::create(0.0001, 0.95, 0.999, 1e-7, None));
182
183 let (train_loss, val_loss, val_acc);
185 if problem == &"REGRESSION" {
186 panic!("Invalid problem type.");
187 } else {
188 (train_loss, val_loss, val_acc) = network.learn(
189 &x_train,
190 &y_train,
191 Some((&x_test, &y_test, 10)),
192 1,
193 100,
194 None,
195 );
196 }
197
198 writeln!(file, " \"train\": {{").unwrap();
200 writeln!(file, " \"trn-loss\": {:?},", train_loss).unwrap();
201 writeln!(file, " \"val-loss\": {:?},", val_loss).unwrap();
202 writeln!(file, " \"val-acc\": {:?}", val_acc).unwrap();
203
204 if method != &"REGULAR" {
206 println!(" > Without feedback.");
207
208 let loopbacks = network.loopbacks.clone();
210 let layers = network.layers.clone();
211
212 if method.contains(&"FB1") {
214 network.loopbacks = HashMap::new();
215 } else {
216 match &mut network.layers.get_mut(1).unwrap() {
217 network::Layer::Feedback(fb) => {
218 fb.layers = fb.layers.drain(0..2).collect();
220 }
221 _ => panic!("Invalid layer."),
222 };
223 }
224
225 let (test_loss, test_acc);
226 if problem == &"REGRESSION" {
227 panic!("Invalid problem type.");
228 } else {
229 (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
230 }
231
232 writeln!(file, " }},").unwrap();
233 writeln!(file, " \"no-feedback\": {{").unwrap();
234 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
235 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
236
237 network.loopbacks = loopbacks;
239 network.layers = layers;
240 }
241 if *skip {
242 println!(" > Without skip.");
243 network.connect = HashMap::new();
244
245 let (test_loss, test_acc);
246 if problem == &"REGRESSION" {
247 panic!("Invalid problem type.");
248 } else {
249 (test_loss, test_acc) = network.validate(&x_test, &y_test, 1e-6);
250 }
251
252 writeln!(file, " }},").unwrap();
253 writeln!(file, " \"no-skip\": {{").unwrap();
254 writeln!(file, " \"tst-loss\": {},", test_loss).unwrap();
255 writeln!(file, " \"tst-acc\": {}", test_acc).unwrap();
256 }
257 writeln!(file, " }}").unwrap();
258
259 if run == RUNS {
260 writeln!(file, " }}").unwrap();
261 if method == &"FB2x4" && *skip && problem == &"CLASSIFICATION" {
262 writeln!(file, " }}").unwrap();
263 } else {
264 writeln!(file, " }},").unwrap();
265 }
266 } else {
267 writeln!(file, " }},").unwrap();
268 }
269 }
270 });
271 });
272 });
273 writeln!(file, " }}").unwrap();
274 writeln!(file, "]").unwrap();
275}