ai_dataloader/
collate.rs

1//! Merges a list of samples to form a batch.
2//!
3
4mod default_collate;
5pub use default_collate::DefaultCollate;
6
7#[cfg(feature = "tch")]
8#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
9mod torch_collate;
10#[cfg(feature = "tch")]
11#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
12pub use torch_collate::TorchCollate;
13
/// A collate function gathers the samples of one batch into a single output.
///
/// The [`DefaultCollate`] struct is provided and covers most use cases.
/// [`NoOpCollate`] returns the batch unchanged, and any closure or function of
/// type `Fn(Vec<T>) -> O` can also be used directly as a collate function.
///
/// This trait is used instead of `Fn` because [we cannot currently `impl Fn*` on
/// structs on stable Rust](https://github.com/rust-lang/rust/issues/29625).
pub trait Collate<T> {
    /// The type produced by collating a batch.
    type Output;

    /// Takes ownership of a batch of samples and collates them into
    /// [`Self::Output`].
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
26
27// Allow user to specify closure as collate function.
28impl<T, F, O> Collate<T> for F
29where
30    F: Fn(Vec<T>) -> O,
31{
32    type Output = O;
33    fn collate(&self, batch: Vec<T>) -> Self::Output {
34        (self)(batch)
35    }
36}
37
/// A collate that returns the batch of samples unchanged, as a `Vec<T>`.
///
/// This is a stateless unit struct, so it derives the common value traits
/// (`Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) and can be freely copied and
/// compared by users.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NoOpCollate;
41
42impl<T> Collate<T> for NoOpCollate {
43    type Output = Vec<T>;
44    fn collate(&self, batch: Vec<T>) -> Self::Output {
45        batch
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_op_collate() {
        assert_eq!(NoOpCollate.collate(vec![1, 2]), vec![1, 2]);
    }

    #[test]
    fn no_op_collate_empty_batch() {
        assert_eq!(NoOpCollate.collate(Vec::<i32>::new()), Vec::<i32>::new());
    }

    #[test]
    fn no_op_collate_closure() {
        let collate = |x| x;
        assert_eq!(collate.collate(vec![1, 2]), vec![1, 2]);
    }

    #[test]
    fn closure_collate_output_type_can_differ() {
        // The closure's return type becomes `Collate::Output`, so a batch of
        // `i32` samples can be reduced to a single `i32`.
        let sum = |batch: Vec<i32>| -> i32 { batch.into_iter().sum() };
        assert_eq!(sum.collate(vec![1, 2, 3]), 6);
        assert_eq!(sum.collate(Vec::<i32>::new()), 0);
    }
}