ai_dataloader/collate/default_collate/
tuple.rs1use super::super::Collate;
2use super::DefaultCollate;
3use itertools::Itertools;
4
5macro_rules! tuple_impl {
10 ($($name:ident)+) => {
11 impl<$($name),+> Collate<($($name,)+)> for DefaultCollate
12 where
13 $($name: Clone,)+
14 $(DefaultCollate: Collate<$name>,)+
15
16 {
17 type Output = ($(<DefaultCollate as Collate<$name>>::Output,)+);
18
19 #[allow(non_snake_case)]
20 fn collate(&self, batch: Vec<($($name,)+)>) -> Self::Output {
21 let copy = batch.to_vec();
22 let ($($name,)+) = copy.into_iter().multiunzip();
23 (
24 $(DefaultCollate::default().collate($name),)+
25 )
26
27 }
28 }
29 };
30}
31
32tuple_impl! { A }
33tuple_impl! { A B }
34tuple_impl! { A B C }
35tuple_impl! { A B C D }
36tuple_impl! { A B C D E }
37tuple_impl! { A B C D E F }
38tuple_impl! { A B C D E F G }
39tuple_impl! { A B C D E F G H }
40tuple_impl! { A B C D E F G H I }
41tuple_impl! { A B C D E F G H I J }
42tuple_impl! { A B C D E F G H I J K }
43tuple_impl! { A B C D E F G H I J K L }
44
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn vec_of_tuple() {
        assert_eq!(DefaultCollate.collate(vec![(1, 2)]), (array![1], array![2]));
        assert_eq!(
            DefaultCollate.collate(vec![(1.0, 2.0), (3.0, 4.0)]),
            (array![1.0, 3.0], array![2.0, 4.0])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2), (3, 4)]),
            (array![1, 3], array![2, 4])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(-1, 2), (3, 4)]),
            (array![-1, 3], array![2, 4])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]),
            (array![1.0, 3.0, 5.0], array![2.0, 4.0, 6.0])
        );
    }

    #[test]
    fn vec_of_tuple_with_len_1() {
        assert_eq!(DefaultCollate.collate(vec![(1,)]), (array![1],));
    }

    #[test]
    fn vec_of_tuple_with_len_2() {
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0)]),
            (array![1], array![2.0])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0), (3, 4.0)]),
            (array![1, 3], array![2.0, 4.0])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(-1, true), (-3, false)]),
            (array![-1, -3], array![true, false])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(-1, true), (3, false)]),
            (array![-1, 3], array![true, false])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0), (3, 4.0), (5, 6.0)]),
            (array![1, 3, 5], array![2.0, 4.0, 6.0])
        );
    }

    #[test]
    fn vec_of_tuple_with_len_3() {
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0, true)]),
            (array![1], array![2.0], array![true])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0, true), (3, 4.0, true)]),
            (array![1, 3], array![2.0, 4.0], array![true, true])
        );
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0, true), (3, 4.0, false), (5, 6.0, true)]),
            (
                array![1, 3, 5],
                array![2.0, 4.0, 6.0],
                array![true, false, true]
            )
        );
    }

    /// A mid-size arity with mixed element types: every column must keep the rows'
    /// order and its own element type.
    #[test]
    fn vec_of_tuple_with_len_4() {
        assert_eq!(
            DefaultCollate.collate(vec![(1, 2.0, true, -4), (5, 6.0, false, -8)]),
            (
                array![1, 5],
                array![2.0, 6.0],
                array![true, false],
                array![-4, -8]
            )
        );
    }

    /// `tuple_impl!` is instantiated for tuples of up to 12 elements; exercise the
    /// largest instantiation and check that no column gets shuffled.
    #[test]
    fn vec_of_tuple_with_len_12() {
        assert_eq!(
            DefaultCollate.collate(vec![
                (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
                (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24),
            ]),
            (
                array![1, 13],
                array![2, 14],
                array![3, 15],
                array![4, 16],
                array![5, 17],
                array![6, 18],
                array![7, 19],
                array![8, 20],
                array![9, 21],
                array![10, 22],
                array![11, 23],
                array![12, 24]
            )
        );
    }
}