use super::{core::DefaultCollate, Collate};
use torsh_core::error::Result;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
/// Adapts a user-supplied callable into a [`Collate`] implementation.
///
/// The callable receives the whole batch as a `Vec<T>` and produces the
/// collated output (or an error). `CollateFn<F>` implements [`Collate<T>`]
/// whenever `F: Fn(Vec<T>) -> Result<O>`.
///
/// `Clone`, `Copy` and `Debug` are derived, so they are available exactly when
/// the wrapped callable `F` implements them (e.g. non-capturing closures are
/// `Clone + Copy`).
#[derive(Clone, Copy, Debug)]
pub struct CollateFn<F> {
    /// The wrapped collation callable.
    func: F,
}

impl<F> CollateFn<F> {
    /// Wraps `func` so it can be used as a collator.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
/// Collation by delegation: the batch is passed unchanged to the wrapped
/// callable, and its `Result` (success value or error) is returned as-is.
impl<T, O, F: Fn(Vec<T>) -> Result<O>> Collate<T> for CollateFn<F> {
    type Output = O;

    fn collate(&self, batch: Vec<T>) -> Result<O> {
        let collator = &self.func;
        collator(batch)
    }
}
pub fn collate_fn<T>() -> DefaultCollate {
DefaultCollate
}