// ai_dataloader/collate/default_collate/primitive.rs

use super::super::Collate;
use super::DefaultCollate;

use ndarray::{Array, Array1};

6macro_rules! primitive_impl {
7 ($($t:ty)*) => {
8 $(
9 impl Collate<$t> for DefaultCollate {
10 type Output = Array1<$t>;
11 fn collate(&self, batch: Vec<$t>) -> Self::Output {
12 Array::from_vec(batch)
13 }
14 }
15 )*
16 };
17}
18primitive_impl!(usize u16 u32 u64 u128
19 isize i8 i16 i32 i64 i128
20 f32 f64
21 bool char);
22
23impl Collate<u8> for DefaultCollate {
25 type Output = Vec<u8>;
26 fn collate(&self, batch: Vec<u8>) -> Self::Output {
27 batch
28 }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn scalar_type() {
        // Suffixed literals make explicit which `Collate<T>` impl is exercised,
        // instead of relying on the compiler's `i32` / `f64` literal fallback.
        assert_eq!(
            DefaultCollate.collate(vec![0_i32, 1, 2, 3, 4, 5]),
            array![0_i32, 1, 2, 3, 4, 5]
        );
        assert_eq!(
            DefaultCollate.collate(vec![0_f64, 1., 2., 3., 4., 5.]),
            array![0_f64, 1., 2., 3., 4., 5.]
        );
    }

    #[test]
    fn non_numeric_and_wide_primitives() {
        assert_eq!(
            DefaultCollate.collate(vec![true, false, true]),
            array![true, false, true]
        );
        assert_eq!(DefaultCollate.collate(vec!['a', 'b']), array!['a', 'b']);
        assert_eq!(
            DefaultCollate.collate(vec![u128::MAX, 0_u128]),
            array![u128::MAX, 0_u128]
        );
    }

    #[test]
    fn u8_batch_is_returned_unchanged() {
        // The `Vec<u8>` annotation pins the special-cased `Output` type:
        // `u8` batches are not packed into an `Array1` like other primitives.
        let collated: Vec<u8> = DefaultCollate.collate(vec![0_u8, 1, 255]);
        assert_eq!(collated, vec![0_u8, 1, 255]);
    }

    #[test]
    fn empty_batch_yields_empty_array() {
        let collated = DefaultCollate.collate(Vec::<f32>::new());
        assert!(collated.is_empty());
    }
}