use super::MR;
/// Defines `$name`, a scalar (non-vectorized) dense matrix multiply over `$ty`.
///
/// The generated function overwrites, for every `i < m` and `j < n`,
/// `out[i * ldc + j] = sum_{kk < k} a[i * lda + kk] * b[kk * ldb + j]`.
/// Elements of `out` outside those positions (e.g. row padding when `ldc > n`)
/// are neither read nor written.
macro_rules! define_scalar_matmul {
    ($name:ident, $ty:ty) => {
        /// Computes `out = a * b` with plain scalar loops.
        ///
        /// * `a` — `m x k` matrix, row stride `lda`.
        /// * `b` — `k x n` matrix, row stride `ldb`.
        /// * `out` — `m x n` result, row stride `ldc`. It is fully overwritten in its
        ///   `m x n` region, so its previous contents are irrelevant.
        ///
        /// # Safety
        /// * `a` must be readable at `i * lda + kk` for all `i < m`, `kk < k`.
        /// * `b` must be readable at `kk * ldb + j` for all `kk < k`, `j < n`.
        /// * `out` must be writable at `i * ldc + j` for all `i < m`, `j < n`.
        ///   Only `(m - 1) * ldc + n` elements are required; the full `m * ldc` are not.
        /// * All pointers must be non-null and aligned (even when a dimension is 0).
        /// * Each output row is viewed as a `&mut` slice, so the `out` region must not
        ///   overlap `a` or `b`. Distinct rows must not overlap each other either,
        ///   i.e. `ldc >= n` whenever `m > 1`.
        #[allow(clippy::too_many_arguments)]
        pub unsafe fn $name(
            a: *const $ty,
            b: *const $ty,
            out: *mut $ty,
            m: usize,
            n: usize,
            k: usize,
            lda: usize,
            ldb: usize,
            ldc: usize,
        ) {
            for i in 0..m {
                // View exactly the `n` live elements of this output row. A single view
                // of `m * ldc` elements could run past the end of a tightly sized
                // buffer when `ldc > n`.
                let c_row = std::slice::from_raw_parts_mut(out.add(i * ldc), n);
                c_row.fill(0.0);
                for kk in 0..k {
                    let a_val = *a.add(i * lda + kk);
                    let b_row = std::slice::from_raw_parts(b.add(kk * ldb), n);
                    // Zipped iteration over equal-length slices avoids per-element
                    // bounds checks while keeping the original summation order.
                    for (c, &b_val) in c_row.iter_mut().zip(b_row) {
                        *c += a_val * b_val;
                    }
                }
            }
        }
    };
}
/// Defines `$name`, a scalar dense matrix multiply with a per-column bias over `$ty`.
///
/// The generated function overwrites, for every `i < m` and `j < n`,
/// `out[i * ldc + j] = bias[j] + sum_{kk < k} a[i * lda + kk] * b[kk * ldb + j]`.
/// Elements of `out` outside those positions are neither read nor written.
macro_rules! define_scalar_matmul_bias {
    ($name:ident, $ty:ty) => {
        /// Computes `out = a * b + bias`, broadcasting `bias` across every output row.
        ///
        /// * `a` — `m x k` matrix, row stride `lda`.
        /// * `b` — `k x n` matrix, row stride `ldb`.
        /// * `bias` — `n` contiguous values; `bias[j]` seeds column `j` of every row.
        /// * `out` — `m x n` result, row stride `ldc`, fully overwritten in that region.
        ///
        /// # Safety
        /// * `bias` must be readable for `n` contiguous elements.
        /// * `a` must be readable at `i * lda + kk` for all `i < m`, `kk < k`.
        /// * `b` must be readable at `kk * ldb + j` for all `kk < k`, `j < n`.
        /// * `out` must be writable at `i * ldc + j` for all `i < m`, `j < n`.
        /// * All pointers must be non-null and aligned (even when a dimension is 0).
        /// * Each output row is viewed as a `&mut` slice, so the `out` region must not
        ///   overlap `a`, `b`, or `bias`. Rows must not overlap each other either
        ///   (`ldc >= n` whenever `m > 1`).
        #[allow(clippy::too_many_arguments)]
        pub unsafe fn $name(
            a: *const $ty,
            b: *const $ty,
            bias: *const $ty,
            out: *mut $ty,
            m: usize,
            n: usize,
            k: usize,
            lda: usize,
            ldb: usize,
            ldc: usize,
        ) {
            let bias_row = std::slice::from_raw_parts(bias, n);
            for i in 0..m {
                let c_row = std::slice::from_raw_parts_mut(out.add(i * ldc), n);
                // Seed the row with the bias and accumulate into it in the same pass,
                // so each output row is traversed while it is still hot in cache.
                // The sum still starts from bias[j], so results match a separate init pass.
                c_row.copy_from_slice(bias_row);
                for kk in 0..k {
                    let a_val = *a.add(i * lda + kk);
                    let b_row = std::slice::from_raw_parts(b.add(kk * ldb), n);
                    for (c, &b_val) in c_row.iter_mut().zip(b_row) {
                        *c += a_val * b_val;
                    }
                }
            }
        }
    };
}
/// Defines `$name`, a scalar kernel that updates one `mr x nr` tile of `c` from
/// packed `a` / `b` panels.
///
/// For `i < mr`, `j < nr` the generated function performs
/// `c[i * ldc + j] (+)= sum_{p < k} a[p * MR + i] * b[p * nr + j]`. It zeroes the
/// tile first when `first_k` is true and accumulates onto the existing contents
/// otherwise. `MR` is resolved at the macro's invocation site.
macro_rules! define_microkernel_edge {
    ($name:ident, $ty:ty) => {
        /// Scalar update of an `mr x nr` output tile.
        ///
        /// * `a` — packed panel read as `a[p * MR + i]`. Each `k` step advances by the
        ///   constant `MR`, regardless of the runtime `mr`. Presumably the packer pads
        ///   A to `MR` rows on edge tiles — TODO confirm.
        /// * `b` — packed panel read as `b[p * nr + j]`. The step is the runtime edge
        ///   width `nr`, not a full-tile constant. NOTE(review): confirm the B packer
        ///   lays edge panels out with stride `nr`.
        /// * `c` — output tile with row stride `ldc`.
        /// * `first_k` — zero the tile before accumulating.
        ///
        /// # Safety
        /// `a`, `b`, and `c` must be valid (non-null, aligned, in bounds) for every offset
        /// listed above with `p < k`, `i < mr`, `j < nr`, and `c` must also be writable.
        #[inline]
        #[allow(clippy::too_many_arguments)]
        pub unsafe fn $name(
            a: *const $ty,
            b: *const $ty,
            c: *mut $ty,
            mr: usize,
            nr: usize,
            k: usize,
            ldc: usize,
            first_k: bool,
        ) {
            if first_k {
                // Start from an all-zero tile instead of accumulating onto old data.
                for row in 0..mr {
                    for col in 0..nr {
                        c.add(row * ldc + col).write(0.0);
                    }
                }
            }
            // Rank-1 update per packed step `p`: column `p` of A times row `p` of B.
            for p in 0..k {
                for row in 0..mr {
                    let lhs = a.add(p * MR + row).read();
                    for col in 0..nr {
                        let rhs = b.add(p * nr + col).read();
                        let dst = c.add(row * ldc + col);
                        dst.write(dst.read() + lhs * rhs);
                    }
                }
            }
        }
    };
}
// Single-precision kernels.
define_scalar_matmul!(matmul_scalar_f32, f32);
define_scalar_matmul_bias!(matmul_bias_scalar_f32, f32);
define_microkernel_edge!(microkernel_edge_f32, f32);

// Double-precision kernels.
define_scalar_matmul!(matmul_scalar_f64, f64);
define_scalar_matmul_bias!(matmul_bias_scalar_f64, f64);
define_microkernel_edge!(microkernel_edge_f64, f64);