timing_iris/
iris.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2//
3// Code for comparison between the various architectures and their time differences.
4
5use neurons::{activation, feedback, network, objective, optimizer, random, tensor};
6
7use std::{
8    fs::File,
9    io::{BufRead, BufReader, Write},
10    sync::Arc,
11    time,
12};
13
/// How many times every configuration is built, trained and timed.
const RUNS: usize = 5;
/// Number of training epochs passed to `Network::learn` in each timed run.
const EPOCHS: i32 = 1;
16
17fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
18    let reader = BufReader::new(File::open(&path).unwrap());
19
20    let mut x: Vec<Vec<f32>> = Vec::new();
21    let mut y: Vec<usize> = Vec::new();
22
23    for line in reader.lines().skip(1) {
24        let line = line.unwrap();
25        let record: Vec<&str> = line.split(',').collect();
26        x.push(vec![
27            record.get(1).unwrap().parse::<f32>().unwrap(),
28            record.get(2).unwrap().parse::<f32>().unwrap(),
29            record.get(3).unwrap().parse::<f32>().unwrap(),
30            record.get(4).unwrap().parse::<f32>().unwrap(),
31        ]);
32        y.push(match record.get(5).unwrap() {
33            &"Iris-setosa" => 0,
34            &"Iris-versicolor" => 1,
35            &"Iris-virginica" => 2,
36            _ => panic!("> Unknown class."),
37        });
38    }
39
40    let mut generator = random::Generator::create(12345);
41    let mut indices: Vec<usize> = (0..x.len()).collect();
42    generator.shuffle(&mut indices);
43
44    let x: Vec<tensor::Tensor> = indices
45        .iter()
46        .map(|&i| tensor::Tensor::single(x[i].clone()))
47        .collect();
48    let y: Vec<tensor::Tensor> = indices
49        .iter()
50        .map(|&i| tensor::Tensor::one_hot(y[i], 3))
51        .collect();
52
53    (x, y)
54}
55
56fn main() {
57    // Load the iris dataset
58    let (x, y) = data("./examples/datasets/iris.csv");
59    let x: Vec<&tensor::Tensor> = x.iter().collect();
60    let y: Vec<&tensor::Tensor> = y.iter().collect();
61
62    // Create the results file.
63    let mut file = File::create("./output/timing/iris.json").unwrap();
64    writeln!(file, "[").unwrap();
65    writeln!(file, "  {{").unwrap();
66
67    vec![
68        "REGULAR", "FB1x2", "FB1x3", "FB1x4", "FB2x2", "FB2x3", "FB2x4",
69    ]
70    .iter()
71    .for_each(|method| {
72        println!("Method: {}", method);
73        vec![false, true].iter().for_each(|skip| {
74            println!("  Skip: {}", skip);
75            vec!["CLASSIFICATION"].iter().for_each(|problem| {
76                println!("   Problem: {}", problem);
77
78                let mut train_times: Vec<f64> = Vec::new();
79                let mut valid_times: Vec<f64> = Vec::new();
80
81                for _ in 0..RUNS {
82                    // Create the network based on the architecture.
83                    let mut network: network::Network;
84                    network = network::Network::new(tensor::Shape::Single(4));
85
86                    // Check if the method is regular or feedback.
87                    if method == &"REGULAR" || method.contains(&"FB1") {
88                        network.dense(25, activation::Activation::ReLU, false, None);
89                        network.dense(25, activation::Activation::ReLU, false, None);
90                        network.dense(25, activation::Activation::ReLU, false, None);
91
92                        // Add the feedback loop if applicable.
93                        if method.contains(&"FB1") {
94                            network.loopback(
95                                2,
96                                1,
97                                method.chars().last().unwrap().to_digit(10).unwrap() as usize - 1,
98                                Arc::new(|_loops| 1.0),
99                                false,
100                            );
101                        }
102                    } else {
103                        network.dense(25, activation::Activation::ReLU, false, None);
104                        network.feedback(
105                            vec![
106                                feedback::Layer::Dense(
107                                    25,
108                                    activation::Activation::ReLU,
109                                    false,
110                                    None,
111                                ),
112                                feedback::Layer::Dense(
113                                    25,
114                                    activation::Activation::ReLU,
115                                    false,
116                                    None,
117                                ),
118                            ],
119                            method.chars().last().unwrap().to_digit(10).unwrap() as usize,
120                            false,
121                            false,
122                            feedback::Accumulation::Mean,
123                        );
124                    }
125
126                    // Set the output layer based on the problem.
127                    if problem == &"REGRESSION" {
128                        panic!("Invalid problem type.");
129                    } else {
130                        network.dense(3, activation::Activation::Softmax, false, None);
131                        network.set_objective(objective::Objective::CrossEntropy, None);
132                    }
133
134                    // Add the skip connection if applicable.
135                    if *skip {
136                        network.connect(1, network.layers.len() - 1);
137                    }
138
139                    network.set_optimizer(optimizer::Adam::create(0.0001, 0.95, 0.999, 1e-7, None));
140
141                    let start = time::Instant::now();
142
143                    // Train the network
144                    if problem == &"REGRESSION" {
145                        panic!("Invalid problem type.");
146                    } else {
147                        (_, _, _) = network.learn(&x, &y, None, 1, EPOCHS, None);
148                    }
149
150                    let duration = start.elapsed().as_secs_f64();
151                    train_times.push(duration);
152
153                    let start = time::Instant::now();
154
155                    // Validate the network
156                    if problem == &"REGRESSION" {
157                        panic!("Invalid problem type.");
158                    } else {
159                        (_) = network.predict_batch(&x);
160                    }
161
162                    let duration = start.elapsed().as_secs_f64();
163                    valid_times.push(duration);
164                }
165
166                if method == &"FB2x4" && *skip && problem == &"CLASSIFICATION" {
167                    writeln!(
168                        file,
169                        "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
170                        method, skip, problem, train_times, valid_times
171                    )
172                    .unwrap();
173                } else {
174                    writeln!(
175                        file,
176                        "    \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
177                        method, skip, problem, train_times, valid_times
178                    )
179                    .unwrap();
180                }
181            });
182        });
183    });
184    writeln!(file, "  }}").unwrap();
185    writeln!(file, "]").unwrap();
186}