ai_dataloader/indexable/sampler/
batch_sampler.rs1use super::{Sampler, SequentialSampler};
2use crate::Len;
3
4#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
30pub struct BatchSampler<S = SequentialSampler> {
31    pub sampler: S,
33    pub batch_size: usize,
35    pub drop_last: bool,
38}
39
40impl<S: Sampler> Len for BatchSampler<S> {
41    fn len(&self) -> usize {
45        if self.drop_last {
46            self.sampler.len() / self.batch_size
47        } else {
48            (self.sampler.len() + self.batch_size - 1) / self.batch_size
49        }
50    }
51}
52impl<S: Sampler> BatchSampler<S> {
53    pub fn iter(&self) -> BatchIterator<S::IntoIter> {
55        BatchIterator {
56            sampler: self.sampler.into_iter(),
57            batch_size: self.batch_size,
58            drop_last: self.drop_last,
59        }
60    }
61}
62
63impl<S: Sampler> IntoIterator for &BatchSampler<S> {
64    type IntoIter = BatchIterator<<S as IntoIterator>::IntoIter>;
65    type Item = Vec<usize>;
66    fn into_iter(self) -> Self::IntoIter {
67        self.iter()
68    }
69}
70
/// Iterator that groups the indices produced by an inner index iterator into
/// batches (`Vec<usize>`) of up to `batch_size` elements.
///
/// Every batch except possibly the last holds exactly `batch_size` indices.
/// When the inner iterator runs out part-way through a batch, that partial
/// batch is yielded if `drop_last` is `false` and discarded otherwise.
#[derive(Debug)]
pub struct BatchIterator<I> {
    /// Source of the indices being grouped.
    sampler: I,
    /// Maximum number of indices per batch. Must be non-zero: `size_hint`
    /// divides by it.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
}

impl<I> BatchIterator<I> {
    /// Number of batches produced from exactly `n` inner indices.
    ///
    /// Uses quotient + remainder rather than `(n + batch_size - 1) / batch_size`
    /// so it cannot overflow, even for `n == usize::MAX` (e.g. the lower bound
    /// reported by an unbounded range).
    fn batch_count(&self, n: usize) -> usize {
        let full_batches = n / self.batch_size;
        // A non-zero remainder implies `batch_size >= 2`, so `+ 1` cannot overflow.
        if self.drop_last || n % self.batch_size == 0 {
            full_batches
        } else {
            full_batches + 1
        }
    }
}

impl<I> Iterator for BatchIterator<I>
where
    I: Iterator<Item = usize>,
{
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut batch = Vec::with_capacity(self.batch_size);
        // Pull indices until the batch is full or the inner iterator ends;
        // `by_ref` keeps the inner iterator usable for the next batch.
        for idx in self.sampler.by_ref() {
            batch.push(idx);
            if batch.len() == self.batch_size {
                return Some(batch);
            }
        }
        // Inner iterator exhausted: emit the partial batch unless it is empty
        // or `drop_last` asks for it to be discarded.
        if batch.is_empty() || self.drop_last {
            None
        } else {
            Some(batch)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Map each bound of the inner hint separately. Batch counting is
        // monotonic in the number of indices, so the mapped bounds stay valid
        // even when the inner hint is not exact (e.g. after `filter`).
        let (lower, upper) = self.sampler.size_hint();
        (self.batch_count(lower), upper.map(|n| self.batch_count(n)))
    }
}

// For an exact inner hint (lower == upper) both bounds map to the same batch
// count, so `size_hint` is exact as `ExactSizeIterator` requires.
impl<I> ExactSizeIterator for BatchIterator<I> where I: Iterator<Item = usize> + ExactSizeIterator {}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BatchSampler` over a `SequentialSampler` of `data_source_len` indices.
    fn sequential(
        data_source_len: usize,
        batch_size: usize,
        drop_last: bool,
    ) -> BatchSampler<SequentialSampler> {
        BatchSampler {
            sampler: SequentialSampler { data_source_len },
            batch_size,
            drop_last,
        }
    }

    #[test]
    fn basics() {
        let batch_sampler = sequential(10, 3, false);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.next(), Some(vec![0, 1, 2]));
        assert_eq!(iter.next(), Some(vec![3, 4, 5]));
        assert_eq!(iter.next(), Some(vec![6, 7, 8]));
        // The partial trailing batch is kept when `drop_last` is false.
        assert_eq!(iter.next(), Some(vec![9]));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn drop_last_discards_partial_batch() {
        let batch_sampler = sequential(10, 3, true);
        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
    }

    #[test]
    fn batch_sampler() {
        // 20 indices split evenly into batches of 5: [0..5], [5..10], [10..15], [15..20].
        let expected: Vec<Vec<usize>> = (0..20)
            .step_by(5)
            .map(|i| (i..i + 5).collect::<Vec<usize>>())
            .collect();
        let batch_sampler = sequential(20, 5, false);
        assert_eq!(batch_sampler.iter().collect::<Vec<_>>(), expected);

        // Iterating by reference yields the same batches as `iter()`.
        let mut by_ref = Vec::new();
        for batch in &batch_sampler {
            by_ref.push(batch);
        }
        assert_eq!(by_ref, expected);
    }

    #[test]
    fn len() {
        let batch_sampler = sequential(10, 2, false);
        assert_eq!(batch_sampler.len(), 5);
        assert_eq!(batch_sampler.iter().len(), 5);

        let batch_sampler = sequential(11, 2, false);
        assert_eq!(batch_sampler.len(), 6);
        assert_eq!(batch_sampler.iter().len(), 6);

        let batch_sampler = sequential(11, 2, true);
        assert_eq!(batch_sampler.len(), 5);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.len(), 5);
        iter.next();
        iter.next();
        assert_eq!(iter.len(), 3);
    }
}