1use brique::checkpoint::Checkpoint;
2use brique::layers::*;
3use brique::matrix::*;
4use brique::model::*;
5use brique::model_builder::ModelBuilder;
6use brique::optimizer::Optimizer;
7use brique::save_load::*;
8use brique::utils::*;
9
10fn main() {
11 training();
12}
13
14pub fn testing() {
15 println!("extracting mnist data...");
16 let labels: Matrix = extract_labels("t10k-labels.idx1-ubyte");
17 let mut images: Matrix = extract_images("t10k-images.idx3-ubyte");
18 println!("extraction done");
19
20 images.normalize();
21 println!("number of images {}", images.height);
22 println!("number of pixels in each image {}", images.width);
23
24 println!("loading pre-trained model...");
25 let mut model: Model = load_model("mnist_128x128".to_string()).unwrap();
26
27 println!("evaluating...");
28 let score = model.evaluate(&images, false);
29 let acc = model.accuracy(&score, &labels);
30
31 println!("acc : {}", acc);
32}
33
34pub fn training() {
35 println!("extracting mnist data...");
36 let labels: Matrix = extract_labels("train-labels.idx1-ubyte");
37 let mut images: Matrix = extract_images("train-images.idx3-ubyte");
38 println!("extraction done");
39
40 images.normalize();
41 println!("number of images {}", images.height);
42 println!("number of pixels in each image {}", images.width);
43
44 ModelBuilder::new()
45 .add_layer(Layer::init(28 * 28, 128, true))
46 .add_layer(Layer::init(128, 128, true))
47 .add_layer(Layer::init(128, 10, false))
48 .optimizer(Optimizer::Adam {
49 learning_step: 0.001,
50 beta1: 0.9,
51 beta2: 0.999,
52 })
53 .l2_reg(0.001)
54 .checkpoint(Checkpoint::ValAcc {
55 save_path: "mnist_128x128".to_string(),
56 })
57 .verbose(10, false)
58 .build_and_train(&images, &labels, 128, 10, 2000);
59}
60
61fn _print_a_number(labels: Matrix, images: Matrix, v: usize) {
62 println!("{}", labels.get(0, v));
63
64 for i in 0..28 * 28 {
65 if images.get(v, i) > 0.5 && images.get(v, i) < 1.0 {
66 if i % 28 == 0 {
67 print!("\n");
68 }
69
70 if images.get(v, i) > 0.5 && images.get(v, i) < 0.75 {
71 print!("-");
72 } else {
73 print!("*");
74 }
75 } else {
76 if i % 28 == 0 {
77 print!("\n");
78 }
79 print!("_");
80 }
81 }
82}