1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::fs::File;
6use std::io::{BufReader, Read, Result};
7
/// Reads the next four bytes from `reader` and decodes them as a big-endian `u32`.
///
/// Used to decode the header fields of the MNIST IDX files (magic number,
/// item count and dimension sizes).
///
/// # Errors
///
/// Propagates the error from `Read::read_exact`, e.g. `UnexpectedEof` when
/// fewer than four bytes remain.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}
13
14fn load_mnist(path: &str) -> Result<Vec<tensor::Tensor>> {
15 let mut reader = BufReader::new(File::open(path)?);
16 let mut images: Vec<tensor::Tensor> = Vec::new();
17
18 let _magic_number = read(&mut reader)?;
19 let num_images = read(&mut reader)?;
20 let num_rows = read(&mut reader)?;
21 let num_cols = read(&mut reader)?;
22
23 for _ in 0..num_images {
24 let mut image: Vec<Vec<f32>> = Vec::new();
25 for _ in 0..num_rows {
26 let mut row: Vec<f32> = Vec::new();
27 for _ in 0..num_cols {
28 let mut pixel = [0];
29 reader.read_exact(&mut pixel)?;
30 row.push(pixel[0] as f32 / 255.0);
31 }
32 image.push(row);
33 }
34 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
35 }
36
37 Ok(images)
38}
39
40fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
41 let mut reader = BufReader::new(File::open(file_path)?);
42 let _magic_number = read(&mut reader)?;
43 let num_labels = read(&mut reader)?;
44
45 let mut _labels = vec![0; num_labels as usize];
46 reader.read_exact(&mut _labels)?;
47
48 Ok(_labels
49 .iter()
50 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
51 .collect())
52}
53
54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (1, 1),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.feedback(
82 vec![feedback::Layer::Convolution(
83 1,
84 activation::Activation::ReLU,
85 (3, 3),
86 (1, 1),
87 (1, 1),
88 (1, 1),
89 None,
90 )],
91 3,
92 false,
93 false,
94 feedback::Accumulation::Mean,
95 );
96 network.convolution(
97 1,
98 (3, 3),
99 (1, 1),
100 (1, 1),
101 (1, 1),
102 activation::Activation::ReLU,
103 None,
104 );
105 network.maxpool((2, 2), (2, 2));
106 network.dense(10, activation::Activation::Softmax, true, None);
107
108 network.connect(1, 2);
110 network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
111
112 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113 network.set_objective(
114 objective::Objective::CrossEntropy, None, );
117
118 println!("{}", network);
119
120 let (train_loss, val_loss, val_acc) = network.learn(
122 &x_train,
123 &y_train,
124 Some((&x_test, &y_test, 10)),
125 32,
126 25,
127 Some(5),
128 );
129 plot::loss(
130 &train_loss,
131 &val_loss,
132 &val_acc,
133 "FEEDBACK : MNIST",
134 "./output/mnist/feedback.png",
135 );
136
137 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
139 println!(
140 "Final validation accuracy: {:.2} % and loss: {:.5}",
141 val_acc * 100.0,
142 val_loss
143 );
144
145 let prediction = network.predict(x_test.get(0).unwrap());
147 println!(
148 "Prediction on input: Target: {}. Output: {}.",
149 y_test[0].argmax(),
150 prediction.argmax()
151 );
152
153 }