// ai_dataloader/collate/torch_collate/ndarray.rs

1use super::super::Collate;
2use super::TorchCollate;
3use ndarray::{stack, Array, ArrayBase, ArrayView, Axis, Dimension, RemoveAxis};
4use tch::Tensor;
5
6impl<A, D> Collate<Array<A, D>> for TorchCollate
7where
8    A: Clone + tch::kind::Element,
9    D: Dimension,
10    D::Larger: RemoveAxis,
11{
12    type Output = Tensor;
13    fn collate(&self, batch: Vec<Array<A, D>>) -> Self::Output {
14        // Convert it to a `Vec` of view.
15        let vec_of_view: Vec<ArrayView<'_, A, D>> = batch.iter().map(ArrayBase::view).collect();
16        // TODO: maybe use tensor stack here
17        let array = stack(Axis(0), vec_of_view.as_slice())
18            .expect("Make sure you're items from the dataset have the same shape.");
19
20        let tensor = Tensor::from_slice(array.as_slice().unwrap());
21        #[allow(clippy::cast_possible_wrap)]
22        let shape = array
23            .shape()
24            .iter()
25            .map(|dim| *dim as i64)
26            .collect::<Vec<_>>();
27        tensor.reshape(shape)
28    }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Collating two 1-D arrays must add a batch axis, yielding a 2-D tensor.
    #[test]
    fn keep_dimension() {
        let batch = TorchCollate.collate(vec![array![1, 2], array![3, 4]]);
        assert_eq!(batch.dim(), 2);
        batch.print();
    }

    #[test]
    fn nested() {
        // If a type is an ndarray it gets converted into a tensor. But what if
        // this tensor needs to be collated again? Look at the supported types
        // to see whether this case can happen.
    }
}
61}