ai_dataloader/collate/default_collate/
ndarray.rs1use super::super::Collate;
2use super::DefaultCollate;
3use ndarray::{stack, Array, ArrayBase, ArrayView, Axis, Dimension, RemoveAxis};
4
5impl<A, D> Collate<Array<A, D>> for DefaultCollate
6where
7 A: Clone,
8 D: Dimension,
9 D::Larger: RemoveAxis,
10{
11 type Output = Array<A, <D as Dimension>::Larger>;
12 fn collate(&self, batch: Vec<Array<A, D>>) -> Self::Output {
13 let vec_of_view: Vec<ArrayView<'_, A, D>> = batch.iter().map(ArrayBase::view).collect();
15 stack(Axis(0), vec_of_view.as_slice())
16 .expect("Make sure you're items from the dataset have the same shape.")
17 }
18}