use diskann_wide::arch::x86_64::V3;
use diskann_wide::{SIMDMinMax, SIMDMulAdd, SIMDVector};
use super::super::Kernel;
use super::super::layouts;
use super::super::reduce::Reduce;
use super::F32Kernel;
// Introduces `f32s` as the name for the `f32x8` vector type of the `V3` architecture;
// the microkernel below performs all of its SIMD loads, FMAs and stores through it.
diskann_wide::alias!(f32s = <V3>::f32x8);
/// `Kernel` implementation for `F32Kernel<16>` on the `V3` architecture.
///
/// The left operand is declared as `BlockTransposed<f32, 16>` and the right operand as
/// `RowMajor<f32>`. Both entry points forward to `f32_microkernel`, differing only in the
/// number of `b` rows (the `UNROLL` const parameter) processed per call.
unsafe impl Kernel<V3> for F32Kernel<16> {
    type Left = layouts::BlockTransposed<f32, 16>;
    type Right = layouts::RowMajor<f32>;

    const A_PANEL: usize = 16;
    const B_PANEL: usize = 4;

    /// Processes a full panel, i.e. `Self::B_PANEL` rows of `b`, in one microkernel call.
    ///
    /// # Safety
    ///
    /// `a`, `b` and `r` are forwarded unchanged to `f32_microkernel`, which dereferences
    /// them; they must be valid for the accesses it performs with `UNROLL = B_PANEL`.
    #[inline(always)]
    unsafe fn full_panel(arch: V3, a: *const f32, b: *const f32, k: usize, r: *mut f32) {
        // SAFETY: the caller upholds the pointer requirements for a full panel.
        unsafe { f32_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, r) }
    }

    /// Processes a trailing panel that holds fewer than `Self::B_PANEL` rows of `b`.
    ///
    /// `remainder` selects the microkernel instantiation and must be 1, 2 or 3.
    ///
    /// # Panics
    ///
    /// Panics via `unreachable!` for any other `remainder`.
    ///
    /// # Safety
    ///
    /// `a`, `b` and `r` are forwarded unchanged to `f32_microkernel`, which dereferences
    /// them; they must be valid for the accesses it performs with `UNROLL = remainder`.
    #[inline(always)]
    unsafe fn partial_panel(
        arch: V3,
        remainder: usize,
        a: *const f32,
        b: *const f32,
        k: usize,
        r: *mut f32,
    ) {
        // The unroll factor is a const generic, so each legal remainder needs its own
        // monomorphised call.
        match remainder {
            // SAFETY (all arms): the caller upholds the pointer requirements for
            // `remainder` rows of `b`.
            1 => unsafe { f32_microkernel::<1>(arch, a, b, k, r) },
            2 => unsafe { f32_microkernel::<2>(arch, a, b, k, r) },
            3 => unsafe { f32_microkernel::<3>(arch, a, b, k, r) },
            _ => unreachable!(
                "unexpected remainder {remainder} for B_PANEL={}",
                Self::B_PANEL
            ),
        }
    }
}
/// Max-accumulating microkernel over one packed `a` panel and `UNROLL` rows of `b`.
///
/// For every step `i` in `0..k` two vectors are loaded from `a_packed` (at element offsets
/// `2 * LANES * i` and `2 * LANES * i + LANES`), and for every row `j` in `0..UNROLL` the
/// scalar `b[i + k * j]` is broadcast and fused-multiply-accumulated into that row's pair of
/// accumulators. After the loop the per-row accumulators are combined with a lane-wise
/// maximum, merged (again by maximum) with the two vectors currently stored at `r`, and the
/// result is written back to `r`.
///
/// NOTE(review): this assumes `f32s::default` yields an all-zero vector and that
/// `x.mul_add_simd(y, z)` computes `x * y + z` — confirm against `diskann_wide`.
///
/// # Safety
///
/// * `a_packed` must be valid for the `2 * k` vector loads described above
///   (presumably `2 * f32s::LANES * k` contiguous `f32` values).
/// * `b` must be valid for reads at every element offset `i + k * j` with `i < k` and
///   `j < UNROLL`; no alignment is required because `read_unaligned` is used.
/// * `r` must be valid for a vector load and a vector store at `r` and at
///   `r + f32s::LANES`.
#[inline(always)]
unsafe fn f32_microkernel<const UNROLL: usize>(
    arch: V3,
    a_packed: *const f32,
    b: *const f32,
    k: usize,
    r: *mut f32,
) where
    [f32s; UNROLL]: Reduce<Element = f32s>,
{
    let lanes = f32s::LANES;
    // One step of the packed `a` panel spans two vectors.
    let panel_width = 2 * lanes;
    let vmax = |lhs: f32s, rhs: f32s| lhs.max_simd(rhs);

    // Element offset of the start of each `b` row.
    let row_starts: [usize; UNROLL] = core::array::from_fn(|row| k * row);

    // One accumulator per `b` row for the low and the high half of the `a` panel.
    let mut acc_lo = [f32s::default(arch); UNROLL];
    let mut acc_hi = [f32s::default(arch); UNROLL];

    for step in 0..k {
        // SAFETY: the caller guarantees `a_packed` and `b` are valid for the offsets
        // listed in the function-level safety contract.
        unsafe {
            let a_lo = f32s::load_simd(arch, a_packed.add(panel_width * step));
            let a_hi = f32s::load_simd(arch, a_packed.add(panel_width * step + lanes));

            let rows = acc_lo
                .iter_mut()
                .zip(acc_hi.iter_mut())
                .zip(row_starts.iter());
            for ((lo, hi), &start) in rows {
                let bj = f32s::splat(arch, b.add(step + start).read_unaligned());
                *lo = a_lo.mul_add_simd(bj, *lo);
                *hi = a_hi.mul_add_simd(bj, *hi);
            }
        }
    }

    // SAFETY: the caller guarantees `r` is valid for two consecutive vector loads.
    let (prev_lo, prev_hi) =
        unsafe { (f32s::load_simd(arch, r), f32s::load_simd(arch, r.add(lanes))) };

    // Fold the rows together, then merge with what was already stored in `r`.
    let out_lo = vmax(prev_lo, acc_lo.reduce(&vmax));
    let out_hi = vmax(prev_hi, acc_hi.reduce(&vmax));

    // SAFETY: the caller guarantees `r` is valid for two consecutive vector stores.
    unsafe {
        out_lo.store_simd(r);
        out_hi.store_simd(r.add(lanes));
    }
}