//! Hand-rolled AVX kernels for `f32` slices and column-major matrices.
//!
//! Assumption: this module is compiled only for x86_64 targets with AVX
//! enabled (e.g. `RUSTFLAGS="-C target-feature=+avx"`); the intrinsics below
//! are undefined behaviour on CPUs without AVX.
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
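/// Dot product of two equal-length `f32` slices.
///
/// Four independent 256-bit accumulators process 32 lanes per iteration so the
/// dependent adds can overlap; the remainder is handled by one 8-wide pass and
/// a final scalar loop.
///
/// Panics if the slice lengths differ.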
#[inline]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
// Equal lengths are required for the unchecked loads below, so check in all builds.
assert_eq!(a.len(), b.len());
let n = a.len();
let chunks = n / 32;
unsafe {
let ap = a.as_ptr();
let bp = b.as_ptr();
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for i in 0..chunks {
let off = i * 32;
acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_loadu_ps(ap.add(off)), _mm256_loadu_ps(bp.add(off))));
acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_loadu_ps(ap.add(off + 8)), _mm256_loadu_ps(bp.add(off + 8))));
acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_loadu_ps(ap.add(off + 16)), _mm256_loadu_ps(bp.add(off + 16))));
acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_loadu_ps(ap.add(off + 24)), _mm256_loadu_ps(bp.add(off + 24))));
}
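// Combine the four accumulators, then reduce 256 -> 128 -> scalar.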
acc0 = _mm256_add_ps(acc0, acc1);
acc2 = _mm256_add_ps(acc2, acc3);
acc0 = _mm256_add_ps(acc0, acc2);
let hi128 = _mm256_extractf128_ps(acc0, 1);
let lo128 = _mm256_castps256_ps128(acc0);
let sum128 = _mm_add_ps(hi128, lo128);
let shuf = _mm_movehl_ps(sum128, sum128);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_shuffle_ps(sums, sums, 1);
let total = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(total);
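// Fewer than 32 elements remain: one 8-wide pass, then a scalar loop.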
let tail = chunks * 32;
let remaining = n - tail;
let rem_octs = remaining / 8;
let mut acc_rem = _mm256_setzero_ps();
for i in 0..rem_octs {
let off = tail + i * 8;
acc_rem = _mm256_add_ps(acc_rem, _mm256_mul_ps(_mm256_loadu_ps(ap.add(off)), _mm256_loadu_ps(bp.add(off))));
}
let rhi = _mm256_extractf128_ps(acc_rem, 1);
let rlo = _mm256_castps256_ps128(acc_rem);
let rs = _mm_add_ps(rhi, rlo);
let rs2 = _mm_movehl_ps(rs, rs);
let rs3 = _mm_add_ps(rs, rs2);
let rs4 = _mm_shuffle_ps(rs3, rs3, 1);
sum += _mm_cvtss_f32(_mm_add_ss(rs3, rs4));
let scalar_start = tail + rem_octs * 8;
for i in scalar_start..n {
sum += a[i] * b[i];
}
sum
}
}
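/// Computes `C += A * B`, where `A` is `m x n`, `B` is `n x p`, and `C` is
/// `m x p`, all stored column-major. Callers that want a plain product must
/// zero `C` first. The shared dimension is processed in panels of `KC`; each
/// panel is swept with 16x4 AVX microkernels, falling back to 8x4, 4x4, and
/// scalar updates on the row fringe and broadcast AXPYs on the column fringe.
///
/// Panics if the slice lengths do not match the given dimensions.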
#[inline]
pub fn matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, p: usize) {
// Shape checks guard the raw-pointer accesses below, so keep them in release builds too.
assert_eq!(a.len(), m * n);
assert_eq!(b.len(), n * p);
assert_eq!(c.len(), m * p);
// Register blocking: each microkernel updates an MR x NR tile of C.
const MR: usize = 16;
const NR: usize = 4;
// The shared (k) dimension is processed in panels of at most KC per pass over C.
const KC: usize = 256;
let m_full = (m / MR) * MR;
let p_full = (p / NR) * NR;
let mut kb = 0;
while kb < n {
let k_end = (kb + KC).min(n);
for jb in 0..p_full / NR {
let j0 = jb * NR;
for ib in 0..m_full / MR {
let i0 = ib * MR;
unsafe { microkernel_16x4(a, b, c, m, n, i0, j0, kb, k_end); }
}
}
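// Row fringe: fewer than MR rows remain; finish them with 8-wide, 4-wide, and scalar updates.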
let mut i0 = m_full;
while i0 + 8 <= m {
for jb in 0..p_full / NR {
let j0 = jb * NR;
unsafe { microkernel_8x4(a, b, c, m, n, i0, j0, kb, k_end); }
}
i0 += 8;
}
while i0 + 4 <= m {
for jb in 0..p_full / NR {
let j0 = jb * NR;
unsafe { microkernel_4x4(a, b, c, m, n, i0, j0, kb, k_end); }
}
i0 += 4;
}
if i0 < m {
for j in 0..p_full {
for k in kb..k_end {
let b_kj = b[j * n + k];
for i in i0..m {
c[j * m + i] += a[k * m + i] * b_kj;
}
}
}
}
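// Column fringe: the last p - p_full columns, updated as broadcast AXPYs down each column of C.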
let i_simd = m / 8;
let i_tail = i_simd * 8;
for j in p_full..p {
for k in kb..k_end {
let b_kj = b[j * n + k];
let a_col = k * m;
let c_col = j * m;
unsafe {
let vb = _mm256_set1_ps(b_kj);
for i in 0..i_simd {
let offset = i * 8;
let vc = _mm256_loadu_ps(c.as_ptr().add(c_col + offset));
let va = _mm256_loadu_ps(a.as_ptr().add(a_col + offset));
let result = _mm256_add_ps(vc, _mm256_mul_ps(va, vb));
_mm256_storeu_ps(c.as_mut_ptr().add(c_col + offset), result);
}
}
for i in i_tail..m {
c[c_col + i] += a[a_col + i] * b_kj;
}
}
}
kb += KC;
}
}
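/// Accumulates a 16x4 tile of `C += A * B` over `k` in `k_start..k_end`,
/// holding the tile in eight 256-bit registers.
///
/// # Safety
/// The caller must ensure the tile `[i0, i0 + 16) x [j0, j0 + 4)` lies inside
/// the `m x p` matrix `C`, that `k_end <= n`, that the slices have the shapes
/// asserted by `matmul`, and that AVX is available.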
#[inline(always)]
unsafe fn microkernel_16x4(
a: &[f32], b: &[f32], c: &mut [f32],
m: usize, n: usize, i0: usize, j0: usize,
k_start: usize, k_end: usize,
) {
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut acc00 = _mm256_setzero_ps();
let mut acc10 = _mm256_setzero_ps();
let mut acc01 = _mm256_setzero_ps();
let mut acc11 = _mm256_setzero_ps();
let mut acc02 = _mm256_setzero_ps();
let mut acc12 = _mm256_setzero_ps();
let mut acc03 = _mm256_setzero_ps();
let mut acc13 = _mm256_setzero_ps();
for k in k_start..k_end {
let a_off = k * m + i0;
let a0 = _mm256_loadu_ps(a_ptr.add(a_off));
let a1 = _mm256_loadu_ps(a_ptr.add(a_off + 8));
let b0 = _mm256_set1_ps(*b_ptr.add(j0 * n + k));
acc00 = _mm256_add_ps(acc00, _mm256_mul_ps(a0, b0));
acc10 = _mm256_add_ps(acc10, _mm256_mul_ps(a1, b0));
let b1 = _mm256_set1_ps(*b_ptr.add((j0 + 1) * n + k));
acc01 = _mm256_add_ps(acc01, _mm256_mul_ps(a0, b1));
acc11 = _mm256_add_ps(acc11, _mm256_mul_ps(a1, b1));
let b2 = _mm256_set1_ps(*b_ptr.add((j0 + 2) * n + k));
acc02 = _mm256_add_ps(acc02, _mm256_mul_ps(a0, b2));
acc12 = _mm256_add_ps(acc12, _mm256_mul_ps(a1, b2));
let b3 = _mm256_set1_ps(*b_ptr.add((j0 + 3) * n + k));
acc03 = _mm256_add_ps(acc03, _mm256_mul_ps(a0, b3));
acc13 = _mm256_add_ps(acc13, _mm256_mul_ps(a1, b3));
}
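// Add the accumulated tile back into the four affected columns of C.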
let c_ptr = c.as_mut_ptr();
let off0 = j0 * m + i0;
_mm256_storeu_ps(c_ptr.add(off0), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off0)), acc00));
_mm256_storeu_ps(c_ptr.add(off0 + 8), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off0 + 8)), acc10));
let off1 = (j0 + 1) * m + i0;
_mm256_storeu_ps(c_ptr.add(off1), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off1)), acc01));
_mm256_storeu_ps(c_ptr.add(off1 + 8), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off1 + 8)), acc11));
let off2 = (j0 + 2) * m + i0;
_mm256_storeu_ps(c_ptr.add(off2), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off2)), acc02));
_mm256_storeu_ps(c_ptr.add(off2 + 8), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off2 + 8)), acc12));
let off3 = (j0 + 3) * m + i0;
_mm256_storeu_ps(c_ptr.add(off3), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off3)), acc03));
_mm256_storeu_ps(c_ptr.add(off3 + 8), _mm256_add_ps(_mm256_loadu_ps(c_ptr.add(off3 + 8)), acc13));
}
}
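/// Accumulates an 8x4 tile of `C += A * B` over `k` in `k_start..k_end`.
///
/// # Safety
/// Same requirements as `microkernel_16x4`, but for a tile of 8 rows.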
#[inline(always)]
unsafe fn microkernel_8x4(
a: &[f32], b: &[f32], c: &mut [f32],
m: usize, n: usize, i0: usize, j0: usize,
k_start: usize, k_end: usize,
) {
unsafe {
let (ap, bp) = (a.as_ptr(), b.as_ptr());
let mut a0 = _mm256_setzero_ps(); let mut a1 = _mm256_setzero_ps();
let mut a2 = _mm256_setzero_ps(); let mut a3 = _mm256_setzero_ps();
for k in k_start..k_end {
let av = _mm256_loadu_ps(ap.add(k * m + i0));
a0 = _mm256_add_ps(a0, _mm256_mul_ps(av, _mm256_set1_ps(*bp.add(j0 * n + k))));
a1 = _mm256_add_ps(a1, _mm256_mul_ps(av, _mm256_set1_ps(*bp.add((j0+1) * n + k))));
a2 = _mm256_add_ps(a2, _mm256_mul_ps(av, _mm256_set1_ps(*bp.add((j0+2) * n + k))));
a3 = _mm256_add_ps(a3, _mm256_mul_ps(av, _mm256_set1_ps(*bp.add((j0+3) * n + k))));
}
let cp = c.as_mut_ptr();
for (j, acc) in [(j0, a0), (j0+1, a1), (j0+2, a2), (j0+3, a3)] {
let off = j * m + i0;
_mm256_storeu_ps(cp.add(off), _mm256_add_ps(_mm256_loadu_ps(cp.add(off)), acc));
}
}
}
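/// Accumulates a 4x4 tile of `C += A * B` over `k` in `k_start..k_end`,
/// using 128-bit registers for the 4-row tile.
///
/// # Safety
/// Same requirements as `microkernel_16x4`, but for a tile of 4 rows.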
#[inline(always)]
unsafe fn microkernel_4x4(
a: &[f32], b: &[f32], c: &mut [f32],
m: usize, n: usize, i0: usize, j0: usize,
k_start: usize, k_end: usize,
) {
unsafe {
let (ap, bp) = (a.as_ptr(), b.as_ptr());
let mut a0 = _mm_setzero_ps(); let mut a1 = _mm_setzero_ps();
let mut a2 = _mm_setzero_ps(); let mut a3 = _mm_setzero_ps();
for k in k_start..k_end {
let av = _mm_loadu_ps(ap.add(k * m + i0));
a0 = _mm_add_ps(a0, _mm_mul_ps(av, _mm_set1_ps(*bp.add(j0 * n + k))));
a1 = _mm_add_ps(a1, _mm_mul_ps(av, _mm_set1_ps(*bp.add((j0+1) * n + k))));
a2 = _mm_add_ps(a2, _mm_mul_ps(av, _mm_set1_ps(*bp.add((j0+2) * n + k))));
a3 = _mm_add_ps(a3, _mm_mul_ps(av, _mm_set1_ps(*bp.add((j0+3) * n + k))));
}
let cp = c.as_mut_ptr();
for (j, acc) in [(j0, a0), (j0+1, a1), (j0+2, a2), (j0+3, a3)] {
let off = j * m + i0;
_mm_storeu_ps(cp.add(off), _mm_add_ps(_mm_loadu_ps(cp.add(off)), acc));
}
}
}
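/// Element-wise `out[i] = a[i] + b[i]`, 8 lanes at a time with a scalar tail.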
#[inline]
pub fn add_slices(a: &[f32], b: &[f32], out: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), out.len());
let n = a.len();
let chunks = n / 8;
unsafe {
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(offset));
let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
_mm256_storeu_ps(out.as_mut_ptr().add(offset), _mm256_add_ps(va, vb));
}
}
let tail = chunks * 8;
for i in tail..n {
out[i] = a[i] + b[i];
}
}
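/// Element-wise `out[i] = a[i] - b[i]`, 8 lanes at a time with a scalar tail.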
#[inline]
pub fn sub_slices(a: &[f32], b: &[f32], out: &mut [f32]) {
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), out.len());
let n = a.len();
let chunks = n / 8;
unsafe {
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(offset));
let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
_mm256_storeu_ps(out.as_mut_ptr().add(offset), _mm256_sub_ps(va, vb));
}
}
let tail = chunks * 8;
for i in tail..n {
out[i] = a[i] - b[i];
}
}
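/// Element-wise `out[i] = a[i] * scalar`, 8 lanes at a time with a scalar tail.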
#[inline]
pub fn scale_slices(a: &[f32], scalar: f32, out: &mut [f32]) {
assert_eq!(a.len(), out.len());
let n = a.len();
let chunks = n / 8;
unsafe {
let vs = _mm256_set1_ps(scalar);
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(offset));
_mm256_storeu_ps(out.as_mut_ptr().add(offset), _mm256_mul_ps(va, vs));
}
}
let tail = chunks * 8;
for i in tail..n {
out[i] = a[i] * scalar;
}
}
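/// In-place `y[i] -= alpha * x[i]` (AXPY with a negated coefficient).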
#[inline]
pub fn axpy_neg(y: &mut [f32], alpha: f32, x: &[f32]) {
assert_eq!(y.len(), x.len());
let n = y.len();
let chunks = n / 8;
unsafe {
let va = _mm256_set1_ps(alpha);
for i in 0..chunks {
let offset = i * 8;
let vy = _mm256_loadu_ps(y.as_ptr().add(offset));
let vx = _mm256_loadu_ps(x.as_ptr().add(offset));
let result = _mm256_sub_ps(vy, _mm256_mul_ps(va, vx));
_mm256_storeu_ps(y.as_mut_ptr().add(offset), result);
}
}
let tail = chunks * 8;
for i in tail..n {
y[i] -= alpha * x[i];
}
}
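/// In-place `y[i] += alpha * x[i]` (the classic AXPY update).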
#[inline]
pub fn axpy_pos(y: &mut [f32], alpha: f32, x: &[f32]) {
assert_eq!(y.len(), x.len());
let n = y.len();
let chunks = n / 8;
unsafe {
let va = _mm256_set1_ps(alpha);
for i in 0..chunks {
let offset = i * 8;
let vy = _mm256_loadu_ps(y.as_ptr().add(offset));
let vx = _mm256_loadu_ps(x.as_ptr().add(offset));
let result = _mm256_add_ps(vy, _mm256_mul_ps(va, vx));
_mm256_storeu_ps(y.as_mut_ptr().add(offset), result);
}
}
let tail = chunks * 8;
for i in tail..n {
y[i] += alpha * x[i];
}
}
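// Minimal sanity-check sketch: assumes the test build enables AVX
// (e.g. `RUSTFLAGS="-C target-feature=+avx"`), and compares the SIMD paths
// against plain scalar references on sizes chosen to exercise every tail path.
#[cfg(test)]
mod tests {
    use super::*;

    // Scalar reference for the column-major, accumulating matmul above.
    fn matmul_ref(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, p: usize) {
        for j in 0..p {
            for k in 0..n {
                let b_kj = b[j * n + k];
                for i in 0..m {
                    c[j * m + i] += a[k * m + i] * b_kj;
                }
            }
        }
    }

    #[test]
    fn dot_matches_scalar() {
        let n = 1003; // hits the 32-wide loop, the 8-wide pass, and the scalar tail
        let a: Vec<f32> = (0..n).map(|i| (i % 7) as f32 * 0.5).collect();
        let b: Vec<f32> = (0..n).map(|i| (i % 5) as f32 - 2.0).collect();
        // f64 reference; the loose tolerance allows for f32 reassociation error.
        let expected: f64 = a.iter().zip(&b).map(|(x, y)| *x as f64 * *y as f64).sum();
        assert!((dot(&a, &b) as f64 - expected).abs() < 1e-2);
    }

    #[test]
    fn matmul_matches_scalar() {
        // Sizes chosen to hit the 16x4, 8x4, and 4x4 kernels plus both fringes.
        let (m, n, p) = (45, 19, 11);
        let a: Vec<f32> = (0..m * n).map(|i| (i % 13) as f32 * 0.25).collect();
        let b: Vec<f32> = (0..n * p).map(|i| (i % 11) as f32 - 5.0).collect();
        let mut c = vec![0.0f32; m * p];
        let mut c_ref = vec![0.0f32; m * p];
        matmul(&a, &b, &mut c, m, n, p);
        matmul_ref(&a, &b, &mut c_ref, m, n, p);
        for (x, y) in c.iter().zip(&c_ref) {
            assert!((x - y).abs() <= 1e-3 * y.abs().max(1.0), "mismatch: {x} vs {y}");
        }
    }
}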