use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
/// Result alias used by the SIMD arithmetic methods in this file; it is the same type as
/// `RusTorchResult<T>`.
type ParallelResult<T> = RusTorchResult<T>;
use num_traits::Float;
use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::ptr::NonNull;
/// Byte alignment requested for every buffer allocated by `SimdAllocator` (32 bytes = 256 bits).
pub const SIMD_ALIGNMENT: usize = 32;
/// Allocator for zero-initialised `f32` buffers aligned to [`SIMD_ALIGNMENT`] bytes.
pub struct SimdAllocator;

impl SimdAllocator {
    /// Computes the layout of `len` contiguous `f32` values aligned to [`SIMD_ALIGNMENT`].
    ///
    /// Returns `None` when the byte size overflows `usize` or is rejected by
    /// `Layout::from_size_align`.
    fn layout_f32(len: usize) -> Option<Layout> {
        let bytes = len.checked_mul(std::mem::size_of::<f32>())?;
        Layout::from_size_align(bytes, SIMD_ALIGNMENT).ok()
    }

    /// Allocates a zero-filled buffer of `len` `f32` values aligned to [`SIMD_ALIGNMENT`].
    ///
    /// A request for zero elements never reaches the global allocator; it yields an aligned,
    /// non-null dangling pointer that must not be dereferenced.
    ///
    /// # Errors
    ///
    /// Returns a memory-allocation error when the requested size cannot be expressed as a
    /// valid layout or when the global allocator reports failure.
    pub fn alloc_f32(len: usize) -> RusTorchResult<NonNull<f32>> {
        let layout = match Self::layout_f32(len) {
            Some(layout) => layout,
            None => {
                return Err(RusTorchError::memory_alloc(
                    len.saturating_mul(std::mem::size_of::<f32>()),
                    "cpu",
                ))
            }
        };

        if layout.size() == 0 {
            // Calling the global allocator with a zero-sized layout is undefined behaviour,
            // so hand back a well-aligned dangling pointer instead.
            let dangling = SIMD_ALIGNMENT as *mut f32;
            // SAFETY: `SIMD_ALIGNMENT` is a non-zero constant, so `dangling` is non-null.
            return Ok(unsafe { NonNull::new_unchecked(dangling) });
        }

        // SAFETY: `layout` has a non-zero size, as checked above.
        let raw = unsafe { alloc_zeroed(layout) };
        NonNull::new(raw.cast::<f32>())
            .ok_or_else(|| RusTorchError::memory_alloc(layout.size(), "cpu"))
    }

    /// Releases a buffer previously obtained from [`SimdAllocator::alloc_f32`].
    ///
    /// # Safety
    ///
    /// `ptr` must have been returned by `alloc_f32(len)` with this exact `len`, and must not
    /// have been released already.
    pub unsafe fn dealloc_f32(ptr: NonNull<f32>, len: usize) {
        let bytes = len * std::mem::size_of::<f32>();
        if bytes == 0 {
            // Zero-length buffers were never allocated, so there is nothing to free.
            return;
        }
        let layout = Layout::from_size_align_unchecked(bytes, SIMD_ALIGNMENT);
        dealloc(ptr.as_ptr() as *mut u8, layout);
    }

    /// Returns `true` when the address of `ptr` is a multiple of [`SIMD_ALIGNMENT`].
    pub fn is_aligned<T>(ptr: *const T) -> bool {
        (ptr as usize) % SIMD_ALIGNMENT == 0
    }
}
/// Tensor whose element buffer is obtained from `SimdAllocator`.
///
/// The buffer is handed back to the allocator by this type's `Drop` impl.
pub struct SimdTensor<T: Float> {
// Start of the element buffer.
data: NonNull<T>,
// Dimension sizes supplied at construction.
shape: Vec<usize>,
// Number of elements in the buffer; `zeros` sets it to the product of `shape`.
len: usize,
}
// `NonNull` is neither `Send` nor `Sync`, so the auto traits are re-enabled by hand.
// NOTE(review): soundness relies on each tensor being the sole owner of `data`; nothing in
// this file creates a second owner, but confirm no other module does.
unsafe impl<T: Float + Send> Send for SimdTensor<T> {}
unsafe impl<T: Float + Sync> Sync for SimdTensor<T> {}
impl<T: Float + Clone + 'static> SimdTensor<T> {
pub fn zeros(shape: &[usize]) -> RusTorchResult<Self>
where
T: 'static,
{
let len: usize = shape.iter().product();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let ptr = SimdAllocator::alloc_f32(len)?;
Ok(SimdTensor {
data: unsafe { std::mem::transmute(ptr) },
shape: shape.to_vec(),
len,
})
} else {
Err("Only f32 supported for SIMD alignment".into())
}
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn len(&self) -> usize {
self.len
}
pub fn as_ptr(&self) -> *const T {
self.data.as_ptr()
}
pub fn as_mut_ptr(&mut self) -> *mut T {
self.data.as_ptr()
}
pub fn as_slice(&self) -> &[T] {
unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
}
pub fn to_tensor(&self) -> Tensor<T> {
let data = self.as_slice().to_vec();
Tensor::from_vec(data, self.shape.clone())
}
pub fn is_simd_aligned(&self) -> bool {
SimdAllocator::is_aligned(self.data.as_ptr())
}
}
impl<T: Float> Drop for SimdTensor<T> {
    /// Hands the element buffer back to `SimdAllocator`.
    fn drop(&mut self) {
        // The buffer is released through the `f32` deallocation routine, using the same
        // element count that is stored in `len`.
        let buffer = self.data.cast::<f32>();
        let elements = self.len;
        unsafe { SimdAllocator::dealloc_f32(buffer, elements) }
    }
}
impl SimdTensor<f32> {
/// Element-wise sum of `self` and `other`, returned as a new tensor.
///
/// # Errors
///
/// Returns a parallel error when the shapes differ or when the output tensor cannot be
/// allocated.
pub fn add_simd(&self, other: &SimdTensor<f32>) -> ParallelResult<SimdTensor<f32>> {
    if self.shape != other.shape {
        return Err(RusTorchError::parallel("Shape mismatch"));
    }

    let mut sum: SimdTensor<f32> = match SimdTensor::zeros(&self.shape) {
        Ok(buffer) => buffer,
        Err(_) => return Err(RusTorchError::parallel("SIMD allocation failed")),
    };

    let (lhs, rhs) = (self.as_slice(), other.as_slice());
    let out = sum.as_mut_slice();

    // Native targets delegate to the crate's SIMD kernel.
    #[cfg(not(target_arch = "wasm32"))]
    {
        crate::simd::ops::add_optimized(lhs, rhs, out);
    }
    // wasm32 uses a plain scalar loop.
    #[cfg(target_arch = "wasm32")]
    {
        for (dst, (&a, &b)) in out.iter_mut().zip(lhs.iter().zip(rhs.iter())) {
            *dst = a + b;
        }
    }

    Ok(sum)
}
/// Element-wise product of `self` and `other`, returned as a new tensor.
///
/// # Errors
///
/// Returns a parallel error when the shapes differ or when the output tensor cannot be
/// allocated.
pub fn mul_simd(&self, other: &SimdTensor<f32>) -> ParallelResult<SimdTensor<f32>> {
    if self.shape != other.shape {
        return Err(RusTorchError::parallel("Shape mismatch"));
    }

    let mut product: SimdTensor<f32> = match SimdTensor::zeros(&self.shape) {
        Ok(buffer) => buffer,
        Err(_) => return Err(RusTorchError::parallel("SIMD allocation failed")),
    };

    let (lhs, rhs) = (self.as_slice(), other.as_slice());
    let out = product.as_mut_slice();

    // Native targets delegate to the crate's SIMD kernel.
    #[cfg(not(target_arch = "wasm32"))]
    {
        crate::simd::ops::mul_optimized(lhs, rhs, out);
    }
    // wasm32 uses a plain scalar loop.
    #[cfg(target_arch = "wasm32")]
    {
        for (dst, (&a, &b)) in out.iter_mut().zip(lhs.iter().zip(rhs.iter())) {
            *dst = a * b;
        }
    }

    Ok(product)
}
/// Multiplies every element by `scalar` and returns the result as a new tensor.
///
/// # Panics
///
/// Panics when the output tensor cannot be allocated.
pub fn mul_scalar_simd(&self, scalar: f32) -> SimdTensor<f32> {
    let mut scaled: SimdTensor<f32> =
        SimdTensor::zeros(&self.shape).expect("SIMD allocation should succeed");

    let src = self.as_slice();
    let dst = scaled.as_mut_slice();

    // Native targets delegate to the crate's SIMD kernel.
    #[cfg(not(target_arch = "wasm32"))]
    {
        crate::simd::ops::mul_scalar_optimized(src, scalar, dst);
    }
    // wasm32 uses a plain scalar loop.
    #[cfg(target_arch = "wasm32")]
    {
        for (out, &value) in dst.iter_mut().zip(src.iter()) {
            *out = value * scalar;
        }
    }

    scaled
}
/// Matrix product of two 2-D tensors; the result has shape
/// `[self.shape[0], other.shape[1]]`.
///
/// # Errors
///
/// Returns a parallel error when either operand is not 2-D, when the inner dimensions
/// disagree, or when the output tensor cannot be allocated.
pub fn matmul_simd(&self, other: &SimdTensor<f32>) -> ParallelResult<SimdTensor<f32>> {
    let both_2d = self.shape.len() == 2 && other.shape.len() == 2;
    if !both_2d {
        return Err(RusTorchError::parallel("Insufficient dimensions"));
    }

    let (lhs_rows, lhs_cols) = (self.shape[0], self.shape[1]);
    let (rhs_rows, rhs_cols) = (other.shape[0], other.shape[1]);
    if lhs_cols != rhs_rows {
        return Err(RusTorchError::parallel("Matrix dimension mismatch"));
    }

    let out_shape = vec![lhs_rows, rhs_cols];
    let mut product: SimdTensor<f32> = match SimdTensor::zeros(&out_shape) {
        Ok(buffer) => buffer,
        Err(_) => return Err(RusTorchError::parallel("SIMD allocation failed")),
    };

    let lhs = self.as_slice();
    let rhs = other.as_slice();
    let out = product.as_mut_slice();

    // Native targets delegate to the crate's SIMD kernel.
    #[cfg(not(target_arch = "wasm32"))]
    {
        crate::simd::vectorized::matmul_f32_simd(
            lhs, lhs_rows, lhs_cols, rhs, rhs_rows, rhs_cols, out,
        );
    }
    // wasm32 uses a scalar triple loop that indexes all three buffers row by row.
    #[cfg(target_arch = "wasm32")]
    {
        for row in 0..lhs_rows {
            for col in 0..rhs_cols {
                let dot = (0..lhs_cols).fold(0.0f32, |acc, k| {
                    acc + lhs[row * lhs_cols + k] * rhs[k * rhs_cols + col]
                });
                out[row * rhs_cols + col] = dot;
            }
        }
    }

    Ok(product)
}
/// Adds `other` to `self` element by element, in place.
///
/// # Errors
///
/// Returns a parallel error when the shapes differ; `self` is left untouched in that case.
pub fn add_assign_simd(&mut self, other: &SimdTensor<f32>) -> ParallelResult<()> {
    if self.shape != other.shape {
        return Err(RusTorchError::parallel("Shape mismatch"));
    }
    // `self` is borrowed exclusively and `other` is a separate tensor, so each sum can be
    // written straight into `self` without staging the results in a temporary vector.
    for (dst, &src) in self.as_mut_slice().iter_mut().zip(other.as_slice()) {
        *dst += src;
    }
    Ok(())
}
/// Sets every element of the tensor to `value`.
pub fn fill_simd(&mut self, value: f32) {
    // `slice::fill` writes the value in place for every length; no temporary buffer is
    // needed for longer tensors.
    self.as_mut_slice().fill(value);
}
}
impl<T: Float + Clone + 'static> Tensor<T> {
    /// Copies this tensor into a newly allocated [`SimdTensor`] of the same shape.
    ///
    /// # Errors
    ///
    /// * Any error from `SimdTensor::zeros` (for example, `T` is not `f32`).
    /// * `self.data.as_slice()` yields `None`, so the elements cannot be copied as one
    ///   slice (presumably non-contiguous storage; confirm against the `Tensor`
    ///   definition). Previously this case silently produced an all-zero tensor.
    pub fn to_simd_aligned(&self) -> RusTorchResult<SimdTensor<T>> {
        let mut simd_tensor = SimdTensor::zeros(self.data.shape())?;
        match self.data.as_slice() {
            Some(source) => simd_tensor.as_mut_slice().copy_from_slice(source),
            None => {
                return Err(
                    "Tensor data is not available as a contiguous slice for SIMD copy".into(),
                )
            }
        }
        Ok(simd_tensor)
    }

    /// Creates a zero tensor, going through [`SimdTensor`] when `T` is `f32`.
    ///
    /// The SIMD buffer is copied into a regular `Tensor` by `SimdTensor::to_tensor`, so the
    /// returned value does not keep the SIMD buffer itself. For any other element type, or
    /// when the SIMD allocation fails, this falls back to `Self::zeros(shape)`.
    pub fn zeros_simd_aligned(shape: &[usize]) -> Self
    where
        T: 'static,
    {
        let is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
        if is_f32 {
            if let Ok(simd_tensor) = SimdTensor::<T>::zeros(shape) {
                return simd_tensor.to_tensor();
            }
        }
        Self::zeros(shape)
    }
}
/// Pool of reusable `SimdTensor<f32>` values, grouped into buckets by element count.
pub struct SimdMemoryPool {
    /// One free list per size bucket, indexed by `get_pool_index`; grown on demand.
    pools: Vec<Vec<SimdTensor<f32>>>,
    /// Maximum number of tensors retained in each bucket.
    max_pool_size: usize,
}

impl SimdMemoryPool {
    /// Creates an empty pool that keeps at most `max_pool_size` tensors per bucket.
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Vec::new(),
            max_pool_size,
        }
    }

    /// Maps an element count to its bucket: up to 64, 256, 1024, 4096, and everything larger.
    fn get_pool_index(&self, total_elements: usize) -> usize {
        match total_elements {
            0..=64 => 0,
            65..=256 => 1,
            257..=1024 => 2,
            1025..=4096 => 3,
            _ => 4,
        }
    }

    /// Returns the free list for `pool_index`, creating empty buckets up to it if needed.
    fn bucket_mut(&mut self, pool_index: usize) -> &mut Vec<SimdTensor<f32>> {
        if self.pools.len() <= pool_index {
            self.pools.resize_with(pool_index + 1, Vec::new);
        }
        &mut self.pools[pool_index]
    }

    /// Returns a zero-filled tensor of the given `shape`, reusing a pooled one when possible.
    ///
    /// The whole bucket is searched for a tensor with exactly this shape, most recently
    /// returned first, so reuse still works when tensors of different shapes share a bucket.
    ///
    /// # Errors
    ///
    /// Propagates the error from `SimdTensor::zeros` when a fresh tensor must be created and
    /// that fails.
    pub fn allocate(&mut self, shape: &[usize]) -> RusTorchResult<SimdTensor<f32>> {
        let total_elements = match shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        {
            Some(count) => count,
            // The element count does not fit in `usize`; skip the pool and let `zeros`
            // decide how to handle the request.
            None => return SimdTensor::zeros(shape),
        };

        let pool_index = self.get_pool_index(total_elements);
        let bucket = self.bucket_mut(pool_index);
        if let Some(position) = bucket.iter().rposition(|tensor| tensor.shape() == shape) {
            // Order inside a bucket carries no meaning, so the O(1) removal is fine.
            let mut tensor = bucket.swap_remove(position);
            tensor.fill_simd(0.0);
            return Ok(tensor);
        }

        SimdTensor::zeros(shape)
    }

    /// Returns `tensor` to the pool; it is dropped instead when its bucket is already full.
    pub fn deallocate(&mut self, tensor: SimdTensor<f32>) {
        let pool_index = self.get_pool_index(tensor.len());
        let capacity = self.max_pool_size;
        let bucket = self.bucket_mut(pool_index);
        if bucket.len() < capacity {
            bucket.push(tensor);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new tensor reports the requested shape, element count and alignment.
    #[test]
    fn test_simd_tensor_creation() {
        let created = SimdTensor::<f32>::zeros(&[4, 4]);
        assert!(created.is_ok());

        let tensor = created.unwrap();
        assert_eq!(tensor.shape(), &[4, 4]);
        assert_eq!(tensor.len(), 16);
        assert!(tensor.is_simd_aligned());
    }

    /// Element-wise addition of two constant tensors.
    #[test]
    fn test_simd_operations() {
        let mut lhs = SimdTensor::<f32>::zeros(&[4, 4]).unwrap();
        let mut rhs = SimdTensor::<f32>::zeros(&[4, 4]).unwrap();
        lhs.fill_simd(2.0);
        rhs.fill_simd(3.0);

        let outcome = lhs.add_simd(&rhs);
        assert!(outcome.is_ok());

        let sum = outcome.unwrap();
        assert_eq!(sum.as_slice()[0], 5.0);
    }

    /// A 2x3 matrix of ones times a 3x2 matrix of twos gives a 2x2 matrix of sixes.
    #[test]
    fn test_simd_matrix_multiplication() {
        let mut lhs = SimdTensor::<f32>::zeros(&[2, 3]).unwrap();
        let mut rhs = SimdTensor::<f32>::zeros(&[3, 2]).unwrap();
        lhs.fill_simd(1.0);
        rhs.fill_simd(2.0);

        let outcome = lhs.matmul_simd(&rhs);
        assert!(outcome.is_ok());

        let product = outcome.unwrap();
        assert_eq!(product.shape(), &[2, 2]);
        println!("Result values: {:?}", product.as_slice());
        println!("Expected: 6.0 (1*2 + 1*2 + 1*2)");
        for index in 0..4 {
            assert_eq!(product.as_slice()[index], 6.0);
        }
    }

    /// In-place addition updates the left-hand tensor.
    #[test]
    fn test_inplace_operations() {
        let mut target = SimdTensor::<f32>::zeros(&[3, 3]).unwrap();
        let mut addend = SimdTensor::<f32>::zeros(&[3, 3]).unwrap();
        target.fill_simd(1.0);
        addend.fill_simd(2.0);

        let outcome = target.add_assign_simd(&addend);
        assert!(outcome.is_ok());
        assert_eq!(target.as_slice()[0], 3.0);
    }

    /// Round trip from a regular tensor to a SIMD tensor and back keeps the shape.
    #[test]
    fn test_tensor_conversion() {
        let source = Tensor::<f32>::ones(&[2, 2]);

        let converted = source.to_simd_aligned();
        assert!(converted.is_ok());

        let aligned = converted.unwrap();
        assert!(aligned.is_simd_aligned());

        let round_trip = aligned.to_tensor();
        assert_eq!(round_trip.size(), vec![2, 2]);
    }

    /// A tensor returned to the pool can be requested again.
    #[test]
    fn test_simd_memory_pool() {
        let mut pool = SimdMemoryPool::new(5);

        let first = pool.allocate(&[4, 4]);
        assert!(first.is_ok());
        pool.deallocate(first.unwrap());

        let second = pool.allocate(&[4, 4]);
        assert!(second.is_ok());
    }

    /// The allocator's alignment predicate accepts a tensor's data pointer.
    #[test]
    fn test_alignment_check() {
        let tensor = SimdTensor::<f32>::zeros(&[8, 8]).unwrap();
        let data_ptr = tensor.as_ptr();
        assert!(SimdAllocator::is_aligned(data_ptr));
    }
}