1extern crate neurons;
4
5use neurons::{activation, network, objective, optimizer, tensor};
6
7use std::fs::File;
8use std::io::{BufReader, Read, Result};
9use std::time;
10
/// Reads the next four bytes from `reader` and decodes them as a big-endian `u32`.
///
/// # Errors
///
/// Returns whatever error `read_exact` reports, for instance when fewer than
/// four bytes are left in `reader`.
fn read(reader: &mut dyn Read) -> Result<u32> {
    let mut word = [0u8; 4];
    reader
        .read_exact(&mut word)
        .map(|()| u32::from_be_bytes(word))
}
16
17fn load_images(path: &str) -> Result<Vec<tensor::Tensor>> {
18 let mut reader = BufReader::new(File::open(path)?);
19 let mut images: Vec<tensor::Tensor> = Vec::new();
20
21 let _magic_number = read(&mut reader)?;
22 let num_images = read(&mut reader)?;
23 let num_rows = read(&mut reader)?;
24 let num_cols = read(&mut reader)?;
25
26 for _ in 0..num_images {
27 let mut image: Vec<Vec<f32>> = Vec::new();
28 for _ in 0..num_rows {
29 let mut row: Vec<f32> = Vec::new();
30 for _ in 0..num_cols {
31 let mut pixel = [0];
32 reader.read_exact(&mut pixel)?;
33 row.push(pixel[0] as f32 / 255.0);
34 }
35 image.push(row);
36 }
37 images.push(tensor::Tensor::triple(vec![image]).resize(tensor::Shape::Triple(1, 14, 14)));
38 }
39
40 Ok(images)
41}
42
43fn load_labels(file_path: &str, numbers: usize) -> Result<Vec<tensor::Tensor>> {
44 let mut reader = BufReader::new(File::open(file_path)?);
45 let _magic_number = read(&mut reader)?;
46 let num_labels = read(&mut reader)?;
47
48 let mut _labels = vec![0; num_labels as usize];
49 reader.read_exact(&mut _labels)?;
50
51 Ok(_labels
52 .iter()
53 .map(|&x| tensor::Tensor::one_hot(x as usize, numbers))
54 .collect())
55}
56
57fn main() {
58 let x_train = load_images("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
59 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
60
61 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
62 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
63
64 let mut times: Vec<time::Duration> = Vec::new();
65
66 for iteration in 0..10 {
67 let start = time::Instant::now();
68
69 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 8,
74 (3, 3),
75 (1, 1),
76 (0, 0),
77 (1, 1),
78 activation::Activation::ReLU,
79 Some(0.05),
80 );
81 network.maxpool((2, 2), (2, 2));
82 network.dense(10, activation::Activation::Softmax, true, None);
83
84 network.set_optimizer(optimizer::Adam::create(
85 0.001, 0.9, 0.999, 1e-8, Some(0.01), ));
91 network.set_objective(
92 objective::Objective::CrossEntropy, None, );
95
96 let (train_loss, _, _) = network.learn(&x_train, &y_train, None, 128, 10, None);
98
99 println!("Iteration: {}, Loss: {:?}", iteration, train_loss);
100
101 let duration = start.elapsed();
102 times.push(duration);
103 }
104
105 let sum: time::Duration = times.iter().sum();
106 let avg = sum / times.len() as u32;
107 println!("Average time: {:?}", avg);
108}