ai_dataloader/collate/torch_collate/
ndarray.rs1use super::super::Collate;
2use super::TorchCollate;
3use ndarray::{stack, Array, ArrayBase, ArrayView, Axis, Dimension, RemoveAxis};
4use tch::Tensor;
5
6impl<A, D> Collate<Array<A, D>> for TorchCollate
7where
8 A: Clone + tch::kind::Element,
9 D: Dimension,
10 D::Larger: RemoveAxis,
11{
12 type Output = Tensor;
13 fn collate(&self, batch: Vec<Array<A, D>>) -> Self::Output {
14 let vec_of_view: Vec<ArrayView<'_, A, D>> = batch.iter().map(ArrayBase::view).collect();
16 let array = stack(Axis(0), vec_of_view.as_slice())
18 .expect("Make sure you're items from the dataset have the same shape.");
19
20 let tensor = Tensor::from_slice(array.as_slice().unwrap());
21 #[allow(clippy::cast_possible_wrap)]
22 let shape = array
23 .shape()
24 .iter()
25 .map(|dim| *dim as i64)
26 .collect::<Vec<_>>();
27 tensor.reshape(shape)
28 }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Collating two one-dimensional arrays must yield a two-dimensional tensor.
    #[test]
    fn keep_dimension() {
        let samples = vec![array![1, 2], array![3, 4]];
        let collated = TorchCollate.collate(samples);
        assert_eq!(collated.dim(), 2);
        collated.print();
    }

    #[test]
    fn nested() {}
}