1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
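//! Merges a list of samples to form a batch.
//!
//! The [`Collate`] trait is the common interface; [`NoOpCollate`] and plain closures implement it here.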

mod default_collate;
pub use default_collate::DefaultCollate;

#[cfg(feature = "tch")]
#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
mod torch_collate;
#[cfg(feature = "tch")]
#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
pub use torch_collate::TorchCollate;

/// A collate function gathers the samples of one batch together.
///
/// A `DefaultCollate` struct is provided which will cover most of the use cases.
///
/// This trait is used instead of `Fn` because [we cannot currently `impl Fn*` on struct on stable rust](https://github.com/rust-lang/rust/issues/29625).
///
/// `T` is the type of a single sample; the batch is handed over by value as a `Vec<T>`.
pub trait Collate<T> {
    /// The type produced by collating one batch.
    type Output;
    /// Takes ownership of a batch of samples and collates them into [`Self::Output`].
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}

// Blanket impl: any closure or function that takes the batch by value
// can be used directly as a collate function.
impl<T, O, F: Fn(Vec<T>) -> O> Collate<T> for F {
    /// Whatever the wrapped callable returns.
    type Output = O;

    /// Forwards the batch, unchanged, to the wrapped callable.
    fn collate(&self, batch: Vec<T>) -> O {
        self(batch)
    }
}

/// Collate function that returns the batch of samples unchanged.
///
/// It is a stateless unit struct, so it derives the common value traits
/// (`Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) in addition to `Default` and
/// `Debug`; this lets types that store a collator derive those traits too.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NoOpCollate;

impl<T> Collate<T> for NoOpCollate {
    /// The batch itself: samples are returned in the same `Vec`, in the same order.
    type Output = Vec<T>;

    /// Returns `batch` as-is, without touching any sample.
    fn collate(&self, batch: Vec<T>) -> Vec<T> {
        batch
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `NoOpCollate` must hand the batch back untouched.
    #[test]
    fn no_op_collate() {
        let collator = NoOpCollate::default();
        let collated = collator.collate(vec![1, 2]);
        assert_eq!(collated, vec![1, 2]);
    }

    /// A plain closure is usable as a collate function through the blanket impl.
    #[test]
    fn no_op_collate_closure() {
        let identity = |samples| samples;
        let collated = identity.collate(vec![1, 2]);
        assert_eq!(collated, vec![1, 2]);
    }
}