// ai_dataloader/collate/torch_collate/sequence.rs

//! `Collate` implementations for sequences.
//!
//! `BinaryHeap`, `BTreeSet`, `HashSet` and `LinkedList` are currently not supported, because
//! the transpose is only implemented for indexable sequences (`Vec` and `VecDeque`).
6use super::super::Collate;
7use super::TorchCollate;
8use std::collections::VecDeque;
9
10impl<T> Collate<Vec<T>> for TorchCollate
11where
12    Self: Collate<T>,
13    T: Clone,
14{
15    type Output = Vec<<Self as Collate<T>>::Output>;
16    fn collate(&self, batch: Vec<Vec<T>>) -> Self::Output {
17        let elem_size = batch
18            .first()
19            .expect("Batch should contain at least one element")
20            .len();
21
22        assert!(
23            batch.iter().all(|vec| vec.len() == elem_size),
24            "Each Vec in the batch should have equal size"
25        );
26
27        let mut collated = Vec::with_capacity(batch.len());
28
29        for i in 0..batch[0].len() {
30            let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
31            collated.push(self.collate(vec));
32        }
33        collated
34    }
35}
36
37impl<T> Collate<VecDeque<T>> for TorchCollate
38where
39    Self: Collate<T>,
40    T: Clone,
41{
42    type Output = Vec<<Self as Collate<T>>::Output>;
43    fn collate(&self, batch: Vec<VecDeque<T>>) -> Self::Output {
44        let elem_size = batch
45            .first()
46            .expect("Batch should contain at least one element")
47            .len();
48
49        assert!(
50            batch.iter().all(|vec| vec.len() == elem_size),
51            "Each Vec in the batch should have equal size"
52        );
53
54        let mut collated = Vec::with_capacity(batch.len());
55
56        for i in 0..batch[0].len() {
57            let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
58            collated.push(self.collate(vec));
59        }
60        collated
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    use tch::Tensor;

    #[test]
    fn vec_of_vec() {
        assert_eq!(
            TorchCollate.collate(vec![vec![1]]),
            vec![Tensor::from_slice(&[1])]
        );
        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2], vec![3, 4]]),
            vec![Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4])]
        );
        // different type
        assert_eq!(
            TorchCollate.collate(vec![vec![true, false], vec![true, false]]),
            vec![
                Tensor::from_slice(&[true, true]),
                Tensor::from_slice(&[false, false])
            ]
        );

        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2, 3], vec![4, 5, 6]]),
            vec![
                Tensor::from_slice(&[1, 4]),
                Tensor::from_slice(&[2, 5]),
                Tensor::from_slice(&[3, 6])
            ]
        );
        // batch_size 3
        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2], vec![3, 4], vec![5, 6]]),
            vec![
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2, 4, 6])
            ]
        );
        // batch_size 10
        assert_eq!(
            TorchCollate.collate(vec![
                vec![1, 2],
                vec![3, 4],
                vec![5, 6],
                vec![7, 8],
                vec![9, 10],
                vec![11, 12],
                vec![13, 14],
                vec![15, 16],
                vec![17, 18],
                vec![19, 20]
            ]),
            vec![
                Tensor::from_slice(&[1, 3, 5, 7, 9, 11, 13, 15, 17, 19]),
                Tensor::from_slice(&[2, 4, 6, 8, 10, 12, 14, 16, 18, 20])
            ]
        );
    }

    /// The `VecDeque` impl had no coverage; it must transpose exactly like the `Vec` impl.
    #[test]
    fn vec_deque_of_vec_deque() {
        assert_eq!(
            TorchCollate.collate(vec![VecDeque::from(vec![1, 2]), VecDeque::from(vec![3, 4])]),
            vec![Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4])]
        );

        // Built with `push_front`, so the deque's elements are not necessarily stored
        // contiguously; they must still be read front to back.
        let mut front_pushed = VecDeque::new();
        front_pushed.push_back(2);
        front_pushed.push_front(1);
        assert_eq!(
            TorchCollate.collate(vec![front_pushed, VecDeque::from(vec![3, 4])]),
            vec![Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4])]
        );
    }

    #[test]
    fn specialized() {
        assert_eq!(
            TorchCollate.collate(vec![
                vec![String::from("a"), String::from("b")],
                vec![String::from("c"), String::from("d")]
            ]),
            vec![
                vec![String::from("a"), String::from("c")],
                vec![String::from("b"), String::from("d")],
            ]
        );
    }

    #[test]
    #[should_panic(expected = "Batch should contain at least one element")]
    fn empty_batch_panics() {
        let _ = TorchCollate.collate(Vec::<Vec<i32>>::new());
    }

    #[test]
    #[should_panic(expected = "Each Vec in the batch should have equal size")]
    fn ragged_vec_batch_panics() {
        let _ = TorchCollate.collate(vec![vec![1_i32, 2], vec![3]]);
    }

    #[test]
    #[should_panic]
    fn ragged_vec_deque_batch_panics() {
        let _ = TorchCollate.collate(vec![VecDeque::from(vec![1_i32, 2]), VecDeque::from(vec![3])]);
    }
}