use crate::collate::{Collate, DefaultCollate};
use super::DataLoader;
#[must_use]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Ord)]
pub struct Builder<D, C = DefaultCollate>
where
D: IntoIterator,
C: Collate<D::Item>,
{
dataset: D,
batch_size: usize,
drop_last: bool,
collate_fn: C,
shuffle: bool,
}
impl<D> Builder<D, DefaultCollate>
where
D: IntoIterator,
DefaultCollate: Collate<D::Item>,
{
pub fn new(dataset: D) -> Self {
Self {
dataset,
batch_size: 1,
drop_last: false,
collate_fn: DefaultCollate,
shuffle: false,
}
}
}
impl<D, C> Builder<D, C>
where
    D: IntoIterator,
    C: Collate<D::Item>,
{
    /// Enables shuffling for the resulting [`DataLoader`].
    ///
    /// Sets the stored `shuffle` flag to `true`; this builder offers no way to
    /// switch it back off. How the flag is applied is up to [`DataLoader`].
    pub fn shuffle(mut self) -> Self {
        self.shuffle = true;
        self
    }

    /// Sets the number of items per batch forwarded to the [`DataLoader`].
    ///
    /// `Builder::new` starts from a batch size of `1`.
    ///
    /// NOTE(review): the value is not validated here; whether a batch size of
    /// `0` is rejected or misbehaves depends on `DataLoader` — confirm.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the `drop_last` flag forwarded to the [`DataLoader`].
    ///
    /// Presumably this makes the loader discard a trailing batch smaller than
    /// `batch_size` — the behavior itself lives in `DataLoader`. Like
    /// `shuffle`, the flag can only be switched on.
    pub fn drop_last(mut self) -> Self {
        self.drop_last = true;
        self
    }

    /// Replaces the collate function, changing the builder's collate type
    /// parameter from `C` to `CF`.
    ///
    /// Every other setting (dataset, batch size, `drop_last`, `shuffle`) is
    /// carried over unchanged.
    pub fn collate_fn<CF>(self, collate_fn: CF) -> Builder<D, CF>
    where
        CF: Collate<D::Item>,
    {
        // Rebuilt field by field: struct-update syntax (`..self`) cannot
        // change a generic parameter on stable Rust.
        Builder {
            dataset: self.dataset,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            collate_fn,
            shuffle: self.shuffle,
        }
    }

    /// Consumes the builder and moves every setting into a [`DataLoader`].
    pub fn build(self) -> DataLoader<D, C> {
        DataLoader {
            dataset: self.dataset,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            collate_fn: self.collate_fn,
            shuffle: self.shuffle,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::NoOpCollate;

    /// Every supported combination of builder calls type-checks and builds.
    #[test]
    fn api() {
        let _loader = Builder::new(vec![1, 2, 3, 4]).build();
        let _loader = Builder::new(vec![1, 2, 3, 4]).shuffle().build();
        let _loader = Builder::new(vec![1, 2, 3, 4]).batch_size(2).build();
        let _loader = Builder::new(vec![1, 2, 3, 4])
            .batch_size(2)
            .drop_last()
            .build();
        let _loader = Builder::new(vec![1, 2, 3, 4])
            .batch_size(2)
            .drop_last()
            .collate_fn(NoOpCollate)
            .build();
        let _loader = Builder::new(vec![1, 2, 3, 4])
            .shuffle()
            .batch_size(2)
            .drop_last()
            .collate_fn(NoOpCollate)
            .build();
        let _loader = Builder::new(vec![1, 2, 3, 4])
            .shuffle()
            .batch_size(2)
            .drop_last()
            .collate_fn(|x| x)
            .build();
    }

    /// The starting values set by `Builder::new` reach the built loader.
    #[test]
    fn defaults_are_forwarded() {
        let loader = Builder::new(vec![1, 2, 3, 4]).build();
        assert_eq!(loader.dataset, vec![1, 2, 3, 4]);
        assert_eq!(loader.batch_size, 1);
        assert!(!loader.drop_last);
        assert!(!loader.shuffle);
    }

    /// Options set before swapping the collate function survive the swap and
    /// reach the built loader.
    #[test]
    fn options_are_forwarded() {
        let loader = Builder::new(vec![1, 2, 3, 4])
            .shuffle()
            .batch_size(2)
            .drop_last()
            .collate_fn(NoOpCollate)
            .build();
        assert_eq!(loader.dataset, vec![1, 2, 3, 4]);
        assert_eq!(loader.batch_size, 2);
        assert!(loader.drop_last);
        assert!(loader.shuffle);
    }
}