use crate::{
collate::{Collate, DefaultCollate},
Dataset,
};
#[cfg(feature = "rayon")]
use crate::THREAD_POOL;
#[cfg(feature = "rayon")]
use rayon::iter::ParallelIterator;
#[cfg(feature = "rayon")]
use rayon::prelude::IntoParallelIterator;
/// Turns a list of dataset indices into a single collated value.
///
/// `D` is the dataset the indices refer to. `C` is the collate function whose
/// output type is returned; it defaults to [`DefaultCollate`].
pub(crate) trait Fetcher<D: Dataset, C: Collate<D::Sample> = DefaultCollate> {
    /// Fetches the samples addressed by `possibly_batched_index` and returns the
    /// collated result (`C::Output`).
    ///
    /// NOTE(review): judging by the name, the vector may hold a single index for
    /// unbatched loading — confirm against the callers.
    fn fetch(&self, possibly_batched_index: Vec<usize>) -> C::Output;
}
#[derive(Debug)]
pub(crate) struct MapDatasetFetcher<'dataset, D, C = DefaultCollate>
where
D: Dataset + Sync,
C: Collate<D::Sample>,
{
pub(crate) dataset: &'dataset D,
pub(crate) collate_fn: &'dataset C,
}
/// Loads every requested sample from the dataset, then passes the collected
/// samples to the collate function in one call.
///
/// With the `rayon` feature enabled, samples are loaded in parallel inside
/// `THREAD_POOL`. rayon's indexed `collect` keeps them in the same order as
/// `possibly_batched_index`. Without the feature, they are loaded sequentially.
///
/// # Panics
///
/// With the `rayon` feature, panics if `THREAD_POOL` has not been initialized.
impl<'dataset, D, C> Fetcher<D, C> for MapDatasetFetcher<'dataset, D, C>
where
    D: Dataset + Sync,
    D::Sample: Send,
    C: Collate<D::Sample>,
{
    fn fetch(&self, possibly_batched_index: Vec<usize>) -> C::Output {
        // Both fields are shared references (`Copy`). Pulling them out means the
        // closures below capture only the dataset reference, never `self` as a whole.
        let Self {
            dataset,
            collate_fn,
        } = *self;

        #[cfg(not(feature = "rayon"))]
        let samples = possibly_batched_index
            .into_iter()
            .map(|index| dataset.get_sample(index))
            .collect();

        #[cfg(feature = "rayon")]
        let samples = THREAD_POOL
            .get()
            .expect("thread pool is initialized")
            .install(|| {
                possibly_batched_index
                    .into_par_iter()
                    .map(|index| dataset.get_sample(index))
                    .collect()
            });

        collate_fn.collate(samples)
    }
}