use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2};
use std::mem;
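/// CPU capabilities relevant to SIMD kernel selection, covering x86-64
/// (AVX-512 subsets, AMX) and AArch64 (NEON, SVE/SVE2), along with cache
/// sizes used for blocking decisions.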
#[derive(Debug, Clone, Copy)]
pub struct AdvancedHardwareCapabilities {
pub has_avx512f: bool,
pub has_avx512cd: bool,
pub has_avx512vl: bool,
pub has_avx512bw: bool,
pub has_avx512dq: bool,
pub has_avx512_bf16: bool,
pub has_avx512_fp16: bool,
pub has_avx512_vnni: bool,
pub has_amx_tile: bool,
pub has_amx_int8: bool,
pub has_amx_bf16: bool,
pub has_neon: bool,
pub has_sve: bool,
pub has_sve2: bool,
pub has_dotprod: bool,
pub has_fp16: bool,
pub has_bf16: bool,
pub has_i8mm: bool,
pub l1_cachesize: usize,
pub l2_cachesize: usize,
pub l3_cachesize: usize,
pub simd_register_count: usize,
}
impl AdvancedHardwareCapabilities {
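    /// Detect capabilities at runtime where a check exists, falling back to
    /// compile-time `cfg!(target_feature = ...)` flags for features without one.
    /// Cache sizes are conservative defaults rather than values probed from the CPU.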
    pub fn detect() -> Self {
        // `is_x86_feature_detected!` only exists on x86/x86_64, so gate its use.
        #[cfg(target_arch = "x86_64")]
        let (has_avx512f, has_avx512cd, has_avx512vl, has_avx512bw, has_avx512dq) = (
            is_x86_feature_detected!("avx512f"),
            is_x86_feature_detected!("avx512cd"),
            is_x86_feature_detected!("avx512vl"),
            is_x86_feature_detected!("avx512bw"),
            is_x86_feature_detected!("avx512dq"),
        );
        #[cfg(not(target_arch = "x86_64"))]
        let (has_avx512f, has_avx512cd, has_avx512vl, has_avx512bw, has_avx512dq) =
            (false, false, false, false, false);
        Self {
            has_avx512f,
            has_avx512cd,
            has_avx512vl,
            has_avx512bw,
            has_avx512dq,
            has_avx512_bf16: cfg!(target_feature = "avx512bf16"),
            has_avx512_fp16: cfg!(target_feature = "avx512fp16"),
            has_avx512_vnni: cfg!(target_feature = "avx512vnni"),
            has_amx_tile: cfg!(target_feature = "amx-tile"),
            has_amx_int8: cfg!(target_feature = "amx-int8"),
            has_amx_bf16: cfg!(target_feature = "amx-bf16"),
            // NEON is mandatory on AArch64.
            has_neon: cfg!(target_arch = "aarch64"),
            has_sve: cfg!(target_feature = "sve"),
            has_sve2: cfg!(target_feature = "sve2"),
            has_dotprod: cfg!(target_feature = "dotprod"),
            has_fp16: cfg!(target_feature = "fp16"),
            has_bf16: cfg!(target_feature = "bf16"),
            has_i8mm: cfg!(target_feature = "i8mm"),
            l1_cachesize: 32 * 1024,
            l2_cachesize: 512 * 1024,
            l3_cachesize: 8 * 1024 * 1024,
            // Both AVX-512 (zmm0-zmm31) and AArch64 NEON/SVE expose 32 vector registers.
            simd_register_count: 32,
        }
    }
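    /// Number of elements of `T` that fit in the widest SIMD register assumed
    /// to be available on this target.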
    pub fn optimal_vector_width_for_type<T>(&self) -> usize {
        let elementsize = mem::size_of::<T>();
        if self.has_avx512f {
            64 / elementsize // 512-bit zmm registers
        } else if self.has_sve {
            32 / elementsize // SVE is scalable; assume at least 256-bit vectors
        } else {
            32 / elementsize // 256-bit assumption for AVX2 and other fallbacks
        }
    }
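    /// Choose register/tile blocking parameters matched to the best available
    /// ISA extension.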
    pub fn optimal_blocking_strategy(&self) -> BlockingStrategy {
        if self.has_amx_tile {
            BlockingStrategy::AMXTiling { tile_m: 16, tile_n: 64, tile_k: 64 }
        } else if self.has_avx512f {
            BlockingStrategy::AVX512 { block_m: 6, block_n: 16, block_k: 128 }
        } else if self.has_sve {
            BlockingStrategy::SVE { vector_length: 256 }
        } else {
            BlockingStrategy::Standard { blocksize: 64 }
        }
    }
}
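/// Register/tile blocking parameters selected by
/// [`AdvancedHardwareCapabilities::optimal_blocking_strategy`].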
#[derive(Debug, Clone)]
pub enum BlockingStrategy {
Standard { blocksize: usize },
AVX512 { block_m: usize, block_n: usize, block_k: usize },
SVE { vector_length: usize },
AMXTiling { tile_m: usize, tile_n: usize, tile_k: usize },
}
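/// AVX-512 GEMM microkernel computing `C = alpha * A * B + beta * C` for `f32`,
/// using a 6x16 register block and masked loads/stores for the column remainder.
///
/// # Safety
/// The caller must verify AVX-512F/VL/BW/DQ at runtime and ensure `a`, `b`, and
/// `c` are valid for the given dimensions and leading strides (`lda`, `ldb`, `ldc`).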
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512vl", enable = "avx512bw", enable = "avx512dq")]
pub unsafe fn avx512_gemm_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
alpha: f32,
beta: f32,
) -> LinalgResult<()> {
use std::arch::x86_64::*;
    const MR: usize = 6;
    const NR: usize = 16;
let alpha_vec = _mm512_set1_ps(alpha);
let beta_vec = _mm512_set1_ps(beta);
for i in (0..m).step_by(MR) {
for j in (0..n).step_by(NR) {
            // `c_vec` accumulates the A*B contribution; the existing C tile is loaded
            // separately into `c_old` so that `beta * C` can be blended in at the end.
            let mut c_vec = [_mm512_setzero_ps(); MR];
            let mut c_old = [_mm512_setzero_ps(); MR];
            if beta != 0.0 {
                for ii in 0..MR.min(m - i) {
                    if j + NR <= n {
                        c_old[ii] = _mm512_loadu_ps(c.add((i + ii) * ldc + j));
                    } else {
                        let remaining = n - j;
                        let mask = (1u16 << remaining) - 1;
                        c_old[ii] = _mm512_maskz_loadu_ps(mask, c.add((i + ii) * ldc + j));
                    }
                }
            }
for l in 0..k {
let mut a_vec = [_mm512_setzero_ps(); MR];
for ii in 0..MR.min(m - i) {
let a_scalar = *a.add((i + ii) * lda + l);
a_vec[ii] = _mm512_set1_ps(a_scalar);
}
let b_vec = if j + NR <= n {
_mm512_loadu_ps(b.add(l * ldb + j))
} else {
let remaining = n - j;
let mask = (1u16 << remaining) - 1;
_mm512_maskz_loadu_ps(mask, b.add(l * ldb + j))
};
for ii in 0..MR.min(m - i) {
c_vec[ii] = _mm512_fmadd_ps(a_vec[ii], b_vec, c_vec[ii]);
}
}
            for ii in 0..MR.min(m - i) {
                // result = alpha * (A*B) + beta * C
                let result = if beta != 0.0 {
                    _mm512_fmadd_ps(beta_vec, c_old[ii], _mm512_mul_ps(alpha_vec, c_vec[ii]))
                } else {
                    _mm512_mul_ps(alpha_vec, c_vec[ii])
                };
if j + NR <= n {
_mm512_storeu_ps(c.add((i + ii) * ldc + j), result);
} else {
let remaining = n - j;
let mask = (1u16 << remaining) - 1;
_mm512_mask_storeu_ps(c.add((i + ii) * ldc + j), mask, result);
}
}
}
}
Ok(())
}
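/// Mixed-precision GEMM: the 16-bit inputs are interpreted as bfloat16 bit
/// patterns (widened to `f32` by shifting into the upper half of the word),
/// accumulated in `f32`, and stored as `C = A * B`.
///
/// # Safety
/// Requires AVX-512F/BW at runtime; `a_f16`, `b_f16`, and `c_f32` must be valid
/// for row-major `m x k`, `k x n`, and `m x n` matrices respectively.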
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
pub unsafe fn avx512_mixed_precision_gemm(
    a_f16: *const u16,
    b_f16: *const u16,
    c_f32: *mut f32,
    m: usize,
    n: usize,
    k: usize,
) -> LinalgResult<()> {
    use std::arch::x86_64::*;
    const MR: usize = 4;
    const NR: usize = 16;
for i in (0..m).step_by(MR) {
for j in (0..n).step_by(NR) {
let mut c_acc = [_mm512_setzero_ps(); MR];
for l in 0..k {
let mut a_f32 = [_mm512_setzero_ps(); MR];
for ii in 0..MR.min(m - i) {
let a_f16_val = *a_f16.add((i + ii) * k + l);
                    let a_f32_val = f32::from_bits((a_f16_val as u32) << 16); // bf16 bits -> f32
a_f32[ii] = _mm512_set1_ps(a_f32_val);
}
let mut b_f32_vals = [0.0f32; 16];
for jj in 0..NR.min(n - j) {
let b_f16_val = *b_f16.add(l * n + j + jj);
                    b_f32_vals[jj] = f32::from_bits((b_f16_val as u32) << 16); // bf16 bits -> f32
}
let b_f32 = _mm512_loadu_ps(b_f32_vals.as_ptr());
for ii in 0..MR.min(m - i) {
c_acc[ii] = _mm512_fmadd_ps(a_f32[ii], b_f32, c_acc[ii]);
}
}
for ii in 0..MR.min(m - i) {
if j + NR <= n {
_mm512_storeu_ps(c_f32.add((i + ii) * n + j), c_acc[ii]);
} else {
let remaining = n - j;
let mask = (1u16 << remaining) - 1;
_mm512_mask_storeu_ps(c_f32.add((i + ii) * n + j), mask, c_acc[ii]);
}
}
}
}
Ok(())
}
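/// SVE GEMM computing `C = A * B` for `f32` using length-agnostic vectors and
/// `whilelt` predication for the column remainder.
///
/// # Safety
/// Requires SVE support; the pointers must be valid for the given dimensions and
/// leading strides. Note that SVE intrinsics in `std::arch::aarch64` are not
/// available on all toolchains.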
#[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
#[target_feature(enable = "sve")]
pub unsafe fn sve_gemm_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) -> LinalgResult<()> {
use std::arch::aarch64::*;
let sve_len_f32 = svcntw();
for i in 0..m {
let mut j = 0;
while j < n {
let remaining = n - j;
let active_lanes = if remaining >= sve_len_f32 as usize {
svptrue_b32()
} else {
svwhilelt_b32(j as u32, n as u32)
};
let mut acc = svdup_n_f32(0.0);
for l in 0..k {
let a_scalar = *a.add(i * lda + l);
let a_vec = svdup_n_f32(a_scalar);
let b_vec = svld1_f32(active_lanes, b.add(l * ldb + j));
acc = svmla_f32_z(active_lanes, acc, a_vec, b_vec);
}
svst1_f32(active_lanes, c.add(i * ldc + j), acc);
j += sve_len_f32 as usize;
}
}
Ok(())
}
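/// NEON int8 GEMM computing `C += A * B` with the `sdot` (dot product) extension,
/// accumulating in `i32` and handling the `k` remainder with scalar code.
///
/// # Safety
/// Requires NEON and the `dotprod` extension; the pointers must be valid for the
/// given dimensions and leading strides, and `c` must be initialized.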
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon", enable = "dotprod")]
pub unsafe fn neon_advanced_gemm_i8(
a: *const i8,
b: *const i8,
c: *mut i32,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) -> LinalgResult<()> {
use std::arch::aarch64::*;
const MR: usize = 4;
const NR: usize = 4;
let k_aligned = k & !3;
for i in (0..m).step_by(MR) {
for j in (0..n).step_by(NR) {
let mut c_acc = [[vdupq_n_s32(0); NR]; MR];
for l in (0..k_aligned).step_by(4) {
let mut a_vecs = [vdupq_n_s8(0); MR];
let mut b_vecs = [vdupq_n_s8(0); NR];
                for ii in 0..MR.min(m - i) {
                    // Copy four A values into an 8-byte buffer (upper half zero) so the
                    // full 8-byte vld1_s8 cannot read past the end of the row.
                    let a_base = a.add((i + ii) * lda + l);
                    let a_buf = [*a_base, *a_base.add(1), *a_base.add(2), *a_base.add(3), 0, 0, 0, 0];
                    let a_val = vld1_s8(a_buf.as_ptr());
                    a_vecs[ii] = vcombine_s8(a_val, vdup_n_s8(0));
                }
                for jj in 0..NR.min(n - j) {
                    // Gather a column of four B values; pad to 8 bytes for vld1_s8.
                    let b_ptr = b.add(l * ldb + j + jj);
                    let b_buf = [*b_ptr, *b_ptr.add(ldb), *b_ptr.add(2 * ldb), *b_ptr.add(3 * ldb), 0, 0, 0, 0];
                    let b_val = vld1_s8(b_buf.as_ptr());
                    b_vecs[jj] = vcombine_s8(b_val, vdup_n_s8(0));
                }
for ii in 0..MR.min(m - i) {
for jj in 0..NR.min(n - j) {
c_acc[ii][jj] = vdotq_s32(c_acc[ii][jj], a_vecs[ii], b_vecs[jj]);
}
}
}
for l in k_aligned..k {
for ii in 0..MR.min(m - i) {
for jj in 0..NR.min(n - j) {
let a_val = *a.add((i + ii) * lda + l) as i32;
let b_val = *b.add(l * ldb + j + jj) as i32;
let c_ptr = c.add((i + ii) * ldc + j + jj);
*c_ptr += a_val * b_val;
}
}
}
for ii in 0..MR.min(m - i) {
for jj in 0..NR.min(n - j) {
let sum = vaddvq_s32(c_acc[ii][jj]);
let c_ptr = c.add((i + ii) * ldc + j + jj);
*c_ptr += sum;
}
}
}
}
Ok(())
}
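/// AMX int8 GEMM computing `C = A * B` with `i32` accumulation, using tile 0 as
/// the accumulator, tile 1 for A, and tile 2 for B. Dimensions are assumed to be
/// multiples of the tile sizes.
///
/// # Safety
/// Requires AMX tile and int8 support (and OS permission to use AMX state); the
/// pointers must be valid for row-major `m x k`, `k x n`, and `m x n` matrices.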
#[cfg(all(target_arch = "x86_64", target_feature = "amx-tile"))]
#[target_feature(enable = "amx-tile", enable = "amx-int8")]
pub unsafe fn amx_gemm_i8(
a: *const i8,
b: *const i8,
c: *mut i32,
m: usize,
n: usize,
k: usize,
) -> LinalgResult<()> {
use std::arch::x86_64::*;
const TILE_M: usize = 16;
const TILE_N: usize = 64;
const TILE_K: usize = 64;
    // 64-byte AMX tile configuration: byte 0 = palette id, bytes 16..32 hold the
    // per-tile column widths in bytes (u16 each), bytes 48..64 the per-tile row counts.
    let mut tile_cfg = [0u8; 64];
    tile_cfg[0] = 1; // palette 1
    let colsb: [u16; 3] = [(TILE_N * 4) as u16, TILE_K as u16, (TILE_N * 4) as u16];
    let rows: [u8; 3] = [TILE_M as u8, TILE_M as u8, TILE_K as u8];
    for t in 0..3 {
        tile_cfg[16 + 2 * t..18 + 2 * t].copy_from_slice(&colsb[t].to_le_bytes());
        tile_cfg[48 + t] = rows[t];
    }
    _tile_loadconfig(tile_cfg.as_ptr());
for i in (0..m).step_by(TILE_M) {
for j in (0..n).step_by(TILE_N) {
_tile_zero(0);
for l in (0..k).step_by(TILE_K) {
_tile_loadd(1, a.add(i * k + l) as *const u8, k as isize);
_tile_loadd(2, b.add(l * n + j) as *const u8, n as isize);
_tile_dpbssd(0, 1, 2);
}
_tile_stored(0, c.add(i * n + j) as *mut u8, (n * 4) as isize);
}
}
_tile_release();
Ok(())
}
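/// Dispatches GEMM calls to the best available kernel (AMX, AVX-512, SVE, NEON)
/// based on detected hardware capabilities, with a portable scalar fallback.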
pub struct AdaptiveSIMDDispatcher {
capabilities: AdvancedHardwareCapabilities,
blocking_strategy: BlockingStrategy,
}
impl AdaptiveSIMDDispatcher {
pub fn new() -> Self {
let capabilities = AdvancedHardwareCapabilities::detect();
let blocking_strategy = capabilities.optimal_blocking_strategy();
Self {
capabilities,
blocking_strategy,
}
}
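    /// Multiply two `f32` matrices, selecting a SIMD kernel from the detected
    /// capabilities. The views are assumed to be in standard row-major,
    /// contiguous layout when raw pointers are handed to the kernels.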
pub fn adaptive_gemm_f32(
&self,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
let (m, k) = a.dim();
let (k2, n) = b.dim();
if k != k2 {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions incompatible: {}x{} * {}x{}",
m, k, k2, n
)));
}
let mut result = Array2::zeros((m, n));
#[cfg(target_arch = "x86_64")]
{
if self.capabilities.has_amx_tile && m >= 16 && n >= 64 && k >= 64 {
return self.adaptive_gemm_with_amx_fallback(a, b);
} else if self.capabilities.has_avx512f && m >= 6 && n >= 16 {
                unsafe {
                    avx512_gemm_f32(
                        a.as_ptr(),
                        b.as_ptr(),
                        result.as_mut_ptr(),
                        m, n, k, // dimensions
                        k, n, n, // leading strides of A, B, and C
                        1.0, 0.0, // alpha, beta
                    )?;
                }
return Ok(result);
}
}
#[cfg(target_arch = "aarch64")]
{
if self.capabilities.has_sve && m >= 4 {
unsafe {
sve_gemm_f32(
a.as_ptr(),
b.as_ptr(),
result.as_mut_ptr(),
m, n, k,
k, n, n,
)?;
}
return Ok(result);
} else if self.capabilities.has_neon {
return self.adaptive_gemm_with_neon(a, b);
}
}
self.fallback_gemm(a, b)
}
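    /// Multiply two matrices whose 16-bit elements are interpreted as bfloat16
    /// bit patterns, producing an `f32` result.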
pub fn adaptive_mixed_precision_gemm(
&self,
        a_f16: &ArrayView2<u16>,
        b_f16: &ArrayView2<u16>,
    ) -> LinalgResult<Array2<f32>> {
let (m, k) = a_f16.dim();
let (k2, n) = b_f16.dim();
if k != k2 {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions incompatible: {}x{} * {}x{}",
m, k, k2, n
)));
}
let mut result = Array2::zeros((m, n));
#[cfg(target_arch = "x86_64")]
{
if self.capabilities.has_avx512_fp16 {
unsafe {
avx512_mixed_precision_gemm(
a_f16.as_ptr(),
b_f16.as_ptr(),
result.as_mut_ptr(),
m, n, k,
)?;
}
return Ok(result);
}
}
self.fallback_mixed_precision_gemm(a_f16, b_f16)
}
#[cfg(target_arch = "x86_64")]
fn adaptive_gemm_with_amx_fallback(
&self,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
self.fallback_gemm(a, b)
}
#[cfg(target_arch = "aarch64")]
fn adaptive_gemm_with_neon(
&self,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::zeros((m, n));
self.blocked_neon_gemm(a, b, &mut result)?;
Ok(result)
}
#[cfg(target_arch = "aarch64")]
fn blocked_neon_gemm(
&self,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
result: &mut Array2<f32>,
) -> LinalgResult<()> {
use std::arch::aarch64::*;
let (m, k) = a.dim();
let (_, n) = b.dim();
const BLOCK_SIZE: usize = 32;
for ii in (0..m).step_by(BLOCK_SIZE) {
for jj in (0..n).step_by(BLOCK_SIZE) {
for kk in (0..k).step_by(BLOCK_SIZE) {
let i_end = (ii + BLOCK_SIZE).min(m);
let j_end = (jj + BLOCK_SIZE).min(n);
let k_end = (kk + BLOCK_SIZE).min(k);
for i in ii..i_end {
for j in (jj..j_end).step_by(4) {
unsafe {
let mut acc = vdupq_n_f32(0.0);
for l in kk..k_end {
let a_val = vdupq_n_f32(a[[i, l]]);
let b_vals = if j + 4 <= j_end {
[b[[l, j]], b[[l, j+1]], b[[l, j+2]], b[[l, j+3]]]
} else {
[
if j < j_end { b[[l, j]] } else { 0.0 },
if j+1 < j_end { b[[l, j+1]] } else { 0.0 },
if j+2 < j_end { b[[l, j+2]] } else { 0.0 },
if j+3 < j_end { b[[l, j+3]] } else { 0.0 },
]
};
let b_vec = vld1q_f32(b_vals.as_ptr());
acc = vfmaq_f32(acc, a_val, b_vec);
}
let acc_vals = [
vgetq_lane_f32(acc, 0),
vgetq_lane_f32(acc, 1),
vgetq_lane_f32(acc, 2),
vgetq_lane_f32(acc, 3),
];
for (idx, &val) in acc_vals.iter().enumerate() {
if j + idx < j_end {
result[[i, j + idx]] += val;
}
}
}
}
}
}
}
}
Ok(())
}
fn fallback_gemm(
&self,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = 0.0;
for l in 0..k {
sum += a[[i, l]] * b[[l, j]];
}
result[[i, j]] = sum;
}
}
Ok(result)
}
fn fallback_mixed_precision_gemm(
&self,
a_f16: &ArrayView2<u16>,
b_f16: &ArrayView2<u16>,
) -> LinalgResult<Array2<f32>> {
let (m, k) = a_f16.dim();
let (_, n) = b_f16.dim();
let mut result = Array2::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
let a_f32 = f32::from_bits((a_f16[[i, l]] as u32) << 16);
let b_f32 = f32::from_bits((b_f16[[l, j]] as u32) << 16);
sum += a_f32 * b_f32;
}
result[[i, j]] = sum;
}
}
Ok(result)
}
}
impl Default for AdaptiveSIMDDispatcher {
fn default() -> Self {
Self::new()
}
}
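/// Block sizes derived from the cache hierarchy, used for cache-friendly
/// traversal such as tiled transposes.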
pub struct CacheOptimizedOperations {
l1_blocksize: usize,
l2_blocksize: usize,
l3_blocksize: usize,
}
impl CacheOptimizedOperations {
pub fn new(capabilities: &AdvancedHardwareCapabilities) -> Self {
Self {
            l1_blocksize: (capabilities.l1_cachesize / 8).next_power_of_two().min(256),
            l2_blocksize: (capabilities.l2_cachesize / 8).next_power_of_two().min(2048),
            l3_blocksize: (capabilities.l3_cachesize / 8).next_power_of_two().min(8192),
}
}
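    /// Transpose a matrix using L1-sized tiles, with software prefetching of
    /// upcoming rows to reduce cache misses.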
pub fn cache_optimized_transpose_f32(
&self,
input: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
let (rows, cols) = input.dim();
let mut output = Array2::zeros((cols, rows));
let tilesize = self.l1_blocksize.min(64);
for i in (0..rows).step_by(tilesize) {
for j in (0..cols).step_by(tilesize) {
let i_end = (i + tilesize).min(rows);
let j_end = (j + tilesize).min(cols);
if i + 2 * tilesize < rows {
self.prefetch_memory(input.as_ptr(), (i + 2 * tilesize) * cols + j);
}
for ii in i..i_end {
for jj in j..j_end {
output[[jj, ii]] = input[[ii, jj]];
}
}
}
}
Ok(output)
}
#[cfg(target_arch = "x86_64")]
fn prefetch_memory(&self, ptr: *const f32, offset: usize) {
unsafe {
use std::arch::x86_64::*;
let prefetch_ptr = ptr.add(offset) as *const i8;
            // Prefetch into the L2 cache (T1 hint).
            _mm_prefetch::<_MM_HINT_T1>(prefetch_ptr);
}
}
#[cfg(not(target_arch = "x86_64"))]
    fn prefetch_memory(&self, _ptr: *const f32, _offset: usize) {
        // No portable prefetch intrinsic here; this is a no-op on non-x86 targets.
    }
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_advanced_capabilities_detection() {
let caps = AdvancedHardwareCapabilities::detect();
assert!(caps.optimal_vector_width_for_type::<f32>() >= 4);
assert!(caps.optimal_vector_width_for_type::<f64>() >= 2);
let strategy = caps.optimal_blocking_strategy();
match strategy {
BlockingStrategy::Standard { blocksize } => assert!(blocksize > 0),
BlockingStrategy::AVX512 { block_m, block_n, block_k } => {
assert!(block_m > 0 && block_n > 0 && block_k > 0);
},
BlockingStrategy::SVE { vector_length } => assert!(vector_length > 0),
BlockingStrategy::AMXTiling { tile_m, tile_n, tile_k } => {
assert!(tile_m > 0 && tile_n > 0 && tile_k > 0);
},
}
}
#[test]
fn test_adaptive_simd_dispatcher() {
let dispatcher = AdaptiveSIMDDispatcher::new();
let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
let result = dispatcher.adaptive_gemm_f32(&a.view(), &b.view()).expect("Operation failed");
let expected = array![[58.0, 64.0], [139.0, 154.0]];
for (actual, expected) in result.iter().zip(expected.iter()) {
assert_abs_diff_eq!(*actual, *expected, epsilon = 1e-6);
}
}
#[test]
fn test_cache_optimized_transpose() {
let caps = AdvancedHardwareCapabilities::detect();
let cache_ops = CacheOptimizedOperations::new(&caps);
let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let result = cache_ops.cache_optimized_transpose_f32(&input.view()).expect("Operation failed");
let expected = array![[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
for (actual, expected) in result.iter().zip(expected.iter()) {
assert_abs_diff_eq!(*actual, *expected, epsilon = 1e-10);
}
}
}