use super::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
type ParallelResult<T> = RusTorchResult<T>;
use num_traits::Float;
/// Alignment value associated with AVX-512 data. 64 equals the width of a 512-bit register
/// in bytes (512 / 8); presumably this is a byte alignment — TODO confirm against its users,
/// since nothing in this part of the file reads it.
pub const AVX512_ALIGNMENT: usize = 64;
/// Number of `f32` values that fit in one 512-bit vector (512 / 32).
pub const AVX512_F32_LANES: usize = 16;
/// Number of `f64` values that fit in one 512-bit vector (512 / 64).
pub const AVX512_F64_LANES: usize = 8;
/// Reports whether the running CPU supports the AVX-512 Foundation (`avx512f`) feature.
///
/// On `x86_64` targets the answer comes from runtime feature detection; on every other
/// target architecture the function always returns `false`.
pub fn is_avx512_available() -> bool {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(target_arch = "x86_64")]
    let detected = std::arch::is_x86_feature_detected!("avx512f");
    #[cfg(not(target_arch = "x86_64"))]
    let detected = false;

    detected
}
/// Namespace type grouping the `f32` slice kernels (add, mul, multiply-add, dot product)
/// plus a generic dense matrix multiply.
///
/// NOTE: despite the `_vectorized` names, every kernel here is ordinary scalar Rust with no
/// explicit SIMD intrinsics; any vectorization is left to the compiler.
pub struct Avx512F32Ops;

impl Avx512F32Ops {
    /// Returns a shape-mismatch error unless `actual == expected`.
    ///
    /// Centralizes length validation so each kernel reports the operand that really has
    /// the wrong size (expected length first, offending length second).
    fn ensure_len(expected: usize, actual: usize) -> ParallelResult<()> {
        if expected == actual {
            Ok(())
        } else {
            Err(RusTorchError::shape_mismatch(&[expected], &[actual]))
        }
    }

    /// Element-wise addition: `result[i] = a[i] + b[i]`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if `b` or `result` does not have the same length
    /// as `a`; the error carries `a.len()` and the length of the offending slice.
    ///
    /// # Safety
    /// The body performs only safe slice operations and imposes no additional
    /// requirements on the caller. The `unsafe` qualifier is part of the existing
    /// signature and is kept so current call sites compile unchanged.
    pub unsafe fn add_vectorized(a: &[f32], b: &[f32], result: &mut [f32]) -> ParallelResult<()> {
        Self::ensure_len(a.len(), b.len())?;
        Self::ensure_len(a.len(), result.len())?;
        for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *out = x + y;
        }
        Ok(())
    }

    /// Element-wise multiplication: `result[i] = a[i] * b[i]`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if `b` or `result` does not have the same length
    /// as `a`; the error carries `a.len()` and the length of the offending slice.
    ///
    /// # Safety
    /// The body performs only safe slice operations and imposes no additional
    /// requirements on the caller. The `unsafe` qualifier is part of the existing
    /// signature and is kept so current call sites compile unchanged.
    pub unsafe fn mul_vectorized(a: &[f32], b: &[f32], result: &mut [f32]) -> ParallelResult<()> {
        Self::ensure_len(a.len(), b.len())?;
        Self::ensure_len(a.len(), result.len())?;
        for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *out = x * y;
        }
        Ok(())
    }

    /// Element-wise multiply-add: `result[i] = a[i] * b[i] + c[i]`.
    ///
    /// The product and the sum are two separately rounded operations (this is not
    /// `f32::mul_add`), which keeps results identical to the plain expression.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if `b`, `c` or `result` does not have the same
    /// length as `a`; the error carries `a.len()` and the length of the offending slice.
    ///
    /// # Safety
    /// The body performs only safe slice operations and imposes no additional
    /// requirements on the caller. The `unsafe` qualifier is part of the existing
    /// signature and is kept so current call sites compile unchanged.
    pub unsafe fn fmadd_vectorized(
        a: &[f32],
        b: &[f32],
        c: &[f32],
        result: &mut [f32],
    ) -> ParallelResult<()> {
        Self::ensure_len(a.len(), b.len())?;
        Self::ensure_len(a.len(), c.len())?;
        Self::ensure_len(a.len(), result.len())?;
        for (((out, &x), &y), &z) in result.iter_mut().zip(a).zip(b).zip(c) {
            *out = x * y + z;
        }
        Ok(())
    }

    /// Dot product `sum(a[i] * b[i])`, accumulated left to right starting from `0.0`.
    ///
    /// Two empty slices yield `0.0`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if `a` and `b` differ in length.
    ///
    /// # Safety
    /// The body performs only safe slice operations and imposes no additional
    /// requirements on the caller. The `unsafe` qualifier is part of the existing
    /// signature and is kept so current call sites compile unchanged.
    pub unsafe fn dot_product_vectorized(a: &[f32], b: &[f32]) -> ParallelResult<f32> {
        Self::ensure_len(a.len(), b.len())?;
        // Explicit fold keeps the accumulation order and the `0.0` start value.
        Ok(a.iter().zip(b).fold(0.0_f32, |acc, (&x, &y)| acc + x * y))
    }

    /// Dense matrix product `c = a * b`, indexing all three buffers in row-major order.
    ///
    /// * `a` holds `rows_a * cols_a` elements,
    /// * `b` holds `cols_a * cols_b` elements,
    /// * `c` receives `rows_a * cols_b` elements; every element is overwritten.
    ///
    /// # Errors
    /// Returns a shape-mismatch error (expected length, actual length) if any of the
    /// three buffers does not have exactly the length implied by the dimensions.
    /// Validating `b` and `c` up front means an undersized buffer is reported as an
    /// error instead of panicking on an out-of-bounds index inside the loops.
    ///
    /// # Safety
    /// The body performs only safe slice operations and imposes no additional
    /// requirements on the caller. The `unsafe` qualifier is part of the existing
    /// signature and is kept so current call sites compile unchanged.
    pub unsafe fn matrix_multiply_vectorized<T: Float + Send + Sync>(
        a: &[T],
        b: &[T],
        c: &mut [T],
        rows_a: usize,
        cols_a: usize,
        cols_b: usize,
    ) -> ParallelResult<()> {
        Self::ensure_len(rows_a * cols_a, a.len())?;
        Self::ensure_len(cols_a * cols_b, b.len())?;
        Self::ensure_len(rows_a * cols_b, c.len())?;

        for i in 0..rows_a {
            let a_row = &a[i * cols_a..(i + 1) * cols_a];
            for j in 0..cols_b {
                let mut sum = T::zero();
                // Walk row `i` of `a` against column `j` of `b`.
                for (k, &a_ik) in a_row.iter().enumerate() {
                    sum = sum + a_ik * b[k * cols_b + j];
                }
                c[i * cols_b + j] = sum;
            }
        }
        Ok(())
    }
}
/// Namespace type grouping the `f64` slice kernels (element-wise add and dot product).
///
/// NOTE: despite the `_vectorized` names, both kernels are ordinary scalar Rust with no
/// explicit SIMD intrinsics; any vectorization is left to the compiler.
pub struct Avx512F64Ops;

impl Avx512F64Ops {
    /// Element-wise addition: `result[i] = a[i] + b[i]`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error when the three slices do not all share one length.
    /// The error is always built from `a.len()` and `b.len()`, including the case where
    /// only `result` has the wrong length.
    ///
    /// # Safety
    /// The body performs only safe slice operations; no extra caller obligations are
    /// visible in this implementation.
    pub unsafe fn add_vectorized(a: &[f64], b: &[f64], result: &mut [f64]) -> ParallelResult<()> {
        let n = a.len();
        let lengths_agree = n == b.len() && n == result.len();
        if !lengths_agree {
            return Err(RusTorchError::shape_mismatch(&[n], &[b.len()]));
        }

        let operands = a.iter().zip(b.iter());
        for (out, (&lhs, &rhs)) in result.iter_mut().zip(operands) {
            *out = lhs + rhs;
        }
        Ok(())
    }

    /// Dot product `sum(a[i] * b[i])`, accumulated left to right starting from `0.0`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error when `a` and `b` differ in length.
    ///
    /// # Safety
    /// The body performs only safe slice operations; no extra caller obligations are
    /// visible in this implementation.
    pub unsafe fn dot_product_vectorized(a: &[f64], b: &[f64]) -> ParallelResult<f64> {
        if a.len() == b.len() {
            let total = a
                .iter()
                .zip(b.iter())
                .fold(0.0_f64, |acc, (&lhs, &rhs)| acc + lhs * rhs);
            Ok(total)
        } else {
            Err(RusTorchError::shape_mismatch(&[a.len()], &[b.len()]))
        }
    }
}
/// Tensor-level arithmetic entry points, generic over the scalar type `T` returned by
/// the dot product. This file implements the trait for `Tensor<f32>` and `Tensor<f64>`.
pub trait Avx512TensorOps<T> {
/// Combines `self` and `other` by addition and returns the result as a new value.
///
/// The implementations in this file add element-wise and return an error when the two
/// shapes differ or when the underlying data cannot be viewed as a slice.
fn avx512_add(&self, other: &Self) -> ParallelResult<Self>
where
Self: Sized;
/// Combines `self` and `other` by multiplication and returns the result as a new value.
///
/// The `Tensor<f32>` implementation in this file multiplies element-wise and returns an
/// error when the two shapes differ or when the data cannot be viewed as a slice.
fn avx512_mul(&self, other: &Self) -> ParallelResult<Self>
where
Self: Sized;
/// Returns the dot product of `self` and `other` as a single `T`.
///
/// The implementations in this file compare element counts (not shapes) and return an
/// error when they differ or when the data cannot be viewed as a slice.
fn avx512_dot(&self, other: &Self) -> ParallelResult<T>;
}
/// `f32` tensor arithmetic built on the [`Avx512F32Ops`] slice kernels.
///
/// Each method views the tensors' `data` as flat slices via `as_slice()` /
/// `as_slice_mut()`; if either call returns `None` the method fails with a tensor-op
/// error (presumably this happens for non-contiguous storage — confirm against the type
/// of `data`).
impl Avx512TensorOps<f32> for Tensor<f32> {
    /// Element-wise sum of two tensors with identical shapes, returned as a new tensor.
    ///
    /// # Errors
    /// Shape mismatch when the shapes differ; tensor-op error when a slice view of any
    /// involved tensor is unavailable.
    fn avx512_add(&self, other: &Self) -> ParallelResult<Self> {
        let shape = self.data.shape();
        if shape != other.data.shape() {
            return Err(RusTorchError::shape_mismatch(shape, other.data.shape()));
        }

        let mut sum = Tensor::zeros(shape);
        let lhs = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let rhs = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let out = sum.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Failed to get mutable slice from tensor data")
        })?;

        // SAFETY: the kernel is declared `unsafe`, but as written in this file it only
        // validates lengths and uses safe slice operations, so nothing extra is required.
        unsafe { Avx512F32Ops::add_vectorized(lhs, rhs, out) }?;
        Ok(sum)
    }

    /// Element-wise product of two tensors with identical shapes, returned as a new tensor.
    ///
    /// # Errors
    /// Shape mismatch when the shapes differ; tensor-op error when a slice view of any
    /// involved tensor is unavailable.
    fn avx512_mul(&self, other: &Self) -> ParallelResult<Self> {
        let shape = self.data.shape();
        if shape != other.data.shape() {
            return Err(RusTorchError::shape_mismatch(shape, other.data.shape()));
        }

        let mut product = Tensor::zeros(shape);
        let lhs = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let rhs = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let out = product.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Failed to get mutable slice from tensor data")
        })?;

        // SAFETY: the kernel is declared `unsafe`, but as written in this file it only
        // validates lengths and uses safe slice operations, so nothing extra is required.
        unsafe { Avx512F32Ops::mul_vectorized(lhs, rhs, out) }?;
        Ok(product)
    }

    /// Dot product over the flattened elements of both tensors.
    ///
    /// Only the element counts are compared, not the shapes.
    ///
    /// # Errors
    /// Shape mismatch (carrying both element counts) when the counts differ; tensor-op
    /// error when a slice view of either tensor is unavailable.
    fn avx512_dot(&self, other: &Self) -> ParallelResult<f32> {
        let (own_len, other_len) = (self.data.len(), other.data.len());
        if own_len != other_len {
            return Err(RusTorchError::shape_mismatch(&[own_len], &[other_len]));
        }

        let lhs = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let rhs = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;

        // SAFETY: the kernel is declared `unsafe`, but as written in this file it only
        // validates lengths and uses safe slice operations, so nothing extra is required.
        unsafe { Avx512F32Ops::dot_product_vectorized(lhs, rhs) }
    }
}
/// `f64` tensor arithmetic. Addition and dot product delegate to the [`Avx512F64Ops`]
/// slice kernels; multiplication is computed inline.
///
/// Each method views the tensors' `data` as flat slices via `as_slice()` /
/// `as_slice_mut()`; if either call returns `None` the method fails with a tensor-op
/// error (presumably this happens for non-contiguous storage — confirm against the type
/// of `data`).
impl Avx512TensorOps<f64> for Tensor<f64> {
    /// Element-wise sum of two tensors with identical shapes, returned as a new tensor.
    ///
    /// # Errors
    /// Shape mismatch when the shapes differ; tensor-op error when a slice view of any
    /// involved tensor is unavailable.
    fn avx512_add(&self, other: &Self) -> ParallelResult<Self> {
        if self.data.shape() != other.data.shape() {
            return Err(RusTorchError::shape_mismatch(
                self.data.shape(),
                other.data.shape(),
            ));
        }
        let mut result = Tensor::zeros(self.data.shape());
        let self_slice = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let other_slice = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let result_slice = result.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Failed to get mutable slice from tensor data")
        })?;
        // SAFETY: the kernel is declared `unsafe`, but as written in this file it only
        // validates lengths and uses safe slice operations, so nothing extra is required.
        unsafe { Avx512F64Ops::add_vectorized(self_slice, other_slice, result_slice) }?;
        Ok(result)
    }

    /// Element-wise product of two tensors with identical shapes, returned as a new tensor.
    ///
    /// This must multiply, not add: forwarding to `avx512_add` would silently return the
    /// sum. The product is computed inline because `Avx512F64Ops` exposes no multiply
    /// kernel.
    ///
    /// # Errors
    /// Shape mismatch when the shapes differ; tensor-op error when a slice view of any
    /// involved tensor is unavailable.
    fn avx512_mul(&self, other: &Self) -> ParallelResult<Self> {
        if self.data.shape() != other.data.shape() {
            return Err(RusTorchError::shape_mismatch(
                self.data.shape(),
                other.data.shape(),
            ));
        }
        let mut result = Tensor::zeros(self.data.shape());
        let self_slice = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let other_slice = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let result_slice = result.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Failed to get mutable slice from tensor data")
        })?;
        for ((out, &x), &y) in result_slice.iter_mut().zip(self_slice).zip(other_slice) {
            *out = x * y;
        }
        Ok(result)
    }

    /// Dot product over the flattened elements of both tensors.
    ///
    /// Only the element counts are compared, not the shapes.
    ///
    /// # Errors
    /// Shape mismatch (carrying both element counts) when the counts differ; tensor-op
    /// error when a slice view of either tensor is unavailable.
    fn avx512_dot(&self, other: &Self) -> ParallelResult<f64> {
        if self.data.len() != other.data.len() {
            return Err(RusTorchError::shape_mismatch(
                &[self.data.len()],
                &[other.data.len()],
            ));
        }
        let self_slice = self
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        let other_slice = other
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Failed to get slice from tensor data"))?;
        // SAFETY: the kernel is declared `unsafe`, but as written in this file it only
        // validates lengths and uses safe slice operations, so nothing extra is required.
        unsafe { Avx512F64Ops::dot_product_vectorized(self_slice, other_slice) }
    }
}