use diskann_wide::arch::Scalar;
use diskann_wide::{SIMDMinMax, SIMDVector};
use super::super::Kernel;
use super::super::layouts;
use super::super::reduce::Reduce;
use super::F32Kernel;
// Local shorthand used by the kernel below: `f32s` names the `f32x8` vector type
// of the `Scalar` architecture. (The exact expansion of `alias!` is not visible
// in this file — confirm against `diskann_wide`.)
diskann_wide::alias!(f32s = <Scalar>::f32x8);
/// `Kernel` implementation for `F32Kernel<8>` on the `Scalar` architecture.
///
/// The left operand uses the `BlockTransposed<f32, 8>` layout and the right
/// operand uses `RowMajor<f32>`. Panels are 8 wide on the A side and 2 wide on
/// the B side; both entry points forward to [`scalar_f32_microkernel`].
///
/// NOTE(review): `Kernel` is an unsafe trait whose contract is declared
/// elsewhere — confirm that these layouts and panel sizes satisfy it.
unsafe impl Kernel<Scalar> for F32Kernel<8> {
    type Left = layouts::BlockTransposed<f32, 8>;
    type Right = layouts::RowMajor<f32>;

    const A_PANEL: usize = 8;
    const B_PANEL: usize = 2;

    /// Processes a full B panel (`B_PANEL` rows of `b`) against one A panel.
    ///
    /// # Safety
    ///
    /// The pointers are handed unchanged to [`scalar_f32_microkernel`] with
    /// `UNROLL = B_PANEL`; the caller must uphold that function's pointer
    /// requirements.
    #[inline(always)]
    unsafe fn full_panel(arch: Scalar, a: *const f32, b: *const f32, k: usize, r: *mut f32) {
        // SAFETY: pointer validity is the caller's obligation (see `# Safety`).
        unsafe { scalar_f32_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, r) }
    }

    /// Processes a trailing, incomplete B panel containing `remainder` rows.
    ///
    /// With `B_PANEL == 2` the only remainder handled is `1`.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!` for any `remainder` other than `1`.
    ///
    /// # Safety
    ///
    /// The pointers are handed unchanged to [`scalar_f32_microkernel`] with
    /// `UNROLL = 1`; the caller must uphold that function's pointer
    /// requirements.
    #[inline(always)]
    unsafe fn partial_panel(
        arch: Scalar,
        remainder: usize,
        a: *const f32,
        b: *const f32,
        k: usize,
        r: *mut f32,
    ) {
        // Guard: anything but a single leftover row is treated as a caller bug.
        if remainder != 1 {
            unreachable!(
                "unexpected remainder {remainder} for B_PANEL={}",
                Self::B_PANEL
            );
        }

        // SAFETY: pointer validity is the caller's obligation (see `# Safety`).
        unsafe { scalar_f32_microkernel::<1>(arch, a, b, k, r) }
    }
}
/// Scalar-architecture f32 micro-kernel over one A panel and `UNROLL` rows of `b`.
///
/// For each `j` in `0..UNROLL` an accumulator vector is built up over
/// `i in 0..k` as `a_vec(i) * splat(b[j * k + i]) + acc[j]`, where `a_vec(i)` is
/// the `f32s` vector loaded from `a_packed + i * f32s::LANES`. The `UNROLL`
/// accumulators are then combined through `Reduce::reduce` using a lane-wise
/// `max_simd`, and the result is max-merged into the vector stored at `r`.
///
/// Accumulators start from `f32s::default(arch)` (presumably all zeros —
/// confirm against `diskann_wide`).
///
/// # Safety
///
/// The body performs the following raw accesses, all of which must be valid:
///
/// * `a_packed`: one `f32s` load at each offset `i * f32s::LANES`, `i in 0..k`.
/// * `b`: one unaligned `f32` read at each offset `j * k + i` for
///   `j in 0..UNROLL`, `i in 0..k`.
/// * `r`: one `f32s` load followed by one `f32s` store.
///
/// NOTE(review): any alignment requirements of `load_simd` / `store_simd` are
/// not visible here — confirm against the `SIMDVector` documentation.
#[inline(always)]
unsafe fn scalar_f32_microkernel<const UNROLL: usize>(
    arch: Scalar,
    a_packed: *const f32,
    b: *const f32,
    k: usize,
    r: *mut f32,
) where
    [f32s; UNROLL]: Reduce<Element = f32s>,
{
    let lanes = f32s::LANES;

    // Offset (in elements) of the first value of each of the `UNROLL` rows of `b`.
    let row_starts: [usize; UNROLL] = core::array::from_fn(|row| k * row);

    // One running accumulator vector per row of `b`.
    let mut acc = [f32s::default(arch); UNROLL];

    for step in 0..k {
        // SAFETY: the caller guarantees `a_packed` is readable for the `f32s`
        // located at element offset `lanes * step` (see `# Safety`).
        let a_vec = unsafe { f32s::load_simd(arch, a_packed.add(lanes * step)) };

        for (sum, &start) in acc.iter_mut().zip(row_starts.iter()) {
            // SAFETY: the caller guarantees `b` is readable at element offset
            // `step + start`; the read is unaligned-tolerant.
            let b_vec = unsafe { f32s::splat(arch, b.add(step + start).read_unaligned()) };
            // Operand order kept as `a * b + acc` to match the reference arithmetic.
            *sum = a_vec * b_vec + *sum;
        }
    }

    // Lane-wise maximum, used both across rows and against the existing result.
    let lane_max = |x: f32s, y: f32s| x.max_simd(y);

    // SAFETY: the caller guarantees `r` is readable for one `f32s`.
    let current = unsafe { f32s::load_simd(arch, r) };
    let merged = lane_max(current, acc.reduce(&lane_max));

    // SAFETY: the caller guarantees `r` is writable for one `f32s`.
    unsafe { merged.store_simd(r) };
}