use axonml_tensor::Tensor;
/// Strategy for merging a batch of individual samples into one batched value
/// (e.g. stacking per-sample tensors into a single batch tensor).
///
/// Bound `Send + Sync` so a collate function can be shared across
/// data-loading worker threads.
pub trait Collate<T>: Send + Sync {
/// The batched value produced from a `Vec<T>` of samples.
type Output;
/// Combines the owned `batch` of samples into a single `Self::Output`.
fn collate(&self, batch: Vec<T>) -> Self::Output;
}
/// Default collate strategy: stacks samples along a new leading batch axis.
pub struct DefaultCollate;

impl DefaultCollate {
/// Creates a new `DefaultCollate`.
#[must_use]
pub fn new() -> Self {
DefaultCollate
}
}

impl Default for DefaultCollate {
fn default() -> Self {
DefaultCollate::new()
}
}
impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
type Output = (Tensor<f32>, Tensor<f32>);
/// Splits a batch of `(input, target)` pairs and stacks each side along a
/// new leading batch dimension.
///
/// An empty batch yields a pair of empty 1-D tensors of shape `[0]`.
fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
if batch.is_empty() {
return (
Tensor::from_vec(vec![], &[0]).unwrap(),
Tensor::from_vec(vec![], &[0]).unwrap(),
);
}
// `batch` is owned, so unzip MOVES each tensor out of its pair instead
// of cloning every tensor's data as the previous version did.
let (inputs, targets): (Vec<Tensor<f32>>, Vec<Tensor<f32>>) = batch.into_iter().unzip();
(stack_tensors(&inputs), stack_tensors(&targets))
}
}
impl Collate<Tensor<f32>> for DefaultCollate {
type Output = Tensor<f32>;

/// Stacks the batch along a new leading batch dimension; an empty batch
/// yields an empty 1-D tensor of shape `[0]` (see `stack_tensors`).
fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
stack_tensors(batch.as_slice())
}
}
/// Collate strategy that stacks batch items along a configurable dimension.
pub struct StackCollate {
// Dimension along which items are joined; 0 prepends a new batch axis.
dim: usize,
}

impl StackCollate {
/// Creates a collator that stacks along dimension 0.
#[must_use]
pub fn new() -> Self {
Self::with_dim(0)
}

/// Creates a collator that stacks along `dim`.
#[must_use]
pub fn with_dim(dim: usize) -> Self {
StackCollate { dim }
}
}

impl Default for StackCollate {
fn default() -> Self {
StackCollate::new()
}
}
impl Collate<Tensor<f32>> for StackCollate {
type Output = Tensor<f32>;
/// Stacks `batch` along `self.dim`.
///
/// Shapes for the non-stacked axes are taken from the first item; all
/// items are assumed to share one shape — a mismatch surfaces as a panic
/// in the `unwrap` calls below (TODO confirm `from_vec` errors on
/// length/shape mismatch).
fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
// Empty batch -> empty 1-D tensor of shape [0].
if batch.is_empty() {
return Tensor::from_vec(vec![], &[0]).unwrap();
}
// Common case: dim 0 means prepend a new leading batch axis.
if self.dim == 0 {
return stack_tensors(&batch);
}
let first_shape = batch[0].shape();
let ndim = first_shape.len();
// Clamp so an oversized dim appends a trailing axis instead of
// panicking on the slice indexing below.
let dim = self.dim.min(ndim);
// Build the per-item shape with a unit axis inserted at `dim`, e.g.
// [a, b] with dim=1 becomes [a, 1, b].
let mut item_shape_expanded = Vec::with_capacity(ndim + 1);
item_shape_expanded.extend_from_slice(&first_shape[..dim]);
item_shape_expanded.push(1);
item_shape_expanded.extend_from_slice(&first_shape[dim..]);
// Reshape every item to the expanded shape (copies the data), then
// concatenate the unit axes so the output has size `batch.len()` at
// `dim`.
let reshaped: Vec<Tensor<f32>> = batch
.iter()
.map(|t| Tensor::from_vec(t.to_vec(), &item_shape_expanded).unwrap())
.collect();
concat_tensors(&reshaped, dim)
}
}
impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
type Output = (Tensor<f32>, Tensor<f32>);
/// Splits a batch of `(input, target)` pairs and stacks each side along
/// `self.dim`.
///
/// Bug fix: this previously called `stack_tensors` directly and therefore
/// silently ignored `self.dim`, always stacking pairs along dimension 0
/// even for a `StackCollate::with_dim(d)` collator. It now delegates to
/// the `Collate<Tensor<f32>>` impl so both impls honor the configured
/// dimension. Behavior for `dim == 0` (including `StackCollate::new()`)
/// is unchanged.
///
/// An empty batch yields a pair of empty 1-D tensors of shape `[0]`.
fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
if batch.is_empty() {
return (
Tensor::from_vec(vec![], &[0]).unwrap(),
Tensor::from_vec(vec![], &[0]).unwrap(),
);
}
// `batch` is owned: unzip moves the tensors rather than cloning them.
let (inputs, targets): (Vec<Tensor<f32>>, Vec<Tensor<f32>>) = batch.into_iter().unzip();
(
<Self as Collate<Tensor<f32>>>::collate(self, inputs),
<Self as Collate<Tensor<f32>>>::collate(self, targets),
)
}
}
/// Stacks `tensors` along a new leading batch dimension.
///
/// The output shape is `[tensors.len(), ..first_shape]`, where `first_shape`
/// is taken from `tensors[0]`; all inputs are assumed to share that shape.
/// An empty slice yields an empty 1-D tensor of shape `[0]`.
///
/// # Panics
/// Panics if the gathered data does not match the computed shape (e.g. the
/// inputs have differing element counts).
#[must_use]
pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
if tensors.is_empty() {
return Tensor::from_vec(vec![], &[0]).unwrap();
}
let first_shape = tensors[0].shape();
let batch_size = tensors.len();
let mut new_shape = vec![batch_size];
new_shape.extend_from_slice(first_shape);
// Preallocate the flat buffer: repeated `extend` into an empty Vec would
// otherwise grow-and-copy several times for large batches.
let item_len: usize = first_shape.iter().product();
let mut all_data = Vec::with_capacity(batch_size * item_len);
for tensor in tensors {
all_data.extend(tensor.to_vec());
}
Tensor::from_vec(all_data, &new_shape).unwrap()
}
/// Concatenates `tensors` along the existing dimension `dim`.
///
/// The non-`dim` axes are taken from `tensors[0]`; all inputs are assumed to
/// agree on them, while their sizes along `dim` may differ and are summed.
/// An empty slice yields an empty 1-D tensor of shape `[0]`; a single tensor
/// is returned as a clone.
///
/// # Panics
/// Panics if `dim` is out of bounds for any input's rank, or if the inputs'
/// shapes are inconsistent enough that the gathered data does not match the
/// computed output shape.
#[must_use]
pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
if tensors.is_empty() {
return Tensor::from_vec(vec![], &[0]).unwrap();
}
if tensors.len() == 1 {
return tensors[0].clone();
}
let first_shape = tensors[0].shape();
let ndim = first_shape.len();
let mut new_shape = first_shape.to_vec();
let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
new_shape[dim] = concat_size;
// Hoisted so both paths can preallocate their output buffer exactly.
let total_elements: usize = new_shape.iter().product();
if dim == 0 {
// Concatenating along the leading axis of row-major data is a plain
// append of each tensor's buffer. Previously this path used
// `Vec::new()` and regrew repeatedly; now it reserves once.
let mut all_data = Vec::with_capacity(total_elements);
for tensor in tensors {
all_data.extend(tensor.to_vec());
}
return Tensor::from_vec(all_data, &new_shape).unwrap();
}
// General case: split each index space into outer (dims before `dim`) and
// inner (dims after `dim`) parts. For every outer index, each tensor
// contributes one contiguous slab of `dim_size * inner_size` elements.
let outer_size: usize = first_shape[..dim].iter().product();
let inner_size: usize = if dim + 1 < ndim {
first_shape[dim + 1..].iter().product()
} else {
1
};
let mut result = Vec::with_capacity(total_elements);
let all_vecs: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();
for o in 0..outer_size {
for (t_idx, tensor_data) in all_vecs.iter().enumerate() {
let t_dim_size = tensors[t_idx].shape()[dim];
let t_inner_stride = t_dim_size * inner_size;
let src_offset = o * t_inner_stride;
result.extend_from_slice(&tensor_data[src_offset..src_offset + t_inner_stride]);
}
}
Tensor::from_vec(result, &new_shape).unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
// Three 1-D tensors of shape [2] stack into [3, 2], preserving input order.
#[test]
fn test_stack_tensors() {
let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();
let stacked = stack_tensors(&[t1, t2, t3]);
assert_eq!(stacked.shape(), &[3, 2]);
assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// Stacking prepends the batch axis: two [2, 2] tensors become [2, 2, 2].
#[test]
fn test_stack_tensors_2d() {
let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
let stacked = stack_tensors(&[t1, t2]);
assert_eq!(stacked.shape(), &[2, 2, 2]);
}
// DefaultCollate stacks inputs and targets independently along a new axis.
#[test]
fn test_default_collate() {
let collate = DefaultCollate::new();
let batch = vec![
(
Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
Tensor::from_vec(vec![0.0], &[1]).unwrap(),
),
(
Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
Tensor::from_vec(vec![1.0], &[1]).unwrap(),
),
];
let (x, y) = collate.collate(batch);
assert_eq!(x.shape(), &[2, 2]);
assert_eq!(y.shape(), &[2, 1]);
}
// StackCollate::new() stacks along dim 0, same as DefaultCollate.
#[test]
fn test_stack_collate() {
let collate = StackCollate::new();
let batch = vec![
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
];
let result = collate.collate(batch);
assert_eq!(result.shape(), &[2, 3]);
}
// An empty batch collates to a pair of empty 1-D tensors of shape [0].
#[test]
fn test_empty_collate() {
let collate = DefaultCollate::new();
let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
let (x, y) = collate.collate(batch);
assert_eq!(x.shape(), &[0]);
assert_eq!(y.shape(), &[0]);
}
// Concat along dim 0 sums the leading sizes; inputs may differ there.
#[test]
fn test_concat_tensors_dim0() {
let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();
let concat = concat_tensors(&[t1, t2], 0);
assert_eq!(concat.shape(), &[5]);
assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
// Concat along dim 1 interleaves per-row slabs from each tensor.
#[test]
fn test_concat_tensors_dim1() {
let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
let concat = concat_tensors(&[t1, t2], 1);
assert_eq!(concat.shape(), &[2, 4]);
assert_eq!(
concat.to_vec(),
vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
);
}
}