//! `mnist/mnist.rs`: trains a dense neural network on the MNIST digits and reports
//! its test-set accuracy.

1use meuron::NeuralNetwork;
2use meuron::activation::Sigmoid;
3use meuron::cost::MSE;
4use meuron::layer::DenseLayer;
5use ndarray::Array2;
6use std::fs::File;
7use std::io::{self, Read};
8use std::path::PathBuf;
9
/// Reads one big-endian (most significant byte first) `u32` from `reader`.
///
/// IDX files store all header integers in this byte order. The function is generic over
/// [`Read`], so it works on a `File`, a `BufReader`, or an in-memory `&[u8]`.
///
/// # Errors
///
/// Returns [`io::ErrorKind::UnexpectedEof`] if fewer than four bytes remain. Any other
/// error raised by the underlying reader is passed through.
fn read_u32_from_file<R: Read + ?Sized>(reader: &mut R) -> Result<u32, io::Error> {
    let mut buf = [0u8; 4];
    reader.read_exact(&mut buf)?;
    Ok(u32::from_be_bytes(buf))
}
15
16fn load_mnist_data(
17    images_path: PathBuf,
18    labels_path: PathBuf,
19) -> Result<(Array2<f32>, Array2<f32>), io::Error> {
20    let mut image_file = File::open(images_path)?;
21    let mut label_file = File::open(labels_path)?;
22
23    let _magic_images = read_u32_from_file(&mut image_file)?;
24    let num_images = read_u32_from_file(&mut image_file)?;
25    let num_rows = read_u32_from_file(&mut image_file)?;
26    let num_cols = read_u32_from_file(&mut image_file)?;
27
28    let _magic_labels = read_u32_from_file(&mut label_file)?;
29    let num_labels = read_u32_from_file(&mut label_file)?;
30
31    assert_eq!(num_images, num_labels);
32
33    let mut image_data = vec![0u8; (num_images * num_rows * num_cols) as usize];
34    image_file.read_exact(&mut image_data)?;
35
36    let images = Array2::from_shape_vec(
37        (num_images as usize, (num_rows * num_cols) as usize),
38        image_data.into_iter().map(|x| x as f32 / 255.0).collect(),
39    )
40    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
41
42    let mut label_data = vec![0u8; num_labels as usize];
43    label_file.read_exact(&mut label_data)?;
44
45    let labels = Array2::from_shape_vec(
46        (num_labels as usize, 10),
47        label_data
48            .into_iter()
49            .flat_map(|label| {
50                let mut one_hot = vec![0.0; 10];
51                one_hot[label as usize] = 1.0;
52                one_hot
53            })
54            .collect(),
55    )
56    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
57
58    Ok((images, labels))
59}
60
61fn main() {
62    let model_path = "mnist_model.bin";
63
64    let mut nn = if PathBuf::from(model_path).exists() {
65        println!("Loading existing model...");
66        NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67    } else {
68        println!("Creating new model...");
69        let output_size = 10;
70        let input_size = 28 * 28;
71
72        let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73        let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75        NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76    };
77
78    let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79    let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81    let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82        Ok(data) => data,
83        Err(e) => {
84            eprintln!("Error loading MNIST data: {}", e);
85            return;
86        }
87    };
88
89    println!("Loaded {} training images", images.shape()[0]);
90
91    let learning_rate = 0.01;
92    let num_epochs = 10;
93    let batch_size = 32;
94
95    println!("\nTraining with batch size {}...", batch_size);
96    nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98    println!("\nSaving model to {}...", model_path);
99    nn.save(model_path).expect("Failed to save model");
100
101    let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102    let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104    let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105        Ok(data) => data,
106        Err(e) => {
107            eprintln!("Error loading test data: {}", e);
108            return;
109        }
110    };
111
112    let accuracy = nn.accuracy(&test_images, &test_labels);
113    println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}