use super::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
use ndarray::ArrayD;
use num_traits::Float;
use rayon::prelude::*;
/// Allocation strategy requested from [`Tensor::with_strategy`].
///
/// NOTE(review): in this module every variant currently ends up in `Tensor::zeros`
/// (directly or via a placeholder helper); the variants are hooks for future allocators.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AllocationStrategy {
    /// Goes through the pooled-allocation helper (currently a thin wrapper over `zeros`).
    Pool,
    /// Allocates directly with `Tensor::zeros`.
    Direct,
    /// Zero-copy allocation; currently handled exactly like `Direct`.
    ZeroCopy,
    /// Goes through the SIMD-alignment helper (currently a thin wrapper over `zeros`).
    SimdAligned,
}
impl<T: Float + Clone + Send + Sync + 'static> Tensor<T> {
pub fn with_strategy(shape: &[usize], strategy: AllocationStrategy) -> Self {
match strategy {
AllocationStrategy::Pool => Self::with_pool_allocation(shape),
AllocationStrategy::Direct => Self::zeros(shape),
AllocationStrategy::ZeroCopy => Self::zeros(shape), AllocationStrategy::SimdAligned => Self::with_simd_alignment(shape),
}
}
fn with_pool_allocation(shape: &[usize]) -> Self {
Self::zeros(shape)
}
fn with_simd_alignment(shape: &[usize]) -> Self {
Self::zeros(shape)
}
pub fn view_mut(&mut self) -> &mut ArrayD<T> {
&mut self.data
}
pub fn elementwise_inplace<F>(&mut self, other: &Tensor<T>, op: F) -> ParallelResult<()>
where
F: Fn(T, T) -> T + Send + Sync,
{
if self.data.shape() != other.data.shape() {
return Err(RusTorchError::parallel("Shape mismatch")); }
if let (Some(self_slice), Some(other_slice)) = (
self.data.as_slice_mut(),
other.data.as_slice()
) {
if self_slice.len() > 1000 {
self_slice.par_iter_mut()
.zip(other_slice.par_iter())
.for_each(|(a, &b)| {
*a = op(*a, b);
});
} else {
self_slice.iter_mut()
.zip(other_slice.iter())
.for_each(|(a, &b)| {
*a = op(*a, b);
});
}
}
Ok(())
}
pub fn batch_op_pooled<F>(&self, tensors: &[&Tensor<T>], op: F) -> ParallelResult<Vec<Tensor<T>>>
where
F: Fn(&Tensor<T>, &Tensor<T>) -> ParallelResult<Tensor<T>> + Send + Sync,
{
let _result: Vec<Tensor<T>> = Vec::with_capacity(tensors.len());
let parallel_results: Result<Vec<_>, _> = tensors.par_iter()
.map(|tensor| op(self, tensor))
.collect();
match parallel_results {
Ok(tensors) => Ok(tensors),
Err(e) => Err(e),
}
}
pub fn matmul_optimized(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>> {
let self_shape = self.data.shape();
let other_shape = other.data.shape();
if self_shape.len() != 2 || other_shape.len() != 2 {
return Err(RusTorchError::parallel(insufficient_dimensions(
2, self_shape.len(), "matrix multiplication"
).into());
}
if self_shape[1] != other_shape[0] {
return Err(RusTorchError::parallel(matmul_dimension_mismatch(
self_shape, other_shape
).into());
}
let result_shape = vec![self_shape[0], other_shape[1]];
let mut result = Self::with_strategy(&result_shape, AllocationStrategy::Pool);
let block_size = 64;
for i_block in (0..self_shape[0]).step_by(block_size) {
for j_block in (0..other_shape[1]).step_by(block_size) {
for k_block in (0..self_shape[1]).step_by(block_size) {
let i_end = (i_block + block_size).min(self_shape[0]);
let j_end = (j_block + block_size).min(other_shape[1]);
let k_end = (k_block + block_size).min(self_shape[1]);
for i in i_block..i_end {
for j in j_block..j_end {
let mut sum = T::zero();
for k in k_block..k_end {
sum = sum + self.data[[i, k]] * other.data[[k, j]];
}
result.data[[i, j]] = result.data[[i, j]] + sum;
}
}
}
}
}
Ok(result)
}
pub fn memory_info(&self) -> MemoryInfo {
let element_size = std::mem::size_of::<T>();
let total_elements = self.data.len();
let total_bytes = total_elements * element_size;
MemoryInfo {
element_count: total_elements,
element_size_bytes: element_size,
total_bytes,
shape: self.data.shape().to_vec(),
is_contiguous: self.data.is_standard_layout(),
}
}
pub fn return_to_pool(self) {
drop(self);
}
}
/// Snapshot of a tensor's memory footprint, as produced by `Tensor::memory_info`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryInfo {
    /// Number of elements in the tensor.
    pub element_count: usize,
    /// Size in bytes of a single element (`size_of::<T>()`).
    pub element_size_bytes: usize,
    /// Payload size: `element_count * element_size_bytes`.
    pub total_bytes: usize,
    /// Tensor shape, one entry per axis.
    pub shape: Vec<usize>,
    /// Whether the array is in ndarray's standard (row-major contiguous) layout.
    pub is_contiguous: bool,
}
impl std::fmt::Display for MemoryInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Tensor Memory Info:")?;
writeln!(f, " Shape: {:?}", self.shape)?;
writeln!(f, " Elements: {}", self.element_count)?;
writeln!(f, " Element size: {} bytes", self.element_size_bytes)?;
writeln!(f, " Total memory: {} bytes ({:.2} KB)",
self.total_bytes, self.total_bytes as f64 / 1024.0)?;
writeln!(f, " Contiguous: {}", self.is_contiguous)?;
Ok(())
}
}
impl Tensor<f32> {
pub fn add_simd_pooled(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
if self.data.shape() != other.data.shape() {
return Err(RusTorchError::parallel("Shape mismatch")); }
let mut result = Self::with_strategy(self.data.shape(), AllocationStrategy::SimdAligned);
if let (Some(self_slice), Some(other_slice), Some(result_slice)) = (
self.data.as_slice(),
other.data.as_slice(),
result.data.as_slice_mut()
) {
#[cfg(not(target_arch = "wasm32"))]
{
crate::simd::ops::add_optimized(self_slice, other_slice, result_slice);
}
#[cfg(target_arch = "wasm32")]
{
for ((a_elem, b_elem), r_elem) in self_slice.iter().zip(other_slice.iter()).zip(result_slice.iter_mut()) {
*r_elem = *a_elem + *b_elem;
}
}
}
Ok(result)
}
pub fn matmul_simd_aligned(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>> {
let result = self.matmul_optimized(other)?;
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_strategies() {
        let shape = vec![10, 10];
        for &strategy in &[
            AllocationStrategy::Pool,
            AllocationStrategy::Direct,
            AllocationStrategy::ZeroCopy,
            AllocationStrategy::SimdAligned,
        ] {
            let tensor = Tensor::<f32>::with_strategy(&shape, strategy);
            assert_eq!(tensor.size(), shape);
        }
    }

    #[test]
    fn test_inplace_operations() {
        let mut a = Tensor::<f32>::ones(&[3, 3]);
        let b = Tensor::<f32>::ones(&[3, 3]);
        let result = a.elementwise_inplace(&b, |x, y| x + y);
        assert!(result.is_ok());
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(a.as_array()[[i, j]], 2.0);
            }
        }
    }

    #[test]
    fn test_inplace_operations_parallel_path() {
        // 40 * 40 = 1600 elements, above the parallel threshold.
        let mut a = Tensor::<f32>::ones(&[40, 40]);
        let b = Tensor::<f32>::ones(&[40, 40]);
        assert!(a.elementwise_inplace(&b, |x, y| x * 3.0 + y).is_ok());
        assert!(a.as_array().iter().all(|&v| v == 4.0));
    }

    #[test]
    fn test_inplace_shape_mismatch() {
        let mut a = Tensor::<f32>::ones(&[2, 3]);
        let b = Tensor::<f32>::ones(&[3, 2]);
        assert!(a.elementwise_inplace(&b, |x, y| x + y).is_err());
    }

    #[test]
    fn test_memory_info() {
        let tensor = Tensor::<f32>::zeros(&[5, 4]);
        let info = tensor.memory_info();
        assert_eq!(info.element_count, 20);
        assert_eq!(info.element_size_bytes, 4);
        assert_eq!(info.total_bytes, 80);
        assert_eq!(info.shape, vec![5, 4]);
        assert!(info.is_contiguous);
    }

    #[test]
    fn test_optimized_matmul() {
        let a = Tensor::<f32>::ones(&[4, 3]);
        let b = Tensor::<f32>::ones(&[3, 2]);
        let result = a.matmul_optimized(&b);
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.size(), vec![4, 2]);
        for i in 0..4 {
            for j in 0..2 {
                assert_eq!(result.as_array()[[i, j]], 3.0);
            }
        }
    }

    #[test]
    fn test_optimized_matmul_spans_multiple_blocks() {
        // Inner dimension 130 crosses three 64-wide k tiles; rows cross two i tiles.
        let a = Tensor::<f32>::ones(&[70, 130]);
        let b = Tensor::<f32>::ones(&[130, 3]);
        let result = a.matmul_optimized(&b).unwrap();
        assert_eq!(result.size(), vec![70, 3]);
        assert!(result.as_array().iter().all(|&v| v == 130.0));
    }

    #[test]
    fn test_matmul_dimension_errors() {
        let vector = Tensor::<f32>::ones(&[3]);
        let matrix = Tensor::<f32>::ones(&[3, 2]);
        assert!(vector.matmul_optimized(&matrix).is_err());
        assert!(matrix.matmul_optimized(&vector).is_err());

        let a = Tensor::<f32>::ones(&[2, 3]);
        let b = Tensor::<f32>::ones(&[2, 3]);
        assert!(a.matmul_optimized(&b).is_err());
    }

    #[test]
    fn test_matmul_simd_aligned_matches_optimized() {
        let a = Tensor::<f32>::ones(&[4, 3]);
        let b = Tensor::<f32>::ones(&[3, 2]);
        let aligned = a.matmul_simd_aligned(&b).unwrap();
        let optimized = a.matmul_optimized(&b).unwrap();
        assert_eq!(aligned.as_array(), optimized.as_array());
    }

    #[test]
    fn test_batch_operations_pooled() {
        let base = Tensor::<f32>::ones(&[2, 2]);
        let tensor1 = Tensor::<f32>::ones(&[2, 2]);
        let tensor2 = Tensor::<f32>::ones(&[2, 2]);
        let tensors = vec![&tensor1, &tensor2];
        let results = base.batch_op_pooled(&tensors, |a, _b| {
            Ok(Tensor::with_strategy(a.data.shape(), AllocationStrategy::Pool))
        });
        assert!(results.is_ok());
        let results = results.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_batch_operations_propagate_errors() {
        let base = Tensor::<f32>::ones(&[2, 2]);
        let tensor = Tensor::<f32>::ones(&[2, 2]);
        let tensors = vec![&tensor];
        let results =
            base.batch_op_pooled(&tensors, |_a, _b| Err(RusTorchError::parallel("boom")));
        assert!(results.is_err());
    }

    #[test]
    fn test_simd_pooled_operations() {
        let a = Tensor::<f32>::ones(&[100, 100]);
        let b = Tensor::<f32>::ones(&[100, 100]);
        let result = a.add_simd_pooled(&b);
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.size(), vec![100, 100]);
        assert!(result.as_array().iter().all(|&v| v == 2.0));
    }

    #[test]
    fn test_shape_mismatch_error() {
        let a = Tensor::<f32>::ones(&[2, 3]);
        let b = Tensor::<f32>::ones(&[3, 2]);
        assert!(a.add_simd_pooled(&b).is_err());
    }
}