phoenix_gui/data_sets/mnist.rs

1use crate::data_sets::TestSet;
2use crate::matrix::Matrix;
3use image::RgbImage;
4
5
6extern crate image;
7
8use crate::neural_network::{NNConfig, NeuralNetwork};
9use image::{ImageBuffer, Rgb};
10
11use serde::{Deserialize, Serialize};
12
13#[derive(Deserialize, Serialize, Debug)]
14pub struct MNist {
15    loaded: bool,
16    pub train_input: Vec<Matrix>,
17    pub train_target: Vec<Matrix>,
18    pub test_input: Vec<Matrix>,
19    pub test_target: Vec<Matrix>,
20}
21
22impl Default for MNist {
23    fn default() -> Self {
24        Self {
25            loaded: false,
26            train_input: vec![],
27            train_target: vec![],
28            test_input: vec![],
29            test_target: vec![],
30        }
31    }
32}
33
// Raw MNIST CSV exports, embedded at compile time when the `mnist` feature is enabled.
// The paths use `/` only: `include_str!`/`include_bytes!` resolve them relative to this
// file, and a `\` separator is only understood when building on Windows.
#[cfg(feature = "mnist")]
static TRAIN_DATA: &str = include_str!("../resources/mnist/mnist_train.csv");
#[cfg(feature = "mnist")]
#[allow(dead_code)] // not read by this module yet; kept alongside TRAIN_DATA
static TEST_DATA: &str = include_str!("../resources/mnist/mnist_test.csv");

// Placeholders for builds without the data set. After its header line is skipped,
// `TRAIN_DATA` yields no rows.
#[cfg(not(feature = "mnist"))]
static TRAIN_DATA: &str = "Not available";
#[cfg(not(feature = "mnist"))]
#[allow(dead_code)] // not read by this module yet; kept alongside TRAIN_DATA
static TEST_DATA: &str = "Not available";

/// Bincode-serialized `MNist`, decoded by `TestSet::read`.
///
/// Typed as a slice rather than a fixed-size array, so regenerating the data file does
/// not break the build.
#[cfg(feature = "mnist")]
static TEST_SET: &[u8] = include_bytes!("../resources/mnist/mnist_data.bin");

/// Placeholder when the data set is not bundled.
///
/// `bincode::deserialize` (bincode 1.x) uses fixed-width integers and allows trailing
/// bytes. These 33 zero bytes therefore decode as `loaded = false` (1 byte) followed by
/// four empty vectors (four 8-byte zero lengths). That matches what the previous 32 MB
/// zero buffer decoded to, without embedding 32 MB of read-only zeros in the binary.
#[cfg(not(feature = "mnist"))]
static TEST_SET: &[u8] = &[0; 33];
49
50impl TestSet for MNist {
51    fn read(&mut self) {
52        if self.loaded {
53            return;
54        }
55        // let decompressed_data = match decompress(&TEST_SET.to_vec()) {
56        //     Ok(s) => s,
57        //     Err(e) => {
58        //         panic!("Error decompressing file: {:?}", e);
59        //     }
60        // };
61        let data_set: MNist = bincode::deserialize(TEST_SET).unwrap();
62        self.train_input = data_set.train_input;
63        self.train_target = data_set.train_target;
64        self.test_input = data_set.test_input;
65        self.test_target = data_set.test_target;
66        self.loaded = true;
67    }
68}
69
70impl MNist {
71    pub fn read_files(&mut self) {
72        // read data from train_data
73        println!("Reading train data...");
74        // skip first line
75        let mut lines = TRAIN_DATA.lines();
76        lines.next();
77        let (train_input, train_target) = Self::read_lines(lines);
78        self.train_input = train_input;
79        self.train_target = train_target;
80    }
81    fn read_lines(lines: core::str::Lines) -> (Vec<Matrix>, Vec<Matrix>) {
82        let mut input: Vec<Matrix> = vec![];
83        let mut target: Vec<Matrix> = vec![];
84        const LIMIT: usize = 1000000;
85        for (i, line) in lines.enumerate() {
86            let mut parts = line.split(",");
87            let mut target_matrix = Matrix::new(10, 1);
88            let mut input_matrix = Matrix::new(784, 1);
89            let value = parts.next().unwrap().parse::<usize>().unwrap();
90            target_matrix.set(value, 0, 1.0);
91            for i in 0..784 {
92                let value = parts.next().unwrap().parse::<f32>().unwrap();
93                input_matrix.set(i, 0, value / 255.0);
94            }
95            input.push(input_matrix);
96            target.push(target_matrix);
97            if i > LIMIT {
98                break;
99            }
100        }
101        (input, target)
102    }
103
104    pub fn print_data(data: Vec<Matrix>) {
105        println!("Getting image...");
106        let len = data.len() as f64;
107        let width: u32 = 28 * len.sqrt() as u32;
108        let height: u32 = 28 * len.sqrt() as u32;
109        let mut image: RgbImage = ImageBuffer::new(width, height);
110        let mut counter = 0;
111        for col in 0..len.sqrt() as u32 {
112            for row in 0..len.sqrt() as u32 {
113                for i in 0..28 {
114                    for j in 0..28 {
115                        if let Some(matrix) = data.get(counter) {
116                            let value = matrix.get((i * 28 + j) as usize, 0);
117                            let value = (value * 255.0) as u8;
118                            // add col and row to the i j variable
119                            *image.get_pixel_mut(i + col * 28, j + row * 28) =
120                                Rgb([value, value, value]);
121                        }
122                    }
123                }
124                counter += 1;
125            }
126        }
127        // *image.get_pixel_mut(5, 5) = image::Rgb([255, 255, 255]);
128        image.save("output.png").unwrap();
129    }
130
131    pub fn run(&mut self, config: NNConfig) {
132        if config.epochs == 0 {
133            return;
134        }
135        if !self.loaded {
136            let mut data_set = MNist::default();
137            // let t1 = std::time::Instant::now();
138            // data_set.read();
139            // println!("Time to read bin: {:?}", t1.elapsed());
140            let t1 = std::time::Instant::now();
141            data_set.read_files();
142            println!("Time to read files: {:?}", t1.elapsed());
143            self.train_input = data_set.train_input;
144            self.train_target = data_set.train_target;
145            // check if they contain NaNs
146            for matrix in self.train_input.iter() {
147                if matrix.contains_nan() {
148                    panic!("Train input contains NaNs");
149                }
150                if matrix.max() > 1.0 {
151                    panic!("Train input contains values greater than 1.0");
152                }
153            }
154            for matrix in self.train_target.iter() {
155                if matrix.contains_nan() {
156                    panic!("Train target contains NaNs");
157                }
158                if matrix.max() > 1.0 {
159                    panic!("Train input contains values greater than 1.0");
160                }
161            }
162        }
163        let mut nn = NeuralNetwork::new(config.layer_sizes);
164        nn.learning_rate = config.learning_rate;
165        // nn.command_receiver = config.command_receiver;
166        nn.command_sender = config.command_sender;
167        nn.update_interval = config.update_interval;
168        // todo: dont clone
169        nn.train_epochs_m(
170            self.train_input.clone(),
171            self.train_target.clone(),
172            config.batch_number,
173            config.epochs,
174        );
175    }
176}