1use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
6
7use std::{
8 fs::File,
9 io::{BufRead, BufReader, Write},
10 sync::Arc,
11 time,
12};
13
/// Number of times each method/skip/problem combination is rebuilt, trained
/// and timed; one train and one validate timing is recorded per run.
const RUNS: usize = 5;
/// Value passed as the fifth argument of `network.learn` for every run
/// (presumably the number of training epochs; confirm against the `neurons` API).
const EPOCHS: i32 = 1;
16
17fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
18 let reader = BufReader::new(File::open(&path).unwrap());
19
20 let mut x: Vec<tensor::Tensor> = Vec::new();
21 let mut y: Vec<tensor::Tensor> = Vec::new();
22
23 for line in reader.lines().skip(1) {
24 let line = line.unwrap();
25 let record: Vec<&str> = line.split(',').collect();
26
27 let mut data: Vec<f32> = Vec::new();
28 for i in 2..14 {
29 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
30 }
31 x.push(tensor::Tensor::single(data));
32
33 y.push(tensor::Tensor::single(vec![record
34 .get(16)
35 .unwrap()
36 .parse::<f32>()
37 .unwrap()]));
38 }
39
40 let mut generator = random::Generator::create(12345);
41 let mut indices: Vec<usize> = (0..x.len()).collect();
42 generator.shuffle(&mut indices);
43
44 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
45 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
46
47 (x, y)
48}
49
50fn main() {
51 let (x, y) = data("./examples/datasets/bike/hour.csv");
53 let x: Vec<&tensor::Tensor> = x.iter().collect();
54 let y: Vec<&tensor::Tensor> = y.iter().collect();
55
56 let mut file = File::create("./output/timing/bike.json").unwrap();
58 writeln!(file, "[").unwrap();
59 writeln!(file, " {{").unwrap();
60
61 vec![
62 "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
63 ]
64 .iter()
65 .for_each(|method| {
66 println!("Method: {}", method);
67 vec![false, true].iter().for_each(|skip| {
68 println!(" Skip: {}", skip);
69 vec!["REGRESSION"].iter().for_each(|problem| {
70 println!(" Problem: {}", problem);
71
72 let mut train_times: Vec<f64> = Vec::new();
73 let mut valid_times: Vec<f64> = Vec::new();
74
75 for _ in 0..RUNS {
76 let mut network: network::Network;
78 network = network::Network::new(tensor::Shape::Single(12));
79
80 if method == &"REGULAR" || method.contains(&"FB1") {
82 network.dense(24, activation::Activation::ReLU, false, None);
83 network.dense(24, activation::Activation::ReLU, false, None);
84 network.dense(24, activation::Activation::ReLU, false, None);
85
86 if method.contains(&"FB1") {
88 network.loopback(
89 2,
90 1,
91 method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
92 Arc::new(|_loops| 1.0),
93 false,
94 );
95 }
96 } else {
97 network.dense(24, activation::Activation::ReLU, false, None);
98 network.feedback(
99 vec![
100 feedback::Layer::Dense(
101 24,
102 activation::Activation::ReLU,
103 false,
104 None,
105 ),
106 feedback::Layer::Dense(
107 24,
108 activation::Activation::ReLU,
109 false,
110 None,
111 ),
112 ],
113 method.chars().last().unwrap().to_digit(10).unwrap() as usize,
114 false,
115 false,
116 feedback::Accumulation::Mean,
117 );
118 }
119
120 if problem == &"REGRESSION" {
122 network.dense(1, activation::Activation::Linear, false, None);
123 network.set_objective(objective::Objective::RMSE, None);
124 } else {
125 panic!("Invalid problem type.");
126 }
127
128 if *skip {
130 network.connect(1, network.layers.len() - 1);
131 }
132
133 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
134
135 let start = time::Instant::now();
136
137 if problem == &"CLASSIFICATION" {
139 panic!("Invalid problem type.");
140 } else {
141 (_, _, _) = network.learn(&x, &y, None, 64, EPOCHS, None);
142 }
143
144 let duration = start.elapsed().as_secs_f64();
145 train_times.push(duration);
146
147 let start = time::Instant::now();
148
149 if problem == &"CLASSIFICATION" {
151 panic!("Invalid problem type.");
152 } else {
153 (_) = network.predict_batch(&x);
154 }
155
156 let duration = start.elapsed().as_secs_f64();
157 valid_times.push(duration);
158 }
159
160 if method == &"FB2x4" && *skip && problem == &"REGRESSION" {
161 writeln!(
162 file,
163 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
164 method, skip, problem, train_times, valid_times
165 )
166 .unwrap();
167 } else {
168 writeln!(
169 file,
170 " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
171 method, skip, problem, train_times, valid_times
172 )
173 .unwrap();
174 }
175 });
176 });
177 });
178 writeln!(file, " }}").unwrap();
179 writeln!(file, "]").unwrap();
180}