cifar_plain/
plain.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, objective, optimizer, plot, random, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Write};
8use std::time;
9
/// Width and height, in pixels, of a single image (images are square).
const IMAGE_SIZE: usize = 32;
/// Number of channels stored for every image.
const NUM_CHANNELS: usize = 3;
/// Number of pixel bytes in one image record, not counting the leading label byte.
const IMAGE_BYTES: usize = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS;
13
14pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
15    let file = File::open(file_path).unwrap();
16    let mut reader = BufReader::new(file);
17    let mut buffer = vec![0u8; 1 + IMAGE_BYTES];
18
19    let mut labels = Vec::new();
20    let mut images = Vec::new();
21
22    while reader.read_exact(&mut buffer).is_ok() {
23        let label = buffer[0];
24        let mut image = vec![vec![vec![0.0f32; IMAGE_SIZE]; IMAGE_SIZE]; NUM_CHANNELS];
25
26        for channel in 0..NUM_CHANNELS {
27            for row in 0..IMAGE_SIZE {
28                for col in 0..IMAGE_SIZE {
29                    let index = 1 + channel * IMAGE_SIZE * IMAGE_SIZE + row * IMAGE_SIZE + col;
30                    image[channel][row][col] = buffer[index] as f32 / 255.0;
31                }
32            }
33        }
34
35        labels.push(tensor::Tensor::one_hot(label as usize, 10));
36        images.push(tensor::Tensor::triple(image));
37    }
38
39    (images, labels)
40}
41
42pub fn shuffle(
43    x: Vec<tensor::Tensor>,
44    y: Vec<tensor::Tensor>,
45) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
46    let mut generator = random::Generator::create(
47        time::SystemTime::now()
48            .duration_since(time::UNIX_EPOCH)
49            .unwrap()
50            .subsec_micros() as u64,
51    );
52
53    let mut indices: Vec<usize> = (0..y.len()).collect();
54    generator.shuffle(&mut indices);
55
56    let a: Vec<tensor::Tensor> = indices.iter().map(|&i| x[i].clone()).collect();
57    let b: Vec<tensor::Tensor> = indices.iter().map(|&i| y[i].clone()).collect();
58
59    (a, b)
60}
61
62fn main() {
63    let _labels: HashMap<u8, &str> = [
64        (0, "airplane"),
65        (1, "automobile"),
66        (2, "bird"),
67        (3, "cat"),
68        (4, "deer"),
69        (5, "dog"),
70        (6, "frog"),
71        (7, "horse"),
72        (8, "ship"),
73        (9, "truck"),
74    ]
75    .iter()
76    .cloned()
77    .collect();
78    let mut x_train = Vec::new();
79    let mut y_train = Vec::new();
80    for i in 1..6 {
81        let (x_batch, y_batch) =
82            load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
83        x_train.extend(x_batch);
84        y_train.extend(y_batch);
85    }
86    let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
87    println!(
88        "Train: {} images, Test: {} images",
89        x_train.len(),
90        x_test.len()
91    );
92
93    // Shuffle the data.
94    // let (x_train, y_train) = shuffle(x_train, y_train);
95    // let (x_test, y_test) = shuffle(x_test, y_test);
96
97    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
98    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
99    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
100    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
101
102    // plot::heatmap(
103    //     &x_train[0],
104    //     &format!("{}", &labels[&(y_train[0].argmax() as u8)]),
105    //     "./output/cifar/input.png",
106    // );
107
108    let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
109
110    network.convolution(
111        32,
112        (3, 3),
113        (1, 1),
114        (1, 1),
115        (1, 1),
116        activation::Activation::ReLU,
117        None,
118    );
119    network.convolution(
120        32,
121        (3, 3),
122        (1, 1),
123        (1, 1),
124        (1, 1),
125        activation::Activation::ReLU,
126        None,
127    );
128    network.maxpool((2, 2), (2, 2));
129    network.convolution(
130        32,
131        (3, 3),
132        (1, 1),
133        (1, 1),
134        (1, 1),
135        activation::Activation::ReLU,
136        None,
137    );
138    network.convolution(
139        32,
140        (3, 3),
141        (1, 1),
142        (1, 1),
143        (1, 1),
144        activation::Activation::ReLU,
145        None,
146    );
147    network.maxpool((2, 2), (2, 2));
148    network.dense(512, activation::Activation::ReLU, true, None);
149    network.dense(10, activation::Activation::Softmax, true, None);
150
151    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
152    network.set_objective(objective::Objective::CrossEntropy, None);
153
154    println!("{}", network);
155
156    // Train the network
157    let (train_loss, val_loss, val_acc) = network.learn(
158        &x_train,
159        &y_train,
160        Some((&x_test, &y_test, 5)),
161        128,
162        50,
163        Some(1),
164    );
165    plot::loss(
166        &train_loss,
167        &val_loss,
168        &val_acc,
169        "PLAIN : CIFAR-10",
170        "./output/cifar/plain.png",
171    );
172
173    // Store the training metrics
174    let mut writer = File::create("./output/cifar/plain.csv").unwrap();
175    writer.write_all(b"train_loss,val_loss,val_acc\n").unwrap();
176    for i in 0..train_loss.len() {
177        writer
178            .write_all(format!("{},{},{}\n", train_loss[i], val_loss[i], val_acc[i]).as_bytes())
179            .unwrap();
180    }
181
182    // Validate the network
183    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
184    println!(
185        "Final validation accuracy: {:.2} % and loss: {:.5}",
186        val_acc * 100.0,
187        val_loss
188    );
189
190    // Use the network
191    let prediction = network.predict(x_test.get(0).unwrap());
192    println!(
193        "Prediction Target: {}. Output: {}.",
194        y_test[0].argmax(),
195        prediction.argmax()
196    );
197}