use super::cpu_features::CpuFeatures;
use crate::{Result, Tensor, TrustformersError};
use scirs2_core::ndarray::Array2;
#[cfg(not(target_os = "macos"))]
use scirs2_core::simd_ops::SimdUnifiedOps;
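/// macOS path: computes C = A · B for row-major f32 matrices via oxiblas GEMM.
/// A is m×k, B is k×n, and C is m×n, all row-major.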
#[cfg(target_os = "macos")]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use oxiblas_blas::level3::gemm;
use oxiblas_matrix::{MatMut, MatRef};
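// A row-major (rows × cols) buffer is byte-identical to a column-major view of
// its transpose, so A, B, and C are handed to the column-major GEMM as
// Aᵀ (k×m, leading dim k), Bᵀ (n×k, leading dim n), and Cᵀ (n×m, leading dim n);
// computing Cᵀ = Bᵀ · Aᵀ then leaves C correct in row-major order.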
let a_t = MatRef::new(a.as_ptr(), k, m, k);
let b_t = MatRef::new(b.as_ptr(), n, k, n);
let c_t = MatMut::new(c.as_mut_ptr(), n, m, n);
gemm(1.0, b_t, a_t, 0.0, c_t);
}
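/// Non-macOS path: computes C = A · B using scirs2's SIMD-accelerated GEMM.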
#[cfg(not(target_os = "macos"))]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
let a_arr = Array2::from_shape_vec((m, k), a.to_vec())
.expect("BLAS input A shape must match dimensions");
let b_arr = Array2::from_shape_vec((k, n), b.to_vec())
.expect("BLAS input B shape must match dimensions");
let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec())
.expect("BLAS output C shape must match dimensions");
f32::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
c.copy_from_slice(c_arr.as_slice().expect("output array must have contiguous layout"));
}
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{vdupq_n_f32, vfmaq_f32, vld1q_f32, vst1q_f32};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
_mm256_setzero_ps, _mm256_storeu_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_set1_ps,
_mm512_setzero_ps, _mm512_storeu_ps,
};
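/// CPU-feature-aware matrix multiplication front end: dispatches to a
/// BLAS-style GEMM for large matrices and keeps hand-written SIMD kernels
/// (AVX2/AVX-512/NEON/RVV) available through the legacy path.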
pub struct SIMDMatrixOps {
cpu_features: CpuFeatures,
}
impl Default for SIMDMatrixOps {
fn default() -> Self {
Self::new()
}
}
impl SIMDMatrixOps {
pub fn new() -> Self {
Self {
cpu_features: CpuFeatures::detect(),
}
}
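/// Multiplies two 2D tensors, returning the m×n product. Shapes are
/// validated before dispatching to the GEMM backend.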
pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Only 2D matrix multiplication supported",
"matmul",
));
}
let m = a_shape[0];
let k = a_shape[1];
let n = b_shape[1];
if b_shape[0] != k {
return Err(TrustformersError::tensor_op_error(
"Matrix dimensions don't match for multiplication",
"matmul",
));
}
self.matmul_blas(a, b, m, k, n)
}
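/// GEMM with a size heuristic: when any dimension is smaller than
/// MIN_SIZE_FOR_BLAS, fall back to ndarray's `dot` to skip GEMM setup
/// overhead; otherwise call `blas_sgemm`.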
fn matmul_blas(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
const MIN_SIZE_FOR_BLAS: usize = 32;
let c_data = if m < MIN_SIZE_FOR_BLAS || n < MIN_SIZE_FOR_BLAS || k < MIN_SIZE_FOR_BLAS {
let a_arr = Array2::from_shape_vec((m, k), a_data.to_vec())
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
let b_arr = Array2::from_shape_vec((k, n), b_data.to_vec())
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
a_arr.dot(&b_arr).into_raw_vec_and_offset().0
} else {
let mut result_vec = vec![0.0f32; m * n];
blas_sgemm(&a_data, &b_data, &mut result_vec, m, k, n);
result_vec
};
Tensor::from_vec(c_data, &[m, n])
}
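/// Legacy hand-written SIMD dispatch, kept for reference. The kernels below
/// have no remainder handling, so the SIMD path is only taken when n is a
/// multiple of the SIMD width and at least 64.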
#[allow(dead_code)]
pub fn matmul_legacy(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Only 2D matrix multiplication supported",
"matmul",
));
}
let m = a_shape[0];
let k = a_shape[1];
let n = b_shape[1];
if b_shape[0] != k {
return Err(TrustformersError::tensor_op_error(
"Matrix dimensions don't match for multiplication",
"matmul",
));
}
let simd_width = self.cpu_features.best_simd_width();
let can_use_simd = simd_width > 1 && n.is_multiple_of(simd_width) && n >= 64;
if can_use_simd {
match self.cpu_features.best_instruction_set() {
"avx512" => self.matmul_avx512(a, b, m, k, n),
"avx2_fma" | "avx2" => self.matmul_avx2(a, b, m, k, n),
"neon" => self.matmul_neon(a, b, m, k, n),
"rvv" => self.matmul_rvv(a, b, m, k, n),
_ => self.matmul_standard(a, b, m, k, n),
}
} else {
self.matmul_standard(a, b, m, k, n)
}
}
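/// Naive triple-loop fallback: c[i][j] = Σ_l a[i][l] · b[l][j].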
fn matmul_standard(
&self,
a: &Tensor,
b: &Tensor,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
let mut c_data = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
sum += a_data[i * k + l] * b_data[l * n + j];
}
c_data[i * n + j] = sum;
}
}
Tensor::from_vec(c_data, &[m, n])
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_avx2(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
let mut c_data = vec![0.0f32; m * n];
unsafe {
self.matmul_avx2_inner(&a_data, &b_data, &mut c_data, m, k, n);
}
Tensor::from_vec(c_data, &[m, n])
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn matmul_avx2_inner(
&self,
a_data: &[f32],
b_data: &[f32],
c_data: &mut [f32],
m: usize,
k: usize,
n: usize,
) {
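// Broadcast a[i*k + l] across a 256-bit register and fused-multiply-add it
// against eight contiguous elements of row l of B. The dispatcher only takes
// this path when n is a multiple of the SIMD width, so no remainder loop is
// needed.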
for i in 0..m {
for j in (0..n).step_by(8) {
let mut sum = _mm256_setzero_ps();
for l in 0..k {
let a_val = _mm256_set1_ps(a_data[i * k + l]);
let b_vec = _mm256_loadu_ps(&b_data[l * n + j]);
sum = _mm256_fmadd_ps(a_val, b_vec, sum);
}
_mm256_storeu_ps(&mut c_data[i * n + j], sum);
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_avx512(
&self,
a: &Tensor,
b: &Tensor,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
let mut c_data = vec![0.0f32; m * n];
unsafe {
self.matmul_avx512_inner(&a_data, &b_data, &mut c_data, m, k, n);
}
Tensor::from_vec(c_data, &[m, n])
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vl")]
unsafe fn matmul_avx512_inner(
&self,
a_data: &[f32],
b_data: &[f32],
c_data: &mut [f32],
m: usize,
k: usize,
n: usize,
) {
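// Same broadcast-FMA scheme as the AVX2 kernel, but over 512-bit registers
// holding sixteen f32 lanes per step.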
for i in 0..m {
for j in (0..n).step_by(16) {
let mut sum = _mm512_setzero_ps();
for l in 0..k {
let a_val = _mm512_set1_ps(a_data[i * k + l]);
let b_vec = _mm512_loadu_ps(&b_data[l * n + j]);
sum = _mm512_fmadd_ps(a_val, b_vec, sum);
}
_mm512_storeu_ps(&mut c_data[i * n + j], sum);
}
}
}
#[cfg(target_arch = "aarch64")]
fn matmul_neon(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
let mut c_data = vec![0.0f32; m * n];
unsafe {
self.matmul_neon_inner(&a_data, &b_data, &mut c_data, m, k, n);
}
Tensor::from_vec(c_data, &[m, n])
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn matmul_neon_inner(
&self,
a_data: &[f32],
b_data: &[f32],
c_data: &mut [f32],
m: usize,
k: usize,
n: usize,
) {
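// NEON variant of the broadcast-FMA scheme: four f32 lanes per 128-bit
// register, with vfmaq_f32 accumulating into `sum`.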
for i in 0..m {
for j in (0..n).step_by(4) {
let mut sum = vdupq_n_f32(0.0);
for l in 0..k {
let a_val = vdupq_n_f32(a_data[i * k + l]);
let b_vec = vld1q_f32(&b_data[l * n + j]);
sum = vfmaq_f32(sum, a_val, b_vec);
}
vst1q_f32(&mut c_data[i * n + j], sum);
}
}
}
#[cfg(not(target_arch = "aarch64"))]
fn matmul_neon(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
self.matmul_standard(a, b, m, k, n)
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn matmul_avx2(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
self.matmul_standard(a, b, m, k, n)
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn matmul_avx512(
&self,
a: &Tensor,
b: &Tensor,
m: usize,
k: usize,
n: usize,
) -> Result<Tensor> {
self.matmul_standard(a, b, m, k, n)
}
#[cfg(target_arch = "riscv64")]
fn matmul_rvv(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
let a_data = a.data()?;
let b_data = b.data()?;
let mut c_data = vec![0.0f32; m * n];
unsafe {
self.matmul_rvv_inner(&a_data, &b_data, &mut c_data, m, k, n);
}
Tensor::from_vec(c_data, &[m, n])
}
#[cfg(target_arch = "riscv64")]
unsafe fn matmul_rvv_inner(
&self,
a_data: &[f32],
b_data: &[f32],
c_data: &mut [f32],
m: usize,
k: usize,
n: usize,
) {
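// Portable stand-in for RISC-V Vector: block the column loop by the detected
// vector length (rvv_vlen bits / 32 bits per f32) so the compiler can
// auto-vectorize, then handle any leftover columns scalar-wise. No RVV
// intrinsics are used; the fn stays `unsafe` for parity with the other
// kernels' call sites.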
let vlen_elements = self.cpu_features.rvv_vlen / 32;
for i in 0..m {
let mut j = 0;
while j + vlen_elements <= n {
let mut sum = vec![0.0f32; vlen_elements];
for l in 0..k {
let a_val = a_data[i * k + l];
for v in 0..vlen_elements {
let b_val = b_data[l * n + j + v];
sum[v] += a_val * b_val;
}
}
for v in 0..vlen_elements {
c_data[i * n + j + v] = sum[v];
}
j += vlen_elements;
}
while j < n {
let mut sum = 0.0f32;
for l in 0..k {
sum += a_data[i * k + l] * b_data[l * n + j];
}
c_data[i * n + j] = sum;
j += 1;
}
}
}
#[cfg(not(target_arch = "riscv64"))]
fn matmul_rvv(&self, a: &Tensor, b: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
self.matmul_standard(a, b, m, k, n)
}
}
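/// Third-order Taylor approximation of e^x: 1 + x + x²/2 + x³/6.
/// Accurate only near x = 0; inputs are not range-reduced first.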
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn simd_exp_approx(x: __m256) -> __m256 {
let one = _mm256_set1_ps(1.0);
let half = _mm256_set1_ps(0.5);
let sixth = _mm256_set1_ps(1.0 / 6.0);
let x2 = _mm256_mul_ps(x, x);
let x3 = _mm256_mul_ps(x2, x);
let term1 = x;
let term2 = _mm256_mul_ps(x2, half);
let term3 = _mm256_mul_ps(x3, sixth);
let result = _mm256_add_ps(one, term1);
let result = _mm256_add_ps(result, term2);
_mm256_add_ps(result, term3)
}