ai_dataloader/collate/default_collate/
array.rs1use super::super::Collate;
2use super::DefaultCollate;
3
4impl<T, const N: usize> Collate<[T; N]> for DefaultCollate
5where
6 Self: Collate<T>,
7 T: Clone,
8{
9 type Output = Vec<<Self as Collate<T>>::Output>;
10 fn collate(&self, batch: Vec<[T; N]>) -> Self::Output {
11 let mut collated = Vec::with_capacity(batch.len());
12 for i in 0..batch[0].len() {
13 let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
14 collated.push(self.collate(vec));
15 }
16 collated
17 }
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Each array position is gathered across the batch and collated into its own array.
    #[test]
    fn vec_of_array() {
        let batch = vec![[1, 2], [3, 4], [5, 6]];
        let expected = vec![array![1, 3, 5], array![2, 4, 6]];
        assert_eq!(DefaultCollate.collate(batch), expected);
    }
}