Skip to main content

burn_dataset/transform/
composed.rs

1use crate::Dataset;
2
3/// Compose multiple datasets together to create a bigger one.
4#[derive(new)]
5pub struct ComposedDataset<D> {
6    datasets: Vec<D>,
7}
8
9impl<D, I> Dataset<I> for ComposedDataset<D>
10where
11    D: Dataset<I>,
12    I: Clone,
13{
14    fn get(&self, index: usize) -> Option<I> {
15        let mut current_index = 0;
16        for dataset in self.datasets.iter() {
17            if index < dataset.len() + current_index {
18                return dataset.get(index - current_index);
19            }
20            current_index += dataset.len();
21        }
22        None
23    }
24    fn len(&self) -> usize {
25        let mut total = 0;
26        for dataset in self.datasets.iter() {
27            total += dataset.len();
28        }
29        total
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;

    /// Iterating the composed dataset yields the items of each inner dataset,
    /// in order, and its length is the sum of the inner lengths.
    #[test]
    fn test_composed_dataset() {
        let dataset1 = FakeDataset::<String>::new(10);
        let dataset2 = FakeDataset::<String>::new(5);

        // Snapshot the items before the datasets are moved into the composition.
        let items1 = dataset1.iter().collect::<Vec<_>>();
        let items2 = dataset2.iter().collect::<Vec<_>>();

        let composed = ComposedDataset::new(vec![dataset1, dataset2]);

        assert_eq!(composed.len(), 15);

        let expected_items: Vec<String> = items1.iter().chain(items2.iter()).cloned().collect();

        let items = composed.iter().collect::<Vec<_>>();

        assert_eq!(items, expected_items);
    }

    /// Direct lookups resolve to the right inner dataset around the boundary
    /// and return `None` past the end.
    #[test]
    fn test_composed_dataset_get_boundaries() {
        let dataset1 = FakeDataset::<String>::new(10);
        let dataset2 = FakeDataset::<String>::new(5);

        let items1 = dataset1.iter().collect::<Vec<_>>();
        let items2 = dataset2.iter().collect::<Vec<_>>();

        let composed = ComposedDataset::new(vec![dataset1, dataset2]);

        assert_eq!(composed.get(0), items1.first().cloned());
        assert_eq!(composed.get(9), items1.last().cloned());
        assert_eq!(composed.get(10), items2.first().cloned());
        assert_eq!(composed.get(14), items2.last().cloned());
        assert_eq!(composed.get(15), None);
    }

    /// A composition of zero datasets is empty.
    #[test]
    fn test_composed_dataset_empty() {
        let composed = ComposedDataset::<FakeDataset<String>>::new(vec![]);

        assert_eq!(composed.len(), 0);
        assert_eq!(composed.get(0), None);
    }
}