ai_dataloader/collate/torch_collate/
tuple.rs

1use super::super::Collate;
2use super::TorchCollate;
3use itertools::Itertools;
4
// An implementation that passes the tuple length and element indices to the macro might be more
// efficient than going through `Itertools::multiunzip`.
7
/// Implements [`Collate`] for tuples of up to 12 elements (the largest arity supported by
/// [`Itertools::multiunzip`]).
///
/// A batch of tuples is transposed into one `Vec` per tuple position, and each of those `Vec`s
/// is collated on its own. For example, `vec![(a1, b1), (a2, b2)]` becomes
/// `(collate(vec![a1, a2]), collate(vec![b1, b2]))`.
macro_rules! tuple_impl {
    ($($name:ident)+) => {
        impl<$($name),+> Collate<($($name,)+)> for TorchCollate
        where
            $(TorchCollate: Collate<$name>,)+
        {
            type Output = ($(<TorchCollate as Collate<$name>>::Output,)+);

            // The type parameters double as the names of the per-position `Vec`s, hence the
            // upper-case local bindings.
            #[allow(non_snake_case)]
            fn collate(&self, batch: Vec<($($name,)+)>) -> Self::Output {
                // `batch` is owned, so it is consumed directly. Cloning it first was a wasted
                // copy, and it forced a `Clone` bound on every element type.
                let ($($name,)+): ($(Vec<$name>,)+) = batch.into_iter().multiunzip();
                // Delegate through `self` rather than a fresh default collator, so any state the
                // collator carries applies to every position.
                ($(<TorchCollate as Collate<$name>>::collate(self, $name),)+)
            }
        }
    };
}
31
32tuple_impl! { A }
33tuple_impl! { A B }
34tuple_impl! { A B C }
35tuple_impl! { A B C D }
36tuple_impl! { A B C D E }
37tuple_impl! { A B C D E F }
38tuple_impl! { A B C D E F G }
39tuple_impl! { A B C D E F G H }
40tuple_impl! { A B C D E F G H I }
41tuple_impl! { A B C D E F G H I J }
42tuple_impl! { A B C D E F G H I J K }
43tuple_impl! { A B C D E F G H I J K L }
44
#[cfg(test)]
mod tests {
    use super::*;

    use tch::Tensor;

    #[test]
    fn vec_of_tuple() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2)]),
            (Tensor::from_slice(&[1]), Tensor::from_slice(&[2]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1.0, 2.0), (3.0, 4.0)]),
            (
                Tensor::from_slice(&[1.0, 3.0]),
                Tensor::from_slice(&[2.0, 4.0])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2), (3, 4)]),
            (Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, 2), (3, 4)]),
            (Tensor::from_slice(&[-1, 3]), Tensor::from_slice(&[2, 4]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]),
            (
                Tensor::from_slice(&[1.0, 3.0, 5.0]),
                Tensor::from_slice(&[2.0, 4.0, 6.0])
            )
        );
    }

    #[test]
    fn vec_of_tuple_with_len_1() {
        assert_eq!(
            TorchCollate.collate(vec![(1,)]),
            (Tensor::from_slice(&[1]),)
        );
    }

    #[test]
    fn vec_of_tuple_with_len_2() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0)]),
            (Tensor::from_slice(&[1]), Tensor::from_slice(&[2.0]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0), (3, 4.0)]),
            (Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2.0, 4.0]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, true), (-3, false)]),
            (
                Tensor::from_slice(&[-1, -3]),
                Tensor::from_slice(&[true, false])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, true), (3, false)]),
            (
                Tensor::from_slice(&[-1, 3]),
                Tensor::from_slice(&[true, false])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0), (3, 4.0), (5, 6.0)]),
            (
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2.0, 4.0, 6.0])
            )
        );
    }

    #[test]
    fn vec_of_tuple_with_len_3() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true)]),
            (
                Tensor::from_slice(&[1]),
                Tensor::from_slice(&[2.0]),
                Tensor::from_slice(&[true])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true), (3, 4.0, true)]),
            (
                Tensor::from_slice(&[1, 3]),
                Tensor::from_slice(&[2.0, 4.0]),
                Tensor::from_slice(&[true, true])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true), (3, 4.0, false), (5, 6.0, true)]),
            (
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2.0, 4.0, 6.0]),
                Tensor::from_slice(&[true, false, true])
            )
        );
    }

    /// The largest tuple size with a `Collate` impl (12 elements) is collated position-wise too.
    #[test]
    fn vec_of_tuple_with_max_len() {
        assert_eq!(
            TorchCollate.collate(vec![
                (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
                (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24),
            ]),
            (
                Tensor::from_slice(&[1, 13]),
                Tensor::from_slice(&[2, 14]),
                Tensor::from_slice(&[3, 15]),
                Tensor::from_slice(&[4, 16]),
                Tensor::from_slice(&[5, 17]),
                Tensor::from_slice(&[6, 18]),
                Tensor::from_slice(&[7, 19]),
                Tensor::from_slice(&[8, 20]),
                Tensor::from_slice(&[9, 21]),
                Tensor::from_slice(&[10, 22]),
                Tensor::from_slice(&[11, 23]),
                Tensor::from_slice(&[12, 24]),
            )
        );
    }
}