ai_dataloader/indexable/sampler/
batch_sampler.rs

1use super::{Sampler, SequentialSampler};
2use crate::Len;
3
4/// Wraps another sampler to yield a mini-batch of indices.
5/// # Arguments
6///
7/// * `sampler` - Base sampler.
8/// * `batch_size` - Size of mini-batch.
9/// * `drop_last` - If `true`, the sampler will drop the last batch if its size would be less than `batch_size`.
10///
11///
12/// # Examples:
13/// ```
14/// use ai_dataloader::sampler::SequentialSampler;
15/// use ai_dataloader::sampler::BatchSampler;
16///
17/// let dataset = vec![0, 1, 2, 3];
18/// let batch_sampler = BatchSampler {
19///     sampler: SequentialSampler {
20///     data_source_len: dataset.len(),
21///     },
22///     batch_size: 2,
23///     drop_last: false,
24/// };
25/// let mut iter = batch_sampler.iter();
26/// assert_eq!(iter.next(), Some(vec![0, 1]));
27/// assert_eq!(iter.next(), Some(vec![2, 3]));
28/// ```
29#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
30pub struct BatchSampler<S = SequentialSampler> {
31    /// Base sampler.
32    pub sampler: S,
33    /// Size of mini batch.
34    pub batch_size: usize,
35    /// If `true`, the sampler will drop the last batch if
36    /// its size were less than `batch_size`.
37    pub drop_last: bool,
38}
39
40impl<S: Sampler> Len for BatchSampler<S> {
41    /// Returns the number of batch.
42    ///
43    /// If `drop_last` is set to false, even an incomplete batch will be counted.
44    fn len(&self) -> usize {
45        if self.drop_last {
46            self.sampler.len() / self.batch_size
47        } else {
48            (self.sampler.len() + self.batch_size - 1) / self.batch_size
49        }
50    }
51}
52impl<S: Sampler> BatchSampler<S> {
53    /// Return an iterator over the [`BatchSampler`].
54    pub fn iter(&self) -> BatchIterator<S::IntoIter> {
55        BatchIterator {
56            sampler: self.sampler.into_iter(),
57            batch_size: self.batch_size,
58            drop_last: self.drop_last,
59        }
60    }
61}
62
63impl<S: Sampler> IntoIterator for &BatchSampler<S> {
64    type IntoIter = BatchIterator<<S as IntoIterator>::IntoIter>;
65    type Item = Vec<usize>;
66    fn into_iter(self) -> Self::IntoIter {
67        self.iter()
68    }
69}
70
/// An iterator over batches of indices.
///
/// Each call to `next` pulls indices from the underlying iterator and yields
/// them as one `Vec<usize>` holding up to `batch_size` elements.
#[derive(Debug)]
pub struct BatchIterator<I>
where
    I: Iterator<Item = usize>,
{
    /// The underlying sampler.
    sampler: I,
    /// The size of one batch.
    batch_size: usize,
    /// Whether to drop the last elements or not.
    drop_last: bool,
}

impl<I> BatchIterator<I>
where
    I: Iterator<Item = usize>,
{
    /// Number of batches `next` yields when the underlying iterator still
    /// holds exactly `remaining` indices.
    ///
    /// Written without `remaining + batch_size - 1` so it cannot overflow, and
    /// without dividing by a zero `batch_size`.
    fn batches_for(&self, remaining: usize) -> usize {
        if self.batch_size == 0 {
            // A zero-sized batch is never "full": `next` gathers everything
            // into one trailing batch, which is kept only if `!drop_last`.
            return usize::from(!self.drop_last && remaining > 0);
        }
        let full_batches = remaining / self.batch_size;
        if self.drop_last || remaining % self.batch_size == 0 {
            full_batches
        } else {
            full_batches + 1
        }
    }
}

impl<I> Iterator for BatchIterator<I>
where
    I: Iterator<Item = usize>,
{
    type Item = Vec<usize>;

    /// Yields the next batch of indices.
    ///
    /// Returns `None` once the underlying iterator is exhausted, or when only
    /// an incomplete batch remains and `drop_last` is set.
    fn next(&mut self) -> Option<Self::Item> {
        // Reserve no more than the sampler reports it can still yield: this
        // avoids a capacity-overflow panic for huge batch sizes and skips the
        // allocation entirely once the sampler is exhausted.
        let (lower, upper) = self.sampler.size_hint();
        let capacity = self.batch_size.min(upper.unwrap_or(lower));
        let mut batch = Vec::with_capacity(capacity);

        // `by_ref` borrows the sampler so the `for` loop does not consume it.
        for idx in self.sampler.by_ref() {
            batch.push(idx);
            if batch.len() == self.batch_size {
                return Some(batch);
            }
        }

        if batch.is_empty() || self.drop_last {
            None
        } else {
            Some(batch)
        }
    }

    /// Translates both bounds of the underlying iterator's size hint from a
    /// number of indices into a number of batches.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.sampler.size_hint();
        (
            self.batches_for(lower),
            upper.map(|remaining| self.batches_for(remaining)),
        )
    }
}

// When the underlying iterator knows its exact length, both bounds of
// `size_hint` coincide, so the batch count is exact as well.
impl<I> ExactSizeIterator for BatchIterator<I> where I: Iterator<Item = usize> + ExactSizeIterator {}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`BatchSampler`] over a [`SequentialSampler`] of `len` elements.
    fn sequential(len: usize, batch_size: usize, drop_last: bool) -> BatchSampler {
        BatchSampler {
            sampler: SequentialSampler {
                data_source_len: len,
            },
            batch_size,
            drop_last,
        }
    }

    /// Iterating yields every full batch, then the trailing partial one.
    #[test]
    fn basics() {
        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batch_sampler = sequential(dataset.len(), 3, false);

        let mut iter = batch_sampler.iter();
        assert_eq!(iter.next(), Some(vec![0, 1, 2]));
        assert_eq!(iter.next(), Some(vec![3, 4, 5]));
        assert_eq!(iter.next(), Some(vec![6, 7, 8]));
        assert_eq!(iter.next(), Some(vec![9]));
        assert_eq!(iter.next(), None);

        // The `IntoIterator` impl for references must agree with `iter()`.
        let via_ref: Vec<Vec<usize>> = (&batch_sampler).into_iter().collect();
        let via_iter: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(via_ref, via_iter);
    }

    /// With `drop_last`, the trailing partial batch is not yielded.
    #[test]
    fn drop_last_discards_partial_batch() {
        let batch_sampler = sequential(10, 3, true);
        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(
            batches,
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]
        );
    }

    #[test]
    fn batch_sampler() {
        // TODO : test from pytorch, need to support custom batch sampler
        // Until then this only builds the expected batches: alternating
        // groups of 2 and 3 consecutive indices (presumably mirroring the
        // PyTorch fixture -- confirm when the feature lands).
        let mut batches: Vec<Vec<usize>> = Vec::new();
        for i in (0..20).step_by(5) {
            batches.push((i..i + 2).collect());
            batches.push((i + 2..i + 5).collect());
        }
        assert_eq!(batches.len(), 8);
        assert_eq!(batches[0], vec![0, 1]);
        assert_eq!(batches[1], vec![2, 3, 4]);
        assert_eq!(batches[7], vec![17, 18, 19]);
    }

    /// `len()` of the sampler and of its iterator agree, with and without
    /// `drop_last`, and the iterator's length shrinks as batches are consumed.
    #[test]
    fn len() {
        // Evenly divisible.
        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batch_sampler = sequential(dataset.len(), 2, false);
        assert_eq!(batch_sampler.len(), 5);
        assert_eq!(batch_sampler.iter().len(), 5);

        // One leftover element, kept as a partial batch.
        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let batch_sampler = sequential(dataset.len(), 2, false);
        assert_eq!(batch_sampler.len(), 6);
        assert_eq!(batch_sampler.iter().len(), 6);

        // One leftover element, dropped.
        let batch_sampler = sequential(dataset.len(), 2, true);
        assert_eq!(batch_sampler.len(), 5);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.len(), 5);
        iter.next();
        iter.next();
        assert_eq!(iter.len(), 3);
    }
}