timing_bike/
bike.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2//
3// Code for comparison between the various architectures and their time differences.
4
5use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
6
7use std::{
8    fs::File,
9    io::{BufRead, BufReader, Write},
10    sync::Arc,
11    time,
12};
13
/// Number of timed repetitions performed for every method/skip/problem configuration.
/// Each repetition builds a fresh network and contributes one training time and one
/// validation time to the results.
const RUNS: usize = 5;

/// Epoch count handed to `Network::learn` for every timed training run.
const EPOCHS: i32 = 1;
16
17fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
18    let reader = BufReader::new(File::open(&path).unwrap());
19
20    let mut x: Vec<tensor::Tensor> = Vec::new();
21    let mut y: Vec<tensor::Tensor> = Vec::new();
22
23    for line in reader.lines().skip(1) {
24        let line = line.unwrap();
25        let record: Vec<&str> = line.split(',').collect();
26
27        let mut data: Vec<f32> = Vec::new();
28        for i in 2..14 {
29            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
30        }
31        x.push(tensor::Tensor::single(data));
32
33        y.push(tensor::Tensor::single(vec![record
34            .get(16)
35            .unwrap()
36            .parse::<f32>()
37            .unwrap()]));
38    }
39
40    let mut generator = random::Generator::create(12345);
41    let mut indices: Vec<usize> = (0..x.len()).collect();
42    generator.shuffle(&mut indices);
43
44    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
45    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
46
47    (x, y)
48}
49
50fn main() {
51    // Load the bike dataset
52    let (x, y) = data("./examples/datasets/bike/hour.csv");
53    let x: Vec<&tensor::Tensor> = x.iter().collect();
54    let y: Vec<&tensor::Tensor> = y.iter().collect();
55
56    // Create the results file.
57    let mut file = File::create("./output/timing/bike.json").unwrap();
58    writeln!(file, "[").unwrap();
59    writeln!(file, "  {{").unwrap();
60
61    vec![
62        "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
63    ]
64    .iter()
65    .for_each(|method| {
66        println!("Method: {}", method);
67        vec![false, true].iter().for_each(|skip| {
68            println!("  Skip: {}", skip);
69            vec!["REGRESSION"].iter().for_each(|problem| {
70                println!("   Problem: {}", problem);
71
72                let mut train_times: Vec<f64> = Vec::new();
73                let mut valid_times: Vec<f64> = Vec::new();
74
75                for _ in 0..RUNS {
76                    // Create the network based on the architecture.
77                    let mut network: network::Network;
78                    network = network::Network::new(tensor::Shape::Single(12));
79
80                    // Check if the method is regular or feedback.
81                    if method == &"REGULAR" || method.contains(&"FB1") {
82                        network.dense(24, activation::Activation::ReLU, false, None);
83                        network.dense(24, activation::Activation::ReLU, false, None);
84                        network.dense(24, activation::Activation::ReLU, false, None);
85
86                        // Add the feedback loop if applicable.
87                        if method.contains(&"FB1") {
88                            network.loopback(
89                                2,
90                                1,
91                                method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
92                                Arc::new(|_loops| 1.0),
93                                false,
94                            );
95                        }
96                    } else {
97                        network.dense(24, activation::Activation::ReLU, false, None);
98                        network.feedback(
99                            vec![
100                                feedback::Layer::Dense(
101                                    24,
102                                    activation::Activation::ReLU,
103                                    false,
104                                    None,
105                                ),
106                                feedback::Layer::Dense(
107                                    24,
108                                    activation::Activation::ReLU,
109                                    false,
110                                    None,
111                                ),
112                            ],
113                            method.chars().last().unwrap().to_digit(10).unwrap() as usize,
114                            false,
115                            false,
116                            feedback::Accumulation::Mean,
117                        );
118                    }
119
120                    // Set the output layer based on the problem.
121                    if problem == &"REGRESSION" {
122                        network.dense(1, activation::Activation::Linear, false, None);
123                        network.set_objective(objective::Objective::RMSE, None);
124                    } else {
125                        panic!("Invalid problem type.");
126                    }
127
128                    // Add the skip connection if applicable.
129                    if *skip {
130                        network.connect(1, network.layers.len() - 1);
131                    }
132
133                    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
134
135                    let start = time::Instant::now();
136
137                    // Train the network
138                    if problem == &"CLASSIFICATION" {
139                        panic!("Invalid problem type.");
140                    } else {
141                        (_, _, _) = network.learn(&x, &y, None, 64, EPOCHS, None);
142                    }
143
144                    let duration = start.elapsed().as_secs_f64();
145                    train_times.push(duration);
146
147                    let start = time::Instant::now();
148
149                    // Validate the network
150                    if problem == &"CLASSIFICATION" {
151                        panic!("Invalid problem type.");
152                    } else {
153                        (_) = network.predict_batch(&x);
154                    }
155
156                    let duration = start.elapsed().as_secs_f64();
157                    valid_times.push(duration);
158                }
159
160                if method == &"FB2x4" && *skip && problem == &"REGRESSION" {
161                    writeln!(
162                        file,
163                        "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
164                        method, skip, problem, train_times, valid_times
165                    )
166                    .unwrap();
167                } else {
168                    writeln!(
169                        file,
170                        "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
171                        method, skip, problem, train_times, valid_times
172                    )
173                    .unwrap();
174                }
175            });
176        });
177    });
178    writeln!(file, "  }}").unwrap();
179    writeln!(file, "]").unwrap();
180}