use std::{path::PathBuf, sync::Mutex};
use flate2::read::GzDecoder;
use tar::Archive;
use crate::network::downloader;
use crate::vision::ImageFolderDataset;
/// Remote location of the CIFAR-10 archive fetched by `download`.
/// The payload is decoded as a gzip-compressed tar archive.
const CIFAR10_URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz";
/// Remote location of the CIFAR-100 archive fetched by `download`.
/// The payload is decoded as a gzip-compressed tar archive.
const CIFAR100_URL: &str = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz";
/// Selects which CIFAR dataset variant a [`CifarDataset`] is backed by.
///
/// The variant determines the download URL, the archive file name and the
/// name of the cache sub-directory used by `download`.
#[derive(Debug, Clone, Copy)]
// NOTE(review): presumably kept because not every variant is constructed
// outside of the tests in this file — confirm before removing.
#[allow(dead_code)]
pub enum CifarType {
/// The CIFAR-10 dataset (cached under `cifar10`).
Cifar10,
/// The CIFAR-100 dataset (cached under `cifar100`).
Cifar100,
}
/// Handle to a locally cached CIFAR dataset.
///
/// Construct it with [`CifarDataset::new`], then obtain the individual splits
/// with [`CifarDataset::train`] and [`CifarDataset::test`].
pub struct CifarDataset {
/// Directory returned by `download`; its `train` and `test` sub-directories
/// are opened as image-folder datasets.
cifar_dir: PathBuf,
}
impl CifarDataset {
    /// Creates a dataset handle for the requested CIFAR variant.
    ///
    /// The archive is fetched and extracted into the local cache by `download`
    /// unless the variant's cache directory already exists.
    ///
    /// # Panics
    ///
    /// Panics if `download` panics, for example when the cache directory
    /// cannot be determined or the archive cannot be extracted.
    pub fn new(cifar_type: CifarType) -> Self {
        Self {
            cifar_dir: download(&cifar_type),
        }
    }

    /// Opens the `train` sub-directory as a classification image-folder dataset.
    ///
    /// # Panics
    ///
    /// Panics if the `train` directory cannot be loaded as a classification dataset.
    pub fn train(&self) -> ImageFolderDataset {
        ImageFolderDataset::new_classification(self.cifar_dir.join("train"))
            .expect("failed to load the CIFAR `train` split as a classification dataset")
    }

    /// Opens the `test` sub-directory as a classification image-folder dataset.
    ///
    /// # Panics
    ///
    /// Panics if the `test` directory cannot be loaded as a classification dataset.
    pub fn test(&self) -> ImageFolderDataset {
        ImageFolderDataset::new_classification(self.cifar_dir.join("test"))
            .expect("failed to load the CIFAR `test` split as a classification dataset")
    }
}
/// Process-wide lock taken for the whole duration of `download`, so that the
/// existence check, the fetch and the extraction never run concurrently from
/// two threads of this process. It guards no data of its own.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
/// Makes sure the requested CIFAR variant is available in the local cache and
/// returns the directory it lives in.
///
/// The cache root is `<user cache dir>/burn-dataset`. If the variant's
/// directory (`cifar10` or `cifar100`) does not exist there yet, the archive
/// is downloaded and unpacked into the cache root; the archive is assumed to
/// contain a top-level directory with that same name.
///
/// Calls are serialized through `DOWNLOAD_LOCK`.
///
/// # Panics
///
/// Panics if the user cache directory cannot be determined or if the archive
/// cannot be extracted. A partially extracted dataset directory is removed
/// (best effort) before panicking so that it is not mistaken for a complete
/// download on the next call.
fn download(cifar_type: &CifarType) -> PathBuf {
    // The mutex guards no data, so a poisoned lock (an earlier caller panicked
    // while downloading) leaves nothing in an inconsistent state. Recover the
    // guard instead of turning one failure into a panic for every later caller.
    let _lock = DOWNLOAD_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    // Everything that depends on the variant, resolved in one place.
    let (dir_name, url, filename) = match cifar_type {
        CifarType::Cifar10 => ("cifar10", CIFAR10_URL, "cifar10.tgz"),
        CifarType::Cifar100 => ("cifar100", CIFAR100_URL, "cifar100.tgz"),
    };

    let cache_dir = dirs::cache_dir()
        .expect("Could not get cache directory")
        .join("burn-dataset");
    let cifar_dir = cache_dir.join(dir_name);

    if !cifar_dir.exists() {
        let bytes = downloader::download_file_as_bytes(url, filename);
        let mut archive = Archive::new(GzDecoder::new(&bytes[..]));

        if let Err(err) = archive.unpack(&cache_dir) {
            // `cifar_dir` did not exist before this call, so whatever is there
            // now is a partial extraction. Remove it; otherwise the `exists()`
            // check above would treat it as a finished download next time.
            let _ = std::fs::remove_dir_all(&cifar_dir);
            panic!(
                "Failed to extract {filename} into {}: {err}",
                cache_dir.display()
            );
        }
    }

    cifar_dir
}
// NOTE: every test below goes through `download`, so it needs either network
// access or an already populated cache directory.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Dataset, vision::Annotation};

    /// Expected number of items in the `train` split (both variants).
    const TRAINDATASET_LEN: usize = 50000;
    /// Expected number of items in the `test` split (both variants).
    const TESTDATASET_LEN: usize = 10000;
    /// Expected smallest and largest label index for CIFAR-10.
    const CIFAR10_LABEL_MIN: usize = 0;
    const CIFAR10_LABEL_MAX: usize = 9;
    /// Expected smallest and largest label index for CIFAR-100.
    const CIFAR100_LABEL_MIN: usize = 0;
    const CIFAR100_LABEL_MAX: usize = 99;

    #[test]
    fn test_cifar10_download() {
        let cifar_dir = download(&CifarType::Cifar10);
        assert!(cifar_dir.exists());
    }

    #[test]
    fn test_cifar100_download() {
        let cifar_dir = download(&CifarType::Cifar100);
        assert!(cifar_dir.exists());
    }

    #[test]
    fn test_cifar10_len() {
        let dataset = CifarDataset::new(CifarType::Cifar10);
        let train_dataset = dataset.train();
        let test_dataset = dataset.test();
        assert_eq!(train_dataset.len(), TRAINDATASET_LEN);
        assert_eq!(test_dataset.len(), TESTDATASET_LEN);
    }

    #[test]
    fn test_cifar100_len() {
        let dataset = CifarDataset::new(CifarType::Cifar100);
        let train_dataset = dataset.train();
        let test_dataset = dataset.test();
        assert_eq!(train_dataset.len(), TRAINDATASET_LEN);
        assert_eq!(test_dataset.len(), TESTDATASET_LEN);
    }

    #[test]
    fn test_cifar10_label_range() {
        let dataset = CifarDataset::new(CifarType::Cifar10);
        let test_dataset = dataset.test();
        let (min, max) = get_label_range(&test_dataset);
        assert_eq!(min, CIFAR10_LABEL_MIN);
        assert_eq!(max, CIFAR10_LABEL_MAX);
    }

    #[test]
    fn test_cifar100_label_range() {
        let dataset = CifarDataset::new(CifarType::Cifar100);
        let test_dataset = dataset.test();
        let (min, max) = get_label_range(&test_dataset);
        assert_eq!(min, CIFAR100_LABEL_MIN);
        assert_eq!(max, CIFAR100_LABEL_MAX);
    }

    /// Returns the `(smallest, largest)` label index found in `dataset`.
    ///
    /// The minimum starts at `usize::MAX` rather than a magic number, so the
    /// result is correct for any label range. An empty dataset yields
    /// `(usize::MAX, 0)`.
    ///
    /// # Panics
    ///
    /// Panics if an item carries anything other than `Annotation::Label`;
    /// silently treating it as label `0` could fake a passing minimum.
    fn get_label_range(dataset: &ImageFolderDataset) -> (usize, usize) {
        dataset
            .iter()
            .fold((usize::MAX, 0), |(min, max), item| {
                let index = match item.annotation {
                    Annotation::Label(index) => index,
                    _ => panic!("expected a classification label annotation"),
                };
                (min.min(index), max.max(index))
            })
    }
}