// fashion_plain/plain.rs
// Copyright (C) 2024 Hallvard Høyland Lavik

use neurons::{activation, network, objective, optimizer, plot, tensor};

use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Error, ErrorKind, Read, Result};

/// Reads the next four bytes from `reader` and decodes them as a big-endian `u32`.
///
/// The loaders in this file use it for the IDX header fields (magic number and sizes).
///
/// # Errors
/// Propagates any error from `read_exact`, including `UnexpectedEof` when fewer than four
/// bytes remain.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}

15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16    let mut reader = BufReader::new(File::open(path)?);
17    let mut images: Vec<tensor::Tensor> = Vec::new();
18
19    let _magic_number = read(&mut reader)?;
20    let num_images = read(&mut reader)?;
21    let num_rows = read(&mut reader)?;
22    let num_cols = read(&mut reader)?;
23
24    for _ in 0..num_images {
25        let mut image: Vec<Vec<f32>> = Vec::new();
26        for _ in 0..num_rows {
27            let mut row: Vec<f32> = Vec::new();
28            for _ in 0..num_cols {
29                let mut pixel = [0];
30                reader.read_exact(&mut pixel)?;
31                row.push(pixel[0] as f32 / 255.0);
32            }
33            image.push(row);
34        }
35        images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36    }
37
38    Ok(images)
39}
40
41fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
42    let mut reader = BufReader::new(File::open(file_path)?);
43    let _magic_number = read(&mut reader)?;
44    let num_labels = read(&mut reader)?;
45
46    let mut _labels = vec![0; num_labels as usize];
47    reader.read_exact(&mut _labels)?;
48
49    Ok(_labels
50        .iter()
51        .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
52        .collect())
53}
54
55fn main() {
56    let _labels: HashMap<u8, &str> = [
57        (0, "top"),
58        (1, "trouser"),
59        (2, "pullover"),
60        (3, "dress"),
61        (4, "coat"),
62        (5, "sandal"),
63        (6, "shirt"),
64        (7, "sneaker"),
65        (8, "bag"),
66        (9, "ankle boot"),
67    ]
68    .iter()
69    .cloned()
70    .collect();
71    let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
72    let y_train = load_labels(
73        "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
74        10,
75    )
76    .unwrap();
77    let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
78    let y_test = load_labels(
79        "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
80        10,
81    )
82    .unwrap();
83    println!(
84        "Train: {} images, Test: {} images",
85        x_train.len(),
86        x_test.len()
87    );
88
89    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
90    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
91    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
92    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
93
94    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
95
96    network.convolution(
97        1,
98        (3, 3),
99        (1, 1),
100        (1, 1),
101        (1, 1),
102        activation::Activation::ReLU,
103        None,
104    );
105    network.convolution(
106        1,
107        (3, 3),
108        (1, 1),
109        (1, 1),
110        (1, 1),
111        activation::Activation::ReLU,
112        None,
113    );
114    network.convolution(
115        1,
116        (3, 3),
117        (1, 1),
118        (1, 1),
119        (1, 1),
120        activation::Activation::ReLU,
121        None,
122    );
123    network.maxpool((2, 2), (2, 2));
124    network.dense(10, activation::Activation::Softmax, true, None);
125
126    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
127    network.set_objective(objective::Objective::CrossEntropy, None);
128
129    println!("{}", network);
130
131    // Train the network
132    let (train_loss, val_loss, val_acc) = network.learn(
133        &x_train,
134        &y_train,
135        Some((&x_test, &y_test, 10)),
136        32,
137        25,
138        Some(5),
139    );
140    plot::loss(
141        &train_loss,
142        &val_loss,
143        &val_acc,
144        "PLAIN : Fashion-MNIST",
145        "./output/mnist-fashion/plain.png",
146    );
147
148    // Validate the network
149    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
150    println!(
151        "Final validation accuracy: {:.2} % and loss: {:.5}",
152        val_acc * 100.0,
153        val_loss
154    );
155
156    // Use the network
157    let prediction = network.predict(x_test.get(0).unwrap());
158    println!(
159        "Prediction on input: Target: {}. Output: {}.",
160        y_test[0].argmax(),
161        prediction.argmax()
162    );
163
164    // let x = x_test.get(5).unwrap();
165    // let y = y_test.get(5).unwrap();
166    // plot::heatmap(
167    //     &x,
168    //     &format!("{}", &labels[&(y.argmax() as u8)]),
169    //     "./output/mnist-fashion/input.png",
170    // );
171
172    // Plot the pre- and post-activation heatmaps for each (image) layer.
173    // let (pre, post, _) = network.forward(x);
174    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
175    //     let pre_title = format!("layer_{}_pre", i);
176    //     let post_title = format!("layer_{}_post", i);
177    //     let pre_file = format!("layer_{}_pre.png", i);
178    //     let post_file = format!("layer_{}_post.png", i);
179    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
180    //     plot::heatmap(&i_post, &post_title, &post_file);
181    // }
182}