use super::{
BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
batcher::Batcher,
};
use burn_dataset::Dataset;
use burn_tensor::backend::Backend;
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;
pub struct DataLoaderBuilder<B: Backend, I, O> {
strategy: Option<Box<dyn BatchStrategy<I>>>,
batcher: Arc<dyn Batcher<B, I, O>>,
num_threads: Option<usize>,
shuffle: Option<u64>,
device: Option<B::Device>,
}
impl<B, I, O> DataLoaderBuilder<B, I, O>
where
B: Backend,
I: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Clone + std::fmt::Debug + 'static,
{
pub fn new<Bt>(batcher: Bt) -> Self
where
Bt: Batcher<B, I, O> + 'static,
{
Self {
batcher: Arc::new(batcher),
strategy: None,
num_threads: None,
shuffle: None,
device: None,
}
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
self
}
pub fn shuffle(mut self, seed: u64) -> Self {
self.shuffle = Some(seed);
self
}
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_threads = Some(num_workers);
self
}
pub fn set_device(mut self, device: B::Device) -> Self {
self.device = Some(device);
self
}
pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
where
D: Dataset<I> + 'static,
{
let dataset = Arc::new(dataset);
let device = self.device.unwrap_or_default();
let rng = self.shuffle.map(StdRng::seed_from_u64);
let strategy = match self.strategy {
Some(strategy) => strategy,
None => Box::new(FixBatchStrategy::new(1)),
};
if let Some(num_threads) = self.num_threads
&& num_threads > 0
{
return Arc::new(MultiThreadDataLoader::new(
strategy,
dataset,
self.batcher,
num_threads,
device,
rng,
));
}
Arc::new(BatchDataLoader::new(
strategy,
dataset,
self.batcher,
device,
rng,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use crate::data::dataset::FakeDataset;
#[derive(new, Clone)]
struct TestBatcherDevice;
#[cfg(test)]
impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
*device
}
}
type TestDevice = <TestBackend as Backend>::Device;
#[test]
fn test_dataloader_no_workers() {
type TestDevice = <TestBackend as Backend>::Device;
let default_device = TestDevice::default();
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.build(FakeDataset::<String>::new(9));
assert_eq!(dataloader.num_items(), 9);
for device in dataloader.iter() {
assert_eq!(device, default_device)
}
}
#[test]
fn test_dataloader_default_device() {
let default_device = TestDevice::default();
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.num_workers(1)
.build(FakeDataset::<String>::new(9));
assert_eq!(dataloader.num_items(), 9);
for device in dataloader.iter() {
assert_eq!(device, default_device)
}
}
#[test]
fn test_dataloader_slice_multi_device() {
let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
.batch_size(1)
.num_workers(1)
.build(FakeDataset::<String>::new(11));
#[cfg(all(
test,
not(feature = "test-tch"),
not(feature = "test-wgpu"),
not(feature = "test-cuda")
))]
let (device1, device2) = (
burn_ndarray::NdArrayDevice::Cpu,
burn_ndarray::NdArrayDevice::Cpu,
);
#[cfg(all(test, feature = "test-tch"))]
let (device1, device2) = (
burn_tch::LibTorchDevice::Cuda(0),
burn_tch::LibTorchDevice::Cuda(1),
);
#[cfg(all(test, feature = "test-wgpu"))]
let (device1, device2) = (
burn_wgpu::WgpuDevice::DiscreteGpu(0),
burn_wgpu::WgpuDevice::DiscreteGpu(1),
);
#[cfg(all(test, feature = "test-cuda"))]
let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));
assert_eq!(dataloader.num_items(), 11);
let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);
assert_eq!(dataloader_1.num_items(), 5);
assert_eq!(dataloader_2.num_items(), 6);
let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());
for _ in 0..5 {
assert_eq!(iterator_1.next(), Some(device1));
assert_eq!(iterator_2.next(), Some(device2));
}
assert_eq!(iterator_1.next(), None);
assert_eq!(iterator_2.next(), Some(device2));
assert_eq!(iterator_2.next(), None);
}
}