ai_dataloader/collate/torch_collate/
primitive.rs1use super::super::Collate;
2use super::TorchCollate;
3use tch::Tensor;
4
/// Generates `impl Collate<T> for TorchCollate` for every primitive type `T` listed.
///
/// Each generated impl collates a `Vec<T>` batch into a [`Tensor`] created with
/// [`Tensor::from_slice`] from the batch's elements.
///
/// Usage: `primitive_impl! { i32 f64 bool }` — types are separated by whitespace.
macro_rules! primitive_impl {
    ($($prim:ty)*) => {
        $(
            impl Collate<$prim> for TorchCollate {
                type Output = Tensor;

                fn collate(&self, batch: Vec<$prim>) -> Self::Output {
                    // `&batch` deref-coerces to the `&[T]` slice `from_slice` expects.
                    Tensor::from_slice(&batch)
                }
            }
        )*
    };
}
17primitive_impl!(
18 i8 i16 i32 i64
19 f32 f64
20 bool);
21
22impl Collate<u8> for TorchCollate {
26 type Output = Vec<u8>;
27 fn collate(&self, batch: Vec<u8>) -> Self::Output {
28 batch
29 }
30}
31
#[cfg(test)]
mod tests {
    // Unit tests for the primitive `Collate` impls of `TorchCollate`.
    //
    // Literals carry explicit type suffixes so each assertion targets one specific
    // impl instead of relying on integer/float literal fallback. Dtypes are checked
    // with `kind()` because `Tensor` equality compares values, not element types.
    use super::*;
    use tch::Kind;

    #[test]
    fn scalar_type() {
        let ints = TorchCollate.collate(vec![0i32, 1, 2, 3, 4, 5]);
        assert_eq!(ints.kind(), Kind::Int);
        assert_eq!(ints.size(), vec![6]);
        assert_eq!(ints, Tensor::from_slice(&[0i32, 1, 2, 3, 4, 5]));

        let floats = TorchCollate.collate(vec![0f64, 1., 2., 3., 4., 5.]);
        assert_eq!(floats.kind(), Kind::Double);
        assert_eq!(floats.size(), vec![6]);
        assert_eq!(floats, Tensor::from_slice(&[0f64, 1., 2., 3., 4., 5.]));
    }

    #[test]
    fn small_ints_and_bools() {
        let small = TorchCollate.collate(vec![-3i8, 0, 7]);
        assert_eq!(small.kind(), Kind::Int8);
        assert_eq!(small, Tensor::from_slice(&[-3i8, 0, 7]));

        let flags = TorchCollate.collate(vec![true, false, true]);
        assert_eq!(flags.kind(), Kind::Bool);
        assert_eq!(flags, Tensor::from_slice(&[true, false, true]));
    }

    #[test]
    fn bytes_are_passed_through() {
        // `u8` has a dedicated identity impl: the output is the same `Vec<u8>`.
        let bytes = vec![0u8, 1, 128, 255];
        assert_eq!(TorchCollate.collate(bytes.clone()), bytes);
        assert_eq!(TorchCollate.collate(Vec::<u8>::new()), Vec::<u8>::new());
    }
}