mlinrust/dataset/
dataloader.rs1use std::marker::PhantomData;
2
3use crate::{ndarray::NdArray, utils::RandGenerator};
4
5use super::{Dataset, TaskLabelType};
6
7
8pub struct Dataloader<T: TaskLabelType + Copy, E: DatasetBorrowTrait<T>> {
9 pub batch_size: usize,
10 pub shuffle: bool,
11 raw_dataset: E,
12 rng: RandGenerator,
13 phantom: PhantomData<T>,
14}
15
16pub struct BatchIterator<T: TaskLabelType + Copy> {
17 batch_features: Vec<NdArray>,
18 batch_labels: Vec<Vec<T>>,
19}
20
21impl<'a, T: TaskLabelType + Copy> Iterator for BatchIterator<T> {
22 type Item = (NdArray, Vec<T>);
23 fn next(&mut self) -> Option<Self::Item> {
24 if self.batch_features.is_empty() {
25 None
26 } else {
27 Some((self.batch_features.pop().unwrap(), self.batch_labels.pop().unwrap()))
28 }
29 }
30}
31
32pub trait DatasetBorrowTrait<T: TaskLabelType + Copy> {
34 fn total_len(&self) -> usize;
35
36 fn get_feature(&self, idx: usize) -> Vec<f32>;
37
38 fn get_label(&self, idx: usize) -> T;
39}
40
/// Expands to the three `DatasetBorrowTrait` method bodies.
///
/// It relies on the implementing type exposing `len()`, an indexable `features`
/// field and an indexable `labels` field, as used by the `Dataset<T>` and
/// `&Dataset<T>` impls below. It also relies on a generic parameter named `T`
/// being in scope at the expansion site.
macro_rules! write_dataset_to_dataloader_trait {
    () => {
        fn total_len(&self) -> usize {
            self.len()
        }

        fn get_feature(&self, idx: usize) -> Vec<f32> {
            // Hand the caller its own copy of the row.
            let row = &self.features[idx];
            row.clone()
        }

        fn get_label(&self, idx: usize) -> T {
            self.labels[idx]
        }
    };
}
56
57impl<T: TaskLabelType + Copy> DatasetBorrowTrait<T> for Dataset<T> {
58 write_dataset_to_dataloader_trait!();
59}
60
61impl<T: TaskLabelType + Copy> DatasetBorrowTrait<T> for &Dataset<T> {
63 write_dataset_to_dataloader_trait!();
64}
65
66
67impl<T: TaskLabelType + Copy, E: DatasetBorrowTrait<T>> Dataloader<T, E> {
68 pub fn new(dataset: E, batch_size: usize, shuffle: bool, seed: Option<usize>) -> Self {
74 Self { batch_size: batch_size, shuffle: shuffle, raw_dataset: dataset, rng: RandGenerator::new(seed.unwrap_or(0)), phantom: PhantomData}
75 }
76
77 fn init_batches(&mut self) -> (Vec<NdArray>, Vec<Vec<T>>) {
78 let mut sampler: Vec<usize> = (0..self.raw_dataset.total_len()).collect();
79 if self.shuffle {
80 self.rng.shuffle(&mut sampler);
81 }
82
83 sampler.chunks(self.batch_size)
84 .fold((Vec::with_capacity(self.raw_dataset.total_len() / self.batch_size + 1), Vec::with_capacity(self.raw_dataset.total_len() / self.batch_size + 1)), |mut s, batch| {
85 let f: Vec<Vec<f32>> = batch.iter().map(|i| self.raw_dataset.get_feature(*i)).collect();
86 let f = NdArray::new(f);
87 let l: Vec<T> = batch.iter().map(|i| self.raw_dataset.get_label(*i)).collect();
88 s.0.push(f);
89 s.1.push(l);
90 s
91 })
92 }
93
94 pub fn iter_mut(&mut self) -> BatchIterator<T> {
96 let (features, labels) = self.init_batches();
97 BatchIterator {
98 batch_features: features,
99 batch_labels: labels,
100 }
101 }
102}
103
104
#[cfg(test)]
mod test {
    use super::{Dataloader, Dataset};

    const N_SAMPLES: usize = 15;
    const BATCH_SIZE: usize = 4;

    /// Builds a loader over `N_SAMPLES` samples whose labels are `0.0..N_SAMPLES`.
    fn make_dataloader(shuffle: bool) -> Dataloader<f32, Dataset<f32>> {
        let features = vec![vec![1.0, 2.0, 3.0]; N_SAMPLES];
        let labels = (0..N_SAMPLES).map(|i| i as f32).collect();
        let dataset = Dataset::new(features, labels, None);
        Dataloader::new(dataset, BATCH_SIZE, shuffle, None)
    }

    /// Runs one epoch and asserts that it visits every sample exactly once, in
    /// `ceil(N_SAMPLES / BATCH_SIZE)` non-empty batches of at most `BATCH_SIZE` labels.
    fn check_epoch(dataloader: &mut Dataloader<f32, Dataset<f32>>) {
        let mut seen: Vec<f32> = Vec::with_capacity(N_SAMPLES);
        let mut n_batches = 0;
        for (_, labels) in dataloader.iter_mut() {
            println!("{labels:?}");
            assert!(
                !labels.is_empty() && labels.len() <= BATCH_SIZE,
                "unexpected batch size {}",
                labels.len()
            );
            seen.extend(labels);
            n_batches += 1;
        }
        assert_eq!(n_batches, (N_SAMPLES + BATCH_SIZE - 1) / BATCH_SIZE);

        seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let expected: Vec<f32> = (0..N_SAMPLES).map(|i| i as f32).collect();
        assert_eq!(seen, expected);
    }

    #[test]
    fn test_dataloader_iterator() {
        let mut dataloader = make_dataloader(true);
        println!("epoch 1 ----------------------------");
        check_epoch(&mut dataloader);
        println!("epoch 2 ----------------------------");
        check_epoch(&mut dataloader);
    }

    #[test]
    fn test_dataloader_without_shuffle() {
        let mut dataloader = make_dataloader(false);
        check_epoch(&mut dataloader);
    }
}