ai_dataloader/collate/torch_collate/
tuple.rs1use super::super::Collate;
2use super::TorchCollate;
3use itertools::Itertools;
4
5macro_rules! tuple_impl {
10 ($($name:ident)+) => {
11 impl<$($name),+> Collate<($($name,)+)> for TorchCollate
12 where
13 $($name: Clone,)+
14 $(TorchCollate: Collate<$name>,)+
15
16 {
17 type Output = ($(<TorchCollate as Collate<$name>>::Output,)+);
18
19 #[allow(non_snake_case)]
20 fn collate(&self, batch: Vec<($($name,)+)>) -> Self::Output {
21 let copy = batch.to_vec();
22 let ($($name,)+) = copy.into_iter().multiunzip();
23 (
24 $(TorchCollate::default().collate($name),)+
25 )
26
27 }
28 }
29 };
30}
31
32tuple_impl! { A }
33tuple_impl! { A B }
34tuple_impl! { A B C }
35tuple_impl! { A B C D }
36tuple_impl! { A B C D E }
37tuple_impl! { A B C D E F }
38tuple_impl! { A B C D E F G }
39tuple_impl! { A B C D E F G H }
40tuple_impl! { A B C D E F G H I }
41tuple_impl! { A B C D E F G H I J }
42tuple_impl! { A B C D E F G H I J K }
43tuple_impl! { A B C D E F G H I J K L }
44
#[cfg(test)]
mod tests {
    // Checks that a batch of tuples is collated column-wise: the i-th output
    // tensor holds the i-th element of every tuple, in batch order.
    use super::*;

    use tch::Tensor;

    #[test]
    fn vec_of_tuple() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2)]),
            (Tensor::from_slice(&[1]), Tensor::from_slice(&[2]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1.0, 2.0), (3.0, 4.0)]),
            (
                Tensor::from_slice(&[1.0, 3.0]),
                Tensor::from_slice(&[2.0, 4.0])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2), (3, 4)]),
            (Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, 2), (3, 4)]),
            (Tensor::from_slice(&[-1, 3]), Tensor::from_slice(&[2, 4]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]),
            (
                Tensor::from_slice(&[1.0, 3.0, 5.0]),
                Tensor::from_slice(&[2.0, 4.0, 6.0])
            )
        );
    }

    #[test]
    fn vec_of_tuple_with_len_1() {
        assert_eq!(
            TorchCollate.collate(vec![(1,)]),
            (Tensor::from_slice(&[1]),)
        );
    }

    #[test]
    fn vec_of_tuple_with_len_2() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0)]),
            (Tensor::from_slice(&[1]), Tensor::from_slice(&[2.0]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0), (3, 4.0)]),
            (Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2.0, 4.0]))
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, true), (-3, false)]),
            (
                Tensor::from_slice(&[-1, -3]),
                Tensor::from_slice(&[true, false])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(-1, true), (3, false)]),
            (
                Tensor::from_slice(&[-1, 3]),
                Tensor::from_slice(&[true, false])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0), (3, 4.0), (5, 6.0)]),
            (
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2.0, 4.0, 6.0])
            )
        );
    }

    #[test]
    fn vec_of_tuple_with_len_3() {
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true)]),
            (
                Tensor::from_slice(&[1]),
                Tensor::from_slice(&[2.0]),
                Tensor::from_slice(&[true])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true), (3, 4.0, true)]),
            (
                Tensor::from_slice(&[1, 3]),
                Tensor::from_slice(&[2.0, 4.0]),
                Tensor::from_slice(&[true, true])
            )
        );
        assert_eq!(
            TorchCollate.collate(vec![(1, 2.0, true), (3, 4.0, false), (5, 6.0, true)]),
            (
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2.0, 4.0, 6.0]),
                Tensor::from_slice(&[true, false, true])
            )
        );
    }

    // Exercises the largest tuple arity generated by `tuple_impl!` (12 elements).
    #[test]
    fn vec_of_tuple_with_len_12() {
        assert_eq!(
            TorchCollate.collate(vec![
                (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
                (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24),
            ]),
            (
                Tensor::from_slice(&[1, 13]),
                Tensor::from_slice(&[2, 14]),
                Tensor::from_slice(&[3, 15]),
                Tensor::from_slice(&[4, 16]),
                Tensor::from_slice(&[5, 17]),
                Tensor::from_slice(&[6, 18]),
                Tensor::from_slice(&[7, 19]),
                Tensor::from_slice(&[8, 20]),
                Tensor::from_slice(&[9, 21]),
                Tensor::from_slice(&[10, 22]),
                Tensor::from_slice(&[11, 23]),
                Tensor::from_slice(&[12, 24])
            )
        );
    }
}