ai_dataloader/indexable/sampler/
batch_sampler.rs1use super::{Sampler, SequentialSampler};
2use crate::Len;
3
4#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
30pub struct BatchSampler<S = SequentialSampler> {
31 pub sampler: S,
33 pub batch_size: usize,
35 pub drop_last: bool,
38}
39
40impl<S: Sampler> Len for BatchSampler<S> {
41 fn len(&self) -> usize {
45 if self.drop_last {
46 self.sampler.len() / self.batch_size
47 } else {
48 (self.sampler.len() + self.batch_size - 1) / self.batch_size
49 }
50 }
51}
52impl<S: Sampler> BatchSampler<S> {
53 pub fn iter(&self) -> BatchIterator<S::IntoIter> {
55 BatchIterator {
56 sampler: self.sampler.into_iter(),
57 batch_size: self.batch_size,
58 drop_last: self.drop_last,
59 }
60 }
61}
62
63impl<S: Sampler> IntoIterator for &BatchSampler<S> {
64 type IntoIter = BatchIterator<<S as IntoIterator>::IntoIter>;
65 type Item = Vec<usize>;
66 fn into_iter(self) -> Self::IntoIter {
67 self.iter()
68 }
69}
70
/// Iterator that gathers indices from an inner iterator into `Vec<usize>` batches.
#[derive(Debug)]
pub struct BatchIterator<I: Iterator<Item = usize>> {
    /// Source of the individual indices.
    sampler: I,
    /// Number of indices that make up a full batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
}
84
85impl<I> Iterator for BatchIterator<I>
86where
87 I: Iterator<Item = usize>,
88{
89 type Item = Vec<usize>;
90 fn next(&mut self) -> Option<Self::Item> {
91 let mut batch = Vec::with_capacity(self.batch_size);
92
93 let mut current_idx = self.sampler.next();
96 while let Some(idx) = current_idx {
97 batch.push(idx);
98 if batch.len() == self.batch_size {
99 return Some(batch);
100 }
101 current_idx = self.sampler.next();
102 }
103 if !batch.is_empty() && !self.drop_last {
104 return Some(batch);
105 }
106 None
107 }
108 fn size_hint(&self) -> (usize, Option<usize>) {
109 let (lower, _) = self.sampler.size_hint();
110 let lower = if self.drop_last {
111 lower / self.batch_size
112 } else {
113 (lower + self.batch_size - 1) / self.batch_size
114 };
115 (lower, Some(lower))
116 }
117}
118
119impl<I> ExactSizeIterator for BatchIterator<I> where I: Iterator<Item = usize> + ExactSizeIterator {}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`BatchSampler`] over a [`SequentialSampler`] of `len` indices.
    fn sequential(len: usize, batch_size: usize, drop_last: bool) -> BatchSampler {
        BatchSampler {
            sampler: SequentialSampler {
                data_source_len: len,
            },
            batch_size,
            drop_last,
        }
    }

    #[test]
    fn basics() {
        let batch_sampler = sequential(10, 3, false);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.next(), Some(vec![0, 1, 2]));
        assert_eq!(iter.next(), Some(vec![3, 4, 5]));
        assert_eq!(iter.next(), Some(vec![6, 7, 8]));
        // Without `drop_last` the trailing partial batch is kept.
        assert_eq!(iter.next(), Some(vec![9]));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn drop_last_discards_partial_batch() {
        let batch_sampler = sequential(10, 3, true);
        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
    }

    #[test]
    fn borrowed_sampler_is_iterable() {
        let batch_sampler = sequential(10, 3, false);
        let mut count = 0;
        for batch in &batch_sampler {
            assert!(!batch.is_empty());
            count += 1;
        }
        assert_eq!(count, batch_sampler.len());
    }

    #[test]
    fn batch_sampler() {
        // NOTE(review): placeholder. It builds alternating index ranges of
        // two and three elements but does not exercise `BatchSampler` yet;
        // presumably intended for a custom batch sampler test. Confirm intent.
        let mut batches = Vec::new();
        for i in (0..20).step_by(5) {
            batches.push([i..i + 2]);
            batches.push([i + 2..i + 5]);
        }
        assert_eq!(batches.len(), 8);
    }

    #[test]
    fn len() {
        let batch_sampler = sequential(10, 2, false);
        assert_eq!(batch_sampler.len(), 5);
        assert_eq!(batch_sampler.iter().len(), 5);

        let batch_sampler = sequential(11, 2, false);
        assert_eq!(batch_sampler.len(), 6);
        assert_eq!(batch_sampler.iter().len(), 6);

        let batch_sampler = sequential(11, 2, true);
        assert_eq!(batch_sampler.len(), 5);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.len(), 5);
        // The remaining length shrinks as batches are consumed.
        iter.next();
        iter.next();
        assert_eq!(iter.len(), 3);
    }
}