ai_dataloader/collate/default_collate/
ndarray.rs

1use super::super::Collate;
2use super::DefaultCollate;
3use ndarray::{stack, Array, ArrayBase, ArrayView, Axis, Dimension, RemoveAxis};
4
5impl<A, D> Collate<Array<A, D>> for DefaultCollate
6where
7    A: Clone,
8    D: Dimension,
9    D::Larger: RemoveAxis,
10{
11    type Output = Array<A, <D as Dimension>::Larger>;
12    fn collate(&self, batch: Vec<Array<A, D>>) -> Self::Output {
13        // Convert it to a `Vec` of view.
14        let vec_of_view: Vec<ArrayView<'_, A, D>> = batch.iter().map(ArrayBase::view).collect();
15        stack(Axis(0), vec_of_view.as_slice())
16            .expect("Make sure you're items from the dataset have the same shape.")
17    }
18}