#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use super::{NEON_LANE_WIDTH, PREFETCH_DISTANCE};
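// Cache-blocking tile sizes for GEMM (rows, columns, depth) and the
// register-blocked micro-kernel dimensions (MR rows x NR-lane columns).
// PARALLEL_THRESHOLD is the minimum problem size (m * n) before work is
// split across Rayon threads.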
const TILE_M: usize = 96;
const TILE_N: usize = 64;
const TILE_K: usize = 256;
const MR: usize = 12;
const NR: usize = 4;
const PARALLEL_THRESHOLD: usize = 4096;
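/// Computes `y = A * x` for a row-major `m x n` matrix `A`.
///
/// Dispatches to Accelerate (macOS + `accelerate` feature), a Rayon-parallel
/// NEON path, the serial NEON kernel, or a scalar fallback depending on the
/// enabled features, target architecture, and problem size.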
#[inline(always)]
pub fn gemv_neon(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
debug_assert_eq!(a.len(), m * n);
debug_assert_eq!(x.len(), n);
debug_assert_eq!(y.len(), m);
#[cfg(all(target_os = "macos", feature = "accelerate"))]
{
if super::accelerate::should_use_accelerate(m, n) {
super::accelerate::gemv_accelerate(
a,
x,
y,
m,
n,
super::accelerate::MatrixLayout::RowMajor,
);
return;
}
}
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
{
if m * n >= PARALLEL_THRESHOLD {
unsafe { gemv_parallel(a, x, y, m, n) };
return;
}
}
#[cfg(target_arch = "aarch64")]
unsafe {
gemv_neon_impl(a, x, y, m, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemv_scalar(a, x, y, m, n);
}
}
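/// Row-parallel GEMV: `y` is split into fixed-size row chunks and each chunk
/// is computed by the serial NEON kernel on a Rayon worker.
///
/// # Safety
/// Caller must ensure `a.len() == m * n`, `x.len() == n`, and `y.len() == m`.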
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
pub unsafe fn gemv_parallel(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
use rayon::prelude::*;
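    // Each task covers 64 rows; MR.max(64) keeps a chunk at least one micro-kernel block tall.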
let chunk_size = MR.max(64);
y.par_chunks_mut(chunk_size)
.enumerate()
.for_each(|(chunk_idx, y_chunk)| {
let row_start = chunk_idx * chunk_size;
let row_end = (row_start + y_chunk.len()).min(m);
let chunk_m = row_end - row_start;
let a_chunk = &a[row_start * n..(row_start + chunk_m) * n];
gemv_neon_impl(a_chunk, x, y_chunk, chunk_m, n);
});
}
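/// Serial NEON GEMV kernel: processes blocks of `MR` (12) rows, each with its
/// own 128-bit accumulator, consuming 8 columns of `x` per inner iteration.
///
/// # Safety
/// Caller must ensure `a.len() == m * n`, `x.len() == n`, and `y.len() == m`.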
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemv_neon_impl(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
let a_ptr = a.as_ptr();
let x_ptr = x.as_ptr();
let y_ptr = y.as_mut_ptr();
let row_chunks = m / MR;
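    // Main loop: one vector accumulator per row across a 12-row block.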
for rc in 0..row_chunks {
let row_base = rc * MR;
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let mut sum4 = vdupq_n_f32(0.0);
let mut sum5 = vdupq_n_f32(0.0);
let mut sum6 = vdupq_n_f32(0.0);
let mut sum7 = vdupq_n_f32(0.0);
let mut sum8 = vdupq_n_f32(0.0);
let mut sum9 = vdupq_n_f32(0.0);
let mut sum10 = vdupq_n_f32(0.0);
let mut sum11 = vdupq_n_f32(0.0);
let col_chunks_8 = n / 8;
let mut col = 0usize;
for _ in 0..col_chunks_8 {
let x_v0 = vld1q_f32(x_ptr.add(col));
let x_v1 = vld1q_f32(x_ptr.add(col + 4));
sum0 = vfmaq_f32(sum0, vld1q_f32(a_ptr.add((row_base + 0) * n + col)), x_v0);
sum0 = vfmaq_f32(
sum0,
vld1q_f32(a_ptr.add((row_base + 0) * n + col + 4)),
x_v1,
);
sum1 = vfmaq_f32(sum1, vld1q_f32(a_ptr.add((row_base + 1) * n + col)), x_v0);
sum1 = vfmaq_f32(
sum1,
vld1q_f32(a_ptr.add((row_base + 1) * n + col + 4)),
x_v1,
);
sum2 = vfmaq_f32(sum2, vld1q_f32(a_ptr.add((row_base + 2) * n + col)), x_v0);
sum2 = vfmaq_f32(
sum2,
vld1q_f32(a_ptr.add((row_base + 2) * n + col + 4)),
x_v1,
);
sum3 = vfmaq_f32(sum3, vld1q_f32(a_ptr.add((row_base + 3) * n + col)), x_v0);
sum3 = vfmaq_f32(
sum3,
vld1q_f32(a_ptr.add((row_base + 3) * n + col + 4)),
x_v1,
);
sum4 = vfmaq_f32(sum4, vld1q_f32(a_ptr.add((row_base + 4) * n + col)), x_v0);
sum4 = vfmaq_f32(
sum4,
vld1q_f32(a_ptr.add((row_base + 4) * n + col + 4)),
x_v1,
);
sum5 = vfmaq_f32(sum5, vld1q_f32(a_ptr.add((row_base + 5) * n + col)), x_v0);
sum5 = vfmaq_f32(
sum5,
vld1q_f32(a_ptr.add((row_base + 5) * n + col + 4)),
x_v1,
);
sum6 = vfmaq_f32(sum6, vld1q_f32(a_ptr.add((row_base + 6) * n + col)), x_v0);
sum6 = vfmaq_f32(
sum6,
vld1q_f32(a_ptr.add((row_base + 6) * n + col + 4)),
x_v1,
);
sum7 = vfmaq_f32(sum7, vld1q_f32(a_ptr.add((row_base + 7) * n + col)), x_v0);
sum7 = vfmaq_f32(
sum7,
vld1q_f32(a_ptr.add((row_base + 7) * n + col + 4)),
x_v1,
);
sum8 = vfmaq_f32(sum8, vld1q_f32(a_ptr.add((row_base + 8) * n + col)), x_v0);
sum8 = vfmaq_f32(
sum8,
vld1q_f32(a_ptr.add((row_base + 8) * n + col + 4)),
x_v1,
);
sum9 = vfmaq_f32(sum9, vld1q_f32(a_ptr.add((row_base + 9) * n + col)), x_v0);
sum9 = vfmaq_f32(
sum9,
vld1q_f32(a_ptr.add((row_base + 9) * n + col + 4)),
x_v1,
);
sum10 = vfmaq_f32(sum10, vld1q_f32(a_ptr.add((row_base + 10) * n + col)), x_v0);
sum10 = vfmaq_f32(
sum10,
vld1q_f32(a_ptr.add((row_base + 10) * n + col + 4)),
x_v1,
);
sum11 = vfmaq_f32(sum11, vld1q_f32(a_ptr.add((row_base + 11) * n + col)), x_v0);
sum11 = vfmaq_f32(
sum11,
vld1q_f32(a_ptr.add((row_base + 11) * n + col + 4)),
x_v1,
);
col += 8;
}
while col + 4 <= n {
let x_v = vld1q_f32(x_ptr.add(col));
sum0 = vfmaq_f32(sum0, vld1q_f32(a_ptr.add((row_base + 0) * n + col)), x_v);
sum1 = vfmaq_f32(sum1, vld1q_f32(a_ptr.add((row_base + 1) * n + col)), x_v);
sum2 = vfmaq_f32(sum2, vld1q_f32(a_ptr.add((row_base + 2) * n + col)), x_v);
sum3 = vfmaq_f32(sum3, vld1q_f32(a_ptr.add((row_base + 3) * n + col)), x_v);
sum4 = vfmaq_f32(sum4, vld1q_f32(a_ptr.add((row_base + 4) * n + col)), x_v);
sum5 = vfmaq_f32(sum5, vld1q_f32(a_ptr.add((row_base + 5) * n + col)), x_v);
sum6 = vfmaq_f32(sum6, vld1q_f32(a_ptr.add((row_base + 6) * n + col)), x_v);
sum7 = vfmaq_f32(sum7, vld1q_f32(a_ptr.add((row_base + 7) * n + col)), x_v);
sum8 = vfmaq_f32(sum8, vld1q_f32(a_ptr.add((row_base + 8) * n + col)), x_v);
sum9 = vfmaq_f32(sum9, vld1q_f32(a_ptr.add((row_base + 9) * n + col)), x_v);
sum10 = vfmaq_f32(sum10, vld1q_f32(a_ptr.add((row_base + 10) * n + col)), x_v);
sum11 = vfmaq_f32(sum11, vld1q_f32(a_ptr.add((row_base + 11) * n + col)), x_v);
col += 4;
}
let mut y0 = vaddvq_f32(sum0);
let mut y1 = vaddvq_f32(sum1);
let mut y2 = vaddvq_f32(sum2);
let mut y3 = vaddvq_f32(sum3);
let mut y4 = vaddvq_f32(sum4);
let mut y5 = vaddvq_f32(sum5);
let mut y6 = vaddvq_f32(sum6);
let mut y7 = vaddvq_f32(sum7);
let mut y8 = vaddvq_f32(sum8);
let mut y9 = vaddvq_f32(sum9);
let mut y10 = vaddvq_f32(sum10);
let mut y11 = vaddvq_f32(sum11);
for c in col..n {
let x_val = *x_ptr.add(c);
y0 += *a_ptr.add((row_base + 0) * n + c) * x_val;
y1 += *a_ptr.add((row_base + 1) * n + c) * x_val;
y2 += *a_ptr.add((row_base + 2) * n + c) * x_val;
y3 += *a_ptr.add((row_base + 3) * n + c) * x_val;
y4 += *a_ptr.add((row_base + 4) * n + c) * x_val;
y5 += *a_ptr.add((row_base + 5) * n + c) * x_val;
y6 += *a_ptr.add((row_base + 6) * n + c) * x_val;
y7 += *a_ptr.add((row_base + 7) * n + c) * x_val;
y8 += *a_ptr.add((row_base + 8) * n + c) * x_val;
y9 += *a_ptr.add((row_base + 9) * n + c) * x_val;
y10 += *a_ptr.add((row_base + 10) * n + c) * x_val;
y11 += *a_ptr.add((row_base + 11) * n + c) * x_val;
}
*y_ptr.add(row_base + 0) = y0;
*y_ptr.add(row_base + 1) = y1;
*y_ptr.add(row_base + 2) = y2;
*y_ptr.add(row_base + 3) = y3;
*y_ptr.add(row_base + 4) = y4;
*y_ptr.add(row_base + 5) = y5;
*y_ptr.add(row_base + 6) = y6;
*y_ptr.add(row_base + 7) = y7;
*y_ptr.add(row_base + 8) = y8;
*y_ptr.add(row_base + 9) = y9;
*y_ptr.add(row_base + 10) = y10;
*y_ptr.add(row_base + 11) = y11;
}
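    // Remainder rows that do not fill a complete 12-row block.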
for row in (row_chunks * MR)..m {
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let col_chunks_8 = n / 8;
let mut col = 0usize;
for _ in 0..col_chunks_8 {
let x_v0 = vld1q_f32(x_ptr.add(col));
let x_v1 = vld1q_f32(x_ptr.add(col + 4));
sum0 = vfmaq_f32(sum0, vld1q_f32(a_ptr.add(row * n + col)), x_v0);
sum1 = vfmaq_f32(sum1, vld1q_f32(a_ptr.add(row * n + col + 4)), x_v1);
col += 8;
}
let mut y_val = vaddvq_f32(vaddq_f32(sum0, sum1));
while col + 4 <= n {
let x_v = vld1q_f32(x_ptr.add(col));
let a_v = vld1q_f32(a_ptr.add(row * n + col));
y_val += vaddvq_f32(vmulq_f32(a_v, x_v));
col += 4;
}
for c in col..n {
y_val += *a_ptr.add(row * n + c) * *x_ptr.add(c);
}
*y_ptr.add(row) = y_val;
}
}
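/// Scalar reference GEMV, used on non-aarch64 targets and by the tests.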
#[allow(dead_code)]
fn gemv_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
for row in 0..m {
let mut sum = 0.0f32;
for col in 0..n {
sum += a[row * n + col] * x[col];
}
y[row] = sum;
}
}
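/// Computes `C = A * B` for row-major matrices (`A` is `m x k`, `B` is `k x n`,
/// `C` is `m x n`). `C` is zeroed first, then accumulated into by the tiled
/// NEON kernel; non-aarch64 targets fall back to the scalar triple loop.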
#[inline(always)]
pub fn gemm_neon(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
c.fill(0.0);
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
{
if m * n >= PARALLEL_THRESHOLD {
unsafe { gemm_parallel(a, b, c, m, k, n) };
return;
}
}
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_neon_impl(a, b, c, m, k, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_scalar(a, b, c, m, k, n);
}
}
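/// Parallel GEMM: rows of `C` are split into `TILE_M`-row bands and each band
/// is handled by the serial tiled kernel on a Rayon worker. `C` must already
/// be zeroed (as `gemm_neon` does) because the kernel accumulates into it.
///
/// # Safety
/// Caller must ensure the slice lengths are consistent with `m`, `k`, and `n`.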
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
pub unsafe fn gemm_parallel(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use rayon::prelude::*;
    let rows_per_chunk = TILE_M;
    let elements_per_chunk = rows_per_chunk * n;
c.par_chunks_mut(elements_per_chunk)
.enumerate()
.for_each(|(chunk_idx, c_chunk)| {
let i_start = chunk_idx * rows_per_chunk;
let chunk_rows = c_chunk.len() / n;
let i_end = i_start + chunk_rows;
let a_start = i_start * k;
let a_end = i_end * k;
let a_chunk = &a[a_start..a_end];
gemm_neon_impl(a_chunk, b, c_chunk, chunk_rows, k, n);
});
}
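/// Register-blocked 12x4 GEMM micro-kernel for one
/// `[i_start, i_end) x [j_start, j_end) x [k_start, k_end)` tile of `C`,
/// with the depth loop unrolled by four. Row and column fringes fall back to
/// 4-row and scalar paths below.
///
/// # Safety
/// All index ranges must stay within the bounds implied by `k`, `n`, and the
/// lengths of `a`, `b`, and the buffer behind `c_ptr`.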
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_tile_12x4(
a: &[f32],
b: &[f32],
c_ptr: *mut f32,
_m: usize,
k: usize,
n: usize,
i_start: usize,
i_end: usize,
j_start: usize,
j_end: usize,
k_start: usize,
k_end: usize,
) {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut ii = i_start;
while ii + MR <= i_end {
let mut jj = j_start;
while jj + NR <= j_end {
let mut c00 = vld1q_f32(c_ptr.add(ii * n + jj));
let mut c10 = vld1q_f32(c_ptr.add((ii + 1) * n + jj));
let mut c20 = vld1q_f32(c_ptr.add((ii + 2) * n + jj));
let mut c30 = vld1q_f32(c_ptr.add((ii + 3) * n + jj));
let mut c40 = vld1q_f32(c_ptr.add((ii + 4) * n + jj));
let mut c50 = vld1q_f32(c_ptr.add((ii + 5) * n + jj));
let mut c60 = vld1q_f32(c_ptr.add((ii + 6) * n + jj));
let mut c70 = vld1q_f32(c_ptr.add((ii + 7) * n + jj));
let mut c80 = vld1q_f32(c_ptr.add((ii + 8) * n + jj));
let mut c90 = vld1q_f32(c_ptr.add((ii + 9) * n + jj));
let mut ca0 = vld1q_f32(c_ptr.add((ii + 10) * n + jj));
let mut cb0 = vld1q_f32(c_ptr.add((ii + 11) * n + jj));
let mut kkk = k_start;
while kkk + 4 <= k_end {
let b0 = vld1q_f32(b_ptr.add(kkk * n + jj));
let a0 = vdupq_n_f32(*a_ptr.add(ii * k + kkk));
let a1 = vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk));
let a2 = vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk));
let a3 = vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk));
let a4 = vdupq_n_f32(*a_ptr.add((ii + 4) * k + kkk));
let a5 = vdupq_n_f32(*a_ptr.add((ii + 5) * k + kkk));
let a6 = vdupq_n_f32(*a_ptr.add((ii + 6) * k + kkk));
let a7 = vdupq_n_f32(*a_ptr.add((ii + 7) * k + kkk));
let a8 = vdupq_n_f32(*a_ptr.add((ii + 8) * k + kkk));
let a9 = vdupq_n_f32(*a_ptr.add((ii + 9) * k + kkk));
let aa = vdupq_n_f32(*a_ptr.add((ii + 10) * k + kkk));
let ab = vdupq_n_f32(*a_ptr.add((ii + 11) * k + kkk));
c00 = vfmaq_f32(c00, a0, b0);
c10 = vfmaq_f32(c10, a1, b0);
c20 = vfmaq_f32(c20, a2, b0);
c30 = vfmaq_f32(c30, a3, b0);
c40 = vfmaq_f32(c40, a4, b0);
c50 = vfmaq_f32(c50, a5, b0);
c60 = vfmaq_f32(c60, a6, b0);
c70 = vfmaq_f32(c70, a7, b0);
c80 = vfmaq_f32(c80, a8, b0);
c90 = vfmaq_f32(c90, a9, b0);
ca0 = vfmaq_f32(ca0, aa, b0);
cb0 = vfmaq_f32(cb0, ab, b0);
let b1 = vld1q_f32(b_ptr.add((kkk + 1) * n + jj));
let a0 = vdupq_n_f32(*a_ptr.add(ii * k + kkk + 1));
let a1 = vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk + 1));
let a2 = vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk + 1));
let a3 = vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk + 1));
let a4 = vdupq_n_f32(*a_ptr.add((ii + 4) * k + kkk + 1));
let a5 = vdupq_n_f32(*a_ptr.add((ii + 5) * k + kkk + 1));
let a6 = vdupq_n_f32(*a_ptr.add((ii + 6) * k + kkk + 1));
let a7 = vdupq_n_f32(*a_ptr.add((ii + 7) * k + kkk + 1));
let a8 = vdupq_n_f32(*a_ptr.add((ii + 8) * k + kkk + 1));
let a9 = vdupq_n_f32(*a_ptr.add((ii + 9) * k + kkk + 1));
let aa = vdupq_n_f32(*a_ptr.add((ii + 10) * k + kkk + 1));
let ab = vdupq_n_f32(*a_ptr.add((ii + 11) * k + kkk + 1));
c00 = vfmaq_f32(c00, a0, b1);
c10 = vfmaq_f32(c10, a1, b1);
c20 = vfmaq_f32(c20, a2, b1);
c30 = vfmaq_f32(c30, a3, b1);
c40 = vfmaq_f32(c40, a4, b1);
c50 = vfmaq_f32(c50, a5, b1);
c60 = vfmaq_f32(c60, a6, b1);
c70 = vfmaq_f32(c70, a7, b1);
c80 = vfmaq_f32(c80, a8, b1);
c90 = vfmaq_f32(c90, a9, b1);
ca0 = vfmaq_f32(ca0, aa, b1);
cb0 = vfmaq_f32(cb0, ab, b1);
let b2 = vld1q_f32(b_ptr.add((kkk + 2) * n + jj));
let a0 = vdupq_n_f32(*a_ptr.add(ii * k + kkk + 2));
let a1 = vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk + 2));
let a2 = vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk + 2));
let a3 = vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk + 2));
let a4 = vdupq_n_f32(*a_ptr.add((ii + 4) * k + kkk + 2));
let a5 = vdupq_n_f32(*a_ptr.add((ii + 5) * k + kkk + 2));
let a6 = vdupq_n_f32(*a_ptr.add((ii + 6) * k + kkk + 2));
let a7 = vdupq_n_f32(*a_ptr.add((ii + 7) * k + kkk + 2));
let a8 = vdupq_n_f32(*a_ptr.add((ii + 8) * k + kkk + 2));
let a9 = vdupq_n_f32(*a_ptr.add((ii + 9) * k + kkk + 2));
let aa = vdupq_n_f32(*a_ptr.add((ii + 10) * k + kkk + 2));
let ab = vdupq_n_f32(*a_ptr.add((ii + 11) * k + kkk + 2));
c00 = vfmaq_f32(c00, a0, b2);
c10 = vfmaq_f32(c10, a1, b2);
c20 = vfmaq_f32(c20, a2, b2);
c30 = vfmaq_f32(c30, a3, b2);
c40 = vfmaq_f32(c40, a4, b2);
c50 = vfmaq_f32(c50, a5, b2);
c60 = vfmaq_f32(c60, a6, b2);
c70 = vfmaq_f32(c70, a7, b2);
c80 = vfmaq_f32(c80, a8, b2);
c90 = vfmaq_f32(c90, a9, b2);
ca0 = vfmaq_f32(ca0, aa, b2);
cb0 = vfmaq_f32(cb0, ab, b2);
let b3 = vld1q_f32(b_ptr.add((kkk + 3) * n + jj));
let a0 = vdupq_n_f32(*a_ptr.add(ii * k + kkk + 3));
let a1 = vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk + 3));
let a2 = vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk + 3));
let a3 = vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk + 3));
let a4 = vdupq_n_f32(*a_ptr.add((ii + 4) * k + kkk + 3));
let a5 = vdupq_n_f32(*a_ptr.add((ii + 5) * k + kkk + 3));
let a6 = vdupq_n_f32(*a_ptr.add((ii + 6) * k + kkk + 3));
let a7 = vdupq_n_f32(*a_ptr.add((ii + 7) * k + kkk + 3));
let a8 = vdupq_n_f32(*a_ptr.add((ii + 8) * k + kkk + 3));
let a9 = vdupq_n_f32(*a_ptr.add((ii + 9) * k + kkk + 3));
let aa = vdupq_n_f32(*a_ptr.add((ii + 10) * k + kkk + 3));
let ab = vdupq_n_f32(*a_ptr.add((ii + 11) * k + kkk + 3));
c00 = vfmaq_f32(c00, a0, b3);
c10 = vfmaq_f32(c10, a1, b3);
c20 = vfmaq_f32(c20, a2, b3);
c30 = vfmaq_f32(c30, a3, b3);
c40 = vfmaq_f32(c40, a4, b3);
c50 = vfmaq_f32(c50, a5, b3);
c60 = vfmaq_f32(c60, a6, b3);
c70 = vfmaq_f32(c70, a7, b3);
c80 = vfmaq_f32(c80, a8, b3);
c90 = vfmaq_f32(c90, a9, b3);
ca0 = vfmaq_f32(ca0, aa, b3);
cb0 = vfmaq_f32(cb0, ab, b3);
kkk += 4;
}
while kkk < k_end {
let b0 = vld1q_f32(b_ptr.add(kkk * n + jj));
let a0 = vdupq_n_f32(*a_ptr.add(ii * k + kkk));
let a1 = vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk));
let a2 = vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk));
let a3 = vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk));
let a4 = vdupq_n_f32(*a_ptr.add((ii + 4) * k + kkk));
let a5 = vdupq_n_f32(*a_ptr.add((ii + 5) * k + kkk));
let a6 = vdupq_n_f32(*a_ptr.add((ii + 6) * k + kkk));
let a7 = vdupq_n_f32(*a_ptr.add((ii + 7) * k + kkk));
let a8 = vdupq_n_f32(*a_ptr.add((ii + 8) * k + kkk));
let a9 = vdupq_n_f32(*a_ptr.add((ii + 9) * k + kkk));
let aa = vdupq_n_f32(*a_ptr.add((ii + 10) * k + kkk));
let ab = vdupq_n_f32(*a_ptr.add((ii + 11) * k + kkk));
c00 = vfmaq_f32(c00, a0, b0);
c10 = vfmaq_f32(c10, a1, b0);
c20 = vfmaq_f32(c20, a2, b0);
c30 = vfmaq_f32(c30, a3, b0);
c40 = vfmaq_f32(c40, a4, b0);
c50 = vfmaq_f32(c50, a5, b0);
c60 = vfmaq_f32(c60, a6, b0);
c70 = vfmaq_f32(c70, a7, b0);
c80 = vfmaq_f32(c80, a8, b0);
c90 = vfmaq_f32(c90, a9, b0);
ca0 = vfmaq_f32(ca0, aa, b0);
cb0 = vfmaq_f32(cb0, ab, b0);
kkk += 1;
}
vst1q_f32(c_ptr.add(ii * n + jj), c00);
vst1q_f32(c_ptr.add((ii + 1) * n + jj), c10);
vst1q_f32(c_ptr.add((ii + 2) * n + jj), c20);
vst1q_f32(c_ptr.add((ii + 3) * n + jj), c30);
vst1q_f32(c_ptr.add((ii + 4) * n + jj), c40);
vst1q_f32(c_ptr.add((ii + 5) * n + jj), c50);
vst1q_f32(c_ptr.add((ii + 6) * n + jj), c60);
vst1q_f32(c_ptr.add((ii + 7) * n + jj), c70);
vst1q_f32(c_ptr.add((ii + 8) * n + jj), c80);
vst1q_f32(c_ptr.add((ii + 9) * n + jj), c90);
vst1q_f32(c_ptr.add((ii + 10) * n + jj), ca0);
vst1q_f32(c_ptr.add((ii + 11) * n + jj), cb0);
jj += NR;
}
while jj < j_end {
for row in ii..ii + MR {
let mut sum = *c_ptr.add(row * n + jj);
for kkk in k_start..k_end {
sum += *a_ptr.add(row * k + kkk) * *b_ptr.add(kkk * n + jj);
}
*c_ptr.add(row * n + jj) = sum;
}
jj += 1;
}
ii += MR;
}
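    // Row fringe: fewer than MR rows remain, process 4 at a time.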
while ii + 4 <= i_end {
let mut jj = j_start;
while jj + NR <= j_end {
let mut c00 = vld1q_f32(c_ptr.add(ii * n + jj));
let mut c10 = vld1q_f32(c_ptr.add((ii + 1) * n + jj));
let mut c20 = vld1q_f32(c_ptr.add((ii + 2) * n + jj));
let mut c30 = vld1q_f32(c_ptr.add((ii + 3) * n + jj));
for kkk in k_start..k_end {
let b0 = vld1q_f32(b_ptr.add(kkk * n + jj));
c00 = vfmaq_f32(c00, vdupq_n_f32(*a_ptr.add(ii * k + kkk)), b0);
c10 = vfmaq_f32(c10, vdupq_n_f32(*a_ptr.add((ii + 1) * k + kkk)), b0);
c20 = vfmaq_f32(c20, vdupq_n_f32(*a_ptr.add((ii + 2) * k + kkk)), b0);
c30 = vfmaq_f32(c30, vdupq_n_f32(*a_ptr.add((ii + 3) * k + kkk)), b0);
}
vst1q_f32(c_ptr.add(ii * n + jj), c00);
vst1q_f32(c_ptr.add((ii + 1) * n + jj), c10);
vst1q_f32(c_ptr.add((ii + 2) * n + jj), c20);
vst1q_f32(c_ptr.add((ii + 3) * n + jj), c30);
jj += NR;
}
while jj < j_end {
for row in ii..ii + 4 {
let mut sum = *c_ptr.add(row * n + jj);
for kkk in k_start..k_end {
sum += *a_ptr.add(row * k + kkk) * *b_ptr.add(kkk * n + jj);
}
*c_ptr.add(row * n + jj) = sum;
}
jj += 1;
}
ii += 4;
}
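    // Final fringe: remaining rows one at a time.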
for row in ii..i_end {
let mut jj = j_start;
while jj + NR <= j_end {
let mut acc = vld1q_f32(c_ptr.add(row * n + jj));
for kkk in k_start..k_end {
let a_val = vdupq_n_f32(*a_ptr.add(row * k + kkk));
let b_v = vld1q_f32(b_ptr.add(kkk * n + jj));
acc = vfmaq_f32(acc, a_val, b_v);
}
vst1q_f32(c_ptr.add(row * n + jj), acc);
jj += NR;
}
for jjj in jj..j_end {
let mut sum = *c_ptr.add(row * n + jjj);
for kkk in k_start..k_end {
sum += *a_ptr.add(row * k + kkk) * *b_ptr.add(kkk * n + jjj);
}
*c_ptr.add(row * n + jjj) = sum;
}
}
}
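/// Serial tiled GEMM: iterates `TILE_M x TILE_N x TILE_K` blocks and invokes
/// the 12x4 micro-kernel on each. `C` is accumulated into, so the caller must
/// zero it first.
///
/// # Safety
/// Caller must ensure `a.len() == m * k`, `b.len() == k * n`, and
/// `c.len() == m * n`.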
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_neon_impl(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
let c_ptr = c.as_mut_ptr();
let mut i = 0usize;
while i < m {
let i_end = (i + TILE_M).min(m);
let mut j = 0usize;
while j < n {
let j_end = (j + TILE_N).min(n);
let mut kk = 0usize;
while kk < k {
let kk_end = (kk + TILE_K).min(k);
gemm_tile_12x4(a, b, c_ptr, m, k, n, i, i_end, j, j_end, kk, kk_end);
kk = kk_end;
}
j = j_end;
}
i = i_end;
}
}
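/// Scalar reference GEMM (triple loop), used on non-aarch64 targets and by the tests.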
#[allow(dead_code)]
fn gemm_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
}
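/// Batched GEMM over `batch_size` independent `m x k` by `k x n` problems laid
/// out contiguously. With the `parallel` feature, large batches are distributed
/// across Rayon threads, one batch element per task.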
#[inline(always)]
pub fn batched_gemm_neon(
a: &[f32],
b: &[f32],
c: &mut [f32],
batch_size: usize,
m: usize,
k: usize,
n: usize,
) {
debug_assert_eq!(a.len(), batch_size * m * k);
debug_assert_eq!(b.len(), batch_size * k * n);
debug_assert_eq!(c.len(), batch_size * m * n);
let a_batch_stride = m * k;
let b_batch_stride = k * n;
let c_batch_stride = m * n;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
{
use rayon::prelude::*;
if batch_size > 1 && batch_size * m * n >= PARALLEL_THRESHOLD {
c.par_chunks_mut(c_batch_stride)
.enumerate()
.for_each(|(batch, c_slice)| {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
c_slice.fill(0.0);
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_neon_impl(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
c_slice,
m,
k,
n,
);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_scalar(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
c_slice,
m,
k,
n,
);
}
});
return;
}
}
for batch in 0..batch_size {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
let c_offset = batch * c_batch_stride;
gemm_neon(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
&mut c[c_offset..c_offset + c_batch_stride],
m,
k,
n,
);
}
}
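/// Computes `C = A * B^T`, where `b_t` holds `B` transposed (`n x k`,
/// row-major) so that both operands are read along contiguous `k` strides.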
pub fn gemm_nt_neon(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b_t.len(), n * k);
debug_assert_eq!(c.len(), m * n);
c.fill(0.0);
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_nt_neon_impl(a, b_t, c, m, k, n);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_nt_scalar(a, b_t, c, m, k, n);
}
}
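/// NEON kernel for `C = A * B^T`: 4x4 blocks of `C` are built from dot
/// products over `k`, with scalar handling of the row, column, and depth
/// fringes.
///
/// # Safety
/// Caller must ensure `a.len() == m * k`, `b_t.len() == n * k`, and
/// `c.len() == m * n`.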
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gemm_nt_neon_impl(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
let a_ptr = a.as_ptr();
let b_ptr = b_t.as_ptr();
let c_ptr = c.as_mut_ptr();
let m_chunks = m / 4;
let mut i = 0usize;
for _ in 0..m_chunks {
let n_chunks = n / 4;
let mut j = 0usize;
for _ in 0..n_chunks {
let mut c00 = 0.0f32;
let mut c01 = 0.0f32;
let mut c02 = 0.0f32;
let mut c03 = 0.0f32;
let mut c10 = 0.0f32;
let mut c11 = 0.0f32;
let mut c12 = 0.0f32;
let mut c13 = 0.0f32;
let mut c20 = 0.0f32;
let mut c21 = 0.0f32;
let mut c22 = 0.0f32;
let mut c23 = 0.0f32;
let mut c30 = 0.0f32;
let mut c31 = 0.0f32;
let mut c32 = 0.0f32;
let mut c33 = 0.0f32;
let k_chunks = k / 4;
let mut kk = 0usize;
for _ in 0..k_chunks {
let a0 = vld1q_f32(a_ptr.add(i * k + kk));
let a1 = vld1q_f32(a_ptr.add((i + 1) * k + kk));
let a2 = vld1q_f32(a_ptr.add((i + 2) * k + kk));
let a3 = vld1q_f32(a_ptr.add((i + 3) * k + kk));
let b0 = vld1q_f32(b_ptr.add(j * k + kk));
let b1 = vld1q_f32(b_ptr.add((j + 1) * k + kk));
let b2 = vld1q_f32(b_ptr.add((j + 2) * k + kk));
let b3 = vld1q_f32(b_ptr.add((j + 3) * k + kk));
c00 += vaddvq_f32(vmulq_f32(a0, b0));
c01 += vaddvq_f32(vmulq_f32(a0, b1));
c02 += vaddvq_f32(vmulq_f32(a0, b2));
c03 += vaddvq_f32(vmulq_f32(a0, b3));
c10 += vaddvq_f32(vmulq_f32(a1, b0));
c11 += vaddvq_f32(vmulq_f32(a1, b1));
c12 += vaddvq_f32(vmulq_f32(a1, b2));
c13 += vaddvq_f32(vmulq_f32(a1, b3));
c20 += vaddvq_f32(vmulq_f32(a2, b0));
c21 += vaddvq_f32(vmulq_f32(a2, b1));
c22 += vaddvq_f32(vmulq_f32(a2, b2));
c23 += vaddvq_f32(vmulq_f32(a2, b3));
c30 += vaddvq_f32(vmulq_f32(a3, b0));
c31 += vaddvq_f32(vmulq_f32(a3, b1));
c32 += vaddvq_f32(vmulq_f32(a3, b2));
c33 += vaddvq_f32(vmulq_f32(a3, b3));
kk += 4;
}
for kkk in kk..k {
let a0 = *a_ptr.add(i * k + kkk);
let a1 = *a_ptr.add((i + 1) * k + kkk);
let a2 = *a_ptr.add((i + 2) * k + kkk);
let a3 = *a_ptr.add((i + 3) * k + kkk);
let b0 = *b_ptr.add(j * k + kkk);
let b1 = *b_ptr.add((j + 1) * k + kkk);
let b2 = *b_ptr.add((j + 2) * k + kkk);
let b3 = *b_ptr.add((j + 3) * k + kkk);
c00 += a0 * b0;
c01 += a0 * b1;
c02 += a0 * b2;
c03 += a0 * b3;
c10 += a1 * b0;
c11 += a1 * b1;
c12 += a1 * b2;
c13 += a1 * b3;
c20 += a2 * b0;
c21 += a2 * b1;
c22 += a2 * b2;
c23 += a2 * b3;
c30 += a3 * b0;
c31 += a3 * b1;
c32 += a3 * b2;
c33 += a3 * b3;
}
*c_ptr.add(i * n + j) = c00;
*c_ptr.add(i * n + j + 1) = c01;
*c_ptr.add(i * n + j + 2) = c02;
*c_ptr.add(i * n + j + 3) = c03;
*c_ptr.add((i + 1) * n + j) = c10;
*c_ptr.add((i + 1) * n + j + 1) = c11;
*c_ptr.add((i + 1) * n + j + 2) = c12;
*c_ptr.add((i + 1) * n + j + 3) = c13;
*c_ptr.add((i + 2) * n + j) = c20;
*c_ptr.add((i + 2) * n + j + 1) = c21;
*c_ptr.add((i + 2) * n + j + 2) = c22;
*c_ptr.add((i + 2) * n + j + 3) = c23;
*c_ptr.add((i + 3) * n + j) = c30;
*c_ptr.add((i + 3) * n + j + 1) = c31;
*c_ptr.add((i + 3) * n + j + 2) = c32;
*c_ptr.add((i + 3) * n + j + 3) = c33;
j += 4;
}
for jj in j..n {
for ii in i..i + 4 {
let mut acc = vdupq_n_f32(0.0);
let k_chunks = k / 4;
let mut kk = 0usize;
for _ in 0..k_chunks {
let a_v = vld1q_f32(a_ptr.add(ii * k + kk));
let b_v = vld1q_f32(b_ptr.add(jj * k + kk));
acc = vfmaq_f32(acc, a_v, b_v);
kk += 4;
}
let mut sum = vaddvq_f32(acc);
for kkk in kk..k {
sum += *a_ptr.add(ii * k + kkk) * *b_ptr.add(jj * k + kkk);
}
*c_ptr.add(ii * n + jj) = sum;
}
}
i += 4;
}
for ii in i..m {
for jj in 0..n {
let mut acc = vdupq_n_f32(0.0);
let k_chunks = k / 4;
let mut kk = 0usize;
for _ in 0..k_chunks {
let a_v = vld1q_f32(a_ptr.add(ii * k + kk));
let b_v = vld1q_f32(b_ptr.add(jj * k + kk));
acc = vfmaq_f32(acc, a_v, b_v);
kk += 4;
}
let mut sum = vaddvq_f32(acc);
for kkk in kk..k {
sum += *a_ptr.add(ii * k + kkk) * *b_ptr.add(jj * k + kkk);
}
*c_ptr.add(ii * n + jj) = sum;
}
}
}
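/// Scalar fallback for `C = A * B^T` on non-aarch64 targets.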
#[allow(dead_code)]
fn gemm_nt_scalar(a: &[f32], b_t: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b_t[j * k + kk];
}
c[i * n + j] = sum;
}
}
}
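/// Dot product of two equal-length slices using eight NEON accumulators
/// (32 floats per iteration), followed by 4-wide and scalar tails.
///
/// # Safety
/// Caller must ensure `a.len() == b.len()`.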
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let mut sum4 = vdupq_n_f32(0.0);
let mut sum5 = vdupq_n_f32(0.0);
let mut sum6 = vdupq_n_f32(0.0);
let mut sum7 = vdupq_n_f32(0.0);
    let chunks = len / 32;
    let mut idx = 0usize;
for _ in 0..chunks {
let a0 = vld1q_f32(a_ptr.add(idx));
let b0 = vld1q_f32(b_ptr.add(idx));
sum0 = vfmaq_f32(sum0, a0, b0);
let a1 = vld1q_f32(a_ptr.add(idx + 4));
let b1 = vld1q_f32(b_ptr.add(idx + 4));
sum1 = vfmaq_f32(sum1, a1, b1);
let a2 = vld1q_f32(a_ptr.add(idx + 8));
let b2 = vld1q_f32(b_ptr.add(idx + 8));
sum2 = vfmaq_f32(sum2, a2, b2);
let a3 = vld1q_f32(a_ptr.add(idx + 12));
let b3 = vld1q_f32(b_ptr.add(idx + 12));
sum3 = vfmaq_f32(sum3, a3, b3);
let a4 = vld1q_f32(a_ptr.add(idx + 16));
let b4 = vld1q_f32(b_ptr.add(idx + 16));
sum4 = vfmaq_f32(sum4, a4, b4);
let a5 = vld1q_f32(a_ptr.add(idx + 20));
let b5 = vld1q_f32(b_ptr.add(idx + 20));
sum5 = vfmaq_f32(sum5, a5, b5);
let a6 = vld1q_f32(a_ptr.add(idx + 24));
let b6 = vld1q_f32(b_ptr.add(idx + 24));
sum6 = vfmaq_f32(sum6, a6, b6);
let a7 = vld1q_f32(a_ptr.add(idx + 28));
let b7 = vld1q_f32(b_ptr.add(idx + 28));
sum7 = vfmaq_f32(sum7, a7, b7);
idx += 32;
}
let sum01 = vaddq_f32(sum0, sum1);
let sum23 = vaddq_f32(sum2, sum3);
let sum45 = vaddq_f32(sum4, sum5);
let sum67 = vaddq_f32(sum6, sum7);
let sum0123 = vaddq_f32(sum01, sum23);
let sum4567 = vaddq_f32(sum45, sum67);
let mut sum = vaddq_f32(sum0123, sum4567);
while idx + 4 <= len {
let a_v = vld1q_f32(a_ptr.add(idx));
let b_v = vld1q_f32(b_ptr.add(idx));
sum = vfmaq_f32(sum, a_v, b_v);
idx += 4;
}
let mut result = vaddvq_f32(sum);
for i in idx..len {
result += *a_ptr.add(i) * *b_ptr.add(i);
}
result
}
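/// Scales every element of `x` by `scale` in place, 16 floats per iteration.
///
/// # Safety
/// All pointer accesses stay within `x`; unsafe only because of the NEON
/// intrinsics.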
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn scale_vector_neon(x: &mut [f32], scale: f32) {
let len = x.len();
let x_ptr = x.as_mut_ptr();
let scale_vec = vdupq_n_f32(scale);
let chunks = len / 16;
let mut idx = 0usize;
for _ in 0..chunks {
let v0 = vld1q_f32(x_ptr.add(idx));
vst1q_f32(x_ptr.add(idx), vmulq_f32(v0, scale_vec));
let v1 = vld1q_f32(x_ptr.add(idx + 4));
vst1q_f32(x_ptr.add(idx + 4), vmulq_f32(v1, scale_vec));
let v2 = vld1q_f32(x_ptr.add(idx + 8));
vst1q_f32(x_ptr.add(idx + 8), vmulq_f32(v2, scale_vec));
let v3 = vld1q_f32(x_ptr.add(idx + 12));
vst1q_f32(x_ptr.add(idx + 12), vmulq_f32(v3, scale_vec));
idx += 16;
}
while idx + 4 <= len {
let v = vld1q_f32(x_ptr.add(idx));
vst1q_f32(x_ptr.add(idx), vmulq_f32(v, scale_vec));
idx += 4;
}
for i in idx..len {
*x_ptr.add(i) *= scale;
}
}
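/// Element-wise `x[i] += y[i]` over two equal-length slices.
///
/// # Safety
/// Caller must ensure `x.len() == y.len()`.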
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn add_vectors_neon(x: &mut [f32], y: &[f32]) {
debug_assert_eq!(x.len(), y.len());
let len = x.len();
let x_ptr = x.as_mut_ptr();
let y_ptr = y.as_ptr();
let chunks = len / 16;
let mut idx = 0usize;
for _ in 0..chunks {
let x0 = vld1q_f32(x_ptr.add(idx));
let y0 = vld1q_f32(y_ptr.add(idx));
vst1q_f32(x_ptr.add(idx), vaddq_f32(x0, y0));
let x1 = vld1q_f32(x_ptr.add(idx + 4));
let y1 = vld1q_f32(y_ptr.add(idx + 4));
vst1q_f32(x_ptr.add(idx + 4), vaddq_f32(x1, y1));
let x2 = vld1q_f32(x_ptr.add(idx + 8));
let y2 = vld1q_f32(y_ptr.add(idx + 8));
vst1q_f32(x_ptr.add(idx + 8), vaddq_f32(x2, y2));
let x3 = vld1q_f32(x_ptr.add(idx + 12));
let y3 = vld1q_f32(y_ptr.add(idx + 12));
vst1q_f32(x_ptr.add(idx + 12), vaddq_f32(x3, y3));
idx += 16;
}
while idx + 4 <= len {
let x_v = vld1q_f32(x_ptr.add(idx));
let y_v = vld1q_f32(y_ptr.add(idx));
vst1q_f32(x_ptr.add(idx), vaddq_f32(x_v, y_v));
idx += 4;
}
for i in idx..len {
*x_ptr.add(i) += *y_ptr.add(i);
}
}
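/// Fused update `x[i] = a * x[i] + b * y[i]` over two equal-length slices.
///
/// # Safety
/// Caller must ensure `x.len() == y.len()`.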
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn fused_axpby_neon(x: &mut [f32], y: &[f32], a: f32, b: f32) {
debug_assert_eq!(x.len(), y.len());
let len = x.len();
let x_ptr = x.as_mut_ptr();
let y_ptr = y.as_ptr();
let a_vec = vdupq_n_f32(a);
let b_vec = vdupq_n_f32(b);
let chunks = len / NEON_LANE_WIDTH;
let mut idx = 0usize;
for _ in 0..chunks {
let x_v = vld1q_f32(x_ptr.add(idx));
let y_v = vld1q_f32(y_ptr.add(idx));
let result = vfmaq_f32(vmulq_f32(x_v, a_vec), y_v, b_vec);
vst1q_f32(x_ptr.add(idx), result);
idx += 4;
}
for i in idx..len {
*x_ptr.add(i) = a * *x_ptr.add(i) + b * *y_ptr.add(i);
}
}
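/// Half-precision GEMV reference: inputs are converted to `f16`, dot products
/// are accumulated in `f16`, and results are widened back to `f32`. This path
/// is scalar and intended for checking reduced-precision behaviour, not speed.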
#[cfg(target_arch = "aarch64")]
pub fn gemv_f16(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
use half::f16;
debug_assert_eq!(a.len(), m * n);
debug_assert_eq!(x.len(), n);
debug_assert_eq!(y.len(), m);
let a_f16: Vec<f16> = a.iter().map(|&v| f16::from_f32(v)).collect();
let x_f16: Vec<f16> = x.iter().map(|&v| f16::from_f32(v)).collect();
for row in 0..m {
let mut sum = f16::from_f32(0.0);
for col in 0..n {
sum += a_f16[row * n + col] * x_f16[col];
}
y[row] = sum.to_f32();
}
}
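/// Half-precision GEMM reference: converts inputs to `f16`, accumulates in
/// `f16`, and writes `f32` results. Scalar only.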
#[cfg(target_arch = "aarch64")]
pub fn gemm_f16(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use half::f16;
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
let a_f16: Vec<f16> = a.iter().map(|&v| f16::from_f32(v)).collect();
let b_f16: Vec<f16> = b.iter().map(|&v| f16::from_f32(v)).collect();
for i in 0..m {
for j in 0..n {
let mut sum = f16::from_f32(0.0);
for kk in 0..k {
sum += a_f16[i * k + kk] * b_f16[kk * n + j];
}
c[i * n + j] = sum.to_f32();
}
}
}
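// Keep the PREFETCH_DISTANCE import referenced so it does not trigger an
// unused-import warning.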
#[allow(dead_code)]
const _: usize = PREFETCH_DISTANCE;
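/// Minimum `m * n` before GEMV is offloaded to Metal.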
const METAL_GEMV_THRESHOLD: usize = 512 * 512;
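/// GEMV that offloads to Metal for sufficiently large matrices on macOS (with
/// the `metal-compute` feature) and otherwise falls back to `gemv_neon`,
/// returning a freshly allocated result vector.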
pub fn gemv_metal_if_available(a: &[f32], x: &[f32], m: usize, n: usize) -> Vec<f32> {
debug_assert_eq!(a.len(), m * n);
debug_assert_eq!(x.len(), n);
#[cfg(all(target_os = "macos", feature = "metal-compute"))]
{
if m * n >= METAL_GEMV_THRESHOLD {
if let Some(result) = try_gemv_metal(a, x, m, n) {
return result;
}
}
}
let mut y = vec![0.0f32; m];
gemv_neon(a, x, &mut y, m, n);
y
}
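/// In-place variant of [`gemv_metal_if_available`]. Returns `true` if the
/// Metal path produced the result and `false` if the NEON/scalar path ran.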
pub fn gemv_metal_if_available_inplace(
a: &[f32],
x: &[f32],
y: &mut [f32],
m: usize,
n: usize,
) -> bool {
debug_assert_eq!(a.len(), m * n);
debug_assert_eq!(x.len(), n);
debug_assert_eq!(y.len(), m);
#[cfg(all(target_os = "macos", feature = "metal-compute"))]
{
if m * n >= METAL_GEMV_THRESHOLD {
if let Some(result) = try_gemv_metal(a, x, m, n) {
y.copy_from_slice(&result);
return true;
}
}
}
gemv_neon(a, x, y, m, n);
false
}
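/// Attempts a Metal GEMV, returning `None` if Metal is unavailable, the
/// context cannot be created, or the kernel launch fails.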
#[cfg(all(target_os = "macos", feature = "metal-compute"))]
fn try_gemv_metal(a: &[f32], x: &[f32], m: usize, n: usize) -> Option<Vec<f32>> {
use crate::metal::{gemv_metal, is_metal_available, MetalConfig, MetalContext};
if !is_metal_available() {
return None;
}
let ctx = match MetalContext::new(MetalConfig::default()) {
Ok(ctx) => ctx,
Err(_) => return None,
};
match gemv_metal(&ctx, a, x, m, n) {
Ok(result) => Some(result),
Err(_) => None,
}
}
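/// Reports whether the Metal GEMV backend is available on this build and host.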
#[cfg(all(target_os = "macos", feature = "metal-compute"))]
pub fn is_metal_gemv_available() -> bool {
crate::metal::is_metal_available()
}
#[cfg(not(all(target_os = "macos", feature = "metal-compute")))]
pub fn is_metal_gemv_available() -> bool {
false
}
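/// Returns the `m * n` threshold above which GEMV is offloaded to Metal.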
pub fn get_metal_gemv_threshold() -> usize {
METAL_GEMV_THRESHOLD
}
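/// Builds the global Rayon thread pool with `num_threads` workers, or with the
/// detected parallelism when `num_threads == 0`. Returns `false` if the global
/// pool has already been initialized.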
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn configure_thread_pool(num_threads: usize) -> bool {
use rayon::ThreadPoolBuilder;
let threads = if num_threads == 0 {
get_physical_cores()
} else {
num_threads
};
ThreadPoolBuilder::new()
.num_threads(threads)
.build_global()
.is_ok()
}
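/// Returns `std::thread::available_parallelism()` (typically the logical CPU
/// count, despite the name), falling back to 4 if it cannot be queried.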
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn get_physical_cores() -> usize {
std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(4)
}
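/// Batched GEMM with one Rayon task per batch element; each slice of `C` is
/// zeroed and then filled by the tiled NEON kernel (scalar fallback off
/// aarch64).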
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn batched_gemm_parallel(
a: &[f32],
b: &[f32],
c: &mut [f32],
batch_size: usize,
m: usize,
k: usize,
n: usize,
) {
use rayon::prelude::*;
debug_assert_eq!(a.len(), batch_size * m * k);
debug_assert_eq!(b.len(), batch_size * k * n);
debug_assert_eq!(c.len(), batch_size * m * n);
let a_batch_stride = m * k;
let b_batch_stride = k * n;
let c_batch_stride = m * n;
c.par_chunks_mut(c_batch_stride)
.enumerate()
.for_each(|(batch, c_slice)| {
let a_offset = batch * a_batch_stride;
let b_offset = batch * b_batch_stride;
c_slice.fill(0.0);
#[cfg(target_arch = "aarch64")]
unsafe {
gemm_neon_impl(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
c_slice,
m,
k,
n,
);
}
#[cfg(not(target_arch = "aarch64"))]
{
gemm_scalar(
&a[a_offset..a_offset + a_batch_stride],
&b[b_offset..b_offset + b_batch_stride],
c_slice,
m,
k,
n,
);
}
});
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gemv_basic() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let x = vec![1.0, 2.0, 3.0];
let mut y = vec![0.0; 2];
gemv_neon(&a, &x, &mut y, 2, 3);
assert!((y[0] - 14.0).abs() < 1e-5);
assert!((y[1] - 32.0).abs() < 1e-5);
}
#[test]
fn test_gemv_large() {
let m = 64;
let n = 128;
let a: Vec<f32> = (0..m * n).map(|i| (i as f32) * 0.01).collect();
let x: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
let mut y = vec![0.0; m];
gemv_neon(&a, &x, &mut y, m, n);
let mut y_scalar = vec![0.0; m];
gemv_scalar(&a, &x, &mut y_scalar, m, n);
for i in 0..m {
let tol = (y_scalar[i].abs() * 1e-5).max(1e-3);
assert!(
(y[i] - y_scalar[i]).abs() < tol,
"Mismatch at {}: {} vs {} (tol: {})",
i,
y[i],
y_scalar[i],
tol
);
}
}
#[test]
fn test_gemm_basic() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut c = vec![0.0; 4];
gemm_neon(&a, &b, &mut c, 2, 3, 2);
assert!((c[0] - 22.0).abs() < 1e-4, "c[0,0] = {}", c[0]);
assert!((c[1] - 28.0).abs() < 1e-4, "c[0,1] = {}", c[1]);
assert!((c[2] - 49.0).abs() < 1e-4, "c[1,0] = {}", c[2]);
assert!((c[3] - 64.0).abs() < 1e-4, "c[1,1] = {}", c[3]);
}
#[test]
fn test_gemm_large() {
let m = 32;
let k = 64;
let n = 32;
let a: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.001).collect();
let b: Vec<f32> = (0..k * n).map(|i| (i as f32) * 0.001).collect();
let mut c = vec![0.0; m * n];
gemm_neon(&a, &b, &mut c, m, k, n);
let mut c_scalar = vec![0.0; m * n];
gemm_scalar(&a, &b, &mut c_scalar, m, k, n);
for i in 0..(m * n) {
assert!(
(c[i] - c_scalar[i]).abs() < 0.1,
"Mismatch at {}: {} vs {}",
i,
c[i],
c_scalar[i]
);
}
}
#[test]
fn test_batched_gemm() {
let batch = 4;
let m = 8;
let k = 16;
let n = 8;
let a: Vec<f32> = (0..batch * m * k).map(|i| (i as f32) * 0.01).collect();
let b: Vec<f32> = (0..batch * k * n).map(|i| (i as f32) * 0.01).collect();
let mut c = vec![0.0; batch * m * n];
batched_gemm_neon(&a, &b, &mut c, batch, m, k, n);
assert!(c.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_gemm_nt() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_t = vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut c = vec![0.0; 4];
gemm_nt_neon(&a, &b_t, &mut c, 2, 3, 2);
assert!((c[0] - 22.0).abs() < 1e-4, "c[0,0] = {}", c[0]);
assert!((c[1] - 28.0).abs() < 1e-4, "c[0,1] = {}", c[1]);
assert!((c[2] - 49.0).abs() < 1e-4, "c[1,0] = {}", c[2]);
assert!((c[3] - 64.0).abs() < 1e-4, "c[1,1] = {}", c[3]);
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_dot_product() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
let result = unsafe { dot_product_neon(&a, &b) };
assert!((result - 36.0).abs() < 1e-5);
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_scale_vector() {
let mut x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
unsafe { scale_vector_neon(&mut x, 2.0) };
for (i, &v) in x.iter().enumerate() {
assert!((v - ((i + 1) as f32 * 2.0)).abs() < 1e-5);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_add_vectors() {
let mut x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![10.0, 20.0, 30.0, 40.0];
unsafe { add_vectors_neon(&mut x, &y) };
assert!((x[0] - 11.0).abs() < 1e-5);
assert!((x[1] - 22.0).abs() < 1e-5);
assert!((x[2] - 33.0).abs() < 1e-5);
assert!((x[3] - 44.0).abs() < 1e-5);
}
#[test]
fn test_identity_gemm() {
        let a = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 4];
gemm_neon(&a, &b, &mut c, 2, 2, 2);
assert!((c[0] - 5.0).abs() < 1e-5);
assert!((c[1] - 6.0).abs() < 1e-5);
assert!((c[2] - 7.0).abs() < 1e-5);
assert!((c[3] - 8.0).abs() < 1e-5);
}
#[test]
fn test_gemm_12_row_boundary() {
        let m = 13;
        let k = 16;
let n = 8;
let a: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.01).collect();
let b: Vec<f32> = (0..k * n).map(|i| (i as f32) * 0.01).collect();
let mut c = vec![0.0; m * n];
gemm_neon(&a, &b, &mut c, m, k, n);
let mut c_scalar = vec![0.0; m * n];
gemm_scalar(&a, &b, &mut c_scalar, m, k, n);
for i in 0..(m * n) {
assert!(
(c[i] - c_scalar[i]).abs() < 0.01,
"Mismatch at {}: {} vs {}",
i,
c[i],
c_scalar[i]
);
}
}
#[test]
fn test_gemv_12_row_boundary() {
        let m = 13;
        let n = 32;
let a: Vec<f32> = (0..m * n).map(|i| (i as f32) * 0.01).collect();
let x: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
let mut y = vec![0.0; m];
gemv_neon(&a, &x, &mut y, m, n);
let mut y_scalar = vec![0.0; m];
gemv_scalar(&a, &x, &mut y_scalar, m, n);
for i in 0..m {
let tol = (y_scalar[i].abs() * 1e-5).max(1e-3);
assert!(
(y[i] - y_scalar[i]).abs() < tol,
"Mismatch at {}: {} vs {}",
i,
y[i],
y_scalar[i]
);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_gemv_f16() {
let m = 8;
let n = 16;
let a: Vec<f32> = (0..m * n).map(|i| (i as f32) * 0.01).collect();
let x: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
let mut y = vec![0.0; m];
gemv_f16(&a, &x, &mut y, m, n);
assert!(y.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_gemv_metal_if_available_small() {
let m = 4;
let n = 8;
let a = vec![1.0f32; m * n];
let x = vec![1.0f32; n];
let y = gemv_metal_if_available(&a, &x, m, n);
assert_eq!(y.len(), m);
for i in 0..m {
assert!(
(y[i] - n as f32).abs() < 1e-5,
"y[{}] = {}, expected {}",
i,
y[i],
n
);
}
}
#[test]
fn test_gemv_metal_if_available_correctness() {
let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let x = vec![1.0f32, 2.0, 3.0];
let y = gemv_metal_if_available(&a, &x, 2, 3);
assert_eq!(y.len(), 2);
assert!((y[0] - 14.0).abs() < 1e-4, "y[0] = {}, expected 14", y[0]);
assert!((y[1] - 32.0).abs() < 1e-4, "y[1] = {}, expected 32", y[1]);
}
#[test]
fn test_gemv_metal_if_available_inplace() {
let m = 8;
let n = 16;
let a = vec![1.0f32; m * n];
let x = vec![1.0f32; n];
let mut y = vec![0.0f32; m];
let _used_metal = gemv_metal_if_available_inplace(&a, &x, &mut y, m, n);
for i in 0..m {
assert!(
(y[i] - n as f32).abs() < 1e-5,
"y[{}] = {}, expected {}",
i,
y[i],
n
);
}
}
#[test]
fn test_is_metal_gemv_available() {
let available = is_metal_gemv_available();
println!("Metal GEMV available: {}", available);
}
#[test]
fn test_get_metal_gemv_threshold() {
let threshold = get_metal_gemv_threshold();
assert_eq!(threshold, 512 * 512);
}
#[cfg(target_os = "macos")]
#[test]
fn test_gemv_metal_large_matrix() {
let m = 512;
let n = 512;
let a = vec![1.0f32; m * n];
let x = vec![1.0f32; n];
let y = gemv_metal_if_available(&a, &x, m, n);
assert_eq!(y.len(), m);
for i in 0..m {
assert!(
(y[i] - n as f32).abs() < 1e-3,
"y[{}] = {}, expected {}",
i,
y[i],
n
);
}
}
}