1use neurons::{activation, network, objective, optimizer, plot, random, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Write};
8use std::sync::Arc;
9use std::time;
10
/// Side length, in pixels, of one square image (used for both rows and columns).
const IMAGE_SIZE: usize = 32;
/// Number of colour channels stored per image.
const NUM_CHANNELS: usize = 3;
/// Number of pixel bytes in one image record, excluding the leading label byte.
const IMAGE_BYTES: usize = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS;
14
15pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
16 let file = File::open(file_path).unwrap();
17 let mut reader = BufReader::new(file);
18 let mut buffer = vec![0u8; 1 + IMAGE_BYTES];
19
20 let mut labels = Vec::new();
21 let mut images = Vec::new();
22
23 while reader.read_exact(&mut buffer).is_ok() {
24 let label = buffer[0];
25 let mut image = vec![vec![vec![0.0f32; IMAGE_SIZE]; IMAGE_SIZE]; NUM_CHANNELS];
26
27 for channel in 0..NUM_CHANNELS {
28 for row in 0..IMAGE_SIZE {
29 for col in 0..IMAGE_SIZE {
30 let index = 1 + channel * IMAGE_SIZE * IMAGE_SIZE + row * IMAGE_SIZE + col;
31 image[channel][row][col] = buffer[index] as f32 / 255.0;
32 }
33 }
34 }
35
36 labels.push(tensor::Tensor::one_hot(label as usize, 10));
37 images.push(tensor::Tensor::triple(image));
38 }
39
40 (images, labels)
41}
42
43pub fn shuffle(
44 x: Vec<tensor::Tensor>,
45 y: Vec<tensor::Tensor>,
46) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
47 let mut generator = random::Generator::create(
48 time::SystemTime::now()
49 .duration_since(time::UNIX_EPOCH)
50 .unwrap()
51 .subsec_micros() as u64,
52 );
53
54 let mut indices: Vec<usize> = (0..y.len()).collect();
55 generator.shuffle(&mut indices);
56
57 let a: Vec<tensor::Tensor> = indices.iter().map(|&i| x[i].clone()).collect();
58 let b: Vec<tensor::Tensor> = indices.iter().map(|&i| y[i].clone()).collect();
59
60 (a, b)
61}
62
63fn main() {
64 let _labels: HashMap<u8, &str> = [
65 (0, "airplane"),
66 (1, "automobile"),
67 (2, "bird"),
68 (3, "cat"),
69 (4, "deer"),
70 (5, "dog"),
71 (6, "frog"),
72 (7, "horse"),
73 (8, "ship"),
74 (9, "truck"),
75 ]
76 .iter()
77 .cloned()
78 .collect();
79 let mut x_train = Vec::new();
80 let mut y_train = Vec::new();
81 for i in 1..6 {
82 let (x_batch, y_batch) =
83 load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
84 x_train.extend(x_batch);
85 y_train.extend(y_batch);
86 }
87 let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
88 println!(
89 "Train: {} images, Test: {} images",
90 x_train.len(),
91 x_test.len()
92 );
93
94 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
99 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
100 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
101 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
102
103 let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
110
111 network.convolution(
112 32,
113 (3, 3),
114 (1, 1),
115 (1, 1),
116 (1, 1),
117 activation::Activation::ReLU,
118 None,
119 );
120 network.convolution(
121 32,
122 (3, 3),
123 (1, 1),
124 (1, 1),
125 (1, 1),
126 activation::Activation::ReLU,
127 None,
128 );
129 network.maxpool((2, 2), (2, 2));
130 network.convolution(
131 32,
132 (3, 3),
133 (1, 1),
134 (1, 1),
135 (1, 1),
136 activation::Activation::ReLU,
137 None,
138 );
139 network.convolution(
140 32,
141 (3, 3),
142 (1, 1),
143 (1, 1),
144 (1, 1),
145 activation::Activation::ReLU,
146 None,
147 );
148 network.maxpool((2, 2), (2, 2));
149 network.dense(512, activation::Activation::ReLU, true, None);
150 network.dense(10, activation::Activation::Softmax, true, None);
151
152 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
153 network.set_objective(objective::Objective::CrossEntropy, None);
154
155 network.loopback(1, 1, 1, Arc::new(|_loops| 1.0), false);
156 network.loopback(4, 3, 1, Arc::new(|_loops| 1.0), false);
157
158 println!("{}", network);
159
160 let (train_loss, val_loss, val_acc) = network.learn(
162 &x_train,
163 &y_train,
164 Some((&x_test, &y_test, 5)),
165 128,
166 50,
167 Some(1),
168 );
169 plot::loss(
170 &train_loss,
171 &val_loss,
172 &val_acc,
173 "LOOP : CIFAR-10",
174 "./output/cifar/loop.png",
175 );
176
177 let mut writer = File::create("./output/cifar/loop.csv").unwrap();
179 writer.write_all(b"train_loss,val_loss,val_acc\n").unwrap();
180 for i in 0..train_loss.len() {
181 writer
182 .write_all(format!("{},{},{}\n", train_loss[i], val_loss[i], val_acc[i]).as_bytes())
183 .unwrap();
184 }
185
186 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
188 println!(
189 "Final validation accuracy: {:.2} % and loss: {:.5}",
190 val_acc * 100.0,
191 val_loss
192 );
193
194 let prediction = network.predict(x_test.get(0).unwrap());
196 println!(
197 "Prediction Target: {}. Output: {}.",
198 y_test[0].argmax(),
199 prediction.argmax()
200 );
201}