use neurons::{activation, feedback, network, objective, optimizer, plot, random, tensor};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read};
pub fn load_cifar10(file_path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
let file = File::open(file_path).unwrap();
let mut reader = BufReader::new(file);
let mut buffer = vec![0u8; 1 + 3072];
let mut labels = Vec::new();
let mut images = Vec::new();
while reader.read_exact(&mut buffer).is_ok() {
let label = buffer[0];
let mut image = vec![vec![vec![0.0f32; 32]; 32]; 3];
for channel in 0..3 {
for row in 0..32 {
for col in 0..32 {
let index = 1 + channel * 1024 + row * 32 + col;
image[channel][row][col] = buffer[index] as f32 / 255.0;
}
}
}
labels.push(label as usize);
images.push(tensor::Tensor::triple(image));
}
let mut generator = random::Generator::create(12345);
let mut indices: Vec<usize> = (0..labels.len()).collect();
generator.shuffle(&mut indices);
let images: Vec<tensor::Tensor> = indices.iter().map(|&i| images[i].clone()).collect();
let labels: Vec<tensor::Tensor> = indices
.iter()
.map(|&i| tensor::Tensor::one_hot(labels[i], 10))
.collect();
(images, labels)
}
fn main() {
let labels: HashMap<u8, &str> = [
(0, "airplane"),
(1, "automobile"),
(2, "bird"),
(3, "cat"),
(4, "deer"),
(5, "dog"),
(6, "frog"),
(7, "horse"),
(8, "ship"),
(9, "truck"),
]
.iter()
.cloned()
.collect();
let mut x_train = Vec::new();
let mut y_train = Vec::new();
for i in 1..6 {
let (x_batch, y_batch) =
load_cifar10(&format!("./examples/datasets/cifar10/data_batch_{}.bin", i));
x_train.extend(x_batch);
y_train.extend(y_batch);
}
let (x_test, y_test) = load_cifar10("./examples/datasets/cifar10/test_batch.bin");
println!(
"Train: {} images, Test: {} images",
x_train.len(),
x_test.len()
);
let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
plot::heatmap(
&x_train[0],
&format!("{}", &labels[&(y_train[0].argmax() as u8)]),
"./static/input.png",
);
let mut network = network::Network::new(tensor::Shape::Triple(3, 32, 32));
network.convolution(
32,
(5, 5),
(1, 1),
(1, 1),
activation::Activation::ReLU,
None,
);
for _ in 0..3 {
network.convolution(
32,
(3, 3),
(1, 1),
(1, 1),
activation::Activation::ReLU,
None,
);
}
network.maxpool((2, 2), (2, 2));
network.dense(128, activation::Activation::ReLU, true, None);
network.dense(10, activation::Activation::Softmax, true, None);
network.connect(1, 4);
network.set_accumulation(feedback::Accumulation::Add);
network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
network.set_objective(
objective::Objective::CrossEntropy, None, );
println!("{}", network);
let (train_loss, val_loss, val_acc) = network.learn(
&x_train,
&y_train,
Some((&x_test, &y_test, 5)),
128,
10,
Some(1),
);
plot::loss(
&train_loss,
&val_loss,
&val_acc,
"SKIP : CIFAR-10",
"./static/cifar10-skip.png",
);
let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
println!(
"Final validation accuracy: {:.2} % and loss: {:.5}",
val_acc * 100.0,
val_loss
);
let prediction = network.predict(x_test.get(0).unwrap());
println!(
"Prediction Target: {}. Output: {}.",
y_test[0].argmax(),
prediction.argmax()
);
}