ai_dataloader/collate/torch_collate/
sequence.rs1use super::super::Collate;
7use super::TorchCollate;
8use std::collections::VecDeque;
9
10impl<T> Collate<Vec<T>> for TorchCollate
11where
12 Self: Collate<T>,
13 T: Clone,
14{
15 type Output = Vec<<Self as Collate<T>>::Output>;
16 fn collate(&self, batch: Vec<Vec<T>>) -> Self::Output {
17 let elem_size = batch
18 .first()
19 .expect("Batch should contain at least one element")
20 .len();
21
22 assert!(
23 batch.iter().all(|vec| vec.len() == elem_size),
24 "Each Vec in the batch should have equal size"
25 );
26
27 let mut collated = Vec::with_capacity(batch.len());
28
29 for i in 0..batch[0].len() {
30 let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
31 collated.push(self.collate(vec));
32 }
33 collated
34 }
35}
36
37impl<T> Collate<VecDeque<T>> for TorchCollate
38where
39 Self: Collate<T>,
40 T: Clone,
41{
42 type Output = Vec<<Self as Collate<T>>::Output>;
43 fn collate(&self, batch: Vec<VecDeque<T>>) -> Self::Output {
44 let elem_size = batch
45 .first()
46 .expect("Batch should contain at least one element")
47 .len();
48
49 assert!(
50 batch.iter().all(|vec| vec.len() == elem_size),
51 "Each Vec in the batch should have equal size"
52 );
53
54 let mut collated = Vec::with_capacity(batch.len());
55
56 for i in 0..batch[0].len() {
57 let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
58 collated.push(self.collate(vec));
59 }
60 collated
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    use tch::Tensor;

    #[test]
    fn vec_of_vec() {
        assert_eq!(
            TorchCollate.collate(vec![vec![1]]),
            vec![Tensor::from_slice(&[1])]
        );
        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2], vec![3, 4]]),
            vec![Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4])]
        );
        assert_eq!(
            TorchCollate.collate(vec![vec![true, false], vec![true, false]]),
            vec![
                Tensor::from_slice(&[true, true]),
                Tensor::from_slice(&[false, false])
            ]
        );

        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2, 3], vec![4, 5, 6]]),
            vec![
                Tensor::from_slice(&[1, 4]),
                Tensor::from_slice(&[2, 5]),
                Tensor::from_slice(&[3, 6])
            ]
        );
        assert_eq!(
            TorchCollate.collate(vec![vec![1, 2], vec![3, 4], vec![5, 6]]),
            vec![
                Tensor::from_slice(&[1, 3, 5]),
                Tensor::from_slice(&[2, 4, 6])
            ]
        );
        assert_eq!(
            TorchCollate.collate(vec![
                vec![1, 2],
                vec![3, 4],
                vec![5, 6],
                vec![7, 8],
                vec![9, 10],
                vec![11, 12],
                vec![13, 14],
                vec![15, 16],
                vec![17, 18],
                vec![19, 20]
            ]),
            vec![
                Tensor::from_slice(&[1, 3, 5, 7, 9, 11, 13, 15, 17, 19]),
                Tensor::from_slice(&[2, 4, 6, 8, 10, 12, 14, 16, 18, 20])
            ]
        );
    }

    /// The `VecDeque` impl must transpose exactly like the `Vec` impl.
    #[test]
    fn vecdeque_of_vec() {
        assert_eq!(
            TorchCollate.collate(vec![
                VecDeque::from(vec![1, 2]),
                VecDeque::from(vec![3, 4])
            ]),
            vec![Tensor::from_slice(&[1, 3]), Tensor::from_slice(&[2, 4])]
        );
    }

    #[test]
    fn specialized() {
        assert_eq!(
            TorchCollate.collate(vec![
                vec![String::from("a"), String::from("b")],
                vec![String::from("c"), String::from("d")]
            ]),
            vec![
                vec![String::from('a'), String::from('c')],
                vec![String::from('b'), String::from('d')],
            ]
        );
    }

    #[test]
    #[should_panic(expected = "at least one element")]
    fn empty_batch_panics() {
        let _ = TorchCollate.collate(Vec::<Vec<i32>>::new());
    }

    #[test]
    #[should_panic(expected = "equal size")]
    fn ragged_batch_panics() {
        let _ = TorchCollate.collate(vec![vec![1, 2], vec![3]]);
    }
}