1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
extern crate libflate;
extern crate ndarray;
extern crate reqwest;
use libflate::gzip::Decoder;
use ndarray::prelude::*;
use std::fs;
use std::io;
// Base URL that `download` prepends to a file name (plus ".gz") to build the request URL.
const URL: &str = "http://yann.lecun.com/exdb/mnist/";
// Directory prefix (with trailing slash) under which archives and extracted files are stored.
const PATH: &str = "./data/";
// Base names of the four data files, without the ".gz" suffix used for the archives.
// `get_all` parses the `*_LBL` files as labels and the `*_IMG` files as images.
const TE_LBL: &str = "t10k-labels-idx1-ubyte";
const TE_IMG: &str = "t10k-images-idx3-ubyte";
const TR_LBL: &str = "train-labels-idx1-ubyte";
const TR_IMG: &str = "train-images-idx3-ubyte";
/// Loads all four data files and returns them as
/// `(TE_LBL labels, TE_IMG images, TR_LBL labels, TR_IMG images)`.
///
/// Each file is obtained through `maybe_read`, so it may be read from disk,
/// extracted from a local archive, or downloaded first. Label arrays have
/// 10 columns per row, image arrays have 784 columns per row.
///
/// # Panics
///
/// Panics if any file cannot be obtained or parsed.
pub fn get_all() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>) {
    let load_labels = |name: &str| read_deflated_labels(maybe_read(name)).unwrap();
    let load_images = |name: &str| read_deflated_images(maybe_read(name)).unwrap();

    // Keep this order: each load may print progress and trigger a download.
    let test_labels = load_labels(TE_LBL);
    let test_images = load_images(TE_IMG);
    let train_labels = load_labels(TR_LBL);
    let train_images = load_images(TR_IMG);

    (test_labels, test_images, train_labels, train_images)
}
/// Deletes the four extracted (uncompressed) data files from `PATH`.
///
/// # Errors
///
/// Stops at, and returns, the first `remove_file` error; files later in the
/// list are not attempted after a failure.
pub fn clean_all_extracted() -> io::Result<()> {
    for name in [TE_LBL, TE_IMG, TR_LBL, TR_IMG].iter() {
        fs::remove_file(format!("{}{}", PATH, name))?;
    }
    Ok(())
}
/// Deletes every cached data file from `PATH`: the four extracted files and
/// their four `.gz` archives.
///
/// Extracted files are removed on a best-effort basis; failures there (for
/// example a file that was never extracted) are ignored. Every removal is
/// attempted regardless of earlier failures, so one missing file no longer
/// leaves the remaining files behind.
///
/// # Errors
///
/// Returns the first error hit while removing a `.gz` archive, checked in the
/// order `TE_LBL`, `TE_IMG`, `TR_LBL`, `TR_IMG`. The remaining archives are
/// still removed before the error is returned.
pub fn clean_everything() -> io::Result<()> {
    let names = [TE_LBL, TE_IMG, TR_LBL, TR_IMG];

    // Best effort: an extracted file may legitimately not exist.
    for name in names.iter() {
        let _ = fs::remove_file(format!("{}{}", PATH, name));
    }

    // Try every archive, remembering only the first failure to report.
    let mut first_error: Option<io::Error> = None;
    for name in names.iter() {
        if let Err(err) = fs::remove_file(format!("{}{}.gz", PATH, name)) {
            if first_error.is_none() {
                first_error = Some(err);
            }
        }
    }

    match first_error {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Returns the bytes of the extracted file `name` under `PATH`.
///
/// If the file cannot be read for any reason, falls back to
/// `maybe_download`, which extracts it from the local archive or downloads
/// the archive first.
fn maybe_read(name: &str) -> Vec<u8> {
    let extracted_path = PATH.to_owned() + name;
    fs::read(extracted_path).unwrap_or_else(|_| maybe_download(name))
}
/// Decompresses the gzip stream in `file`, caches the result as `PATH + name`
/// and returns the decompressed bytes.
///
/// The data is decoded fully into memory before the cache file is written.
/// A corrupt or truncated archive therefore never leaves a partial extracted
/// file that a later `fs::read` would accept as valid.
///
/// # Panics
///
/// Panics if the gzip header is invalid, decompression fails, or the cache
/// file cannot be written.
fn uncompress_file(file: &mut fs::File, name: &str) -> Vec<u8> {
    println!("UNZIPPING {}", name);
    let mut decoder = Decoder::new(file).expect("failed to read gzip header");

    // `Vec<u8>` implements `Write`, so `io::copy` can decode straight into memory.
    let mut unzipped: Vec<u8> = Vec::new();
    io::copy(&mut decoder, &mut unzipped).expect("failed to decompress data");

    // Only persist the cache once the whole stream decoded successfully.
    fs::write(PATH.to_owned() + name, &unzipped).expect("failed to create unzipped file");
    unzipped
}
/// Returns the decompressed contents of `name`, using the local `.gz`
/// archive when it can be opened and downloading it first otherwise.
///
/// # Panics
///
/// Panics if the archive still cannot be opened after downloading, or if
/// `download` / `uncompress_file` panic.
fn maybe_download(name: &str) -> Vec<u8> {
    let archive_path = PATH.to_owned() + name + ".gz";

    // Fast path: the archive is already on disk.
    if let Ok(mut archive) = fs::File::open(&archive_path) {
        return uncompress_file(&mut archive, name);
    }

    download(name);
    let mut archive = fs::File::open(&archive_path).expect("file downloaded but can't read it");
    uncompress_file(&mut archive, name)
}
/// Downloads `URL + name + ".gz"` and stores it as `PATH + name + ".gz"`.
///
/// The target directory is created if missing. The body is first written to
/// a `.part` file and renamed into place only after the transfer completes,
/// so an interrupted download is never mistaken for a complete archive.
/// Non-success HTTP responses are rejected instead of being saved to disk.
///
/// # Panics
///
/// Panics if the directory cannot be created, the request fails or returns a
/// non-success status, or the file cannot be written or renamed.
fn download(name: &str) {
    println!("DOWNLOADING {}", name);
    let url = URL.to_owned() + name + ".gz";
    let target = PATH.to_owned() + name + ".gz";
    let partial = target.clone() + ".part";

    // A fresh checkout has no data directory yet.
    fs::create_dir_all(PATH).expect("failed to create data directory");

    let mut resp = reqwest::get(url.as_str()).expect("request failed");
    if !resp.status().is_success() {
        // Saving an error page as ".gz" would poison the cache for later runs.
        panic!("request failed: {} returned HTTP status {}", url, resp.status());
    }

    let mut out = fs::File::create(&partial).expect("failed to create file");
    io::copy(&mut resp, &mut out).expect("failed to copy content");
    drop(out); // close the handle before renaming (required on some platforms)

    fs::rename(&partial, &target).expect("failed to move downloaded file into place");
}
fn read_deflated_labels(file: Vec<u8>) -> Result<Array2<f64>, std::io::Error> {
println!("PASSED with labels file {:?}", file.len());
let file: Vec<u8> = Vec::from(&file[8..]);
let labels = hot_ones(file);
Ok(labels)
}
fn read_deflated_images(file: Vec<u8>) -> Result<Array2<f64>, std::io::Error> {
println!("PASSED with images file {:?}", file.len());
let file: Vec<f64> = file[16..].iter().map(|&e| e as f64 / 256.).collect();
let images: Array2<f64> = Array::from_shape_vec((file.len() / 784, 784), file).unwrap();
Ok(images)
}
/// One-hot encodes each byte of `data` into a row of 10 values.
///
/// Row `r` has `1.0` in column `data[r]` and `0.0` elsewhere. A byte outside
/// `0..10` yields a row of all zeros. The result has shape `(data.len(), 10)`.
fn hot_ones(data: Vec<u8>) -> Array2<f64> {
    let encoded: Vec<f64> = data
        .into_iter()
        .flat_map(|label| (0..10u8).map(move |class| if class == label { 1. } else { 0. }))
        .collect();
    let rows = encoded.len() / 10;
    Array::from_shape_vec((rows, 10), encoded).unwrap()
}