use super::{
advanced::{CachedCollate, DynamicBatchCollateWrapper, PadCollate},
core::DefaultCollate,
optimized::OptimizedCollate,
};
use crate::collate::Collate;
use torsh_core::dtype::TensorElement;
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
/// Selects which collation implementation [`CollateBuilder::build`] produces.
///
/// The type is a plain field-less enum, so it is cheap to copy and can be
/// compared or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CollateStrategy {
    /// Builds a `DefaultCollate`. This is the strategy a fresh builder starts with.
    Stack,
    /// Builds an `OptimizedCollate`.
    Optimized,
    /// Builds a `PadCollate` using the builder's padding value.
    Padding,
    /// Builds a `DynamicBatchCollateWrapper` using the builder's padding value.
    Dynamic,
    /// Builds a `CachedCollate` when the `std` feature is enabled; otherwise
    /// the builder falls back to an `OptimizedCollate`.
    Cached,
}
/// Fluent builder that assembles a boxed collation function for tensors of
/// element type `T`.
///
/// Configure it through the chained setter methods and finish with `build`.
#[derive(Debug, Clone)]
pub struct CollateBuilder<T> {
    /// Which collation implementation `build` should construct.
    strategy: CollateStrategy,
    /// Fill value for the padding-based strategies; `build` substitutes
    /// `T::default()` when this is `None`.
    padding_value: Option<T>,
    /// Value recorded by the `max_length` setter.
    /// NOTE(review): not read by `build` in this file.
    max_length: Option<usize>,
    /// Flag set by the `with_caching` setter.
    /// NOTE(review): not read by `build` in this file; only
    /// `CollateStrategy::Cached` selects the caching collator.
    use_caching: bool,
    /// Value recorded by the `batch_size_hint` setter.
    /// NOTE(review): not read by `build` in this file.
    batch_size_hint: Option<usize>,
}
impl<T: TensorElement> Default for CollateBuilder<T> {
fn default() -> Self {
Self {
strategy: CollateStrategy::Stack,
padding_value: None,
max_length: None,
use_caching: false,
batch_size_hint: None,
}
}
}
/// Fluent setters and the final `build` step for [`CollateBuilder`].
// The arithmetic bounds are spelled with `core::ops` rather than `std::ops`:
// they are the same traits, but `core` is also nameable when the crate is
// compiled without the `std` feature (this file has a no-`std` `Box` import).
impl<T> CollateBuilder<T>
where
    T: TensorElement
        + core::ops::Add<Output = T>
        + core::ops::Sub<Output = T>
        + core::ops::Mul<Output = T>
        + core::ops::Div<Output = T>
        + Default,
{
    /// Creates a builder with the default configuration (see the `Default` impl).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Selects the collation strategy that `build` will construct.
    #[must_use]
    pub fn strategy(mut self, strategy: CollateStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the fill value used by the `Padding` and `Dynamic` strategies.
    ///
    /// This only stores the value; it does not switch the strategy.
    #[must_use]
    pub fn with_padding(mut self, padding_value: T) -> Self {
        self.padding_value = Some(padding_value);
        self
    }

    /// Records a maximum length.
    ///
    /// NOTE(review): the value is stored but `build` does not consult it.
    #[must_use]
    pub fn max_length(mut self, max_length: usize) -> Self {
        self.max_length = Some(max_length);
        self
    }

    /// Sets the caching flag.
    ///
    /// NOTE(review): the flag is stored but `build` does not consult it; use
    /// `strategy(CollateStrategy::Cached)` to obtain the caching collator.
    #[must_use]
    pub fn with_caching(mut self) -> Self {
        self.use_caching = true;
        self
    }

    /// Records a batch size hint.
    ///
    /// NOTE(review): the value is stored but `build` does not consult it.
    #[must_use]
    pub fn batch_size_hint(mut self, size: usize) -> Self {
        self.batch_size_hint = Some(size);
        self
    }

    /// Consumes the builder and returns the boxed collation function for the
    /// selected strategy.
    ///
    /// * `Stack` -> `DefaultCollate`
    /// * `Optimized` -> `OptimizedCollate`
    /// * `Padding` -> `PadCollate` with the configured padding value, or
    ///   `T::default()` when none was set
    /// * `Dynamic` -> `DynamicBatchCollateWrapper` with the same padding rule
    /// * `Cached` -> `CachedCollate::new(1000)` when the `std` feature is
    ///   enabled, otherwise `OptimizedCollate`
    #[must_use]
    pub fn build(self) -> Box<dyn Collate<Tensor<T>, Output = Tensor<T>> + Send + Sync>
    where
        T: Copy + 'static,
    {
        match self.strategy {
            CollateStrategy::Stack => Box::new(DefaultCollate),
            CollateStrategy::Optimized => Box::new(OptimizedCollate),
            CollateStrategy::Padding => {
                let padding_value = self.padding_value.unwrap_or_default();
                Box::new(PadCollate::new(padding_value))
            }
            CollateStrategy::Dynamic => {
                let padding_value = self.padding_value.unwrap_or_default();
                Box::new(DynamicBatchCollateWrapper::new(padding_value))
            }
            CollateStrategy::Cached => {
                // `cfg!` evaluates to a constant, so both branches are still
                // type-checked in every configuration.
                if cfg!(feature = "std") {
                    // 1000 is presumably the cache capacity - TODO confirm
                    // against `CachedCollate::new`.
                    Box::new(CachedCollate::new(1000))
                } else {
                    Box::new(OptimizedCollate)
                }
            }
        }
    }
}