1use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
6
7use std::{
8 fs::File,
9 io::{BufRead, BufReader, Write},
10 sync::Arc,
11 time,
12};
13
/// Number of times each benchmark configuration is rebuilt and timed.
const RUNS: usize = 5;
/// Epoch count passed to `network.learn` for each timed training run.
const EPOCHS: i32 = 1;
16
17fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
18 let reader = BufReader::new(File::open(&path).unwrap());
19
20 let mut x: Vec<Vec<f32>> = Vec::new();
21 let mut y: Vec<usize> = Vec::new();
22
23 for line in reader.lines().skip(1) {
24 let line = line.unwrap();
25 let record: Vec<&str> = line.split(',').collect();
26 x.push(vec![
27 record.get(1).unwrap().parse::<f32>().unwrap(),
28 record.get(2).unwrap().parse::<f32>().unwrap(),
29 record.get(3).unwrap().parse::<f32>().unwrap(),
30 record.get(4).unwrap().parse::<f32>().unwrap(),
31 ]);
32 y.push(match record.get(5).unwrap() {
33 &"Iris-setosa" => 0,
34 &"Iris-versicolor" => 1,
35 &"Iris-virginica" => 2,
36 _ => panic!("> Unknown class."),
37 });
38 }
39
40 let mut generator = random::Generator::create(12345);
41 let mut indices: Vec<usize> = (0..x.len()).collect();
42 generator.shuffle(&mut indices);
43
44 let x: Vec<tensor::Tensor> = indices
45 .iter()
46 .map(|&i| tensor::Tensor::single(x[i].clone()))
47 .collect();
48 let y: Vec<tensor::Tensor> = indices
49 .iter()
50 .map(|&i| tensor::Tensor::one_hot(y[i], 3))
51 .collect();
52
53 (x, y)
54}
55
56fn main() {
57 let (x, y) = data("./examples/datasets/iris.csv");
59 let x: Vec<&tensor::Tensor> = x.iter().collect();
60 let y: Vec<&tensor::Tensor> = y.iter().collect();
61
62 let mut file = File::create("./output/timing/iris.json").unwrap();
64 writeln!(file, "[").unwrap();
65 writeln!(file, " {{").unwrap();
66
67 vec![
68 "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
69 ]
70 .iter()
71 .for_each(|method| {
72 println!("Method: {}", method);
73 vec![false, true].iter().for_each(|skip| {
74 println!(" Skip: {}", skip);
75 vec!["CLASSIFICATION"].iter().for_each(|problem| {
76 println!(" Problem: {}", problem);
77
78 let mut train_times: Vec<f64> = Vec::new();
79 let mut valid_times: Vec<f64> = Vec::new();
80
81 for _ in 0..RUNS {
82 let mut network: network::Network;
84 network = network::Network::new(tensor::Shape::Single(4));
85
86 if method == &"REGULAR" || method.contains(&"FB1") {
88 network.dense(25, activation::Activation::ReLU, false, None);
89 network.dense(25, activation::Activation::ReLU, false, None);
90 network.dense(25, activation::Activation::ReLU, false, None);
91
92 if method.contains(&"FB1") {
94 network.loopback(
95 2,
96 1,
97 method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
98 Arc::new(|_loops| 1.0),
99 false,
100 );
101 }
102 } else {
103 network.dense(25, activation::Activation::ReLU, false, None);
104 network.feedback(
105 vec![
106 feedback::Layer::Dense(
107 25,
108 activation::Activation::ReLU,
109 false,
110 None,
111 ),
112 feedback::Layer::Dense(
113 25,
114 activation::Activation::ReLU,
115 false,
116 None,
117 ),
118 ],
119 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
120 false,
121 false,
122 feedback::Accumulation::Mean,
123 );
124 }
125
126 if problem == &"REGRESSION" {
128 panic!("Invalid problem type.");
129 } else {
130 network.dense(3, activation::Activation::Softmax, false, None);
131 network.set_objective(objective::Objective::CrossEntropy, None);
132 }
133
134 if *skip {
136 network.connect(1, network.layers.len() - 1);
137 }
138
139 network.set_optimizer(optimizer::Adam::create(0.0001, 0.95, 0.999, 1e-7, None));
140
141 let start = time::Instant::now();
142
143 if problem == &"REGRESSION" {
145 panic!("Invalid problem type.");
146 } else {
147 (_, _, _) = network.learn(&x, &y, None, 1, EPOCHS, None);
148 }
149
150 let duration = start.elapsed().as_secs_f64();
151 train_times.push(duration);
152
153 let start = time::Instant::now();
154
155 if problem == &"REGRESSION" {
157 panic!("Invalid problem type.");
158 } else {
159 (_) = network.predict_batch(&x);
160 }
161
162 let duration = start.elapsed().as_secs_f64();
163 valid_times.push(duration);
164 }
165
166 if method == &"FB2x4" && *skip && problem == &"CLASSIFICATION" {
167 writeln!(
168 file,
169 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
170 method, skip, problem, train_times, valid_times
171 )
172 .unwrap();
173 } else {
174 writeln!(
175 file,
176 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
177 method, skip, problem, train_times, valid_times
178 )
179 .unwrap();
180 }
181 });
182 });
183 });
184 writeln!(file, " }}").unwrap();
185 writeln!(file, "]").unwrap();
186}