#[allow(unused_imports)]
use super::stacking::TensorStacker;
use crate::collate::Collate;
use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
};
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use scirs2_core::parallel_ops::*;
/// Stack a list of tensors along dimension `dim`, producing a tensor whose
/// shape is the common input shape with `tensors.len()` inserted at `dim`.
///
/// All inputs must share the same shape; a `dim` greater than the input rank
/// is clamped so the new axis is appended at the end.
///
/// # Errors
/// Returns `InvalidArgument` for an empty list (or short backing data) and
/// `ShapeMismatch` when any tensor's shape differs from the first's.
pub fn stack_tensors<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    // Every tensor must match the first one's shape.
    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    let original_dims = first_shape.dims();
    // Clamp out-of-range `dim` so the new axis lands at the end, matching how
    // the shape is built below (the old code clamped the shape but not the
    // data layout).
    let dim = dim.min(original_dims.len());

    // Output shape: input dims with the batch count inserted at `dim`.
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    new_dims.extend_from_slice(&original_dims[..dim]);
    new_dims.push(tensors.len());
    new_dims.extend_from_slice(&original_dims[dim..]);

    let tensor_size = tensors[0].numel();

    // Materialize each tensor's contiguous data up front. With `std`, large
    // batches are read in parallel via rayon.
    #[cfg(feature = "std")]
    let per_tensor: Vec<Vec<T>> = if tensors.len() > 4 && tensor_size > 1000 {
        let gathered: std::result::Result<Vec<Vec<T>>, TorshError> =
            tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
        gathered?
    } else {
        let mut gathered = Vec::with_capacity(tensors.len());
        for tensor in tensors {
            gathered.push(tensor.to_vec()?);
        }
        gathered
    };
    #[cfg(not(feature = "std"))]
    let per_tensor: Vec<Vec<T>> = {
        let mut gathered = Vec::with_capacity(tensors.len());
        for tensor in tensors {
            gathered.push(tensor.to_vec()?);
        }
        gathered
    };

    // Interleave the per-tensor data so element order matches `new_dims`:
    // splitting each input at `dim` yields `outer` blocks of `inner`
    // contiguous elements, and block `o` of every tensor is laid out side by
    // side. For dim == 0 (outer == 1) this degenerates to plain
    // concatenation — the only layout the previous implementation produced,
    // regardless of `dim`. Building the Vec by appending also removes the
    // former unsound `set_len` over uninitialized memory.
    let outer: usize = original_dims[..dim].iter().product();
    let inner: usize = original_dims[dim..].iter().product();
    let mut new_data = Vec::with_capacity(tensor_size * tensors.len());
    for o in 0..outer {
        let start = o * inner;
        for data in &per_tensor {
            // Checked slicing: a backend returning fewer elements than
            // `numel()` becomes an error instead of a panic.
            let chunk = data.get(start..start + inner).ok_or_else(|| {
                TorshError::InvalidArgument(
                    "Tensor data shorter than its reported element count".to_string(),
                )
            })?;
            new_data.extend_from_slice(chunk);
        }
    }

    let result = torsh_tensor::Tensor::from_data(new_data, new_dims, tensors[0].device())?;
    Ok(result)
}
/// Stack tensors along `dim`, picking a strategy by batch size.
///
/// With the `mmap-support` feature enabled, batches of more than 100 tensors
/// are routed through the memory-mapped path; everything else falls back to
/// [`stack_tensors`].
///
/// # Errors
/// Returns `InvalidArgument` on an empty tensor list.
#[cfg(feature = "std")]
pub fn stack_tensors_fast<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }
    // Large batches take the memory-mapped route when it is compiled in.
    #[cfg(feature = "mmap-support")]
    if tensors.len() > 100 {
        return stack_tensors_mmap(tensors, dim);
    }
    stack_tensors(tensors, dim)
}
/// Stack tensors by staging their data through a temporary memory-mapped
/// file, intended for very large batches.
///
/// NOTE(review): the staged bytes are still copied back into a heap `Vec`
/// before `Tensor::from_data`, so peak memory is not actually lower than
/// [`stack_tensors`] — confirm whether `from_data` can consume a mapping
/// directly before relying on this path for memory savings.
///
/// # Errors
/// Returns `InvalidArgument` for an empty list or short backing data,
/// `ShapeMismatch` for inconsistent shapes, and `IoError` for temp-file or
/// mapping failures.
#[cfg(all(feature = "std", feature = "mmap-support"))]
pub fn stack_tensors_mmap<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    // Guard before indexing `tensors[0]` — the previous version panicked on
    // an empty slice, unlike its sibling stacking functions.
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    let original_dims = first_shape.dims();
    // Clamp out-of-range `dim` so the new axis is appended at the end.
    let dim = dim.min(original_dims.len());
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    new_dims.extend_from_slice(&original_dims[..dim]);
    new_dims.push(tensors.len());
    new_dims.extend_from_slice(&original_dims[dim..]);

    let tensor_size = tensors[0].numel();
    let total_elements = tensor_size * tensors.len();
    let total_size = total_elements * std::mem::size_of::<T>();

    let mut temp_file =
        tempfile::NamedTempFile::new().map_err(|e| TorshError::IoError(e.to_string()))?;
    temp_file
        .as_file_mut()
        .set_len(total_size as u64)
        .map_err(|e| TorshError::IoError(e.to_string()))?;
    // SAFETY: the file is freshly created, private, and sized to
    // `total_size`, so mapping it mutably is sound; the mapping is
    // page-aligned, which satisfies `T`'s alignment.
    let mut mmap = unsafe {
        memmap2::MmapOptions::new()
            .map_mut(temp_file.as_file())
            .map_err(|e| TorshError::IoError(e.to_string()))?
    };

    // Read all tensor data in parallel; collect short-circuits on error.
    let all_data: std::result::Result<Vec<Vec<T>>, TorshError> =
        tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
    let all_data = all_data?;
    // Validate lengths before the unchecked raw copies below.
    if all_data.iter().any(|data| data.len() != tensor_size) {
        return Err(TorshError::InvalidArgument(
            "Tensor data shorter than its reported element count".to_string(),
        ));
    }

    // Write interleaved so element order matches `new_dims`: split each input
    // at `dim` into `outer` blocks of `inner` elements and lay block `o` of
    // every tensor side by side (dim == 0 degenerates to concatenation).
    // Writing goes through `as_mut_ptr()`; the old code cast the const
    // pointer from `as_ptr()` to `*mut T`.
    let outer: usize = original_dims[..dim].iter().product();
    let inner: usize = original_dims[dim..].iter().product();
    let mmap_ptr = mmap.as_mut_ptr() as *mut T;
    let mut offset = 0usize;
    for o in 0..outer {
        let start = o * inner;
        for data in &all_data {
            // SAFETY: `start + inner <= tensor_size` was validated above, and
            // `offset` advances by `inner` exactly `outer * tensors.len()`
            // times, so `offset + inner <= total_elements` always holds.
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr().add(start), mmap_ptr.add(offset), inner);
            }
            offset += inner;
        }
    }

    // SAFETY: every one of the `total_elements` slots was initialized from
    // valid `T` values by the copies above, and the mapping outlives this
    // borrow.
    let data_vec =
        unsafe { std::slice::from_raw_parts(mmap_ptr as *const T, total_elements).to_vec() };
    let result = torsh_tensor::Tensor::from_data(data_vec, new_dims, tensors[0].device())?;
    Ok(result)
}
/// Collate strategy that batches samples with the optimized stacking
/// routines in this module ([`stack_tensors_fast`]), which use parallel
/// copies and, when enabled, memory-mapped staging for large batches.
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;
#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
    type Output = Tensor<T>;

    /// Stack every sample in `batch` along a new leading (batch) axis.
    ///
    /// Returns `InvalidArgument` for an empty batch; shape mismatches are
    /// reported by the underlying stacking routine.
    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
        match batch.as_slice() {
            [] => Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            )),
            samples => stack_tensors_fast(samples, 0),
        }
    }
}
#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
    type Output = Vec<Tensor<T>>;

    /// Collate a batch of multi-tensor samples field-wise: `output[i]` is the
    /// stack of every sample's `i`-th tensor along a new leading axis.
    ///
    /// Fields are stacked in parallel. Returns `InvalidArgument` for an empty
    /// batch or a ragged batch (the previous implementation panicked on
    /// `sample[i]` when samples had differing tensor counts).
    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }
        let num_tensors = batch[0].len();
        // Reject ragged batches up front instead of panicking mid-collate.
        if batch.iter().any(|sample| sample.len() != num_tensors) {
            return Err(TorshError::InvalidArgument(
                "All samples in batch must contain the same number of tensors".to_string(),
            ));
        }
        // Collect directly into the result; the old code drained the parallel
        // result into a second, pre-allocated Vec for no benefit.
        (0..num_tensors)
            .into_par_iter()
            .map(|i| {
                let field: Vec<Tensor<T>> =
                    batch.iter().map(|sample| sample[i].clone()).collect();
                stack_tensors_fast(&field, 0)
            })
            .collect()
    }
}
/// Convenience constructor for [`OptimizedCollate`].
///
/// The type parameter `T` is not used by the returned value; it is
/// presumably kept for signature symmetry with other collate constructors —
/// TODO confirm against the callers.
#[cfg(feature = "std")]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
OptimizedCollate
}
/// Collate strategy for `no_std` builds; batching is delegated to
/// [`TensorStacker`] and runs sequentially (no rayon available).
#[cfg(not(feature = "std"))]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;
#[cfg(not(feature = "std"))]
impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
type Output = Tensor<T>;
fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
TensorStacker::new().stack(&batch, 0)
}
}
#[cfg(not(feature = "std"))]
impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
    type Output = Vec<Tensor<T>>;

    /// Collate a batch of multi-tensor samples field-wise (sequentially, as
    /// no parallel runtime is available without `std`).
    ///
    /// Returns `InvalidArgument` for an empty batch or a ragged batch (the
    /// previous implementation panicked on `sample[i]` when samples had
    /// differing tensor counts).
    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }
        let num_tensors = batch[0].len();
        // Reject ragged batches up front instead of panicking mid-collate.
        if batch.iter().any(|sample| sample.len() != num_tensors) {
            return Err(TorshError::InvalidArgument(
                "All samples in batch must contain the same number of tensors".to_string(),
            ));
        }
        let stacker = TensorStacker::new();
        let mut collated = Vec::with_capacity(num_tensors);
        for i in 0..num_tensors {
            let field: Vec<Tensor<T>> =
                batch.iter().map(|sample| sample[i].clone()).collect();
            collated.push(stacker.stack(&field, 0)?);
        }
        Ok(collated)
    }
}
/// Convenience constructor for [`OptimizedCollate`] (`no_std` variant).
///
/// The type parameter `T` is not used by the returned value; it is
/// presumably kept for signature symmetry with other collate constructors —
/// TODO confirm against the callers.
#[cfg(not(feature = "std"))]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
OptimizedCollate
}