mlinrust/dataset/
dataloader.rs

1use std::marker::PhantomData;
2
3use crate::{ndarray::NdArray, utils::RandGenerator};
4
5use super::{Dataset, TaskLabelType};
6
7
8pub struct Dataloader<T: TaskLabelType + Copy, E: DatasetBorrowTrait<T>> {
9    pub batch_size: usize,
10    pub shuffle: bool,
11    raw_dataset: E,
12    rng: RandGenerator,
13    phantom: PhantomData<T>,
14}
15
16pub struct BatchIterator<T: TaskLabelType + Copy> {
17    batch_features: Vec<NdArray>,
18    batch_labels: Vec<Vec<T>>,
19}
20
21impl<'a, T: TaskLabelType + Copy> Iterator for BatchIterator<T> {
22    type Item = (NdArray, Vec<T>);
23    fn next(&mut self) -> Option<Self::Item> {
24        if self.batch_features.is_empty() {
25            None
26        } else {
27            Some((self.batch_features.pop().unwrap(), self.batch_labels.pop().unwrap()))
28        }
29    }
30}
31
32/// dataset, &dataset or &mut dataset
33pub trait DatasetBorrowTrait<T: TaskLabelType + Copy> {
34    fn total_len(&self) -> usize;
35
36    fn get_feature(&self, idx: usize) -> Vec<f32>;
37
38    fn get_label(&self, idx: usize) -> T;
39}
40
/// Expands to the three `DatasetBorrowTrait` methods for any receiver that (after
/// auto-deref) has a `len()` method plus `features` and `labels` fields, so the owned
/// and borrowed dataset impls share one body.
macro_rules! write_dataset_to_dataloader_trait {
    () => {
        fn total_len(&self) -> usize {
            // Delegate to the dataset's own length.
            self.len()
        }

        fn get_feature(&self, idx: usize) -> Vec<f32> {
            // Borrow the stored row, then hand back an owned copy.
            let row: &[f32] = &self.features[idx];
            row.to_vec()
        }

        fn get_label(&self, idx: usize) -> T {
            // Labels are `Copy`, so indexing copies the value out.
            let labels = &self.labels;
            labels[idx]
        }
    };
}
56
57impl<T: TaskLabelType + Copy> DatasetBorrowTrait<T> for Dataset<T> {
58    write_dataset_to_dataloader_trait!();
59}
60
61/// note that you should not expect to shuffle a evaluation dataset since evaluate often accepts a non mut reference, it cannot be shuffled
62impl<T: TaskLabelType + Copy> DatasetBorrowTrait<T> for &Dataset<T> {
63    write_dataset_to_dataloader_trait!();
64}
65
66
67impl<T: TaskLabelType + Copy, E: DatasetBorrowTrait<T>> Dataloader<T, E> {
68    /// convert a dataset to Dataloader for batch iterations
69    /// * E: Dataset or &Dataset for both f32(regression) and usize(classification)
70    /// * shuffle: whether shuffle the dataset for each Iteration of Dataloader
71    ///     * **WARNING**: you should not shuffle a evaluation dataset since evaluate often accepts a non mut reference, it will cause panic
72    /// * seed: default is set to 0
73    pub fn new(dataset: E, batch_size: usize, shuffle: bool, seed: Option<usize>) -> Self {
74        Self { batch_size: batch_size, shuffle: shuffle, raw_dataset: dataset, rng: RandGenerator::new(seed.unwrap_or(0)), phantom: PhantomData}
75    }
76
77    fn init_batches(&mut self) -> (Vec<NdArray>, Vec<Vec<T>>) {
78        let mut sampler: Vec<usize> = (0..self.raw_dataset.total_len()).collect();
79        if self.shuffle {
80            self.rng.shuffle(&mut sampler);
81        }
82        
83        sampler.chunks(self.batch_size)
84        .fold((Vec::with_capacity(self.raw_dataset.total_len() / self.batch_size + 1), Vec::with_capacity(self.raw_dataset.total_len() / self.batch_size + 1)), |mut s, batch| {
85            let f: Vec<Vec<f32>> = batch.iter().map(|i| self.raw_dataset.get_feature(*i)).collect();
86            let f = NdArray::new(f);
87            let l: Vec<T> = batch.iter().map(|i| self.raw_dataset.get_label(*i)).collect();
88            s.0.push(f);
89            s.1.push(l);
90            s
91        })
92    }
93
94    /// expicitly create a batch iterator if you don't want it into iter
95    pub fn iter_mut(&mut self) -> BatchIterator<T> {
96        let (features, labels) = self.init_batches();
97        BatchIterator {
98            batch_features: features, 
99            batch_labels: labels,
100        }
101    }
102}
103
104
#[cfg(test)]
mod test {
    use super::{Dataloader, Dataset};

    /// Two shuffled epochs over 15 samples with batch_size 4: each epoch must yield
    /// three full batches plus one of 3, and cover every sample exactly once.
    #[test]
    fn test_dataloader_iterator() {
        let features = vec![vec![1.0, 2.0, 3.0]; 15];
        let labels = (0..15).map(|i| i as f32).collect();
        let dataset = Dataset::new(features, labels, None);
        let mut dataloader = Dataloader::new(dataset, 4, true, None);
        let expected: Vec<f32> = (0..15).map(|i| i as f32).collect();

        for epoch in 1..=2 {
            println!("epoch {epoch} ----------------------------");
            let mut batch_sizes = Vec::new();
            let mut seen: Vec<f32> = Vec::new();
            for (_, label) in dataloader.iter_mut() {
                println!("{label:?}");
                batch_sizes.push(label.len());
                seen.extend(label);
            }

            // Order-agnostic checks: the shuffle decides which batch is which.
            batch_sizes.sort_unstable();
            assert_eq!(batch_sizes, vec![3, 4, 4, 4]);

            seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
            assert_eq!(seen, expected);
        }
    }
}