// mnist/mnist.rs

1use brique::checkpoint::Checkpoint;
2use brique::layers::*;
3use brique::matrix::*;
4use brique::model::*;
5use brique::model_builder::ModelBuilder;
6use brique::optimizer::Optimizer;
7use brique::save_load::*;
8use brique::utils::*;
9
10fn main() {
11    training();
12}
13
14pub fn testing() {
15    println!("extracting mnist data...");
16    let labels: Matrix = extract_labels("t10k-labels.idx1-ubyte");
17    let mut images: Matrix = extract_images("t10k-images.idx3-ubyte");
18    println!("extraction done");
19
20    images.normalize();
21    println!("number of images {}", images.height);
22    println!("number of pixels in each image {}", images.width);
23
24    println!("loading pre-trained model...");
25    let mut model: Model = load_model("mnist_128x128".to_string()).unwrap();
26
27    println!("evaluating...");
28    let score = model.evaluate(&images, false);
29    let acc = model.accuracy(&score, &labels);
30
31    println!("acc : {}", acc);
32}
33
34pub fn training() {
35    println!("extracting mnist data...");
36    let labels: Matrix = extract_labels("train-labels.idx1-ubyte");
37    let mut images: Matrix = extract_images("train-images.idx3-ubyte");
38    println!("extraction done");
39
40    images.normalize();
41    println!("number of images {}", images.height);
42    println!("number of pixels in each image {}", images.width);
43
44    ModelBuilder::new()
45        .add_layer(Layer::init(28 * 28, 128, true))
46        .add_layer(Layer::init(128, 128, true))
47        .add_layer(Layer::init(128, 10, false))
48        .optimizer(Optimizer::Adam {
49            learning_step: 0.001,
50            beta1: 0.9,
51            beta2: 0.999,
52        })
53        .l2_reg(0.001)
54        .checkpoint(Checkpoint::ValAcc {
55            save_path: "mnist_128x128".to_string(),
56        })
57        .verbose(10, false)
58        .build_and_train(&images, &labels, 128, 10, 2000);
59}
60
61fn _print_a_number(labels: Matrix, images: Matrix, v: usize) {
62    println!("{}", labels.get(0, v));
63
64    for i in 0..28 * 28 {
65        if images.get(v, i) > 0.5 && images.get(v, i) < 1.0 {
66            if i % 28 == 0 {
67                print!("\n");
68            }
69
70            if images.get(v, i) > 0.5 && images.get(v, i) < 0.75 {
71                print!("-");
72            } else {
73                print!("*");
74            }
75        } else {
76            if i % 28 == 0 {
77                print!("\n");
78            }
79            print!("_");
80        }
81    }
82}