burn_dataset/transform/
partial.rs1use crate::Dataset;
2use std::{marker::PhantomData, sync::Arc};
3
4#[derive(new)]
6pub struct PartialDataset<D, I> {
7 dataset: D,
8 start_index: usize,
9 end_index: usize,
10 input: PhantomData<I>,
11}
12
13impl<D, I> PartialDataset<D, I>
14where
15 D: Dataset<I>,
16{
17 pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
19 let dataset = Arc::new(dataset); let mut current = 0;
22 let mut datasets = Vec::with_capacity(num);
23
24 let batch_size = dataset.len() / num;
25
26 for i in 0..num {
27 let start = current;
28 let mut end = current + batch_size;
29
30 if i == (num - 1) {
31 end = dataset.len();
32 }
33
34 let dataset = PartialDataset::new(dataset.clone(), start, end);
35
36 current += batch_size;
37 datasets.push(dataset);
38 }
39
40 datasets
41 }
42}
43
44impl<D, I> Dataset<I> for PartialDataset<D, I>
45where
46 D: Dataset<I>,
47 I: Clone + Send + Sync,
48{
49 fn get(&self, index: usize) -> Option<I> {
50 let index = index + self.start_index;
51 if index < self.start_index || index >= self.end_index {
52 return None;
53 }
54 self.dataset.get(index)
55 }
56
57 fn len(&self) -> usize {
58 usize::min(self.end_index - self.start_index, self.dataset.len())
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashSet;

    /// A window anchored at position 0 exposes exactly the first ten items.
    #[test]
    fn test_start_from_beginning() {
        let dataset_original = FakeDataset::<String>::new(27);

        // Sort the original items by whether their position is in the window.
        let mut expected = HashSet::new();
        let mut excluded = HashSet::new();
        for (position, item) in dataset_original.iter().enumerate() {
            if position < 10 {
                expected.insert(item);
            } else {
                excluded.insert(item);
            }
        }

        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
        let items_partial: HashSet<_> = dataset_partial.iter().collect();

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(expected, items_partial);
        for item in excluded {
            assert!(!items_partial.contains(&item));
        }
    }

    /// A window starting inside the dataset exposes only positions 10..20.
    #[test]
    fn test_start_inside() {
        let dataset_original = FakeDataset::<String>::new(27);

        // Sort the original items by whether their position is in the window.
        let mut expected = HashSet::new();
        let mut excluded = HashSet::new();
        for (position, item) in dataset_original.iter().enumerate() {
            if (10..20).contains(&position) {
                expected.insert(item);
            } else {
                excluded.insert(item);
            }
        }

        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
        let items_partial: HashSet<_> = dataset_partial.iter().collect();

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(expected, items_partial);
        for item in excluded {
            assert!(!items_partial.contains(&item));
        }
    }

    /// Concatenating the partitions in order reproduces the original items.
    #[test]
    fn test_split_contains_all_items_without_duplicates() {
        let dataset_original = FakeDataset::<String>::new(27);
        let items_original: Vec<_> = dataset_original.iter().collect();

        let dataset_partials = PartialDataset::split(dataset_original, 4);

        let mut items_partial = Vec::new();
        for partition in &dataset_partials {
            items_partial.extend(partition.iter());
        }

        assert_eq!(items_original, items_partial);
    }
}