ai_dataloader/collate/default_collate/
sequence.rs

1/// Implementation for Sequence.
2///
/// Currently `BinaryHeap`, `BTreeSet`, `HashSet` and `LinkedList` are not supported because the current implementation
/// requires positional access to the elements of each sample in order to do the transpose.
5///
6use super::super::Collate;
7use super::DefaultCollate;
8use std::collections::VecDeque;
9
10impl<T> Collate<Vec<T>> for DefaultCollate
11where
12    Self: Collate<T>,
13    T: Clone,
14{
15    type Output = Vec<<Self as Collate<T>>::Output>;
16    fn collate(&self, batch: Vec<Vec<T>>) -> Self::Output {
17        let elem_size = batch
18            .first()
19            .expect("Batch should contain at least one element")
20            .len();
21
22        assert!(
23            batch.iter().all(|vec| vec.len() == elem_size),
24            "Each Vec in the batch should have equal size"
25        );
26
27        let mut collated = Vec::with_capacity(batch.len());
28
29        for i in 0..batch[0].len() {
30            let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
31            collated.push(self.collate(vec));
32        }
33        collated
34    }
35}
36
37impl<T> Collate<VecDeque<T>> for DefaultCollate
38where
39    Self: Collate<T>,
40    T: Clone,
41{
42    type Output = Vec<<Self as Collate<T>>::Output>;
43    fn collate(&self, batch: Vec<VecDeque<T>>) -> Self::Output {
44        let elem_size = batch
45            .first()
46            .expect("Batch should contain at least one element")
47            .len();
48
49        assert!(
50            batch.iter().all(|vec| vec.len() == elem_size),
51            "Each Vec in the batch should have equal size"
52        );
53
54        let mut collated = Vec::with_capacity(batch.len());
55
56        for i in 0..batch[0].len() {
57            let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
58            collated.push(self.collate(vec));
59        }
60        collated
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Batches of `Vec`s are transposed: output entry `i` gathers element `i` of every sample.
    #[test]
    fn vec_of_vec() {
        assert_eq!(DefaultCollate.collate(vec![vec![1]]), vec![array![1]]);
        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2], vec![3, 4]]),
            vec![array![1, 3], array![2, 4]]
        );
        // different type
        assert_eq!(
            DefaultCollate.collate(vec![vec![true, false], vec![true, false]]),
            vec![array![true, true], array![false, false]]
        );

        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2, 3], vec![4, 5, 6]]),
            vec![array![1, 4], array![2, 5], array![3, 6]]
        );
        // batch_size 3
        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2], vec![3, 4], vec![5, 6]]),
            vec![array![1, 3, 5], array![2, 4, 6]]
        );
        // batch_size 10
        assert_eq!(
            DefaultCollate.collate(vec![
                vec![1, 2],
                vec![3, 4],
                vec![5, 6],
                vec![7, 8],
                vec![9, 10],
                vec![11, 12],
                vec![13, 14],
                vec![15, 16],
                vec![17, 18],
                vec![19, 20]
            ]),
            vec![
                array![1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
                array![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
            ]
        );
    }

    /// `String` elements come back as plain `Vec<String>` columns.
    #[test]
    fn specialized() {
        assert_eq!(
            DefaultCollate.collate(vec![
                vec![String::from("a"), String::from("b")],
                vec![String::from("c"), String::from("d")]
            ]),
            vec![
                vec![String::from('a'), String::from('c')],
                vec![String::from('b'), String::from('d')],
            ]
        );
    }

    /// `VecDeque` samples are transposed front to back, exactly like `Vec` samples.
    #[test]
    fn vec_of_vec_deque() {
        assert_eq!(
            DefaultCollate.collate(vec![
                VecDeque::from(vec![1, 2]),
                VecDeque::from(vec![3, 4])
            ]),
            vec![array![1, 3], array![2, 4]]
        );

        // A deque grown at the front must still be read in logical order.
        let mut first = VecDeque::from(vec![2, 3]);
        first.push_front(1);
        let mut second = VecDeque::from(vec![5, 6]);
        second.push_front(4);
        assert_eq!(
            DefaultCollate.collate(vec![first, second]),
            vec![array![1, 4], array![2, 5], array![3, 6]]
        );
    }

    /// `String` elements held in `VecDeque` samples also come back as `Vec<String>` columns.
    #[test]
    fn vec_deque_specialized() {
        assert_eq!(
            DefaultCollate.collate(vec![
                VecDeque::from(vec![String::from("a"), String::from("b")]),
                VecDeque::from(vec![String::from("c"), String::from("d")])
            ]),
            vec![
                vec![String::from("a"), String::from("c")],
                vec![String::from("b"), String::from("d")],
            ]
        );
    }

    #[test]
    #[should_panic(expected = "Batch should contain at least one element")]
    fn empty_vec_batch_panics() {
        let _ = DefaultCollate.collate(Vec::<Vec<i32>>::new());
    }

    #[test]
    #[should_panic(expected = "Each Vec in the batch should have equal size")]
    fn ragged_vec_batch_panics() {
        let _ = DefaultCollate.collate(vec![vec![1, 2], vec![3]]);
    }

    #[test]
    #[should_panic(expected = "Batch should contain at least one element")]
    fn empty_vec_deque_batch_panics() {
        let _ = DefaultCollate.collate(Vec::<VecDeque<i32>>::new());
    }

    #[test]
    #[should_panic(expected = "Each Vec in the batch should have equal size")]
    fn ragged_vec_deque_batch_panics() {
        let _ = DefaultCollate.collate(vec![
            VecDeque::from(vec![1, 2]),
            VecDeque::from(vec![3])
        ]);
    }
}