1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
use crate::source::huggingface::downloader::HuggingfaceDatasetLoader;
use crate::{Dataset, InMemDataset};
use serde::{Deserialize, Serialize};

/// One MNIST sample: an image together with the label stored alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MNISTItem {
    /// Image data as a 28 x 28 grid of `f32` values.
    // NOTE(review): axis orientation and value range are not visible here — confirm against the loader.
    pub image: [[f32; 28]; 28],
    /// Label attached to the image.
    // NOTE(review): presumably the digit class of the image — confirm.
    pub label: usize,
}

/// MNIST dataset wrapper backed by an [`InMemDataset`] of [`MNISTItem`]s.
pub struct MNISTDataset {
    // Backing store; the `Dataset` implementation forwards `get` and `len` to it.
    dataset: InMemDataset<MNISTItem>,
}

/// Exposes the MNIST items through the generic [`Dataset`] interface by
/// delegating every call to the wrapped in-memory dataset.
impl Dataset<MNISTItem> for MNISTDataset {
    /// Looks up `index` in the wrapped dataset and returns whatever it yields.
    fn get(&self, index: usize) -> Option<MNISTItem> {
        let Self { dataset } = self;
        dataset.get(index)
    }

    /// Reports the length of the wrapped dataset.
    fn len(&self) -> usize {
        let Self { dataset } = self;
        dataset.len()
    }
}

impl MNISTDataset {
    /// Creates the dataset from the `"train"` split.
    ///
    /// # Panics
    ///
    /// Panics if the split cannot be loaded into memory.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates the dataset from the `"test"` split.
    ///
    /// # Panics
    ///
    /// Panics if the split cannot be loaded into memory.
    pub fn test() -> Self {
        Self::new("test")
    }

    /// Loads the given `split` of the `"mnist"` dataset into memory.
    ///
    /// The loader is configured to extract the `"image"` column as an image
    /// and the `"label"` column as a number.
    ///
    /// # Panics
    ///
    /// Panics if `load_in_memory` returns an error. The panic message names
    /// the split and includes the underlying error so the failure can be
    /// diagnosed without a debugger.
    fn new(split: &str) -> Self {
        let dataset = HuggingfaceDatasetLoader::new("mnist", split)
            .extract_image("image")
            .extract_number("label")
            // NOTE(review): presumably Python packages required by the loader — confirm.
            .deps(&["pillow", "numpy"])
            .load_in_memory()
            .unwrap_or_else(|err| {
                panic!("failed to load the `{}` split of the MNIST dataset: {:?}", split, err)
            });

        Self { dataset }
    }
}