burn 0.3.0

BURN: Burn Unstoppable Rusty Neurons
Documentation
use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy};
use burn_dataset::{transform::ShuffledDataset, Dataset};
use std::sync::Arc;

pub struct DataLoaderBuilder<I, O> {
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    batcher: Arc<dyn Batcher<I, O>>,
    num_threads: Option<usize>,
    shuffle: Option<u64>,
}

impl<I, O> DataLoaderBuilder<I, O>
where
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Sync + Clone + std::fmt::Debug + 'static,
{
    pub fn new(batcher: Arc<dyn Batcher<I, O>>) -> Self {
        Self {
            batcher,
            strategy: None,
            num_threads: None,
            shuffle: None,
        }
    }

    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
        self
    }

    pub fn shuffle(mut self, seed: u64) -> Self {
        self.shuffle = Some(seed);
        self
    }

    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_threads = Some(num_workers);
        self
    }

    pub fn build(self, dataset: Arc<dyn Dataset<I>>) -> Arc<dyn DataLoader<O>> {
        let dataset = match self.shuffle {
            Some(seed) => Arc::new(ShuffledDataset::with_seed(dataset, seed)),
            None => dataset,
        };
        let strategy = match self.strategy {
            Some(strategy) => strategy,
            None => Box::new(FixBatchStrategy::new(1)),
        };
        if let Some(num_threads) = self.num_threads {
            return Arc::new(BatchDataLoader::multi_thread(
                strategy,
                dataset,
                self.batcher,
                num_threads,
            ));
        }

        Arc::new(BatchDataLoader::new(strategy, dataset, self.batcher))
    }
}