use super::{Sampler, SequentialSampler};
use crate::Len;
/// Groups the indices produced by an inner sampler into batches.
///
/// The inner sampler type defaults to [`SequentialSampler`]. Iterating (via
/// `iter()` or `&BatchSampler`) yields one `Vec<usize>` of indices per batch.
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct BatchSampler<S = SequentialSampler> {
/// Inner sampler that supplies the individual indices to be batched.
pub sampler: S,
/// Number of indices collected into each batch. `len()` divides by this
/// value, so it must be non-zero.
pub batch_size: usize,
/// When `true`, a trailing batch holding fewer than `batch_size` indices
/// is discarded instead of being yielded.
pub drop_last: bool,
}
impl<S: Sampler> Len for BatchSampler<S> {
    /// Returns the number of batches this sampler produces.
    ///
    /// With `drop_last` set, only complete batches are counted (floor
    /// division). Otherwise a trailing partial batch is counted as well
    /// (ceiling division).
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero, because the count is obtained by
    /// dividing by it.
    fn len(&self) -> usize {
        let total = self.sampler.len();
        let full_batches = total / self.batch_size;
        // The ceiling is derived from quotient and remainder instead of
        // `(total + batch_size - 1) / batch_size`: that intermediate sum can
        // overflow `usize` for large inputs.
        let has_partial_batch = total % self.batch_size != 0;
        if has_partial_batch && !self.drop_last {
            full_batches + 1
        } else {
            full_batches
        }
    }
}
impl<S: Sampler> BatchSampler<S> {
    /// Creates a fresh iterator over batches of indices.
    ///
    /// The iterator is built from the inner sampler's own iterator together
    /// with this sampler's `batch_size` and `drop_last` settings. Each item
    /// it yields is a `Vec<usize>` of indices.
    pub fn iter(&self) -> BatchIterator<S::IntoIter> {
        let sampler = self.sampler.into_iter();
        let batch_size = self.batch_size;
        let drop_last = self.drop_last;
        BatchIterator {
            sampler,
            batch_size,
            drop_last,
        }
    }
}
impl<S: Sampler> IntoIterator for &BatchSampler<S> {
type IntoIter = BatchIterator<<S as IntoIterator>::IntoIter>;
type Item = Vec<usize>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator that pulls indices from an inner index iterator and yields them
/// grouped into `Vec<usize>` batches.
#[derive(Debug)]
pub struct BatchIterator<I>
where
I: Iterator<Item = usize>,
{
/// Inner iterator supplying the individual indices.
sampler: I,
/// Number of indices that make up a complete batch.
batch_size: usize,
/// When `true`, a final batch with fewer than `batch_size` indices is
/// discarded rather than yielded.
drop_last: bool,
}
impl<I> Iterator for BatchIterator<I>
where
    I: Iterator<Item = usize>,
{
    type Item = Vec<usize>;

    /// Yields the next batch of indices.
    ///
    /// Indices are pulled from the inner iterator until `batch_size` of them
    /// have been collected. If the inner iterator runs dry first, the
    /// partial batch is returned only when it is non-empty and `drop_last`
    /// is `false`; otherwise `None` is returned.
    fn next(&mut self) -> Option<Self::Item> {
        let batch_size = self.batch_size;
        let mut batch = Vec::with_capacity(batch_size);
        for idx in self.sampler.by_ref() {
            batch.push(idx);
            if batch.len() == batch_size {
                return Some(batch);
            }
        }
        // The inner iterator was exhausted before the batch filled up.
        if batch.is_empty() || self.drop_last {
            None
        } else {
            Some(batch)
        }
    }

    /// Bounds on the number of batches still to come, derived from the
    /// inner iterator's bounds.
    ///
    /// The lower and upper bounds of the inner iterator are converted
    /// separately, so an inner iterator without an exact length (for example
    /// a filtered one) no longer gets its lower bound reported as the upper
    /// bound. When the inner bounds are equal, the result is exact.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero, because the count is obtained by
    /// dividing by it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let batch_size = self.batch_size;
        let drop_last = self.drop_last;
        // Number of batches produced from `items` indices. Quotient plus
        // remainder is used instead of `(items + batch_size - 1) / batch_size`
        // because that sum can overflow `usize` (e.g. for unbounded iterators
        // whose lower bound is `usize::MAX`).
        let batches_for = |items: usize| {
            let full = items / batch_size;
            if !drop_last && items % batch_size != 0 {
                full + 1
            } else {
                full
            }
        };
        let (lower, upper) = self.sampler.size_hint();
        (batches_for(lower), upper.map(batches_for))
    }
}
/// When the inner index iterator knows its exact length, the number of
/// remaining batches is known exactly as well, so `len()` is available.
impl<I> ExactSizeIterator for BatchIterator<I>
where
    I: ExactSizeIterator<Item = usize>,
{
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BatchSampler` over a `SequentialSampler` covering `len` items.
    fn sequential(len: usize, batch_size: usize, drop_last: bool) -> BatchSampler {
        BatchSampler {
            sampler: SequentialSampler {
                data_source_len: len,
            },
            batch_size,
            drop_last,
        }
    }

    #[test]
    fn basics() {
        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batch_sampler = sequential(dataset.len(), 3, false);

        let mut iter = batch_sampler.iter();
        assert_eq!(iter.next(), Some(vec![0, 1, 2]));
        assert_eq!(iter.next(), Some(vec![3, 4, 5]));
        assert_eq!(iter.next(), Some(vec![6, 7, 8]));
        // The trailing partial batch is kept because `drop_last` is false.
        assert_eq!(iter.next(), Some(vec![9]));
        assert_eq!(iter.next(), None);

        // Iterating a reference goes through the same path as `iter()`.
        let via_ref: Vec<Vec<usize>> = (&batch_sampler).into_iter().collect();
        let via_iter: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(via_ref, via_iter);
    }

    #[test]
    fn batch_sampler() {
        let expected: Vec<Vec<usize>> = (0..20)
            .step_by(5)
            .map(|start| (start..start + 5).collect())
            .collect();

        // 20 items split evenly into batches of 5, so `drop_last` makes no difference.
        for drop_last in [false, true] {
            let batches: Vec<Vec<usize>> = sequential(20, 5, drop_last).iter().collect();
            assert_eq!(batches, expected);
        }
    }

    #[test]
    fn drop_last() {
        let batch_sampler = sequential(10, 3, true);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.next(), Some(vec![0, 1, 2]));
        assert_eq!(iter.next(), Some(vec![3, 4, 5]));
        assert_eq!(iter.next(), Some(vec![6, 7, 8]));
        // The leftover index 9 does not fill a batch and is discarded.
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn len() {
        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let batch_sampler = sequential(dataset.len(), 2, false);
        assert_eq!(batch_sampler.len(), 5);
        assert_eq!(batch_sampler.iter().len(), 5);

        let dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let batch_sampler = sequential(dataset.len(), 2, false);
        assert_eq!(batch_sampler.len(), 6);
        assert_eq!(batch_sampler.iter().len(), 6);

        let batch_sampler = sequential(dataset.len(), 2, true);
        assert_eq!(batch_sampler.len(), 5);
        let mut iter = batch_sampler.iter();
        assert_eq!(iter.len(), 5);
        iter.next();
        iter.next();
        assert_eq!(iter.len(), 3);
    }
}