use crate::{
collate::{Collate, DefaultCollate},
sampler::{BatchSampler, RandomSampler, Sampler, SequentialSampler},
Dataset,
};
#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
use super::DataLoader;
/// Builder used to configure and create a [`DataLoader`].
///
/// Type parameters:
/// - `D`: the dataset type.
/// - `S`: the sampler wrapped by the internal `BatchSampler`
///   (`SequentialSampler` unless changed).
/// - `C`: the collate implementation for `D::Sample` (`DefaultCollate` unless changed).
///
/// Marked `#[must_use]`: every configuration method consumes the builder and returns
/// a new one, so ignoring the returned value discards the configuration.
#[must_use]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Ord)]
pub struct Builder<D, S = SequentialSampler, C = DefaultCollate>
where
D: Dataset,
S: Sampler,
C: Collate<D::Sample>,
{
/// Dataset that is moved into the `DataLoader` by `build`.
dataset: D,
/// Sampler together with the batching options (`batch_size`, `drop_last`).
batch_sampler: BatchSampler<S>,
/// Collate implementation forwarded to the `DataLoader`.
// NOTE(review): presumably used to merge samples into a batch; confirm in the `collate` module.
collate_fn: C,
/// Thread count used by `build` to size the rayon thread pool.
/// This field only exists when the `rayon` feature is enabled.
#[cfg(feature = "rayon")]
num_threads: usize,
}
impl<D> Builder<D, SequentialSampler, DefaultCollate>
where
    D: Dataset,
    DefaultCollate: Collate<D::Sample>,
{
    /// Creates a builder for `dataset` with the default configuration.
    ///
    /// Defaults:
    /// - a `SequentialSampler` created from `dataset.len()`,
    /// - a batch size of `1`,
    /// - `drop_last` disabled,
    /// - `DefaultCollate` as the collate implementation,
    /// - with the `rayon` feature: a thread count equal to
    ///   `std::thread::available_parallelism()`, or `1` when that query fails.
    pub fn new(dataset: D) -> Self {
        // Fall back to a single thread when the parallelism cannot be queried.
        #[cfg(feature = "rayon")]
        let num_threads =
            std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get);

        // The length must be read before `dataset` is moved into the builder.
        let batch_sampler = BatchSampler {
            sampler: SequentialSampler::new(dataset.len()),
            batch_size: 1,
            drop_last: false,
        };

        Self {
            dataset,
            batch_sampler,
            collate_fn: DefaultCollate,
            #[cfg(feature = "rayon")]
            num_threads,
        }
    }
}
impl<D, S, C> Builder<D, S, C>
where
    D: Dataset,
    S: Sampler,
    C: Collate<D::Sample>,
{
    /// Switches the sampler to a `RandomSampler`.
    ///
    /// Shorthand for `self.sampler::<RandomSampler>()`; the batch size and the
    /// `drop_last` flag are kept.
    pub fn shuffle(self) -> Builder<D, RandomSampler, C> {
        self.sampler::<RandomSampler>()
    }

    /// Sets the batch size stored in the batch sampler.
    ///
    /// The value is stored as-is; no validation is performed here.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_sampler.batch_size = batch_size;
        self
    }

    /// Sets the number of threads requested for the rayon thread pool.
    ///
    /// The value is only consumed by [`Builder::build`], and only when no global
    /// pool has been installed yet.
    #[cfg(feature = "rayon")]
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Enables the `drop_last` flag of the batch sampler.
    // NOTE(review): presumably makes the sampler skip a final incomplete batch;
    // confirm against `BatchSampler`.
    pub fn drop_last(mut self) -> Self {
        self.batch_sampler.drop_last = true;
        self
    }

    /// Replaces the collate implementation, keeping every other setting.
    pub fn collate_fn<CF>(self, collate_fn: CF) -> Builder<D, S, CF>
    where
        CF: Collate<D::Sample>,
    {
        Builder {
            dataset: self.dataset,
            batch_sampler: self.batch_sampler,
            collate_fn,
            #[cfg(feature = "rayon")]
            num_threads: self.num_threads,
        }
    }

    /// Replaces the sampler with a new `SA` created from the dataset length.
    ///
    /// The previous sampler is discarded; the batch size, the `drop_last` flag
    /// and the collate implementation are carried over.
    pub fn sampler<SA>(self) -> Builder<D, SA, C>
    where
        SA: Sampler,
    {
        // Build the sampler first: it needs the length of the dataset that is
        // moved into the new builder right below.
        let sampler: SA = SA::new(self.dataset.len());
        Builder {
            dataset: self.dataset,
            batch_sampler: BatchSampler {
                sampler,
                batch_size: self.batch_sampler.batch_size,
                drop_last: self.batch_sampler.drop_last,
            },
            collate_fn: self.collate_fn,
            #[cfg(feature = "rayon")]
            num_threads: self.num_threads,
        }
    }

    /// Consumes the builder and creates the [`DataLoader`].
    ///
    /// With the `rayon` feature enabled, this also installs the global
    /// `THREAD_POOL` with `num_threads` threads if no pool is installed yet.
    /// An already installed pool is left untouched, so a `num_threads` value
    /// that differs from the existing pool has no effect.
    ///
    /// # Panics
    ///
    /// With the `rayon` feature enabled, panics with "could not spawn threads"
    /// when the thread pool cannot be built.
    pub fn build(self) -> DataLoader<D, S, C> {
        // `THREAD_POOL` is accessed through `get`/`set` like a write-once cell:
        // `set` on an initialized cell only returns an error. Building a second
        // pool for a different thread count would therefore spawn threads just
        // to drop them again, so a pool is only created when none exists.
        #[cfg(feature = "rayon")]
        if THREAD_POOL.get().is_none() {
            THREAD_POOL
                .set(
                    rayon::ThreadPoolBuilder::new()
                        .num_threads(self.num_threads)
                        .build()
                        .expect("could not spawn threads"),
                )
                // Deliberately best-effort: another thread may have installed
                // a pool between the check and the `set`.
                .ok();
        }

        DataLoader {
            dataset: self.dataset,
            batch_sampler: self.batch_sampler,
            collate_fn: self.collate_fn,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::NoOpCollate;

    /// Small in-memory dataset shared by every builder combination below.
    fn sample_data() -> Vec<i32> {
        vec![1, 2, 3, 4]
    }

    /// Smoke test of the builder API: each supported combination and ordering
    /// of the configuration methods must compile and build a loader.
    #[test]
    fn api() {
        // Single-option builders.
        let _loader = Builder::new(sample_data()).build();
        let _loader = Builder::new(sample_data()).shuffle().build();
        let _loader = Builder::new(sample_data()).batch_size(2).build();

        // Batching options, then with a custom collate.
        let batched = Builder::new(sample_data()).batch_size(2).drop_last();
        let _loader = batched.build();
        let batched = Builder::new(sample_data()).batch_size(2).drop_last();
        let _loader = batched.collate_fn(NoOpCollate).build();

        // Explicit sampler selection, then with a custom collate.
        let sampled = Builder::new(sample_data())
            .batch_size(2)
            .drop_last()
            .sampler::<RandomSampler>();
        let _loader = sampled.build();
        let sampled = Builder::new(sample_data())
            .batch_size(2)
            .drop_last()
            .sampler::<RandomSampler>();
        let _loader = sampled.collate_fn(NoOpCollate).build();

        // Shuffle first, configure afterwards.
        let _loader = Builder::new(sample_data())
            .shuffle()
            .batch_size(2)
            .drop_last()
            .collate_fn(NoOpCollate)
            .build();

        // Collate set before the batch size, with a type and with a closure.
        let _loader = Builder::new(sample_data())
            .collate_fn(NoOpCollate)
            .batch_size(2)
            .build();
        let _loader = Builder::new(sample_data())
            .collate_fn(|x| x)
            .batch_size(2)
            .build();
    }
}