ai_dataloader/collate/default_collate/
sequence.rs1use super::super::Collate;
7use super::DefaultCollate;
8use std::collections::VecDeque;
9
10impl<T> Collate<Vec<T>> for DefaultCollate
11where
12 Self: Collate<T>,
13 T: Clone,
14{
15 type Output = Vec<<Self as Collate<T>>::Output>;
16 fn collate(&self, batch: Vec<Vec<T>>) -> Self::Output {
17 let elem_size = batch
18 .first()
19 .expect("Batch should contain at least one element")
20 .len();
21
22 assert!(
23 batch.iter().all(|vec| vec.len() == elem_size),
24 "Each Vec in the batch should have equal size"
25 );
26
27 let mut collated = Vec::with_capacity(batch.len());
28
29 for i in 0..batch[0].len() {
30 let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
31 collated.push(self.collate(vec));
32 }
33 collated
34 }
35}
36
37impl<T> Collate<VecDeque<T>> for DefaultCollate
38where
39 Self: Collate<T>,
40 T: Clone,
41{
42 type Output = Vec<<Self as Collate<T>>::Output>;
43 fn collate(&self, batch: Vec<VecDeque<T>>) -> Self::Output {
44 let elem_size = batch
45 .first()
46 .expect("Batch should contain at least one element")
47 .len();
48
49 assert!(
50 batch.iter().all(|vec| vec.len() == elem_size),
51 "Each Vec in the batch should have equal size"
52 );
53
54 let mut collated = Vec::with_capacity(batch.len());
55
56 for i in 0..batch[0].len() {
57 let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
58 collated.push(self.collate(vec));
59 }
60 collated
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A batch of `Vec` samples is transposed: entry `i` collates element `i` of each sample.
    #[test]
    fn vec_of_vec() {
        assert_eq!(DefaultCollate.collate(vec![vec![1]]), vec![array![1]]);
        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2], vec![3, 4]]),
            vec![array![1, 3], array![2, 4]]
        );
        assert_eq!(
            DefaultCollate.collate(vec![vec![true, false], vec![true, false]]),
            vec![array![true, true], array![false, false]]
        );

        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2, 3], vec![4, 5, 6]]),
            vec![array![1, 4], array![2, 5], array![3, 6]]
        );
        assert_eq!(
            DefaultCollate.collate(vec![vec![1, 2], vec![3, 4], vec![5, 6]]),
            vec![array![1, 3, 5], array![2, 4, 6]]
        );
        assert_eq!(
            DefaultCollate.collate(vec![
                vec![1, 2],
                vec![3, 4],
                vec![5, 6],
                vec![7, 8],
                vec![9, 10],
                vec![11, 12],
                vec![13, 14],
                vec![15, 16],
                vec![17, 18],
                vec![19, 20]
            ]),
            vec![
                array![1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
                array![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
            ]
        );
    }

    /// `VecDeque` samples are collated the same way as `Vec` samples (front to back).
    #[test]
    fn vec_deque_of_vec_deque() {
        assert_eq!(
            DefaultCollate.collate(vec![VecDeque::from(vec![1]), VecDeque::from(vec![2])]),
            vec![array![1, 2]]
        );
        assert_eq!(
            DefaultCollate.collate(vec![
                VecDeque::from(vec![1, 2, 3]),
                VecDeque::from(vec![4, 5, 6])
            ]),
            vec![array![1, 4], array![2, 5], array![3, 6]]
        );
    }

    /// Strings are collated into plain `Vec<String>` columns.
    #[test]
    fn specialized() {
        assert_eq!(
            DefaultCollate.collate(vec![
                vec![String::from("a"), String::from("b")],
                vec![String::from("c"), String::from("d")]
            ]),
            vec![
                vec![String::from("a"), String::from("c")],
                vec![String::from("b"), String::from("d")],
            ]
        );
    }

    #[test]
    #[should_panic(expected = "Batch should contain at least one element")]
    fn empty_batch_panics() {
        let batch: Vec<Vec<i32>> = Vec::new();
        let _ = DefaultCollate.collate(batch);
    }

    #[test]
    #[should_panic(expected = "Each Vec in the batch should have equal size")]
    fn ragged_vec_batch_panics() {
        let _ = DefaultCollate.collate(vec![vec![1, 2], vec![3]]);
    }

    #[test]
    #[should_panic(expected = "Each Vec in the batch should have equal size")]
    fn ragged_vec_deque_batch_panics() {
        let _ = DefaultCollate.collate(vec![VecDeque::from(vec![1, 2]), VecDeque::from(vec![3])]);
    }
}