1use meuron::NeuralNetwork;
2use meuron::activation::Sigmoid;
3use meuron::cost::MSE;
4use meuron::layer::DenseLayer;
5use ndarray::Array2;
6use std::fs::File;
7use std::io::{self, Read};
8use std::path::PathBuf;
9
/// Reads one big-endian `u32` from `file`, consuming exactly four bytes.
///
/// Any [`Read`] implementor is accepted (a `File`, a buffered reader, an
/// in-memory byte slice, a `dyn Read`, ...), so existing `&mut File` callers
/// keep working unchanged.
///
/// # Errors
///
/// Propagates the error from `read_exact`; running out of input before four
/// bytes were read is reported as [`io::ErrorKind::UnexpectedEof`].
fn read_u32_from_file<R: Read + ?Sized>(file: &mut R) -> Result<u32, io::Error> {
    let mut buf = [0u8; 4];
    file.read_exact(&mut buf)?;
    Ok(u32::from_be_bytes(buf))
}
15
16fn load_mnist_data(
17 images_path: PathBuf,
18 labels_path: PathBuf,
19) -> Result<(Array2<f32>, Array2<f32>), io::Error> {
20 let mut image_file = File::open(images_path)?;
21 let mut label_file = File::open(labels_path)?;
22
23 let _magic_images = read_u32_from_file(&mut image_file)?;
24 let num_images = read_u32_from_file(&mut image_file)?;
25 let num_rows = read_u32_from_file(&mut image_file)?;
26 let num_cols = read_u32_from_file(&mut image_file)?;
27
28 let _magic_labels = read_u32_from_file(&mut label_file)?;
29 let num_labels = read_u32_from_file(&mut label_file)?;
30
31 assert_eq!(num_images, num_labels);
32
33 let mut image_data = vec![0u8; (num_images * num_rows * num_cols) as usize];
34 image_file.read_exact(&mut image_data)?;
35
36 let images = Array2::from_shape_vec(
37 (num_images as usize, (num_rows * num_cols) as usize),
38 image_data.into_iter().map(|x| x as f32 / 255.0).collect(),
39 )
40 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
41
42 let mut label_data = vec![0u8; num_labels as usize];
43 label_file.read_exact(&mut label_data)?;
44
45 let labels = Array2::from_shape_vec(
46 (num_labels as usize, 10),
47 label_data
48 .into_iter()
49 .flat_map(|label| {
50 let mut one_hot = vec![0.0; 10];
51 one_hot[label as usize] = 1.0;
52 one_hot
53 })
54 .collect(),
55 )
56 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
57
58 Ok((images, labels))
59}
60
61fn main() {
62 let model_path = "mnist_model.bin";
63
64 let mut nn = if PathBuf::from(model_path).exists() {
65 println!("Loading existing model...");
66 NeuralNetwork::load(model_path, MSE).expect("Failed to load model")
67 } else {
68 println!("Creating new model...");
69 let output_size = 10;
70 let input_size = 28 * 28;
71
72 let dense_layer_1 = DenseLayer::new(input_size, 128, Sigmoid);
73 let dense_layer_2 = DenseLayer::new(128, output_size, Sigmoid);
74
75 NeuralNetwork::new(vec![dense_layer_1, dense_layer_2], MSE)
76 };
77
78 let train_images_path = PathBuf::from("./train-images.idx3-ubyte");
79 let train_labels_path = PathBuf::from("./train-labels.idx1-ubyte");
80
81 let (images, labels) = match load_mnist_data(train_images_path, train_labels_path) {
82 Ok(data) => data,
83 Err(e) => {
84 eprintln!("Error loading MNIST data: {}", e);
85 return;
86 }
87 };
88
89 println!("Loaded {} training images", images.shape()[0]);
90
91 let learning_rate = 0.01;
92 let num_epochs = 10;
93 let batch_size = 32;
94
95 println!("\nTraining with batch size {}...", batch_size);
96 nn.train(&images, &labels, learning_rate, num_epochs, batch_size);
97
98 println!("\nSaving model to {}...", model_path);
99 nn.save(model_path).expect("Failed to save model");
100
101 let test_images_path = PathBuf::from("./t10k-images.idx3-ubyte");
102 let test_labels_path = PathBuf::from("./t10k-labels.idx1-ubyte");
103
104 let (test_images, test_labels) = match load_mnist_data(test_images_path, test_labels_path) {
105 Ok(data) => data,
106 Err(e) => {
107 eprintln!("Error loading test data: {}", e);
108 return;
109 }
110 };
111
112 let accuracy = nn.accuracy(&test_images, &test_labels);
113 println!("\nTest accuracy: {:.2}%", accuracy * 100.0);
114}