use super::collate::CollateFn;
use tenflowers_core::{Result, Tensor};
#[derive(Debug, Clone)]
#[allow(clippy::large_enum_variant)]
pub enum BatchResult<T> {
Samples(Vec<(Tensor<T>, Tensor<T>)>),
Collated(Tensor<T>, Tensor<T>),
}
impl<T> BatchResult<T>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
pub fn len(&self) -> usize {
match self {
BatchResult::Samples(samples) => samples.len(),
BatchResult::Collated(features, _) => features.shape().dims()[0],
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn into_samples(self) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
match self {
BatchResult::Samples(samples) => Ok(samples),
BatchResult::Collated(features, labels) => {
let batch_size = features.shape().dims()[0];
let mut samples = Vec::with_capacity(batch_size);
for i in 0..batch_size {
#[allow(clippy::single_range_in_vec_init)]
let feature_slice = tenflowers_core::ops::slice(&features, &[i..i + 1])?;
#[allow(clippy::single_range_in_vec_init)]
let label_slice = tenflowers_core::ops::slice(&labels, &[i..i + 1])?;
let squeezed_feature = if feature_slice.shape().rank() > 1 {
let new_shape: Vec<usize> = feature_slice.shape().dims()[1..].to_vec();
tenflowers_core::ops::reshape(&feature_slice, &new_shape)?
} else {
feature_slice
};
let squeezed_label = if label_slice.shape().rank() > 1 {
let new_shape: Vec<usize> = label_slice.shape().dims()[1..].to_vec();
tenflowers_core::ops::reshape(&label_slice, &new_shape)?
} else {
label_slice
};
samples.push((squeezed_feature, squeezed_label));
}
Ok(samples)
}
}
}
pub fn into_collated(self) -> Result<(Tensor<T>, Tensor<T>)>
where
T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
{
match self {
BatchResult::Samples(samples) => {
let collate_fn = super::collate::DefaultCollate;
collate_fn.collate(samples)
}
BatchResult::Collated(features, labels) => Ok((features, labels)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tenflowers_core::{Device, Tensor};
#[test]
fn test_batch_result_len() {
let samples = vec![
(Tensor::<f32>::ones(&[2, 3]), Tensor::<f32>::zeros(&[1])),
(Tensor::<f32>::ones(&[2, 3]), Tensor::<f32>::zeros(&[1])),
];
let batch_samples = BatchResult::Samples(samples);
assert_eq!(batch_samples.len(), 2);
let features = Tensor::<f32>::ones(&[3, 2, 3]);
let labels = Tensor::<f32>::zeros(&[3, 1]);
let batch_collated = BatchResult::Collated(features, labels);
assert_eq!(batch_collated.len(), 3);
}
#[test]
fn test_batch_result_is_empty() {
let empty_samples = BatchResult::<f32>::Samples(vec![]);
assert!(empty_samples.is_empty());
let non_empty_samples = vec![(Tensor::<f32>::ones(&[2]), Tensor::<f32>::zeros(&[1]))];
let batch = BatchResult::Samples(non_empty_samples);
assert!(!batch.is_empty());
}
#[test]
fn test_batch_result_samples_variant() {
let samples = vec![
(Tensor::<f32>::ones(&[2]), Tensor::<f32>::zeros(&[1])),
(Tensor::<f32>::ones(&[2]), Tensor::<f32>::ones(&[1])),
];
let batch = BatchResult::Samples(samples.clone());
match batch {
BatchResult::Samples(s) => assert_eq!(s.len(), 2),
_ => panic!("Expected Samples variant"),
}
}
#[test]
fn test_batch_result_collated_variant() {
let features: Tensor<f32> = Tensor::ones(&[2, 3]);
let labels = Tensor::zeros(&[2, 1]);
let batch = BatchResult::Collated(features.clone(), labels.clone());
match batch {
BatchResult::Collated(f, l) => {
assert_eq!(f.shape().dims(), features.shape().dims());
assert_eq!(l.shape().dims(), labels.shape().dims());
}
_ => panic!("Expected Collated variant"),
}
}
}