ai_dataloader/collate/default_collate/
tuple.rs

1use super::super::Collate;
2use super::DefaultCollate;
3use itertools::Itertools;
4
// Maybe an implementation that passes the length and the index of each element to the macro could
// be more efficient than going through `Itertools::multiunzip`.
7
/// Implements [`Collate`] for tuples by collating each tuple position independently.
///
/// A batch of `n`-tuples is first transposed ("unzipped") into `n` vectors, one per tuple
/// position. Each vector is then collated with the `DefaultCollate` implementation for that
/// position's element type. The output is the tuple of those per-position results.
///
/// Instantiated for tuples of 1 to 12 elements. Before adding larger arities, check that
/// `Itertools::multiunzip` supports them.
macro_rules! tuple_impl {
    ($($name:ident)+) => {
        impl<$($name),+> Collate<($($name,)+)> for DefaultCollate
        where
            $(DefaultCollate: Collate<$name>,)+
        {
            /// One collated output per tuple position, in the same order as the input tuple.
            type Output = ($(<DefaultCollate as Collate<$name>>::Output,)+);

            // Each type parameter's identifier is reused as the name of the local `Vec`
            // holding that position's column, hence the lint allowance.
            #[allow(non_snake_case)]
            fn collate(&self, batch: Vec<($($name,)+)>) -> Self::Output {
                // `batch` is already owned, so it is consumed directly. Cloning it first would
                // copy every element for nothing and would force a `Clone` bound on each element.
                let ($($name,)+): ($(Vec<$name>,)+) = batch.into_iter().multiunzip();
                (
                    $(<DefaultCollate as Collate<$name>>::collate(self, $name),)+
                )
            }
        }
    };
}
31
32tuple_impl! { A }
33tuple_impl! { A B }
34tuple_impl! { A B C }
35tuple_impl! { A B C D }
36tuple_impl! { A B C D E }
37tuple_impl! { A B C D E F }
38tuple_impl! { A B C D E F G }
39tuple_impl! { A B C D E F G H }
40tuple_impl! { A B C D E F G H I }
41tuple_impl! { A B C D E F G H I J }
42tuple_impl! { A B C D E F G H I J K }
43tuple_impl! { A B C D E F G H I J K L }
44
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Homogeneous pairs: every tuple position becomes its own array, in batch order.
    #[test]
    fn vec_of_tuple() {
        let batch = vec![(1, 2)];
        assert_eq!(DefaultCollate.collate(batch), (array![1], array![2]));

        let batch = vec![(1.0, 2.0), (3.0, 4.0)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1.0, 3.0], array![2.0, 4.0])
        );

        let batch = vec![(1, 2), (3, 4)];
        assert_eq!(DefaultCollate.collate(batch), (array![1, 3], array![2, 4]));

        let batch = vec![(-1, 2), (3, 4)];
        assert_eq!(DefaultCollate.collate(batch), (array![-1, 3], array![2, 4]));

        let batch = vec![(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1.0, 3.0, 5.0], array![2.0, 4.0, 6.0])
        );
    }

    /// A 1-tuple collates to a 1-tuple holding a single array.
    #[test]
    fn vec_of_tuple_with_len_1() {
        let batch = vec![(1,)];
        assert_eq!(DefaultCollate.collate(batch), (array![1],));
    }

    /// Pairs whose positions have different element types are collated independently.
    #[test]
    fn vec_of_tuple_with_len_2() {
        let batch = vec![(1, 2.0)];
        assert_eq!(DefaultCollate.collate(batch), (array![1], array![2.0]));

        let batch = vec![(1, 2.0), (3, 4.0)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1, 3], array![2.0, 4.0])
        );

        let batch = vec![(-1, true), (-3, false)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![-1, -3], array![true, false])
        );

        let batch = vec![(-1, true), (3, false)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![-1, 3], array![true, false])
        );

        let batch = vec![(1, 2.0), (3, 4.0), (5, 6.0)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1, 3, 5], array![2.0, 4.0, 6.0])
        );
    }

    /// Triples mixing integers, floats and booleans yield one array per position.
    #[test]
    fn vec_of_tuple_with_len_3() {
        let batch = vec![(1, 2.0, true)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1], array![2.0], array![true])
        );

        let batch = vec![(1, 2.0, true), (3, 4.0, true)];
        assert_eq!(
            DefaultCollate.collate(batch),
            (array![1, 3], array![2.0, 4.0], array![true, true])
        );

        let batch = vec![(1, 2.0, true), (3, 4.0, false), (5, 6.0, true)];
        let expected = (
            array![1, 3, 5],
            array![2.0, 4.0, 6.0],
            array![true, false, true],
        );
        assert_eq!(DefaultCollate.collate(batch), expected);
    }
}