mod default_collate;
pub use default_collate::DefaultCollate;
#[cfg(feature = "tch")]
#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
mod torch_collate;
#[cfg(feature = "tch")]
#[cfg_attr(docsrs, doc(cfg(feature = "tch")))]
pub use torch_collate::TorchCollate;
/// Merges the samples of one batch into a single collated value.
///
/// Any function or closure of type `Fn(Vec<T>) -> O` implements this trait
/// through a blanket impl (with `Output = O`), so a plain closure can be used
/// wherever a `Collate` is expected. Named types such as [`NoOpCollate`]
/// implement it directly. This is a trait rather than a bare `Fn` bound
/// because implementing the `Fn*` traits on a user-defined struct is not
/// possible on stable Rust.
pub trait Collate<T> {
    /// The value produced from one batch of `T` samples.
    type Output;

    /// Consumes `batch` and returns its collated form.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
/// Lets any function, function pointer or closure that takes a `Vec<T>` act
/// as a [`Collate`].
///
/// Whatever the callable returns becomes [`Collate::Output`]; for example
/// `|batch: Vec<i32>| batch.len()` collates each batch to a `usize`.
impl<T, Func, Out> Collate<T> for Func
where
    Func: Fn(Vec<T>) -> Out,
{
    type Output = Out;

    fn collate(&self, batch: Vec<T>) -> Self::Output {
        // `&Func` is itself callable when `Func: Fn`, so no explicit deref is needed.
        self(batch)
    }
}
/// A [`Collate`] that hands the batch back unchanged.
///
/// Useful when the samples of a batch should be kept as a plain `Vec<T>`
/// instead of being merged into another representation. As a stateless unit
/// struct it is cheap to copy, compare and hash.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NoOpCollate;

impl<T> Collate<T> for NoOpCollate {
    type Output = Vec<T>;

    /// Returns `batch` as-is; the vector is moved through without touching
    /// its elements, so `T` needs no extra bounds.
    fn collate(&self, batch: Vec<T>) -> Self::Output {
        batch
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_op_collate() {
        assert_eq!(NoOpCollate::default().collate(vec![1, 2]), vec![1, 2]);
    }

    #[test]
    fn no_op_collate_empty_batch() {
        let empty: Vec<i32> = Vec::new();
        assert!(NoOpCollate.collate(empty).is_empty());
    }

    #[test]
    fn no_op_collate_closure() {
        let collate = |x| x;
        assert_eq!(collate.collate(vec![1, 2]), vec![1, 2]);
    }

    /// The blanket impl must also cover closures whose output type differs
    /// from the input batch type (`Output = O`).
    #[test]
    fn closure_can_change_output_type() {
        let collate = |batch: Vec<i32>| batch.iter().sum::<i32>();
        assert_eq!(collate.collate(vec![1, 2, 3]), 6);
    }

    /// Named `fn` items implement `Fn`, so they are valid collates too.
    #[test]
    fn fn_item_is_a_collate() {
        fn count(batch: Vec<char>) -> usize {
            batch.len()
        }
        assert_eq!(count.collate(vec!['a', 'b']), 2);
    }
}