use diskann_wide::Architecture;
use super::Kernel;
use super::TileBudget;
use super::layouts::{self, DescribeLayout};
use super::tiled_reduce::tiled_reduce;
use crate::multi_vector::{BlockTransposedRef, MatRef, Standard};
mod scalar;
#[cfg(target_arch = "x86_64")]
mod v3;
/// Zero-sized marker type naming the `f32` kernel used by `max_ip_kernel`.
///
/// `GROUP` is the same const parameter carried by the `BlockTransposedRef` left operand;
/// `max_ip_kernel` asserts at compile time that `<Self as Kernel<A>>::A_PANEL == GROUP`.
///
/// NOTE(review): the `Kernel<A>` implementations for this type are not visible in this part
/// of the file; presumably they live in the `scalar` / `v3` submodules. Confirm before
/// relying on that.
pub(crate) struct F32Kernel<const GROUP: usize>;
/// Cold, out-of-line failure path for the precondition check in `max_ip_kernel`.
///
/// The message formatting lives in its own `#[cold]`, `#[inline(never)]` function so that it
/// is not expanded into the caller. The function never returns; the `!` return type makes
/// that visible to the compiler, so every call site is known to diverge.
///
/// # Arguments
/// * `scratch_len` - the observed `scratch.len()`.
/// * `padded_nrows` - the length `scratch` was expected to have.
/// * `a_ncols` - the observed `a.ncols()`.
/// * `b_dim` - the observed `b.vector_dim()`.
///
/// # Panics
/// Always, with a message that reports all four values.
#[inline(never)]
#[cold]
// Panicking is the sole purpose of this helper, so the lint does not apply.
#[allow(clippy::panic)]
fn max_ip_kernel_panic(
    scratch_len: usize,
    padded_nrows: usize,
    a_ncols: usize,
    b_dim: usize,
) -> ! {
    panic!(
        "max_ip_kernel: precondition failed: \
         scratch.len()={scratch_len} (expected {padded_nrows}), \
         a.ncols()={a_ncols}, b.vector_dim()={b_dim}"
    )
}
/// Checks operand shapes and then runs `tiled_reduce` with `F32Kernel<GROUP>` as the kernel.
///
/// NOTE(review): the name suggests a maximum-inner-product reduction, but the arithmetic
/// itself lives in `tiled_reduce` and the `Kernel` implementations, which are not visible
/// here.
///
/// # Arguments
/// * `arch` - architecture token, forwarded to `tiled_reduce`.
/// * `a` - block-transposed left operand; its layout, base pointer, padded row count and
///   column count are forwarded.
/// * `b` - row-major right operand; its layout, base pointer and vector count are forwarded.
/// * `scratch` - `f32` buffer handed to `tiled_reduce`; must hold exactly
///   `a.padded_nrows()` elements.
/// * `budget` - tile budget, forwarded unchanged.
///
/// # Panics
/// Panics if `scratch.len() != a.padded_nrows()` or `a.ncols() != b.vector_dim()`.
pub(super) fn max_ip_kernel<A: Architecture, T: Copy, const GROUP: usize>(
    arch: A,
    a: BlockTransposedRef<'_, T, GROUP>,
    b: MatRef<'_, Standard<T>>,
    scratch: &mut [f32],
    budget: TileBudget,
) where
    F32Kernel<GROUP>: Kernel<A>,
    layouts::BlockTransposed<T, GROUP>:
        layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left> + layouts::Layout<Element = T>,
    layouts::RowMajor<T>: layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right>
        + layouts::Layout<Element = T>,
{
    // Compile-time guard: the kernel's A-panel size has to equal the block-transposed group
    // size of `a`.
    const { assert!(<F32Kernel<GROUP> as Kernel<A>>::A_PANEL == GROUP) };

    let padded_rows = a.padded_nrows();
    let inner_dim = a.ncols();

    // Runtime preconditions, phrased positively; any violation takes the cold panic path.
    let scratch_fits = scratch.len() == padded_rows;
    let dims_agree = inner_dim == b.vector_dim();
    if !(scratch_fits && dims_agree) {
        max_ip_kernel_panic(scratch.len(), padded_rows, inner_dim, b.vector_dim());
    }

    let left_layout = a.layout();
    let right_layout = b.layout();
    let rhs_rows = b.num_vectors();

    // SAFETY: NOTE(review): the exact contract of `tiled_reduce` is not visible here, so
    // confirm against its `# Safety` section. What this function establishes: both pointers
    // are taken from the live borrows `a` and `b`, the row/column counts passed are the ones
    // those same views report, `scratch.len() == a.padded_nrows()`, and
    // `a.ncols() == b.vector_dim()` (checked above).
    unsafe {
        tiled_reduce::<A, F32Kernel<GROUP>, _, _>(
            arch,
            &left_layout,
            &right_layout,
            a.as_ptr(),
            padded_rows,
            b.as_slice().as_ptr(),
            rhs_rows,
            inner_dim,
            scratch,
            budget,
        );
    }
}
/// `Target3` adapter for `F32Kernel<GROUP>` over `f32` operands.
///
/// `run` forwards its three operands to `max_ip_kernel` together with
/// `TileBudget::default()`, so the panics documented on `max_ip_kernel` apply here as well.
///
/// NOTE(review): the `()` type argument is presumably the return type of `run`; confirm
/// against the `Target3` definition in `diskann_wide`.
impl<A, const GROUP: usize>
    diskann_wide::arch::Target3<
        A,
        (),
        BlockTransposedRef<'_, f32, GROUP>,
        MatRef<'_, Standard<f32>>,
        &mut [f32],
    > for F32Kernel<GROUP>
where
    A: Architecture,
    F32Kernel<GROUP>: Kernel<A>,
    layouts::BlockTransposed<f32, GROUP>: layouts::Layout<Element = f32>
        + layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left>,
    layouts::RowMajor<f32>: layouts::Layout<Element = f32>
        + layouts::ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right>,
{
    /// Runs the kernel on `lhs` and `rhs`, using `scratch` as the `f32` buffer and the
    /// default tile budget.
    #[inline(always)]
    fn run(
        self,
        arch: A,
        lhs: BlockTransposedRef<'_, f32, GROUP>,
        rhs: MatRef<'_, Standard<f32>>,
        scratch: &mut [f32],
    ) {
        // This entry point has no budget parameter, so the default is always used.
        let budget = TileBudget::default();
        max_ip_kernel(arch, lhs, rhs, scratch, budget);
    }
}