phoenix_gui/data_sets/
mnist.rs1use crate::data_sets::TestSet;
2use crate::matrix::Matrix;
3use image::RgbImage;
4
5
6extern crate image;
7
8use crate::neural_network::{NNConfig, NeuralNetwork};
9use image::{ImageBuffer, Rgb};
10
11use serde::{Deserialize, Serialize};
12
13#[derive(Deserialize, Serialize, Debug)]
14pub struct MNist {
15 loaded: bool,
16 pub train_input: Vec<Matrix>,
17 pub train_target: Vec<Matrix>,
18 pub test_input: Vec<Matrix>,
19 pub test_target: Vec<Matrix>,
20}
21
22impl Default for MNist {
23 fn default() -> Self {
24 Self {
25 loaded: false,
26 train_input: vec![],
27 train_target: vec![],
28 test_input: vec![],
29 test_target: vec![],
30 }
31 }
32}
33
// The embedded resources below use forward slashes on purpose: the include
// macros resolve their path with the host platform's rules, so a backslash
// separator only builds on Windows while `/` builds everywhere.

/// Training CSV embedded at compile time (only with the `mnist` feature).
#[cfg(feature = "mnist")]
static TRAIN_DATA: &str = include_str!("../resources/mnist/mnist_train.csv");
/// Placeholder used when the `mnist` feature is disabled.
#[cfg(not(feature = "mnist"))]
static TRAIN_DATA: &str = "Not available";

/// Test CSV embedded at compile time (only with the `mnist` feature).
#[cfg(feature = "mnist")]
static TEST_DATA: &str = include_str!("../resources/mnist/mnist_test.csv");
/// Placeholder used when the `mnist` feature is disabled.
#[cfg(not(feature = "mnist"))]
static TEST_DATA: &str = "Not available";

/// `bincode`-encoded `MNist` value embedded at compile time (only with the
/// `mnist` feature).
#[cfg(feature = "mnist")]
static TEST_SET: &[u8; 32240033] = include_bytes!("../resources/mnist/mnist_data.bin");
/// All-zero stand-in of the same length, so the static keeps the same type
/// when the `mnist` feature is disabled.
#[cfg(not(feature = "mnist"))]
static TEST_SET: &[u8; 32240033] = &[0; 32240033];
49
50impl TestSet for MNist {
51 fn read(&mut self) {
52 if self.loaded {
53 return;
54 }
55 let data_set: MNist = bincode::deserialize(TEST_SET).unwrap();
62 self.train_input = data_set.train_input;
63 self.train_target = data_set.train_target;
64 self.test_input = data_set.test_input;
65 self.test_target = data_set.test_target;
66 self.loaded = true;
67 }
68}
69
70impl MNist {
71 pub fn read_files(&mut self) {
72 println!("Reading train data...");
74 let mut lines = TRAIN_DATA.lines();
76 lines.next();
77 let (train_input, train_target) = Self::read_lines(lines);
78 self.train_input = train_input;
79 self.train_target = train_target;
80 }
81 fn read_lines(lines: core::str::Lines) -> (Vec<Matrix>, Vec<Matrix>) {
82 let mut input: Vec<Matrix> = vec![];
83 let mut target: Vec<Matrix> = vec![];
84 const LIMIT: usize = 1000000;
85 for (i, line) in lines.enumerate() {
86 let mut parts = line.split(",");
87 let mut target_matrix = Matrix::new(10, 1);
88 let mut input_matrix = Matrix::new(784, 1);
89 let value = parts.next().unwrap().parse::<usize>().unwrap();
90 target_matrix.set(value, 0, 1.0);
91 for i in 0..784 {
92 let value = parts.next().unwrap().parse::<f32>().unwrap();
93 input_matrix.set(i, 0, value / 255.0);
94 }
95 input.push(input_matrix);
96 target.push(target_matrix);
97 if i > LIMIT {
98 break;
99 }
100 }
101 (input, target)
102 }
103
104 pub fn print_data(data: Vec<Matrix>) {
105 println!("Getting image...");
106 let len = data.len() as f64;
107 let width: u32 = 28 * len.sqrt() as u32;
108 let height: u32 = 28 * len.sqrt() as u32;
109 let mut image: RgbImage = ImageBuffer::new(width, height);
110 let mut counter = 0;
111 for col in 0..len.sqrt() as u32 {
112 for row in 0..len.sqrt() as u32 {
113 for i in 0..28 {
114 for j in 0..28 {
115 if let Some(matrix) = data.get(counter) {
116 let value = matrix.get((i * 28 + j) as usize, 0);
117 let value = (value * 255.0) as u8;
118 *image.get_pixel_mut(i + col * 28, j + row * 28) =
120 Rgb([value, value, value]);
121 }
122 }
123 }
124 counter += 1;
125 }
126 }
127 image.save("output.png").unwrap();
129 }
130
131 pub fn run(&mut self, config: NNConfig) {
132 if config.epochs == 0 {
133 return;
134 }
135 if !self.loaded {
136 let mut data_set = MNist::default();
137 let t1 = std::time::Instant::now();
141 data_set.read_files();
142 println!("Time to read files: {:?}", t1.elapsed());
143 self.train_input = data_set.train_input;
144 self.train_target = data_set.train_target;
145 for matrix in self.train_input.iter() {
147 if matrix.contains_nan() {
148 panic!("Train input contains NaNs");
149 }
150 if matrix.max() > 1.0 {
151 panic!("Train input contains values greater than 1.0");
152 }
153 }
154 for matrix in self.train_target.iter() {
155 if matrix.contains_nan() {
156 panic!("Train target contains NaNs");
157 }
158 if matrix.max() > 1.0 {
159 panic!("Train input contains values greater than 1.0");
160 }
161 }
162 }
163 let mut nn = NeuralNetwork::new(config.layer_sizes);
164 nn.learning_rate = config.learning_rate;
165 nn.command_sender = config.command_sender;
167 nn.update_interval = config.update_interval;
168 nn.train_epochs_m(
170 self.train_input.clone(),
171 self.train_target.clone(),
172 config.batch_number,
173 config.epochs,
174 );
175 }
176}