cifar_ten/
lib.rs

1#![allow(dead_code)]
2
3//! This library parses the binary files of the CIFAR-10 data set and returns them as a tuple struct
4//! - `CifarResult`: `(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>)` which is organized as `(train_data, train_labels, test_data, test_labels)`
5//!
//! Convenience methods for converting these to Rust `ndarray` numeric arrays are provided behind the `to_ndarray_015`, `to_ndarray_014`,
//! and `to_ndarray_013` feature flags, and automatic downloading of the binary data from a remote URL is provided behind the `download` feature flag.
8//!
9//! ```rust
//! // $ cargo build --features=download,to_ndarray_015
11//! use cifar_ten::*;
12//!
13//! fn main() {
14//!     let (train_data, train_labels, test_data, test_labels) = Cifar10::default()
15//!         .download_and_extract(true)
16//!         .encode_one_hot(true)
17//!         .build()
18//!         .unwrap()
19//!         .to_ndarray::<f32>()
20//!         .expect("Failed to build CIFAR-10 data");
21//! }
22//! ```
23//!
24//! A `tar.gz` file with the original binaries can be found [here](https://www.cs.toronto.edu/~kriz/cifar.html). The crate's author also
//! provides several ML data mirrors [here](https://cmoran.xyz/data/), which are used for running tests on this library. Please feel free to use them,
//! but if you expect to make heavy use of these files, please consider creating your own mirror.
27//!
//! If you'd like to verify that the correct images and labels are being provided, the `examples/preview_images.rs` example uses `show-image` to
//! preview an RGB representation of a given image alongside its corresponding one-hot formatted label.
30
31mod test;
32
33#[cfg(any(
34    feature = "to_ndarray_015",
35    feature = "to_ndarray_014",
36    feature = "to_ndarray_013"
37))]
38pub(self) use ndarray::prelude::*;
39
40#[cfg(feature = "to_ndarray_013")]
41use ndarray_013 as ndarray;
42#[cfg(feature = "to_ndarray_014")]
43use ndarray_014 as ndarray;
44#[cfg(feature = "to_ndarray_015")]
45use ndarray_015 as ndarray;
46
47use std::error::Error;
48use std::io::Read;
49use std::path::Path;
50
51#[cfg(feature = "download")]
52mod download;
53// Dependencies for download feature
54#[cfg(feature = "download")]
55use crate::download::download_and_extract;
56#[cfg(feature = "download")]
57use std::fs::File;
58#[cfg(feature = "download")]
59use tar::Archive;
60
/// Primary data return: the raw CIFAR-10 bytes as `(train_data, train_labels, test_data, test_labels)`.
///
/// Each data vector holds 3072 bytes per image, copied unchanged from the binary files. Each
/// label vector holds either one class index (0-9) per image, or ten bytes per image with a
/// single `1` at the class index when one-hot encoding is enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CifarResult(pub Vec<u8>, pub Vec<u8>, pub Vec<u8>, pub Vec<u8>);
63
/// Data structure used to specify where and how the CIFAR-10 binary data is parsed.
///
/// Start from `Cifar10::default()`, adjust it with the builder-style setters, then call
/// `build` to read the data. The type is `Clone`, so one configuration can be reused even
/// though `build` consumes `self`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cifar10 {
    /// Root directory that `cifar_data_path` is joined onto.
    base_path: String,
    /// Directory, relative to `base_path`, that holds the `.bin` batch files.
    cifar_data_path: String,
    /// Emit labels as ten one-hot bytes per record instead of one class-index byte.
    encode_one_hot: bool,
    /// File names of the training batches, read and concatenated in order.
    training_bin_paths: Vec<String>,
    /// File names of the testing batches, read and concatenated in order.
    testing_bin_paths: Vec<String>,
    /// Number of records parsed from the concatenated training batches.
    num_records_train: usize,
    /// Number of records parsed from the concatenated testing batches.
    num_records_test: usize,
    /// NOTE(review): never read in lib.rs and has no setter; kept for compatibility.
    as_f32: bool,
    /// NOTE(review): never read in lib.rs and has no setter; kept for compatibility.
    normalize: bool,
    /// Download and extract the archive before parsing (acted on only with the `download` feature).
    download_and_extract: bool,
    /// URL the archive is fetched from when `download_and_extract` is set.
    download_url: String,
}
79
80impl Cifar10 {
81    /// Returns the default struct, looking in the "./data/" directory with default binary names
82    pub fn default() -> Self {
83        Cifar10 {
84            base_path: "data/".into(),
85            cifar_data_path: "cifar-10-batches-bin/".into(),
86            encode_one_hot: true,
87            training_bin_paths: vec![
88                "data_batch_1.bin".into(),
89                "data_batch_2.bin".into(),
90                "data_batch_3.bin".into(),
91                "data_batch_4.bin".into(),
92                "data_batch_5.bin".into(),
93            ],
94            testing_bin_paths: vec!["test_batch.bin".into()],
95            num_records_train: 50_000,
96            num_records_test: 10_000,
97            as_f32: false,
98            normalize: false,
99            download_and_extract: false,
100            download_url: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz".to_string(),
101        }
102    }
103
104    /// Manually set the base path
105    pub fn base_path(mut self, base_path: impl Into<String>) -> Self {
106        self.base_path = base_path.into();
107        self
108    }
109
110    /// Manually set the path for the CIFAR-10 data
111    pub fn cifar_data_path(mut self, cifar_data_path: impl Into<String>) -> Self {
112        self.cifar_data_path = cifar_data_path.into();
113        self
114    }
115
116    /// Download CIFAR-10 dataset and extract from compressed tarball
117    pub fn download_and_extract(mut self, download_and_extract: bool) -> Self {
118        self.download_and_extract = download_and_extract;
119        self
120    }
121
122    /// Choose a custom url from which to download the CIFAR-10 dataset
123    pub fn download_url(mut self, download_url: impl Into<String>) -> Self {
124        self.download_url = download_url.into();
125        self
126    }
127
128    /// Choose if the `labels` return is in one-hot format or not (default yes)
129    pub fn encode_one_hot(mut self, encode_one_hot: bool) -> Self {
130        self.encode_one_hot = encode_one_hot;
131        self
132    }
133
134    /// Manually set the path to the training data binaries
135    pub fn training_bin_paths(mut self, training_bin_paths: Vec<String>) -> Self {
136        self.training_bin_paths = training_bin_paths;
137        self
138    }
139
140    /// Manually set the path to the testing data binaries
141    pub fn testing_bin_paths(mut self, testing_bin_paths: Vec<String>) -> Self {
142        self.testing_bin_paths = testing_bin_paths;
143        self
144    }
145
146    /// Set the number of records in the training set (default 50_000)
147    pub fn num_records_train(mut self, num_records_train: usize) -> Self {
148        self.num_records_train = num_records_train;
149        self
150    }
151
152    /// Set the number of records in the training set (default 10_000)
153    pub fn num_records_test(mut self, num_records_test: usize) -> Self {
154        self.num_records_test = num_records_test;
155        self
156    }
157
158    /// Returns the array tuple using the specified options in Array4<T> form
159    pub fn build(self) -> Result<CifarResult, Box<dyn Error>> {
160        #[cfg(feature = "download")]
161        match self.download_and_extract {
162            false => (),
163            true => {
164                download_and_extract(self.download_url.clone(), self.base_path.clone())?;
165            }
166        }
167
168        let (train_data, train_labels) = get_data(&self, "train")?;
169        let (test_data, test_labels) = get_data(&self, "test")?;
170        Ok(CifarResult(
171            train_data,
172            train_labels,
173            test_data,
174            test_labels,
175        ))
176    }
177}
178
179fn get_data(config: &Cifar10, dataset: &str) -> Result<(Vec<u8>, Vec<u8>), Box<dyn Error>> {
180    let mut buffer: Vec<u8> = Vec::new();
181
182    let (bin_paths, num_records) = match dataset {
183        "train" => (config.training_bin_paths.clone(), config.num_records_train),
184        "test" => (config.testing_bin_paths.clone(), config.num_records_test),
185        _ => panic!("An unexpected value was passed for which dataset should be parsed"),
186    };
187
188    for bin in &bin_paths {
189        // let full_cifar_path = [config.base_path, config.cifar_data_path, bin.into()].join("");
190        let full_cifar_path = Path::new(&config.base_path)
191            .join(&config.cifar_data_path)
192            .join(bin);
193        // println!("{}", full_cifar_path);
194
195        let mut f = std::fs::File::open(full_cifar_path)?;
196
197        // read the whole file
198        let mut temp_buffer: Vec<u8> = Vec::new();
199        f.read_to_end(&mut temp_buffer)?;
200        buffer.extend(&temp_buffer);
201        //println!(
202        //    "{}",
203        //    format!("- Done parsing binary file {} to Vec<u8>", bin).as_str()
204        //);
205    }
206
207    let mut labels: Vec<u8> = match config.encode_one_hot {
208        false => vec![0; num_records],
209        true => vec![0; num_records * 10],
210    };
211    let mut data: Vec<u8> = Vec::with_capacity(num_records * 3072);
212
213    for num in 0..num_records {
214        // println!("Through image #{}/{}", num, num_records);
215        let base = num * (3073);
216
217        let label = buffer[base];
218        // dbg!(buffer[base]);
219        if label > 9 {
220            panic!(
221                "Image {}: Label is {}, which is inconsistent with the CIFAR-10 scheme",
222                num, label
223            );
224        }
225
226        data.extend(&buffer[base + 1..=base + 3072]);
227
228        match config.encode_one_hot {
229            false => labels[num] = label as u8,
230            true => labels[(num * 10) + (label as usize)] = 1u8,
231        };
232    }
233
234    Ok((data, labels))
235}
236
237impl CifarResult {
238    #[cfg(any(
239        feature = "to_ndarray_015",
240        feature = "to_ndarray_014",
241        feature = "to_ndarray_013"
242    ))]
243    pub fn to_ndarray<T: std::convert::From<u8>>(
244        self,
245    ) -> Result<(Array4<T>, Array2<T>, Array4<T>, Array2<T>), Box<dyn Error>> {
246        let train_data: Array4<T> =
247            Array::from_shape_vec((50_000, 3, 32, 32), self.0)?.mapv(|x| x.into());
248        let train_labels: Array2<T> =
249            Array::from_shape_vec((50_000, 10), self.1)?.mapv(|x| x.into());
250        let test_data: Array4<T> =
251            Array::from_shape_vec((10_000, 3, 32, 32), self.2)?.mapv(|x| x.into());
252        let test_labels: Array2<T> =
253            Array::from_shape_vec((10_000, 10), self.3)?.mapv(|x| x.into());
254
255        Ok((train_data, train_labels, test_data, test_labels))
256    }
257}
258
/// Returns the CIFAR-10 class name encoded by a one-hot label vector.
///
/// `one_hot` must have exactly ten elements: one `1` and nine `0`s. The position of the `1`
/// selects the class, in order "airplane", "automobile", "bird", "cat", "deer", "dog",
/// "frog", "horse", "ship", "truck".
///
/// Any other input does not produce a Rust error. Instead it returns a string starting with
/// `"Error: no valid label could be assigned to"` followed by the array's `Display` output.
#[cfg(any(
    feature = "to_ndarray_015",
    feature = "to_ndarray_014",
    feature = "to_ndarray_013"
))]
pub fn return_label_from_one_hot(one_hot: Array1<u8>) -> String {
    // Class names in CIFAR-10 label-index order.
    const LABELS: [&str; 10] = [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ];

    // Valid means: one slot per class, every entry 0 or 1, and exactly one entry equal to 1.
    let is_one_hot = one_hot.len() == LABELS.len()
        && one_hot.iter().all(|&v| v <= 1)
        && one_hot.iter().filter(|&&v| v == 1).count() == 1;

    match one_hot.iter().position(|&v| v == 1) {
        Some(index) if is_one_hot => LABELS[index].to_string(),
        _ => format!("Error: no valid label could be assigned to {}", one_hot),
    }
}