cifar_feedback/
feedback.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, feedback, network, objective, optimizer, plot, random, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Write};
8use std::time;
9
10const IMAGE_SIZE: usize = 32;
11const NUM_CHANNELS: usize = 3;
12const IMAGE_BYTES: usize = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS;
13
14pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
15    let file = File::open(file_path).unwrap();
16    let mut reader = BufReader::new(file);
17    let mut buffer = vec![0u8; 1 + IMAGE_BYTES];
18
19    let mut labels = Vec::new();
20    let mut images = Vec::new();
21
22    while reader.read_exact(&mut buffer).is_ok() {
23        let label = buffer[0];
24        let mut image = vec![vec![vec![0.0f32; IMAGE_SIZE]; IMAGE_SIZE]; NUM_CHANNELS];
25
26        for channel in 0..NUM_CHANNELS {
27            for row in 0..IMAGE_SIZE {
28                for col in 0..IMAGE_SIZE {
29                    let index = 1 + channel * IMAGE_SIZE * IMAGE_SIZE + row * IMAGE_SIZE + col;
30                    image[channel][row][col] = buffer[index] as f32 / 255.0;
31                }
32            }
33        }
34
35        labels.push(tensor::Tensor::one_hot(label as usize, 10));
36        images.push(tensor::Tensor::triple(image));
37    }
38
39    (images, labels)
40}
41
42pub fn shuffle(
43    x: Vec<tensor::Tensor>,
44    y: Vec<tensor::Tensor>,
45) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
46    let mut generator = random::Generator::create(
47        time::SystemTime::now()
48            .duration_since(time::UNIX_EPOCH)
49            .unwrap()
50            .subsec_micros() as u64,
51    );
52
53    let mut indices: Vec<usize> = (0..y.len()).collect();
54    generator.shuffle(&mut indices);
55
56    let a: Vec<tensor::Tensor> = indices.iter().map(|&i| x[i].clone()).collect();
57    let b: Vec<tensor::Tensor> = indices.iter().map(|&i| y[i].clone()).collect();
58
59    (a, b)
60}
61
62fn main() {
63    let _labels: HashMap<u8, &str> = [
64        (0, "airplane"),
65        (1, "automobile"),
66        (2, "bird"),
67        (3, "cat"),
68        (4, "deer"),
69        (5, "dog"),
70        (6, "frog"),
71        (7, "horse"),
72        (8, "ship"),
73        (9, "truck"),
74    ]
75    .iter()
76    .cloned()
77    .collect();
78    let mut x_train = Vec::new();
79    let mut y_train = Vec::new();
80    for i in 1..6 {
81        let (x_batch, y_batch) =
82            load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
83        x_train.extend(x_batch);
84        y_train.extend(y_batch);
85    }
86    let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
87    println!(
88        "Train: {} images, Test: {} images",
89        x_train.len(),
90        x_test.len()
91    );
92
93    // Shuffle the data.
94    // let (x_train, y_train) = shuffle(x_train, y_train);
95    // let (x_test, y_test) = shuffle(x_test, y_test);
96
97    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
98    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
99    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
100    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
101
102    // plot::heatmap(
103    //     &x_train[0],
104    //     &format!("{}", &labels[&(y_train[0].argmax() as u8)]),
105    //     "./output/cifar/input.png",
106    // );
107
108    let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
109
110    network.convolution(
111        32,
112        (3, 3),
113        (1, 1),
114        (1, 1),
115        (1, 1),
116        activation::Activation::ReLU,
117        None,
118    );
119    network.feedback(
120        vec![feedback::Layer::Convolution(
121            32,
122            activation::Activation::ReLU,
123            (3, 3),
124            (1, 1),
125            (1, 1),
126            (1, 1),
127            None,
128        )],
129        2,
130        false,
131        false,
132        feedback::Accumulation::Mean,
133    );
134    network.maxpool((2, 2), (2, 2));
135    network.feedback(
136        vec![
137            feedback::Layer::Convolution(
138                32,
139                activation::Activation::ReLU,
140                (3, 3),
141                (1, 1),
142                (1, 1),
143                (1, 1),
144                None,
145            ),
146            feedback::Layer::Convolution(
147                32,
148                activation::Activation::ReLU,
149                (3, 3),
150                (1, 1),
151                (1, 1),
152                (1, 1),
153                None,
154            ),
155        ],
156        2,
157        false,
158        false,
159        feedback::Accumulation::Mean,
160    );
161    network.maxpool((2, 2), (2, 2));
162    network.dense(128, activation::Activation::ReLU, true, None);
163    network.dense(10, activation::Activation::Softmax, true, None);
164
165    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
166    network.set_objective(objective::Objective::CrossEntropy, None);
167
168    println!("{}", network);
169
170    // Train the network
171    let (train_loss, val_loss, val_acc) = network.learn(
172        &x_train,
173        &y_train,
174        Some((&x_test, &y_test, 5)),
175        128,
176        50,
177        Some(1),
178    );
179    plot::loss(
180        &train_loss,
181        &val_loss,
182        &val_acc,
183        "FEEDBACK : CIFAR-10",
184        "./output/cifar/feedback.png",
185    );
186
187    // Store the training metrics
188    let mut writer = File::create("./output/cifar/feedback.csv").unwrap();
189    writer.write_all(b"train_loss,val_loss,val_acc\n").unwrap();
190    for i in 0..train_loss.len() {
191        writer
192            .write_all(format!("{},{},{}\n", train_loss[i], val_loss[i], val_acc[i]).as_bytes())
193            .unwrap();
194    }
195
196    // Validate the network
197    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
198    println!(
199        "Final validation accuracy: {:.2} % and loss: {:.5}",
200        val_acc * 100.0,
201        val_loss
202    );
203
204    // Use the network
205    let prediction = network.predict(x_test.get(0).unwrap());
206    println!(
207        "Prediction Target: {}. Output: {}.",
208        y_test[0].argmax(),
209        prediction.argmax()
210    );
211}