ai_dataloader/collate/torch_collate.rs
/// Collate function that mimics the [`default_collate` function](https://pytorch.org/docs/stable/data.html#automatic-batching-default) from `PyTorch`.
///
/// Data is collated inside a `tch::Tensor`.
///
/// Basic transformations implemented by the default collate:
///
/// - `Vec<Scalar>` -> `tch::Tensor` holding the scalars
/// - `Vec<(A, B, ...)>` -> `(TorchCollate::default().collate(Vec<A>), TorchCollate::default().collate(Vec<B>), ...)`
/// - `Vec<HashMap<Key, Value>>` -> `HashMap<Key, TorchCollate::default().collate(Vec<Value>)>`
/// - `Vec<Array>` -> the arrays stacked together
/// - `Vec[V1_i, V2_i, ...]` -> `Vec[TorchCollate::default().collate([V1_1, V1_2, ...]), TorchCollate::default().collate([V2_1, V2_2, ...]), ...]`
///
/// As in the `PyTorch` version, `String` and `u8` are left unchanged by the
/// collation (no-op):
///
/// - `Vec<String>` -> `Vec<String>`
/// - `Vec<&str>` -> `Vec<&str>`
/// - `Vec<u8>` -> `Vec<u8>`
#[derive(Default, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct TorchCollate;
24
25mod array;
26mod map;
27mod ndarray;
28mod nonzero;
29mod primitive;
30mod reference;
31mod sequence;
32mod string;
33mod tuple;