ai_dataloader/collate/default_collate/array.rs

1use super::super::Collate;
2use super::DefaultCollate;
3
4impl<T, const N: usize> Collate<[T; N]> for DefaultCollate
5where
6    Self: Collate<T>,
7    T: Clone,
8{
9    type Output = Vec<<Self as Collate<T>>::Output>;
10    fn collate(&self, batch: Vec<[T; N]>) -> Self::Output {
11        let mut collated = Vec::with_capacity(batch.len());
12        for i in 0..batch[0].len() {
13            let vec: Vec<_> = batch.iter().map(|sample| sample[i].clone()).collect();
14            collated.push(self.collate(vec));
15        }
16        collated
17    }
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Three two-element samples are transposed into two columns of three
    /// values, and each column is collated into an `ndarray` array.
    #[test]
    fn vec_of_array() {
        let batch: Vec<[i32; 2]> = vec![[1, 2], [3, 4], [5, 6]];
        let expected = vec![array![1_i32, 3, 5], array![2_i32, 4, 6]];

        let collated = DefaultCollate.collate(batch);

        assert_eq!(collated, expected);
    }
}