pub fn matrix_multiply_f32(
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) {
assert_eq!(a.len(), m * k, "matrix_multiply_f32: a.len() != m * k");
assert_eq!(b.len(), k * n, "matrix_multiply_f32: b.len() != k * n");
assert_eq!(c.len(), m * n, "matrix_multiply_f32: c.len() != m * n");
for val in c.iter_mut() {
*val = 0.0;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { avx2::matrix_multiply_f32_avx2(a, b, c, m, n, k) };
return;
}
}
#[cfg(target_arch = "aarch64")]
{
unsafe { neon::matrix_multiply_f32_neon(a, b, c, m, n, k) };
return;
}
#[allow(unreachable_code)]
scalar::matrix_multiply_f32_scalar(a, b, c, m, n, k);
}
const TILE_SIZE: usize = 32;
mod scalar {
use super::TILE_SIZE;
pub fn matrix_multiply_f32_scalar(
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) {
let ti_count = (m + TILE_SIZE - 1) / TILE_SIZE;
let tj_count = (n + TILE_SIZE - 1) / TILE_SIZE;
let tp_count = (k + TILE_SIZE - 1) / TILE_SIZE;
for ti in 0..ti_count {
let i_start = ti * TILE_SIZE;
let i_end = (i_start + TILE_SIZE).min(m);
for tj in 0..tj_count {
let j_start = tj * TILE_SIZE;
let j_end = (j_start + TILE_SIZE).min(n);
for tp in 0..tp_count {
let p_start = tp * TILE_SIZE;
let p_end = (p_start + TILE_SIZE).min(k);
for i in i_start..i_end {
for p in p_start..p_end {
let a_ip = a[i * k + p];
for j in j_start..j_end {
c[i * n + j] += a_ip * b[p * n + j];
}
}
}
}
}
}
}
}
#[cfg(target_arch = "x86_64")]
mod avx2 {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use super::TILE_SIZE;
const AVX2_F32_LANES: usize = 8;
#[target_feature(enable = "avx2")]
pub unsafe fn matrix_multiply_f32_avx2(
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) {
let ti_count = (m + TILE_SIZE - 1) / TILE_SIZE;
let tj_count = (n + TILE_SIZE - 1) / TILE_SIZE;
let tp_count = (k + TILE_SIZE - 1) / TILE_SIZE;
for ti in 0..ti_count {
let i_start = ti * TILE_SIZE;
let i_end = (i_start + TILE_SIZE).min(m);
for tj in 0..tj_count {
let j_start = tj * TILE_SIZE;
let j_end = (j_start + TILE_SIZE).min(n);
for tp in 0..tp_count {
let p_start = tp * TILE_SIZE;
let p_end = (p_start + TILE_SIZE).min(k);
for i in i_start..i_end {
for p in p_start..p_end {
let a_val = _mm256_set1_ps(a[i * k + p]);
let mut j = j_start;
while j + AVX2_F32_LANES <= j_end {
let b_vec = _mm256_loadu_ps(b.as_ptr().add(p * n + j));
let c_vec = _mm256_loadu_ps(c.as_ptr().add(i * n + j));
let result = _mm256_add_ps(c_vec, _mm256_mul_ps(a_val, b_vec));
_mm256_storeu_ps(c.as_mut_ptr().add(i * n + j), result);
j += AVX2_F32_LANES;
}
let a_ip = a[i * k + p];
while j < j_end {
c[i * n + j] += a_ip * b[p * n + j];
j += 1;
}
}
}
}
}
}
}
}
#[cfg(target_arch = "aarch64")]
mod neon {
use std::arch::aarch64::*;
use super::TILE_SIZE;
const NEON_F32_LANES: usize = 4;
pub unsafe fn matrix_multiply_f32_neon(
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) {
let ti_count = (m + TILE_SIZE - 1) / TILE_SIZE;
let tj_count = (n + TILE_SIZE - 1) / TILE_SIZE;
let tp_count = (k + TILE_SIZE - 1) / TILE_SIZE;
for ti in 0..ti_count {
let i_start = ti * TILE_SIZE;
let i_end = (i_start + TILE_SIZE).min(m);
for tj in 0..tj_count {
let j_start = tj * TILE_SIZE;
let j_end = (j_start + TILE_SIZE).min(n);
for tp in 0..tp_count {
let p_start = tp * TILE_SIZE;
let p_end = (p_start + TILE_SIZE).min(k);
for i in i_start..i_end {
for p in p_start..p_end {
let a_val = vdupq_n_f32(a[i * k + p]);
let mut j = j_start;
while j + NEON_F32_LANES <= j_end {
let b_vec = vld1q_f32(b.as_ptr().add(p * n + j));
let c_vec = vld1q_f32(c.as_ptr().add(i * n + j));
let result = vfmaq_f32(c_vec, a_val, b_vec);
vst1q_f32(c.as_mut_ptr().add(i * n + j), result);
j += NEON_F32_LANES;
}
let a_ip = a[i * k + p];
while j < j_end {
c[i * n + j] += a_ip * b[p * n + j];
j += 1;
}
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f32 = 1e-3;
fn approx_eq(a: f32, b: f32) -> bool {
(a - b).abs() < EPSILON
}
#[test]
fn test_identity_multiply() {
let a: Vec<f32> = vec![
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
0.0, 0.0, 1.0,
];
let b: Vec<f32> = vec![
1.0, 2.0,
3.0, 4.0,
5.0, 6.0,
];
let mut c = vec![0.0; 6];
matrix_multiply_f32(&a, &b, &mut c, 3, 2, 3);
assert!(approx_eq(c[0], 1.0));
assert!(approx_eq(c[1], 2.0));
assert!(approx_eq(c[2], 3.0));
assert!(approx_eq(c[3], 4.0));
assert!(approx_eq(c[4], 5.0));
assert!(approx_eq(c[5], 6.0));
}
#[test]
fn test_2x2_multiply() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let mut c = vec![0.0; 4];
matrix_multiply_f32(&a, &b, &mut c, 2, 2, 2);
assert!(approx_eq(c[0], 19.0));
assert!(approx_eq(c[1], 22.0));
assert!(approx_eq(c[2], 43.0));
assert!(approx_eq(c[3], 50.0));
}
#[test]
fn test_large_matrix() {
let m = 64;
let n = 64;
let k = 64;
let a: Vec<f32> = (0..m * k).map(|i| (i % 7) as f32 * 0.1).collect();
let b: Vec<f32> = (0..k * n).map(|i| (i % 5) as f32 * 0.1).collect();
let mut c_simd = vec![0.0; m * n];
let mut c_ref = vec![0.0; m * n];
matrix_multiply_f32(&a, &b, &mut c_simd, m, n, k);
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for p in 0..k {
sum += a[i * k + p] * b[p * n + j];
}
c_ref[i * n + j] = sum;
}
}
for idx in 0..m * n {
assert!(
approx_eq(c_simd[idx], c_ref[idx]),
"Mismatch at index {idx}: got {}, expected {}",
c_simd[idx],
c_ref[idx]
);
}
}
#[test]
fn test_non_square_matrix() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b = vec![
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
];
let mut c = vec![0.0; 8];
matrix_multiply_f32(&a, &b, &mut c, 2, 4, 3);
assert!(approx_eq(c[0], 38.0));
assert!(approx_eq(c[1], 44.0));
assert!(approx_eq(c[2], 50.0));
assert!(approx_eq(c[3], 56.0));
assert!(approx_eq(c[4], 83.0));
assert!(approx_eq(c[5], 98.0));
assert!(approx_eq(c[6], 113.0));
assert!(approx_eq(c[7], 128.0));
}
#[test]
#[should_panic(expected = "a.len() != m * k")]
fn test_dimension_mismatch() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0, 4.0];
let mut c = vec![0.0; 4];
matrix_multiply_f32(&a, &b, &mut c, 2, 2, 2);
}
}