// cifar_loop/looping.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, objective, optimizer, plot, random, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Write};
8use std::sync::Arc;
9use std::time;
10
/// Width and height, in pixels, of one (square) CIFAR-10 image.
const IMAGE_SIZE: usize = 32;
/// Number of colour channels stored for each image.
const NUM_CHANNELS: usize = 3;
/// Bytes of pixel data in one record: one byte per pixel per channel.
const IMAGE_BYTES: usize = NUM_CHANNELS * IMAGE_SIZE * IMAGE_SIZE;
14
15pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
16    let file = File::open(file_path).unwrap();
17    let mut reader = BufReader::new(file);
18    let mut buffer = vec![0u8; 1 + IMAGE_BYTES];
19
20    let mut labels = Vec::new();
21    let mut images = Vec::new();
22
23    while reader.read_exact(&mut buffer).is_ok() {
24        let label = buffer[0];
25        let mut image = vec![vec![vec![0.0f32; IMAGE_SIZE]; IMAGE_SIZE]; NUM_CHANNELS];
26
27        for channel in 0..NUM_CHANNELS {
28            for row in 0..IMAGE_SIZE {
29                for col in 0..IMAGE_SIZE {
30                    let index = 1 + channel * IMAGE_SIZE * IMAGE_SIZE + row * IMAGE_SIZE + col;
31                    image[channel][row][col] = buffer[index] as f32 / 255.0;
32                }
33            }
34        }
35
36        labels.push(tensor::Tensor::one_hot(label as usize, 10));
37        images.push(tensor::Tensor::triple(image));
38    }
39
40    (images, labels)
41}
42
43pub fn shuffle(
44    x: Vec<tensor::Tensor>,
45    y: Vec<tensor::Tensor>,
46) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
47    let mut generator = random::Generator::create(
48        time::SystemTime::now()
49            .duration_since(time::UNIX_EPOCH)
50            .unwrap()
51            .subsec_micros() as u64,
52    );
53
54    let mut indices: Vec<usize> = (0..y.len()).collect();
55    generator.shuffle(&mut indices);
56
57    let a: Vec<tensor::Tensor> = indices.iter().map(|&i| x[i].clone()).collect();
58    let b: Vec<tensor::Tensor> = indices.iter().map(|&i| y[i].clone()).collect();
59
60    (a, b)
61}
62
63fn main() {
64    let _labels: HashMap<u8, &str> = [
65        (0, "airplane"),
66        (1, "automobile"),
67        (2, "bird"),
68        (3, "cat"),
69        (4, "deer"),
70        (5, "dog"),
71        (6, "frog"),
72        (7, "horse"),
73        (8, "ship"),
74        (9, "truck"),
75    ]
76    .iter()
77    .cloned()
78    .collect();
79    let mut x_train = Vec::new();
80    let mut y_train = Vec::new();
81    for i in 1..6 {
82        let (x_batch, y_batch) =
83            load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
84        x_train.extend(x_batch);
85        y_train.extend(y_batch);
86    }
87    let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
88    println!(
89        "Train: {} images, Test: {} images",
90        x_train.len(),
91        x_test.len()
92    );
93
94    // Shuffle the data.
95    // let (x_train, y_train) = shuffle(x_train, y_train);
96    // let (x_test, y_test) = shuffle(x_test, y_test);
97
98    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
99    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
100    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
101    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
102
103    // plot::heatmap(
104    //     &x_train[0],
105    //     &format!("{}", &labels[&(y_train[0].argmax() as u8)]),
106    //     "./output/cifar/input.png",
107    // );
108
109    let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
110
111    network.convolution(
112        32,
113        (3, 3),
114        (1, 1),
115        (1, 1),
116        (1, 1),
117        activation::Activation::ReLU,
118        None,
119    );
120    network.convolution(
121        32,
122        (3, 3),
123        (1, 1),
124        (1, 1),
125        (1, 1),
126        activation::Activation::ReLU,
127        None,
128    );
129    network.maxpool((2, 2), (2, 2));
130    network.convolution(
131        32,
132        (3, 3),
133        (1, 1),
134        (1, 1),
135        (1, 1),
136        activation::Activation::ReLU,
137        None,
138    );
139    network.convolution(
140        32,
141        (3, 3),
142        (1, 1),
143        (1, 1),
144        (1, 1),
145        activation::Activation::ReLU,
146        None,
147    );
148    network.maxpool((2, 2), (2, 2));
149    network.dense(512, activation::Activation::ReLU, true, None);
150    network.dense(10, activation::Activation::Softmax, true, None);
151
152    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
153    network.set_objective(objective::Objective::CrossEntropy, None);
154
155    network.loopback(1, 1, 1, Arc::new(|_loops| 1.0), false);
156    network.loopback(4, 3, 1, Arc::new(|_loops| 1.0), false);
157
158    println!("{}", network);
159
160    // Train the network
161    let (train_loss, val_loss, val_acc) = network.learn(
162        &x_train,
163        &y_train,
164        Some((&x_test, &y_test, 5)),
165        128,
166        50,
167        Some(1),
168    );
169    plot::loss(
170        &train_loss,
171        &val_loss,
172        &val_acc,
173        "LOOP : CIFAR-10",
174        "./output/cifar/loop.png",
175    );
176
177    // Store the training metrics
178    let mut writer = File::create("./output/cifar/loop.csv").unwrap();
179    writer.write_all(b"train_loss,val_loss,val_acc\n").unwrap();
180    for i in 0..train_loss.len() {
181        writer
182            .write_all(format!("{},{},{}\n", train_loss[i], val_loss[i], val_acc[i]).as_bytes())
183            .unwrap();
184    }
185
186    // Validate the network
187    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
188    println!(
189        "Final validation accuracy: {:.2} % and loss: {:.5}",
190        val_acc * 100.0,
191        val_loss
192    );
193
194    // Use the network
195    let prediction = network.predict(x_test.get(0).unwrap());
196    println!(
197        "Prediction Target: {}. Output: {}.",
198        y_test[0].argmax(),
199        prediction.argmax()
200    );
201}