1use neurons::{activation, feedback, network, objective, optimizer, plot, random, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Write};
8use std::time;
9
/// Width and height, in pixels, of every CIFAR-10 image (images are square).
const IMAGE_SIZE: usize = 32;
/// Number of colour channels stored per image.
const NUM_CHANNELS: usize = 3;
/// Pixel bytes in one record, i.e. everything after the leading label byte.
const IMAGE_BYTES: usize = NUM_CHANNELS * (IMAGE_SIZE * IMAGE_SIZE);
13
14pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
15 let file = File::open(file_path).unwrap();
16 let mut reader = BufReader::new(file);
17 let mut buffer = vec![0u8; 1 + IMAGE_BYTES];
18
19 let mut labels = Vec::new();
20 let mut images = Vec::new();
21
22 while reader.read_exact(&mut buffer).is_ok() {
23 let label = buffer[0];
24 let mut image = vec![vec![vec![0.0f32; IMAGE_SIZE]; IMAGE_SIZE]; NUM_CHANNELS];
25
26 for channel in 0..NUM_CHANNELS {
27 for row in 0..IMAGE_SIZE {
28 for col in 0..IMAGE_SIZE {
29 let index = 1 + channel * IMAGE_SIZE * IMAGE_SIZE + row * IMAGE_SIZE + col;
30 image[channel][row][col] = buffer[index] as f32 / 255.0;
31 }
32 }
33 }
34
35 labels.push(tensor::Tensor::one_hot(label as usize, 10));
36 images.push(tensor::Tensor::triple(image));
37 }
38
39 (images, labels)
40}
41
42pub fn shuffle(
43 x: Vec<tensor::Tensor>,
44 y: Vec<tensor::Tensor>,
45) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
46 let mut generator = random::Generator::create(
47 time::SystemTime::now()
48 .duration_since(time::UNIX_EPOCH)
49 .unwrap()
50 .subsec_micros() as u64,
51 );
52
53 let mut indices: Vec<usize> = (0..y.len()).collect();
54 generator.shuffle(&mut indices);
55
56 let a: Vec<tensor::Tensor> = indices.iter().map(|&i| x[i].clone()).collect();
57 let b: Vec<tensor::Tensor> = indices.iter().map(|&i| y[i].clone()).collect();
58
59 (a, b)
60}
61
62fn main() {
63 let _labels: HashMap<u8, &str> = [
64 (0, "airplane"),
65 (1, "automobile"),
66 (2, "bird"),
67 (3, "cat"),
68 (4, "deer"),
69 (5, "dog"),
70 (6, "frog"),
71 (7, "horse"),
72 (8, "ship"),
73 (9, "truck"),
74 ]
75 .iter()
76 .cloned()
77 .collect();
78 let mut x_train = Vec::new();
79 let mut y_train = Vec::new();
80 for i in 1..6 {
81 let (x_batch, y_batch) =
82 load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
83 x_train.extend(x_batch);
84 y_train.extend(y_batch);
85 }
86 let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
87 println!(
88 "Train: {} images, Test: {} images",
89 x_train.len(),
90 x_test.len()
91 );
92
93 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
98 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
99 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
100 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
101
102 let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
109
110 network.convolution(
111 32,
112 (3, 3),
113 (1, 1),
114 (1, 1),
115 (1, 1),
116 activation::Activation::ReLU,
117 None,
118 );
119 network.feedback(
120 vec![feedback::Layer::Convolution(
121 32,
122 activation::Activation::ReLU,
123 (3, 3),
124 (1, 1),
125 (1, 1),
126 (1, 1),
127 None,
128 )],
129 2,
130 false,
131 false,
132 feedback::Accumulation::Mean,
133 );
134 network.maxpool((2, 2), (2, 2));
135 network.feedback(
136 vec![
137 feedback::Layer::Convolution(
138 32,
139 activation::Activation::ReLU,
140 (3, 3),
141 (1, 1),
142 (1, 1),
143 (1, 1),
144 None,
145 ),
146 feedback::Layer::Convolution(
147 32,
148 activation::Activation::ReLU,
149 (3, 3),
150 (1, 1),
151 (1, 1),
152 (1, 1),
153 None,
154 ),
155 ],
156 2,
157 false,
158 false,
159 feedback::Accumulation::Mean,
160 );
161 network.maxpool((2, 2), (2, 2));
162 network.dense(128, activation::Activation::ReLU, true, None);
163 network.dense(10, activation::Activation::Softmax, true, None);
164
165 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
166 network.set_objective(objective::Objective::CrossEntropy, None);
167
168 println!("{}", network);
169
170 let (train_loss, val_loss, val_acc) = network.learn(
172 &x_train,
173 &y_train,
174 Some((&x_test, &y_test, 5)),
175 128,
176 50,
177 Some(1),
178 );
179 plot::loss(
180 &train_loss,
181 &val_loss,
182 &val_acc,
183 "FEEDBACK : CIFAR-10",
184 "./output/cifar/feedback.png",
185 );
186
187 let mut writer = File::create("./output/cifar/feedback.csv").unwrap();
189 writer.write_all(b"train_loss,val_loss,val_acc\n").unwrap();
190 for i in 0..train_loss.len() {
191 writer
192 .write_all(format!("{},{},{}\n", train_loss[i], val_loss[i], val_acc[i]).as_bytes())
193 .unwrap();
194 }
195
196 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
198 println!(
199 "Final validation accuracy: {:.2} % and loss: {:.5}",
200 val_acc * 100.0,
201 val_loss
202 );
203
204 let prediction = network.predict(x_test.get(0).unwrap());
206 println!(
207 "Prediction Target: {}. Output: {}.",
208 y_test[0].argmax(),
209 prediction.argmax()
210 );
211}