1use neurons::{activation, network, objective, optimizer, plot, tensor};
4
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{BufReader, Read, Result};
8
/// Reads a single big-endian `u32` from `reader`.
///
/// IDX headers (magic number, item count, dimensions) are stored as
/// big-endian 32-bit unsigned integers, so this is the primitive used to
/// parse them.
///
/// # Errors
///
/// Propagates the error from `read_exact`, e.g. `UnexpectedEof` when fewer
/// than four bytes remain.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|_| u32::from_be_bytes(word))
}
14
15fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
16 let mut reader = BufReader::new(File::open(path)?);
17 let mut images: Vec<tensor::Tensor> = Vec::new();
18
19 let _magic_number = read(&mut reader)?;
20 let num_images = read(&mut reader)?;
21 let num_rows = read(&mut reader)?;
22 let num_cols = read(&mut reader)?;
23
24 for _ in 0..num_images {
25 let mut image: Vec<Vec<f32>> = Vec::new();
26 for _ in 0..num_rows {
27 let mut row: Vec<f32> = Vec::new();
28 for _ in 0..num_cols {
29 let mut pixel = [0];
30 reader.read_exact(&mut pixel)?;
31 row.push(pixel[0] as f32 / 255.0);
32 }
33 image.push(row);
34 }
35 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
36 }
37
38 Ok(images)
39}
40
41fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
42 let mut reader = BufReader::new(File::open(file_path)?);
43 let _magic_number = read(&mut reader)?;
44 let num_labels = read(&mut reader)?;
45
46 let mut _labels = vec![0; num_labels as usize];
47 reader.read_exact(&mut _labels)?;
48
49 Ok(_labels
50 .iter()
51 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
52 .collect())
53}
54
55fn main() {
56 let _labels: HashMap<u8, &str> = [
57 (0, "top"),
58 (1, "trouser"),
59 (2, "pullover"),
60 (3, "dress"),
61 (4, "coat"),
62 (5, "sandal"),
63 (6, "shirt"),
64 (7, "sneaker"),
65 (8, "bag"),
66 (9, "ankle boot"),
67 ]
68 .iter()
69 .cloned()
70 .collect();
71 let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
72 let y_train = load_labels(
73 "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
74 10,
75 )
76 .unwrap();
77 let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
78 let y_test = load_labels(
79 "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
80 10,
81 )
82 .unwrap();
83 println!(
84 "Train: {} images, Test: {} images",
85 x_train.len(),
86 x_test.len()
87 );
88
89 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
90 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
91 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
92 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
93
94 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
95
96 network.convolution(
97 1,
98 (3, 3),
99 (1, 1),
100 (1, 1),
101 (1, 1),
102 activation::Activation::ReLU,
103 None,
104 );
105 network.convolution(
106 1,
107 (3, 3),
108 (1, 1),
109 (1, 1),
110 (1, 1),
111 activation::Activation::ReLU,
112 None,
113 );
114 network.convolution(
115 1,
116 (3, 3),
117 (1, 1),
118 (1, 1),
119 (1, 1),
120 activation::Activation::ReLU,
121 None,
122 );
123 network.maxpool((2, 2), (2, 2));
124 network.dense(10, activation::Activation::Softmax, true, None);
125
126 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
127 network.set_objective(objective::Objective::CrossEntropy, None);
128
129 println!("{}", network);
130
131 let (train_loss, val_loss, val_acc) = network.learn(
133 &x_train,
134 &y_train,
135 Some((&x_test, &y_test, 10)),
136 32,
137 25,
138 Some(5),
139 );
140 plot::loss(
141 &train_loss,
142 &val_loss,
143 &val_acc,
144 "PLAIN : Fashion-MNIST",
145 "./output/mnist-fashion/plain.png",
146 );
147
148 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
150 println!(
151 "Final validation accuracy: {:.2} % and loss: {:.5}",
152 val_acc * 100.0,
153 val_loss
154 );
155
156 let prediction = network.predict(x_test.get(0).unwrap());
158 println!(
159 "Prediction on input: Target: {}. Output: {}.",
160 y_test[0].argmax(),
161 prediction.argmax()
162 );
163
164 }