1use neurons::{activation, feedback, network, objective, optimizer, tensor};
6
7use std::{
8 fs::File,
9 io::{BufReader, Read, Result, Write},
10 sync::Arc,
11 time,
12};
13
// Number of timed repetitions per benchmark configuration.
const RUNS: usize = 5;
// Training epochs per timed run.
const EPOCHS: i32 = 1;
16
/// Reads four bytes from `reader` and decodes them as a big-endian `u32`
/// (the integer encoding used by the IDX file headers parsed below).
///
/// Returns an I/O error if fewer than four bytes are available.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader.read_exact(&mut word).map(|()| u32::from_be_bytes(word))
}
22
23fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
24 let mut reader = BufReader::new(File::open(path)?);
25 let mut images: Vec<tensor::Tensor> = Vec::new();
26
27 let _magic_number = read(&mut reader)?;
28 let num_images = read(&mut reader)?;
29 let num_rows = read(&mut reader)?;
30 let num_cols = read(&mut reader)?;
31
32 for _ in 0..num_images {
33 let mut image: Vec<Vec<f32>> = Vec::new();
34 for _ in 0..num_rows {
35 let mut row: Vec<f32> = Vec::new();
36 for _ in 0..num_cols {
37 let mut pixel = [0];
38 reader.read_exact(&mut pixel)?;
39 row.push(pixel[0] as f32 / 255.0);
40 }
41 image.push(row);
42 }
43 images.push(tensor::Tensor::triple(vec![image]));
44 }
45
46 Ok(images)
47}
48
49fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
50 let mut reader = BufReader::new(File::open(file_path)?);
51 let _magic_number = read(&mut reader)?;
52 let num_labels = read(&mut reader)?;
53
54 let mut _labels = vec![0; num_labels as usize];
55 reader.read_exact(&mut _labels)?;
56
57 Ok(_labels
58 .iter()
59 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
60 .collect())
61}
62
/// Benchmarks wall-clock training and prediction time on MNIST for several
/// architecture variants ("REGULAR" plain stack, "FB1xN" loopback, "FB2xN"
/// feedback block — N is the trailing digit of the method name), each with
/// and without an extra skip connection, and writes the per-run timings as a
/// hand-built JSON document to `./output/timing/mnist.json`.
fn main() {
    let mut x = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
    let mut y = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
    let class_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();

    // Merge the train and test splits into one dataset: this is a timing
    // benchmark, so no held-out evaluation set is kept.
    x.extend(x_test);
    y.extend(class_test);

    // The learn/predict APIs below take slices of tensor references.
    let x: Vec<&tensor::Tensor> = x.iter().collect();
    let y: Vec<&tensor::Tensor> = y.iter().collect();

    // JSON is emitted manually, one entry per configuration.
    let mut file = File::create("./output/timing/mnist.json").unwrap();
    writeln!(file, "[").unwrap();
    writeln!(file, " {{").unwrap();

    vec!["REGULAR", "FB1x2", "FB1x3", "FB2x2", "FB2x3"]
        .iter()
        .for_each(|method| {
            println!("Method: {}", method);
            vec![false, true].iter().for_each(|skip| {
                println!(" Skip: {}", skip);
                vec!["CLASSIFICATION"].iter().for_each(|problem| {
                    println!(" Problem: {}", problem);

                    let mut train_times: Vec<f64> = Vec::new();
                    let mut valid_times: Vec<f64> = Vec::new();

                    // Each configuration is rebuilt from scratch and retrained
                    // RUNS times; every run contributes one timing sample.
                    for _ in 0..RUNS {
                        let mut network: network::Network;
                        network = network::Network::new(tensor::Shape::Triple(1, 28, 28));
                        // Input convolution shared by all variants:
                        // 1 filter, 3x3 kernel, unit stride/padding/dilation.
                        network.convolution(
                            1,
                            (3, 3),
                            (1, 1),
                            (1, 1),
                            (1, 1),
                            activation::Activation::ReLU,
                            None,
                        );

                        if method == &"REGULAR" || method.contains(&"FB1") {
                            // Plain stack: two more convolutions, then a pool.
                            network.convolution(
                                1,
                                (3, 3),
                                (1, 1),
                                (1, 1),
                                (1, 1),
                                activation::Activation::ReLU,
                                None,
                            );
                            network.convolution(
                                1,
                                (3, 3),
                                (1, 1),
                                (1, 1),
                                (1, 1),
                                activation::Activation::ReLU,
                                None,
                            );
                            network.maxpool((2, 2), (2, 2));

                            if method.contains(&"FB1") {
                                // FB1xN: loopback between layers 2 and 0,
                                // repeated N-1 times, with a closure that
                                // always weighs the looped signal at 1.0.
                                network.loopback(
                                    2,
                                    0,
                                    method.chars().last().unwrap().to_digit(10).unwrap() as usize
                                        - 1,
                                    Arc::new(|_loops| 1.0),
                                    false,
                                );
                            }
                        } else {
                            // FB2xN: a feedback block wrapping one convolution,
                            // iterated N times, accumulated by the mean.
                            network.feedback(
                                vec![feedback::Layer::Convolution(
                                    1,
                                    activation::Activation::ReLU,
                                    (3, 3),
                                    (1, 1),
                                    (1, 1),
                                    (1, 1),
                                    None,
                                )],
                                method.chars().last().unwrap().to_digit(10).unwrap() as usize,
                                false,
                                false,
                                feedback::Accumulation::Mean,
                            );
                            network.convolution(
                                1,
                                (3, 3),
                                (1, 1),
                                (1, 1),
                                (1, 1),
                                activation::Activation::ReLU,
                                None,
                            );
                            network.maxpool((2, 2), (2, 2));
                        }

                        // Only CLASSIFICATION is iterated above; the REGRESSION
                        // arms here and below are guards against misuse.
                        if problem == &"REGRESSION" {
                            panic!("Invalid problem type.");
                        } else {
                            network.dense(10, activation::Activation::Softmax, true, None);
                            network.set_objective(objective::Objective::CrossEntropy, None);
                        }

                        if *skip {
                            // Optional skip connection between layer 1 and the
                            // second-to-last layer.
                            network.connect(1, network.layers.len() - 2);
                        }

                        network
                            .set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));

                        // Time training (32 appears to be the batch size —
                        // TODO confirm against network::Network::learn).
                        let start = time::Instant::now();

                        if problem == &"REGRESSION" {
                            panic!("Invalid problem type.");
                        } else {
                            (_, _, _) = network.learn(&x, &y, None, 32, EPOCHS, None);
                        }

                        let duration = start.elapsed().as_secs_f64();
                        train_times.push(duration);

                        // Time a full forward pass over the same merged data.
                        let start = time::Instant::now();

                        if problem == &"REGRESSION" {
                            panic!("Invalid problem type.");
                        } else {
                            (_) = network.predict_batch(&x);
                        }

                        let duration = start.elapsed().as_secs_f64();
                        valid_times.push(duration);
                    }

                    // Manual JSON bookkeeping: the final configuration
                    // (FB2x3, skip, CLASSIFICATION) omits the trailing comma
                    // so the document stays valid.
                    if method == &"FB2x3" && *skip && problem == &"CLASSIFICATION" {
                        writeln!(
                            file,
                            " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}}",
                            method, skip, problem, train_times, valid_times
                        )
                        .unwrap();
                    } else {
                        writeln!(
                            file,
                            " \"{}-{}-{}\": {{\"train\": {:?}, \"validate\": {:?}}},",
                            method, skip, problem, train_times, valid_times
                        )
                        .unwrap();
                    }
                });
            });
        });
    writeln!(file, " }}").unwrap();
    writeln!(file, "]").unwrap();
}