1mod default_collate;
5pub use default_collate::DefaultCollate;
6
7#[cfg(feature = "tch")]
8#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
9mod torch_collate;
10#[cfg(feature = "tch")]
11#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
12pub use torch_collate::TorchCollate;
13
/// Turns a batch of individual samples into a single collated value.
///
/// `T` is the type of one sample. The whole batch is handed over by value as a
/// `Vec<T>`, so implementations may move samples out without cloning. The
/// result type is chosen by the implementor through [`Collate::Output`].
///
/// Any `Fn(Vec<T>) -> O` closure or function also implements this trait.
pub trait Collate<T> {
    /// The value produced from one batch.
    type Output;

    /// Collates `batch` into a single `Output` value.
    fn collate(&self, batch: Vec<T>) -> <Self as Collate<T>>::Output;
}
26
27impl<T, F, O> Collate<T> for F
29where
30 F: Fn(Vec<T>) -> O,
31{
32 type Output = O;
33 fn collate(&self, batch: Vec<T>) -> Self::Output {
34 (self)(batch)
35 }
36}
37
38#[derive(Default, Debug)]
40pub struct NoOpCollate;
41
42impl<T> Collate<T> for NoOpCollate {
43 type Output = Vec<T>;
44 fn collate(&self, batch: Vec<T>) -> Self::Output {
45 batch
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_op_collate() {
        let expected = vec![1, 2];
        let collated = NoOpCollate.collate(vec![1, 2]);
        assert_eq!(collated, expected);
    }

    #[test]
    fn no_op_collate_closure() {
        // An identity closure goes through the blanket `Fn` impl and must
        // behave exactly like `NoOpCollate`.
        let identity = |batch: Vec<i32>| batch;
        let collated = identity.collate(vec![1, 2]);
        assert_eq!(collated, vec![1, 2]);
    }
}