use tenflowers_core::{Result, Tensor, TensorError};
/// Combines the individual samples of one batch into a single pair of tensors.
///
/// Each sample is a `(features, labels)` tuple; an implementation returns one
/// batched feature tensor and one batched label tensor, or an error.
pub trait CollateFn<T> {
/// Collates `batch` into a single `(features, labels)` pair.
///
/// The implementations in this file reject an empty batch with an
/// invalid-argument error.
fn collate(&self, batch: Vec<(Tensor<T>, Tensor<T>)>) -> Result<(Tensor<T>, Tensor<T>)>;
}
/// Collate strategy that stacks samples as they are, with no padding or truncation.
///
/// The type is stateless, so it is cheap to copy and can be created with
/// `DefaultCollate` or `DefaultCollate::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultCollate;
impl<T> CollateFn<T> for DefaultCollate
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Stacks all feature tensors and all label tensors along axis 0.
    ///
    /// # Errors
    /// Returns an invalid-argument error when `batch` is empty, and forwards
    /// any error reported by `tenflowers_core::ops::stack`.
    fn collate(&self, batch: Vec<(Tensor<T>, Tensor<T>)>) -> Result<(Tensor<T>, Tensor<T>)> {
        if batch.is_empty() {
            return Err(TensorError::invalid_argument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        // Split the samples into two parallel lists, preserving sample order.
        let mut feature_list = Vec::with_capacity(batch.len());
        let mut label_list = Vec::with_capacity(batch.len());
        for (feature, label) in batch {
            feature_list.push(feature);
            label_list.push(label);
        }

        // `stack` works on borrowed tensors, so build reference views first.
        let feature_views: Vec<&Tensor<T>> = feature_list.iter().collect();
        let label_views: Vec<&Tensor<T>> = label_list.iter().collect();

        let feature_batch = tenflowers_core::ops::stack(&feature_views, 0)?;
        let label_batch = tenflowers_core::ops::stack(&label_views, 0)?;
        Ok((feature_batch, label_batch))
    }
}
/// How the target length of the last dimension is chosen when padding a batch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PaddingStrategy {
    /// Use the longest last-dimension length found in the batch.
    MaxLength,
    /// Use exactly this length, regardless of the batch contents.
    FixedLength(usize),
    /// Use the longest length in the batch rounded up to a multiple of this bucket size.
    Bucket(usize),
}
/// Collate strategy that pads (and optionally truncates) the last dimension of
/// every tensor to a common length before stacking.
#[derive(Debug, Clone)]
pub struct PaddingCollate<T> {
    /// Value written into the padded positions.
    padding_value: T,
    /// Rule used to choose the common target length.
    padding_strategy: PaddingStrategy,
    /// Whether tensors longer than the target length are cut down to it.
    truncate: bool,
}
impl<T> PaddingCollate<T>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
pub fn new(padding_value: T, padding_strategy: PaddingStrategy) -> Self {
Self {
padding_value,
padding_strategy,
truncate: true,
}
}
pub fn with_zero_padding(padding_strategy: PaddingStrategy) -> Self {
Self::new(T::zero(), padding_strategy)
}
pub fn with_truncate(mut self, truncate: bool) -> Self {
self.truncate = truncate;
self
}
fn determine_target_length(&self, sequences: &[&Tensor<T>]) -> usize {
match &self.padding_strategy {
PaddingStrategy::MaxLength => sequences
.iter()
.map(|tensor| tensor.shape().dims()[tensor.shape().rank() - 1])
.max()
.unwrap_or(0),
PaddingStrategy::FixedLength(length) => *length,
PaddingStrategy::Bucket(bucket_size) => {
let max_len = sequences
.iter()
.map(|tensor| tensor.shape().dims()[tensor.shape().rank() - 1])
.max()
.unwrap_or(0);
((max_len + bucket_size - 1) / bucket_size) * bucket_size
}
}
}
fn pad_sequence(&self, tensor: &Tensor<T>, target_length: usize) -> Result<Tensor<T>> {
let shape = tensor.shape().dims();
let seq_length = shape[shape.len() - 1];
if seq_length == target_length {
return Ok(tensor.clone());
}
if seq_length > target_length && self.truncate {
let mut ranges = Vec::new();
#[allow(clippy::needless_range_loop)]
for i in 0..shape.len() - 1 {
ranges.push(0..shape[i]);
}
ranges.push(0..target_length);
tenflowers_core::ops::slice(tensor, &ranges)
} else if seq_length < target_length {
let mut new_shape = shape.to_vec();
let last_dim_idx = new_shape.len() - 1;
new_shape[last_dim_idx] = target_length;
let padding_length = target_length - seq_length;
let mut padding_shape = shape.to_vec();
let padding_last_dim_idx = padding_shape.len() - 1;
padding_shape[padding_last_dim_idx] = padding_length;
let total_padding_size = padding_shape.iter().product();
let padding_data = vec![self.padding_value.clone(); total_padding_size];
let padding_tensor = Tensor::from_vec(padding_data, &padding_shape)?;
tenflowers_core::ops::concat(&[tensor, &padding_tensor], shape.len() - 1)
} else {
Ok(tensor.clone())
}
}
}
impl<T> CollateFn<T> for PaddingCollate<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Pads features and labels to their respective target lengths, then stacks
    /// each group along axis 0.
    ///
    /// The target length is determined separately for features and for labels,
    /// both using this collator's padding strategy.
    ///
    /// # Errors
    /// Returns an invalid-argument error when `batch` is empty, and forwards
    /// the first error produced while padding or stacking.
    fn collate(&self, batch: Vec<(Tensor<T>, Tensor<T>)>) -> Result<(Tensor<T>, Tensor<T>)> {
        if batch.is_empty() {
            return Err(TensorError::invalid_argument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        let (raw_features, raw_labels): (Vec<Tensor<T>>, Vec<Tensor<T>>) =
            batch.into_iter().unzip();

        let feature_len = {
            let views: Vec<&Tensor<T>> = raw_features.iter().collect();
            self.determine_target_length(&views)
        };
        let label_len = {
            let views: Vec<&Tensor<T>> = raw_labels.iter().collect();
            self.determine_target_length(&views)
        };

        // Features are padded first; collecting into `Result` stops at the first failure.
        let features_ready = raw_features
            .iter()
            .map(|feature| self.pad_sequence(feature, feature_len))
            .collect::<Result<Vec<_>>>()?;
        let labels_ready = raw_labels
            .iter()
            .map(|label| self.pad_sequence(label, label_len))
            .collect::<Result<Vec<_>>>()?;

        let feature_views: Vec<&Tensor<T>> = features_ready.iter().collect();
        let label_views: Vec<&Tensor<T>> = labels_ready.iter().collect();

        let feature_batch = tenflowers_core::ops::stack(&feature_views, 0)?;
        let label_batch = tenflowers_core::ops::stack(&label_views, 0)?;
        Ok((feature_batch, label_batch))
    }
}
/// Collate strategy that pads each batch up to one of a set of predefined lengths.
#[derive(Debug, Clone)]
pub struct BucketCollate<T> {
    /// Candidate target lengths for the last dimension.
    bucket_sizes: Vec<usize>,
    /// Value written into the padded positions.
    padding_value: T,
}
impl<T> BucketCollate<T>
where
    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
    /// Creates a collator that pads with `padding_value` up to one of `bucket_sizes`.
    ///
    /// `bucket_sizes` may be given in any order.
    pub fn new(bucket_sizes: Vec<usize>, padding_value: T) -> Self {
        Self {
            bucket_sizes,
            padding_value,
        }
    }

    /// Creates a collator that pads with `T::zero()`.
    pub fn with_zero_padding(bucket_sizes: Vec<usize>) -> Self {
        Self::new(bucket_sizes, T::zero())
    }

    /// Picks the bucket to use for a sequence of the given `length`.
    ///
    /// Returns the smallest bucket that is at least `length`. When no bucket is
    /// large enough, the largest bucket is returned; when there are no buckets
    /// at all, `length` itself is returned.
    ///
    /// The search does not depend on the order of `bucket_sizes`, so an unsorted
    /// list still yields the tightest fit.
    fn find_bucket_size(&self, length: usize) -> usize {
        self.bucket_sizes
            .iter()
            .copied()
            .filter(|&bucket_size| bucket_size >= length)
            .min()
            .or_else(|| self.bucket_sizes.iter().copied().max())
            .unwrap_or(length)
    }
}
impl<T> CollateFn<T> for BucketCollate<T>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
fn collate(&self, batch: Vec<(Tensor<T>, Tensor<T>)>) -> Result<(Tensor<T>, Tensor<T>)> {
if batch.is_empty() {
return Err(TensorError::invalid_argument(
"Cannot collate empty batch".to_string(),
));
}
let max_feature_length = batch
.iter()
.map(|(f, _)| f.shape().dims()[f.shape().rank() - 1])
.max()
.unwrap_or(0);
let max_label_length = batch
.iter()
.map(|(_, l)| l.shape().dims()[l.shape().rank() - 1])
.max()
.unwrap_or(0);
let feature_bucket_size = self.find_bucket_size(max_feature_length);
let label_bucket_size = self.find_bucket_size(max_label_length);
let feature_padding_collate = PaddingCollate::new(
self.padding_value.clone(),
PaddingStrategy::FixedLength(feature_bucket_size),
);
let _label_padding_collate = PaddingCollate::new(
self.padding_value.clone(),
PaddingStrategy::FixedLength(label_bucket_size),
);
feature_padding_collate.collate(batch)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tenflowers_core::Tensor;

    /// Stacking two `[2, 3]` features and two `[1]` labels adds a leading batch axis.
    #[test]
    fn test_default_collate() {
        let collate_fn = DefaultCollate;
        let batch = vec![
            (Tensor::<f32>::ones(&[2, 3]), Tensor::<f32>::zeros(&[1])),
            (Tensor::<f32>::ones(&[2, 3]), Tensor::<f32>::ones(&[1])),
        ];

        let (features, labels) = collate_fn
            .collate(batch)
            .expect("test: operation should succeed");

        assert_eq!(features.shape().dims(), &[2, 2, 3]);
        assert_eq!(labels.shape().dims(), &[2, 1]);
    }

    /// An empty batch is rejected rather than producing an empty tensor.
    #[test]
    fn test_default_collate_empty_batch() {
        let collate_fn = DefaultCollate;
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = Vec::new();

        assert!(collate_fn.collate(batch).is_err());
    }

    #[test]
    fn test_padding_strategy_max_length() {
        let strategy = PaddingStrategy::MaxLength;
        assert!(
            matches!(strategy, PaddingStrategy::MaxLength),
            "Expected MaxLength strategy"
        );
    }

    #[test]
    fn test_padding_strategy_fixed_length() {
        let strategy = PaddingStrategy::FixedLength(10);
        assert!(
            matches!(strategy, PaddingStrategy::FixedLength(10)),
            "Expected FixedLength strategy"
        );
    }

    #[test]
    fn test_padding_strategy_bucket() {
        let strategy = PaddingStrategy::Bucket(8);
        assert!(
            matches!(strategy, PaddingStrategy::Bucket(8)),
            "Expected Bucket strategy"
        );
    }

    /// `new` stores the padding value and enables truncation by default.
    #[test]
    fn test_padding_collate_creation() {
        let collate_fn = PaddingCollate::new(0.0_f32, PaddingStrategy::MaxLength);
        assert_eq!(collate_fn.padding_value, 0.0);
        assert!(collate_fn.truncate);
    }

    #[test]
    fn test_padding_collate_with_zero_padding() {
        let collate_fn = PaddingCollate::<f32>::with_zero_padding(PaddingStrategy::FixedLength(5));
        assert_eq!(collate_fn.padding_value, 0.0);
    }

    #[test]
    fn test_padding_collate_with_truncate() {
        let collate_fn =
            PaddingCollate::new(0.0_f32, PaddingStrategy::MaxLength).with_truncate(false);
        assert!(!collate_fn.truncate);
    }

    #[test]
    fn test_bucket_collate_creation() {
        let bucket_sizes = vec![16, 32, 64, 128];
        let collate_fn = BucketCollate::new(bucket_sizes.clone(), 0.0_f32);
        assert_eq!(collate_fn.bucket_sizes, bucket_sizes);
        assert_eq!(collate_fn.padding_value, 0.0);
    }

    #[test]
    fn test_bucket_collate_with_zero_padding() {
        let bucket_sizes = vec![8, 16, 32];
        let collate_fn = BucketCollate::<f32>::with_zero_padding(bucket_sizes.clone());
        assert_eq!(collate_fn.bucket_sizes, bucket_sizes);
        assert_eq!(collate_fn.padding_value, 0.0);
    }

    /// Lengths map to the first bucket that fits; oversize lengths fall back to the last bucket.
    #[test]
    fn test_bucket_collate_find_bucket_size() {
        let bucket_sizes = vec![16, 32, 64];
        let collate_fn = BucketCollate::new(bucket_sizes, 0.0_f32);
        assert_eq!(collate_fn.find_bucket_size(10), 16);
        assert_eq!(collate_fn.find_bucket_size(20), 32);
        assert_eq!(collate_fn.find_bucket_size(50), 64);
        assert_eq!(collate_fn.find_bucket_size(100), 64);
    }
}