1#![allow(dead_code)]
2
3mod test;
32
33#[cfg(any(
34 feature = "to_ndarray_015",
35 feature = "to_ndarray_014",
36 feature = "to_ndarray_013"
37))]
38pub(self) use ndarray::prelude::*;
39
40#[cfg(feature = "to_ndarray_013")]
41use ndarray_013 as ndarray;
42#[cfg(feature = "to_ndarray_014")]
43use ndarray_014 as ndarray;
44#[cfg(feature = "to_ndarray_015")]
45use ndarray_015 as ndarray;
46
47use std::error::Error;
48use std::io::Read;
49use std::path::Path;
50
51#[cfg(feature = "download")]
52mod download;
53#[cfg(feature = "download")]
55use crate::download::download_and_extract;
56#[cfg(feature = "download")]
57use std::fs::File;
58#[cfg(feature = "download")]
59use tar::Archive;
60
/// Decoded CIFAR-10 splits returned by [`Cifar10::build`].
///
/// Image buffers hold `3072` bytes per record, concatenated in record order.
/// Label buffers hold either one class-index byte per record or, when one-hot
/// encoding is enabled, `10` bytes per record.
#[derive(Debug, Clone, PartialEq)]
pub struct CifarResult(
    /// Training images.
    pub Vec<u8>,
    /// Training labels.
    pub Vec<u8>,
    /// Test images.
    pub Vec<u8>,
    /// Test labels.
    pub Vec<u8>,
);
63
/// Builder describing where the CIFAR-10 binary batch files live and how to decode them.
///
/// Start from `Cifar10::default()`, adjust it with the setter methods, then call
/// `build()` to read the batch files.
#[derive(Debug)]
pub struct Cifar10 {
    /// Root directory; batch files are read from `base_path/cifar_data_path/<bin>`.
    /// Also passed to `download_and_extract` when downloading is enabled.
    base_path: String,
    /// Directory, relative to `base_path`, that contains the `.bin` batch files.
    cifar_data_path: String,
    /// When `true`, labels are one-hot encoded (10 bytes per record); otherwise each
    /// record gets a single byte holding its class index.
    encode_one_hot: bool,
    /// File names of the training batches, read and concatenated in this order.
    training_bin_paths: Vec<String>,
    /// File names of the test batches, read and concatenated in this order.
    testing_bin_paths: Vec<String>,
    /// Number of records decoded from the concatenated training batches.
    num_records_train: usize,
    /// Number of records decoded from the concatenated test batches.
    num_records_test: usize,
    // NOTE(review): `as_f32` and `normalize` are not read by any code visible in this
    // file and have no builder setter — presumably reserved for future options; confirm.
    as_f32: bool,
    normalize: bool,
    /// Whether `build` calls `download_and_extract` before reading the batches; only
    /// honoured when the crate is compiled with the `download` feature.
    download_and_extract: bool,
    /// URL handed to `download_and_extract` when downloading is enabled.
    download_url: String,
}
79
80impl Cifar10 {
81 pub fn default() -> Self {
83 Cifar10 {
84 base_path: "data/".into(),
85 cifar_data_path: "cifar-10-batches-bin/".into(),
86 encode_one_hot: true,
87 training_bin_paths: vec![
88 "data_batch_1.bin".into(),
89 "data_batch_2.bin".into(),
90 "data_batch_3.bin".into(),
91 "data_batch_4.bin".into(),
92 "data_batch_5.bin".into(),
93 ],
94 testing_bin_paths: vec!["test_batch.bin".into()],
95 num_records_train: 50_000,
96 num_records_test: 10_000,
97 as_f32: false,
98 normalize: false,
99 download_and_extract: false,
100 download_url: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz".to_string(),
101 }
102 }
103
104 pub fn base_path(mut self, base_path: impl Into<String>) -> Self {
106 self.base_path = base_path.into();
107 self
108 }
109
110 pub fn cifar_data_path(mut self, cifar_data_path: impl Into<String>) -> Self {
112 self.cifar_data_path = cifar_data_path.into();
113 self
114 }
115
116 pub fn download_and_extract(mut self, download_and_extract: bool) -> Self {
118 self.download_and_extract = download_and_extract;
119 self
120 }
121
122 pub fn download_url(mut self, download_url: impl Into<String>) -> Self {
124 self.download_url = download_url.into();
125 self
126 }
127
128 pub fn encode_one_hot(mut self, encode_one_hot: bool) -> Self {
130 self.encode_one_hot = encode_one_hot;
131 self
132 }
133
134 pub fn training_bin_paths(mut self, training_bin_paths: Vec<String>) -> Self {
136 self.training_bin_paths = training_bin_paths;
137 self
138 }
139
140 pub fn testing_bin_paths(mut self, testing_bin_paths: Vec<String>) -> Self {
142 self.testing_bin_paths = testing_bin_paths;
143 self
144 }
145
146 pub fn num_records_train(mut self, num_records_train: usize) -> Self {
148 self.num_records_train = num_records_train;
149 self
150 }
151
152 pub fn num_records_test(mut self, num_records_test: usize) -> Self {
154 self.num_records_test = num_records_test;
155 self
156 }
157
158 pub fn build(self) -> Result<CifarResult, Box<dyn Error>> {
160 #[cfg(feature = "download")]
161 match self.download_and_extract {
162 false => (),
163 true => {
164 download_and_extract(self.download_url.clone(), self.base_path.clone())?;
165 }
166 }
167
168 let (train_data, train_labels) = get_data(&self, "train")?;
169 let (test_data, test_labels) = get_data(&self, "test")?;
170 Ok(CifarResult(
171 train_data,
172 train_labels,
173 test_data,
174 test_labels,
175 ))
176 }
177}
178
179fn get_data(config: &Cifar10, dataset: &str) -> Result<(Vec<u8>, Vec<u8>), Box<dyn Error>> {
180 let mut buffer: Vec<u8> = Vec::new();
181
182 let (bin_paths, num_records) = match dataset {
183 "train" => (config.training_bin_paths.clone(), config.num_records_train),
184 "test" => (config.testing_bin_paths.clone(), config.num_records_test),
185 _ => panic!("An unexpected value was passed for which dataset should be parsed"),
186 };
187
188 for bin in &bin_paths {
189 let full_cifar_path = Path::new(&config.base_path)
191 .join(&config.cifar_data_path)
192 .join(bin);
193 let mut f = std::fs::File::open(full_cifar_path)?;
196
197 let mut temp_buffer: Vec<u8> = Vec::new();
199 f.read_to_end(&mut temp_buffer)?;
200 buffer.extend(&temp_buffer);
201 }
206
207 let mut labels: Vec<u8> = match config.encode_one_hot {
208 false => vec![0; num_records],
209 true => vec![0; num_records * 10],
210 };
211 let mut data: Vec<u8> = Vec::with_capacity(num_records * 3072);
212
213 for num in 0..num_records {
214 let base = num * (3073);
216
217 let label = buffer[base];
218 if label > 9 {
220 panic!(
221 "Image {}: Label is {}, which is inconsistent with the CIFAR-10 scheme",
222 num, label
223 );
224 }
225
226 data.extend(&buffer[base + 1..=base + 3072]);
227
228 match config.encode_one_hot {
229 false => labels[num] = label as u8,
230 true => labels[(num * 10) + (label as usize)] = 1u8,
231 };
232 }
233
234 Ok((data, labels))
235}
236
237impl CifarResult {
238 #[cfg(any(
239 feature = "to_ndarray_015",
240 feature = "to_ndarray_014",
241 feature = "to_ndarray_013"
242 ))]
243 pub fn to_ndarray<T: std::convert::From<u8>>(
244 self,
245 ) -> Result<(Array4<T>, Array2<T>, Array4<T>, Array2<T>), Box<dyn Error>> {
246 let train_data: Array4<T> =
247 Array::from_shape_vec((50_000, 3, 32, 32), self.0)?.mapv(|x| x.into());
248 let train_labels: Array2<T> =
249 Array::from_shape_vec((50_000, 10), self.1)?.mapv(|x| x.into());
250 let test_data: Array4<T> =
251 Array::from_shape_vec((10_000, 3, 32, 32), self.2)?.mapv(|x| x.into());
252 let test_labels: Array2<T> =
253 Array::from_shape_vec((10_000, 10), self.3)?.mapv(|x| x.into());
254
255 Ok((train_data, train_labels, test_data, test_labels))
256 }
257}
258
/// Maps a one-hot CIFAR-10 label vector to its class name.
///
/// A valid input has exactly 10 elements: a single `1` with every other entry `0`.
/// The position of the `1` selects the class. Any other input produces an error
/// *string* (not a `Result`) that embeds the offending vector.
#[cfg(any(
    feature = "to_ndarray_015",
    feature = "to_ndarray_014",
    feature = "to_ndarray_013"
))]
pub fn return_label_from_one_hot(one_hot: Array1<u8>) -> String {
    const CLASS_NAMES: [&str; 10] = [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ];

    // Valid iff the length is right, there is exactly one non-zero entry, and it is 1.
    let mut nonzero = one_hot
        .iter()
        .enumerate()
        .filter(|&(_, &value)| value != 0);
    match (one_hot.len() == CLASS_NAMES.len(), nonzero.next(), nonzero.next()) {
        (true, Some((class, &1)), None) => CLASS_NAMES[class].to_string(),
        _ => format!("Error: no valid label could be assigned to {}", one_hot),
    }
}