#[cfg(target_arch = "aarch64")]
use super::aarch64;
#[cfg(target_arch = "x86_64")]
use super::avx2;
#[cfg(target_arch = "x86_64")]
use super::avx512;
use super::scalar::{matmul_bias_scalar_f32, matmul_bias_scalar_f64};
use super::scalar::{matmul_scalar_f32, matmul_scalar_f64};
use super::scalar::{microkernel_edge_f32, microkernel_edge_f64};
use super::small;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use super::tiling::{matmul_bias_tiled_f32, matmul_bias_tiled_f64};
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use super::tiling::{matmul_tiled_f32, matmul_tiled_f64};
use crate::runtime::cpu::kernels::simd::{SimdLevel, detect_simd};
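// Blocking parameters for the tiled (Goto-style) GEMM path below.
/// Micro-kernel tile height: rows of `C` computed per micro-kernel call
/// (every micro-kernel in this module is 6 rows tall).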
pub const MR: usize = 6;
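/// Cache block size along `M`: rows of `A` packed per outer block
/// (a multiple of `MR`: 126 = 21 * 6).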
pub const MC: usize = 126;
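/// Cache block size along `K`: depth of the packed `A` and `B` panels.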
pub const KC: usize = 256;
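/// Cache block size along `N`: columns of `B` packed per outer block.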
pub const NC: usize = 512;
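/// Problems with `m * n * k` of at most `128 * 128 * 128` are routed to
/// the small-matrix path; anything larger goes through the tiled path.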
const SMALL_MATRIX_THRESHOLD: usize = 128 * 128 * 128 + 1;
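/// `f32` GEMM entry point: computes `out = a * b` for row-major matrices,
/// where `a` is `m x k` (row stride `lda`), `b` is `k x n` (row stride
/// `ldb`), and `out` is `m x n` (row stride `ldc`). Small products take
/// the small-matrix path; larger ones take the tiled path when AVX-512,
/// AVX2+FMA, or NEON is available, and the scalar kernel otherwise.
///
/// A minimal usage sketch (contiguous row-major, so `lda = k`, `ldb = n`,
/// `ldc = n`):
///
/// ```ignore
/// let (m, n, k) = (2, 2, 2);
/// let a = [1.0f32, 2.0, 3.0, 4.0];
/// let b = [5.0f32, 6.0, 7.0, 8.0];
/// let mut c = [0.0f32; 4];
/// unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
/// assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
/// ```
///
/// # Safety
///
/// - `a` must be valid for reads of `(m - 1) * lda + k` elements.
/// - `b` must be valid for reads of `(k - 1) * ldb + n` elements.
/// - `out` must be valid for writes of `(m - 1) * ldc + n` elements and
///   must not alias `a` or `b`.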
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_f32(
a: *const f32,
b: *const f32,
out: *mut f32,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
let level = detect_simd();
if m * n * k < SMALL_MATRIX_THRESHOLD {
small::small_matmul_f32(a, b, out, m, n, k, lda, ldb, ldc, level);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => matmul_tiled_f32::<32>(a, b, out, m, n, k, lda, ldb, ldc, level),
SimdLevel::Avx2Fma => matmul_tiled_f32::<16>(a, b, out, m, n, k, lda, ldb, ldc, level),
_ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
matmul_tiled_f32::<8>(a, b, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc);
}
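/// `f64` GEMM entry point; see [`matmul_f32`] for the layout conventions
/// and dispatch logic. Vector widths are halved relative to `f32`.
///
/// # Safety
///
/// Same requirements as [`matmul_f32`], with `f64` elements.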
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_f64(
a: *const f64,
b: *const f64,
out: *mut f64,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
let level = detect_simd();
if m * n * k < SMALL_MATRIX_THRESHOLD {
small::small_matmul_f64(a, b, out, m, n, k, lda, ldb, ldc, level);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => matmul_tiled_f64::<16>(a, b, out, m, n, k, lda, ldb, ldc, level),
SimdLevel::Avx2Fma => matmul_tiled_f64::<8>(a, b, out, m, n, k, lda, ldb, ldc, level),
_ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
matmul_tiled_f64::<4>(a, b, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc);
}
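/// `f32` GEMM with a fused row-broadcast bias:
/// `out[i][j] = sum_k a[i][k] * b[k][j] + bias[j]`.
///
/// # Safety
///
/// Same requirements as [`matmul_f32`]; additionally, `bias` must be valid
/// for reads of `n` elements.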
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_bias_f32(
a: *const f32,
b: *const f32,
bias: *const f32,
out: *mut f32,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
let level = detect_simd();
if m * n * k < SMALL_MATRIX_THRESHOLD {
small::small_matmul_bias_f32(a, b, bias, out, m, n, k, lda, ldb, ldc, level);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => {
matmul_bias_tiled_f32::<32>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
SimdLevel::Avx2Fma => {
matmul_bias_tiled_f32::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
matmul_bias_tiled_f32::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc);
}
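/// `f64` counterpart of [`matmul_bias_f32`].
///
/// # Safety
///
/// Same requirements as [`matmul_bias_f32`], with `f64` elements.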
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_bias_f64(
a: *const f64,
b: *const f64,
bias: *const f64,
out: *mut f64,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
let level = detect_simd();
if m * n * k < SMALL_MATRIX_THRESHOLD {
small::small_matmul_bias_f64(a, b, bias, out, m, n, k, lda, ldb, ldc, level);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => {
matmul_bias_tiled_f64::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
SimdLevel::Avx2Fma => {
matmul_bias_tiled_f64::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
matmul_bias_tiled_f64::<4>(a, b, bias, out, m, n, k, lda, ldb, ldc, level)
}
_ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc);
}
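/// Dispatches one micro-kernel on a packed `MR x k` panel of `A` and a
/// packed `k x NR` panel of `B`, updating an `MR x NR` tile of `C` at row
/// stride `ldc`. `NR` is the per-level vector width: 16 for AVX-512, 8 for
/// AVX2+FMA, 4 for NEON. When `first_k` is true the tile is written fresh
/// for the first K-block; otherwise the kernel accumulates into it.
///
/// # Safety
///
/// `a` and `b` must point to panels packed in the layout the tiling code
/// produces for the given `level`, and `c` must be valid for reads and
/// writes of the full `MR x NR` tile at row stride `ldc`.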
#[inline]
pub unsafe fn call_microkernel_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
k: usize,
ldc: usize,
level: SimdLevel,
first_k: bool,
) {
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::microkernel_6x16_f32(a, b, c, k, ldc, first_k),
SimdLevel::Avx2Fma => avx2::microkernel_6x8_f32(a, b, c, k, ldc, first_k),
_ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k)
}
_ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k);
}
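/// Double-width variant of [`call_microkernel_f32`]: consumes two adjacent
/// packed `B` panels per call (32 lanes for AVX-512, 16 for AVX2+FMA, 8
/// for NEON). Fallback arms issue the single-width kernel twice, advancing
/// `b` by one packed panel (`nr * k` elements) and `c` by `nr` columns.
///
/// # Safety
///
/// Same requirements as [`call_microkernel_f32`], for a tile twice as wide.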
#[inline]
pub unsafe fn call_microkernel_2x_f32(
a: *const f32,
b: *const f32,
c: *mut f32,
k: usize,
ldc: usize,
level: SimdLevel,
first_k: bool,
) {
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::microkernel_6x32_f32(a, b, c, k, ldc, first_k),
SimdLevel::Avx2Fma => avx2::microkernel_6x16_f32(a, b, c, k, ldc, first_k),
_ => {
let nr = 4usize;
microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k);
aarch64::neon::microkernel_6x4_f32(a, b.add(4 * k), c.add(4), k, ldc, first_k);
}
_ => {
let nr = 4usize;
microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
let nr = 4usize;
microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
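/// `f64` counterpart of [`call_microkernel_f32`], with halved vector
/// widths: `NR` is 8 for AVX-512, 4 for AVX2+FMA, 2 for NEON.
///
/// # Safety
///
/// Same requirements as [`call_microkernel_f32`], with `f64` elements.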
#[inline]
pub unsafe fn call_microkernel_f64(
a: *const f64,
b: *const f64,
c: *mut f64,
k: usize,
ldc: usize,
level: SimdLevel,
first_k: bool,
) {
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::microkernel_6x8_f64(a, b, c, k, ldc, first_k),
SimdLevel::Avx2Fma => avx2::microkernel_6x4_f64(a, b, c, k, ldc, first_k),
_ => microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k)
}
_ => microkernel_edge_f64(a, b, c, MR, 2, k, ldc, first_k),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k);
}
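/// Double-width variant of [`call_microkernel_f64`] (16 lanes for AVX-512,
/// 8 for AVX2+FMA, 4 for NEON); see [`call_microkernel_2x_f32`] for the
/// two-panel fallback scheme.
///
/// # Safety
///
/// Same requirements as [`call_microkernel_f64`], for a tile twice as wide.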
#[inline]
pub unsafe fn call_microkernel_2x_f64(
a: *const f64,
b: *const f64,
c: *mut f64,
k: usize,
ldc: usize,
level: SimdLevel,
first_k: bool,
) {
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::microkernel_6x16_f64(a, b, c, k, ldc, first_k),
SimdLevel::Avx2Fma => avx2::microkernel_6x8_f64(a, b, c, k, ldc, first_k),
_ => {
let nr = 4usize;
microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k);
aarch64::neon::microkernel_6x2_f64(a, b.add(2 * k), c.add(2), k, ldc, first_k);
}
_ => {
let nr = 2usize;
microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
let nr = 4usize;
microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k);
microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k);
}
}
#[cfg(test)]
mod tests {
use super::*;
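    // Naive triple-loop references over contiguous row-major matrices;
    // the kernels are checked against these within a small tolerance.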
fn reference_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
let mut c = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
c
}
fn reference_matmul_f64(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
let mut c = vec![0.0f64; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f64;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
c
}
fn reference_matmul_bias_f32(
a: &[f32],
b: &[f32],
bias: &[f32],
m: usize,
n: usize,
k: usize,
) -> Vec<f32> {
let mut c = reference_matmul_f32(a, b, m, n, k);
for i in 0..m {
for j in 0..n {
c[i * n + j] += bias[j];
}
}
c
}
const F32_SMALL_TOL: f32 = 1e-4;
const F32_LARGE_TOL: f32 = 1e-3;
const F64_SMALL_TOL: f64 = 1e-10;
const F64_LARGE_TOL: f64 = 1e-9;
#[test]
fn test_matmul_f32_small() {
let (m, n, k) = (4, 4, 4);
let a: Vec<f32> = (0..m * k).map(|i| (i + 1) as f32).collect();
let b: Vec<f32> = (0..k * n).map(|i| (i + 1) as f32).collect();
let mut c = vec![0.0f32; m * n];
let expected = reference_matmul_f32(&a, &b, m, n, k);
unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
for i in 0..m * n {
assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL);
}
}
#[test]
fn test_matmul_f32_large() {
        // 160 * 128 * 128 exceeds SMALL_MATRIX_THRESHOLD, so this
        // exercises the tiled path rather than the small-matrix path.
        let (m, n, k) = (160, 128, 128);
let a: Vec<f32> = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect();
let b: Vec<f32> = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect();
let mut c = vec![0.0f32; m * n];
let expected = reference_matmul_f32(&a, &b, m, n, k);
unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
let max_diff = (0..m * n)
.map(|i| (c[i] - expected[i]).abs())
.fold(0.0f32, f32::max);
assert!(max_diff < F32_LARGE_TOL);
}
#[test]
fn test_matmul_f64_small() {
let (m, n, k) = (4, 4, 4);
let a: Vec<f64> = (0..m * k).map(|i| (i + 1) as f64).collect();
let b: Vec<f64> = (0..k * n).map(|i| (i + 1) as f64).collect();
let mut c = vec![0.0f64; m * n];
let expected = reference_matmul_f64(&a, &b, m, n, k);
unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
for i in 0..m * n {
assert!((c[i] - expected[i]).abs() < F64_SMALL_TOL);
}
}
#[test]
fn test_matmul_f64_large() {
        // As above: exceed SMALL_MATRIX_THRESHOLD to hit the tiled path.
        let (m, n, k) = (160, 128, 128);
let a: Vec<f64> = (0..m * k).map(|i| ((i % 17) as f64) * 0.1).collect();
let b: Vec<f64> = (0..k * n).map(|i| ((i % 13) as f64) * 0.1).collect();
let mut c = vec![0.0f64; m * n];
let expected = reference_matmul_f64(&a, &b, m, n, k);
unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
let max_diff = (0..m * n)
.map(|i| (c[i] - expected[i]).abs())
.fold(0.0f64, f64::max);
assert!(max_diff < F64_LARGE_TOL);
}
#[test]
fn test_matmul_non_square() {
let (m, n, k) = (37, 53, 41);
let a: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32) * 0.5).collect();
let b: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32) * 0.3).collect();
let mut c = vec![0.0f32; m * n];
let expected = reference_matmul_f32(&a, &b, m, n, k);
unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) };
let max_diff = (0..m * n)
.map(|i| (c[i] - expected[i]).abs())
.fold(0.0f32, f32::max);
assert!(max_diff < F32_LARGE_TOL);
}
#[test]
fn test_matmul_bias_f32_small() {
let (m, n, k) = (4, 4, 4);
let a: Vec<f32> = (0..m * k).map(|i| (i + 1) as f32).collect();
let b: Vec<f32> = (0..k * n).map(|i| (i + 1) as f32).collect();
let bias: Vec<f32> = (0..n).map(|i| (i * 10) as f32).collect();
let mut c = vec![0.0f32; m * n];
let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k);
unsafe {
matmul_bias_f32(
a.as_ptr(),
b.as_ptr(),
bias.as_ptr(),
c.as_mut_ptr(),
m,
n,
k,
k,
n,
n,
)
};
for i in 0..m * n {
assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL);
}
}
#[test]
fn test_matmul_bias_f32_large() {
        // As above: exceed SMALL_MATRIX_THRESHOLD to hit the tiled path.
        let (m, n, k) = (160, 128, 128);
let a: Vec<f32> = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect();
let b: Vec<f32> = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect();
let bias: Vec<f32> = (0..n).map(|i| ((i % 7) as f32) * 0.5).collect();
let mut c = vec![0.0f32; m * n];
let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k);
unsafe {
matmul_bias_f32(
a.as_ptr(),
b.as_ptr(),
bias.as_ptr(),
c.as_mut_ptr(),
m,
n,
k,
k,
n,
n,
)
};
let max_diff = (0..m * n)
.map(|i| (c[i] - expected[i]).abs())
.fold(0.0f32, f32::max);
assert!(max_diff < F32_LARGE_TOL);
}
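    // Smoke test: the detected level is machine-dependent, so just confirm
    // detection runs and report it.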
#[test]
fn test_simd_level_detection() {
let level = detect_simd();
println!("Detected SIMD level: {level:?}");
}
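    // Added coverage: matmul_bias_f64 had no test mirroring the f32 bias
    // tests, so this follows the same pattern with the f64 tolerance.
    fn reference_matmul_bias_f64(
        a: &[f64],
        b: &[f64],
        bias: &[f64],
        m: usize,
        n: usize,
        k: usize,
    ) -> Vec<f64> {
        let mut c = reference_matmul_f64(a, b, m, n, k);
        for i in 0..m {
            for j in 0..n {
                c[i * n + j] += bias[j];
            }
        }
        c
    }
    #[test]
    fn test_matmul_bias_f64_small() {
        let (m, n, k) = (4, 4, 4);
        let a: Vec<f64> = (0..m * k).map(|i| (i + 1) as f64).collect();
        let b: Vec<f64> = (0..k * n).map(|i| (i + 1) as f64).collect();
        let bias: Vec<f64> = (0..n).map(|i| (i * 10) as f64).collect();
        let mut c = vec![0.0f64; m * n];
        let expected = reference_matmul_bias_f64(&a, &b, &bias, m, n, k);
        unsafe {
            matmul_bias_f64(
                a.as_ptr(),
                b.as_ptr(),
                bias.as_ptr(),
                c.as_mut_ptr(),
                m,
                n,
                k,
                k,
                n,
                n,
            )
        };
        for i in 0..m * n {
            assert!((c[i] - expected[i]).abs() < F64_SMALL_TOL);
        }
    }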
}