Skip to main content

mnist/
mnist.rs

1use std::path::Path;
2
3use newron::dataset::Dataset;
4use newron::layers::LayerEnum::*;
5use newron::loss::{categorical_entropy::CategoricalEntropy};
6use newron::metrics::Metric;
7use newron::sequential::Sequential;
8use newron::optimizers::sgd::SGD;
9
10fn main() {
11    // Path to a folder containing the 4 files :
12    // 1/ train-images-idx3-ubyte
13    // 2/ train-labels-idx1-ubyte
14    // 3/ t10k-images-idx3-ubyte
15    // 4/ t10k-labels-idx1-ubyte
16    let path = Path::new("datasets/fashion_mnist/");
17
18    let dataset = Dataset::from_ubyte(path).unwrap();
19    println!("{:?}", dataset);
20
21    let mut model = Sequential::new();
22    model.set_seed(99);
23
24    model.add(Dense {
25        input_units: dataset.get_number_features(),
26        output_units: 256
27    });
28
29    model.add(Dropout {prob: 0.2});
30    
31    model.add(ReLU);
32
33    model.add(Dense {
34        input_units: 256,
35        output_units: dataset.get_number_targets()
36    });
37
38    model.compile(CategoricalEntropy{},
39              SGD::new(0.2),
40              vec![Metric::Accuracy]);
41
42    model.summary();
43
44    model.fit(&dataset, 10, true);
45}