use crate::kernels::kernel_4x4::kernel_4x4_avx2;
use crate::matrix::transpose::transpose;
/// Accumulates `C[start..end, ..] += A[start..end, ..] * B` using cache blocking,
/// operand packing and a 4x4 AVX2/FMA micro-kernel.
///
/// Matrices are dense row-major: `a` is `m x k`, `b` is `k x n`, `c` is `m x n`.
/// `row_start` / `row_end` select the half-open row range `start..end` (defaulting to
/// `0..m`). Only rows inside that range are written, so a matrix can be split into
/// disjoint row ranges handled by separate calls without any row being updated twice.
/// An empty (or inverted) range is a no-op.
///
/// Work split:
/// * the longest run of rows beginning at `start` whose length is a multiple of 4,
///   restricted to the first `n - n % 4` columns, goes through the packed kernel;
/// * the last `n % 4` columns of those rows go through `edge_case_cols`;
/// * the remaining `(end - start) % 4` rows (all columns) go through `edge_case_rows`.
///
/// # Panics
///
/// Panics if `end > m` or `c.len() < end * n`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma` target
/// features (for example by checking `is_x86_feature_detected!` beforehand).
#[target_feature(enable = "avx2,fma")]
#[allow(clippy::identity_op)]
#[allow(clippy::erasing_op)]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_blocked_4x4(
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    m: usize,
    n: usize,
    k: usize,
    row_start: Option<usize>,
    row_end: Option<usize>,
) {
    // Cache-blocking sizes: depth of one packed k-slice, and height of one packed A
    // panel. MC must stay a multiple of 4 so every panel splits into whole 4-row strips.
    const KC: usize = 256;
    const MC: usize = 128;

    let start = row_start.unwrap_or(0);
    let end = row_end.unwrap_or(m);
    if end <= start {
        // Empty or inverted row range: nothing to update.
        return;
    }
    assert!(end <= m, "row_end ({}) exceeds the row count m ({})", end, m);
    // The micro-kernel writes through a raw pointer, so the output bound must be
    // checked up front rather than relying on slice indexing.
    assert!(
        c.len() >= end * n,
        "output slice too short: {} < {} * {}",
        c.len(),
        end,
        n
    );

    // Rows are tiled in groups of four starting exactly at `start` (never rounded
    // down), so no row before `start` is touched.
    let m_start = start;
    let m_end = start + ((end - start) / 4) * 4;
    let n_main = (n / 4) * 4;

    // The packed path needs at least one full 4x4 tile and a non-empty inner
    // dimension; this also guarantees `kc` and `mc` are non-zero for `step_by`.
    if m_start < m_end && n_main > 0 && k > 0 {
        let kc = k.min(KC);
        // Both operands are multiples of 4, so every panel is whole 4-row strips.
        let mc = (m_end - m_start).min(MC);

        // NOTE(review): assumes `transpose` writes B (k x n) as an n x k matrix so that
        // `bt[j * k + p] == b[p * n + j]`, which is how `pack_b_panel` indexes it.
        let mut bt = vec![0.0; k * n];
        transpose(b, &mut bt, k, n);

        let mut a_panel = vec![0.0; mc * kc];
        let mut b_pack = vec![0.0; 4 * kc];
        for kk in (0..k).step_by(kc) {
            let k_block = (kk + kc).min(k) - kk;
            for ii in (m_start..m_end).step_by(mc) {
                let m_block = (ii + mc).min(m_end) - ii;
                pack_a_panel_large(a, &mut a_panel, ii, kk, m_block, k_block, k);
                for j in (0..n_main).step_by(4) {
                    pack_b_panel(&bt, &mut b_pack, j, kk, k_block, k);
                    for i in (0..m_block).step_by(4) {
                        let a_pack_offset = i * k_block;
                        // SAFETY: the CPU features are guaranteed by the caller (see
                        // `# Safety`). `i + 4 <= m_block <= mc`, so the A strip at
                        // `a_pack_offset` has `4 * k_block` packed values inside
                        // `a_panel`, and `b_pack` holds `4 * k_block` values. The C tile
                        // starting at row `ii + i`, column `j` ends at index
                        // `(ii + i + 3) * n + j + 3 < m_end * n <= end * n <= c.len()`.
                        // NOTE(review): assumes the kernel reads exactly those extents
                        // and accumulates into C with row stride `n` (it is called once
                        // per k-block, so it must accumulate).
                        kernel_4x4_avx2(
                            a_panel.as_ptr().add(a_pack_offset),
                            b_pack.as_ptr(),
                            c.as_mut_ptr().add((ii + i) * n + j),
                            k_block,
                            n,
                        );
                    }
                }
            }
        }
    }

    // Trailing rows of the range that do not fill a 4-row group: all columns.
    if m_end < end {
        edge_case_rows(a, b, c, m_end, end, n, k);
    }
    // Trailing `n % 4` columns of the tiled rows.
    if n_main < n {
        edge_case_cols(a, b, c, m_start, m_end, n_main, n, k);
    }
}
/// Copies rows `i_start..i_start + m_block` and columns `k_start..k_start + k_block`
/// of the row-major matrix `a` (row stride `k_total`) into `a_panel`.
///
/// Rows are consumed four at a time. Each group of four rows becomes one contiguous
/// strip of `4 * k_block` values, in which the four row entries for a given column sit
/// side by side. The strip whose first row is `i_start + r` begins at offset
/// `r * k_block`.
///
/// `m_block` is expected to be a multiple of 4. Otherwise the last strip reads rows
/// beyond `i_start + m_block`, and indexing panics if they do not exist.
fn pack_a_panel_large(
    a: &[f64],
    a_panel: &mut [f64],
    i_start: usize,
    k_start: usize,
    m_block: usize,
    k_block: usize,
    k_total: usize,
) {
    for strip_row in (0..m_block).step_by(4) {
        let top_row = i_start + strip_row;
        let strip_base = strip_row * k_block;
        for col in 0..k_block {
            let src_col = k_start + col;
            let dst_start = strip_base + 4 * col;
            let quad = &mut a_panel[dst_start..dst_start + 4];
            for (lane, slot) in quad.iter_mut().enumerate() {
                *slot = a[(top_row + lane) * k_total + src_col];
            }
        }
    }
}
/// Packs four consecutive columns `j_start..j_start + 4` over the depth range
/// `k_start..k_start + k_block` into `b_pack`.
///
/// `bt` is indexed as `bt[j * k_total + p]`; the caller passes the transpose of B, so
/// each column of B is contiguous. The output is depth-major: the four column values
/// for depth `p` occupy `b_pack[4 * p..4 * p + 4]`.
fn pack_b_panel(
    bt: &[f64],
    b_pack: &mut [f64],
    j_start: usize,
    k_start: usize,
    k_block: usize,
    k_total: usize,
) {
    let packed = &mut b_pack[..4 * k_block];
    for (depth, quad) in packed.chunks_exact_mut(4).enumerate() {
        let src_col = k_start + depth;
        for (lane, slot) in quad.iter_mut().enumerate() {
            *slot = bt[(j_start + lane) * k_total + src_col];
        }
    }
}
/// Scalar fallback: accumulates `C[i, ..] += A[i, ..] * B` for every row `i` in
/// `i_start..m`, across all `n` columns.
///
/// All matrices are row-major: `a` has row stride `k`, while `b` and `c` have row
/// stride `n`. The caller uses it for the rows left over when the row range does not
/// split evenly into groups of four.
fn edge_case_rows(
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    i_start: usize,
    m: usize,
    n: usize,
    k: usize,
) {
    // With no output columns or no inner dimension there is nothing to add.
    if n == 0 || k == 0 {
        return;
    }
    for row in i_start..m {
        let a_row = &a[row * k..(row + 1) * k];
        let c_row = &mut c[row * n..(row + 1) * n];
        // i-p-j order: add row `p` of B, scaled by A[row, p], into the output row.
        for (&a_val, b_row) in a_row.iter().zip(b[..k * n].chunks_exact(n)) {
            for (c_val, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_val += a_val * b_val;
            }
        }
    }
}
/// Scalar fallback: accumulates `C[i, j] += sum_p A[i, p] * B[p, j]` for rows
/// `i_start..i_end` and only the trailing columns `j_start..n`.
///
/// All matrices are row-major: `a` has row stride `k`, while `b` and `c` have row
/// stride `n`. The caller uses it for the columns that do not fill a 4-wide tile.
#[allow(clippy::too_many_arguments)]
fn edge_case_cols(
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    i_start: usize,
    i_end: usize,
    j_start: usize,
    n: usize,
    k: usize,
) {
    // No trailing columns, or no inner dimension: nothing to add.
    if j_start >= n || k == 0 {
        return;
    }
    for row in i_start..i_end {
        let a_row = &a[row * k..(row + 1) * k];
        let c_tail = &mut c[row * n + j_start..(row + 1) * n];
        for (p, &a_val) in a_row.iter().enumerate() {
            let b_tail = &b[p * n + j_start..(p + 1) * n];
            for (c_val, &b_val) in c_tail.iter_mut().zip(b_tail) {
                *c_val += a_val * b_val;
            }
        }
    }
}