use diskann_wide::Architecture;
use super::layouts::{ConvertTo, Layout};
use super::{Kernel, TileBudget};
/// Tiling plan for the tiled reduction: how many A panels and B panels are
/// processed per tile, derived from the cache byte budgets in `TileBudget`.
#[derive(Debug, Clone, Copy)]
struct FullReduce {
    // Number of A panels per tile (sized against `TileBudget::l2_a`); >= 1.
    a_panels_per_tile: usize,
    // Number of B panels per tile (sized against `TileBudget::l1_b` after
    // reserving room for one A panel); >= 1.
    b_panels_per_tile: usize,
}
impl FullReduce {
    /// Derives a tiling plan from per-row byte sizes, panel heights, and the
    /// cache budgets. Zero-byte rows (k == 0) are clamped to one byte before
    /// dividing, and both panel counts are clamped to at least one so the
    /// resulting plan is always usable.
    fn new(
        a_row_bytes: usize,
        b_row_bytes: usize,
        a_panel: usize,
        b_panel: usize,
        budget: TileBudget,
    ) -> Self {
        // Avoid division by zero when a dimension collapses to zero bytes.
        let a_bytes = a_row_bytes.max(1);
        let b_bytes = b_row_bytes.max(1);

        // A tiles are budgeted against L2: fit as many whole A panels as
        // the budget allows, but never fewer than one.
        let a_panel_bytes = a_bytes * a_panel;
        let a_panels_per_tile = (budget.l2_a / a_panel_bytes).max(1);

        // B tiles share L1 with one resident A panel; whatever budget is
        // left after that reservation is spent on B panels.
        let b_budget = budget.l1_b.saturating_sub(a_panel_bytes);
        let b_panels_per_tile = (b_budget / (b_bytes * b_panel)).max(1);

        Self {
            a_panels_per_tile,
            b_panels_per_tile,
        }
    }
}
#[allow(clippy::too_many_arguments)]
/// Cache-tiled panel reduction: for every row of A, applies kernel `K`
/// against every row of B, accumulating one f32 result per A row into
/// `scratch` (the tests use it as a max-inner-product reduction). A is
/// walked in tiles of `A_PANEL`-row panels sized by `budget.l2_a`, B in
/// tiles of `B_PANEL`-row panels sized by `budget.l1_b`; each tile is
/// converted into the kernel-native layouts via `ca`/`cb` before the inner
/// panel loops run.
///
/// # Safety
/// - `a_ptr` must be valid for `a_padded_nrows * k` elements of `LA::Element`.
/// - `b_ptr` must be valid for `b_nrows * k` elements of `LB::Element`.
/// - `scratch` must hold at least `a_padded_nrows` values and is expected to
///   be pre-seeded with the reduction identity (the tests seed `f32::MIN`).
pub(super) unsafe fn tiled_reduce<A, K, LA, LB>(
    arch: A,
    ca: &LA,
    cb: &LB,
    a_ptr: *const LA::Element,
    a_padded_nrows: usize,
    b_ptr: *const LB::Element,
    b_nrows: usize,
    k: usize,
    scratch: &mut [f32],
    budget: TileBudget,
) where
    A: Architecture,
    K: Kernel<A>,
    LA: ConvertTo<A, K::Left>,
    LB: ConvertTo<A, K::Right>,
{
    // Budget the tiles by the *kernel-layout* element sizes, since those
    // converted buffers are what actually occupy cache in the inner loops.
    let a_row_bytes = k * std::mem::size_of::<<K::Left as Layout>::Element>();
    let b_row_bytes = k * std::mem::size_of::<<K::Right as Layout>::Element>();
    let plan = FullReduce::new(a_row_bytes, b_row_bytes, K::A_PANEL, K::B_PANEL, budget);
    // Strides over the *source* B data, in elements.
    let b_src_panel_stride = K::B_PANEL * k;
    let b_src_tile_stride = b_src_panel_stride * plan.b_panels_per_tile;
    // Strides over the converted, kernel-layout tiles, in elements.
    let a_kern_panel_stride = K::A_PANEL * k;
    let b_kern_panel_stride = K::B_PANEL * k;
    // Trailing B rows that do not fill a whole panel; handled by partial_panel.
    let b_remainder = b_nrows % K::B_PANEL;
    assert_eq!(
        a_padded_nrows % K::A_PANEL,
        0,
        "a_padded_nrows ({a_padded_nrows}) must be a multiple of A_PANEL ({})",
        K::A_PANEL,
    );
    if k == 0 {
        // Zero-dimensional rows: every inner product is 0.0, so reducing over
        // a non-empty B yields exactly 0.0 for every A row. With no B rows at
        // all, the caller's scratch seed is left untouched.
        if b_nrows > 0 {
            scratch[..a_padded_nrows].fill(0.0);
        }
        return;
    }
    let a_tile_rows = K::A_PANEL * plan.a_panels_per_tile;
    let b_tile_rows = K::B_PANEL * plan.b_panels_per_tile;
    // Conversion scratch buffers sized for a full tile, reused across tiles.
    let mut a_buf = ca.new_buffer(a_tile_rows, k);
    let mut b_buf = cb.new_buffer(b_tile_rows, k);
    // One-past-the-end of B, and the end of B's last *whole* panel.
    let pb_end = unsafe { b_ptr.add(b_nrows * k) };
    let pb_full_end = unsafe { pb_end.sub(b_remainder * k) };
    unsafe {
        let mut rows_done: usize = 0;
        // Outer loop: one A tile per iteration. The assert above guarantees
        // a_padded_nrows is panel-aligned, so every tile is whole panels.
        while rows_done < a_padded_nrows {
            let tile_rows = a_tile_rows.min(a_padded_nrows - rows_done);
            let pa_tile_src = a_ptr.add(rows_done * k);
            // Results for this A tile land at the matching scratch offset.
            let pr_tile = scratch.as_mut_ptr().add(rows_done);
            let pa_tile = ca.convert(&mut a_buf, arch, pa_tile_src, tile_rows, k);
            let pa_tile_end = pa_tile.add(tile_rows * k);
            let mut pb_tile_src = b_ptr;
            // Section A: full B tiles (b_tile_rows rows each).
            while pb_full_end.offset_from(pb_tile_src) >= b_src_tile_stride as isize {
                let pb_tile = cb.convert(&mut b_buf, arch, pb_tile_src, b_tile_rows, k);
                let pb_tile_end = pb_tile.add(b_tile_rows * k);
                let mut pa_panel = pa_tile;
                let mut pr_panel = pr_tile;
                while pa_panel < pa_tile_end {
                    let mut pb_panel = pb_tile;
                    while pb_panel < pb_tile_end {
                        K::full_panel(arch, pa_panel, pb_panel, k, pr_panel);
                        pb_panel = pb_panel.add(b_kern_panel_stride);
                    }
                    pa_panel = pa_panel.add(a_kern_panel_stride);
                    pr_panel = pr_panel.add(K::A_PANEL);
                }
                pb_tile_src = pb_tile_src.add(b_src_tile_stride);
            }
            // Section B: the trailing short B tile — zero or more whole
            // panels plus an optional sub-panel remainder.
            if pb_tile_src < pb_end {
                let remaining_b_rows = b_nrows - ((pb_tile_src.offset_from(b_ptr) as usize) / k);
                let pb_tile = cb.convert(&mut b_buf, arch, pb_tile_src, remaining_b_rows, k);
                let full_panels_in_remainder = remaining_b_rows / K::B_PANEL;
                let pb_full_end_local = pb_tile.add(full_panels_in_remainder * b_kern_panel_stride);
                let mut pa_panel = pa_tile;
                let mut pr_panel = pr_tile;
                while pa_panel < pa_tile_end {
                    let mut pb_panel = pb_tile;
                    while pb_panel < pb_full_end_local {
                        K::full_panel(arch, pa_panel, pb_panel, k, pr_panel);
                        pb_panel = pb_panel.add(b_kern_panel_stride);
                    }
                    if b_remainder > 0 {
                        // pb_panel now points at the start of the partial panel.
                        K::partial_panel(arch, b_remainder, pa_panel, pb_panel, k, pr_panel);
                    }
                    pa_panel = pa_panel.add(a_kern_panel_stride);
                    pr_panel = pr_panel.add(K::A_PANEL);
                }
            }
            rows_done += tile_rows;
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use diskann_wide::arch::Scalar;
use super::super::f32::{F32Kernel, max_ip_kernel};
use super::super::layouts;
use crate::multi_vector::{BlockTransposed, MatRef, Standard};
/// A round 40 KiB / 36 KB budget with 256-byte rows yields 10 A panels
/// (40960 / (256*16)) and 31 B panels ((36000 - 4096) / (256*4)).
#[test]
fn basic_panel_counts() {
    let budget = TileBudget {
        l2_a: 40960,
        l1_b: 36000,
    };
    let plan = FullReduce::new(256, 256, 16, 4, budget);
    assert_eq!(plan.a_panels_per_tile, 10);
    assert_eq!(plan.b_panels_per_tile, 31);
}
/// Budgets far too small for even a single panel still produce a usable
/// plan: both panel counts clamp up to one.
#[test]
fn tiny_budget_clamps_to_one() {
    let budget = TileBudget { l2_a: 1, l1_b: 1 };
    let plan = FullReduce::new(1024, 1024, 16, 4, budget);
    assert_eq!(plan.a_panels_per_tile, 1);
    assert_eq!(plan.b_panels_per_tile, 1);
}
/// Zero-byte rows (k == 0) are clamped to one byte per row, so the plan
/// divides the budgets by the panel heights alone instead of panicking.
#[test]
fn zero_byte_rows_clamped() {
    let budget = TileBudget {
        l2_a: 100_000,
        l1_b: 50_000,
    };
    let plan = FullReduce::new(0, 0, 16, 4, budget);
    // 100_000 / 16 and (50_000 - 16) / 4 respectively.
    assert_eq!(plan.a_panels_per_tile, 6250);
    assert_eq!(plan.b_panels_per_tile, 12_496);
}
/// An L2 budget of exactly one A panel (64*16 = 1024) and an L1 budget of
/// one A panel plus four B panels (1024 + 4*256) divides with no slack.
#[test]
fn exact_fit_one_panel() {
    let budget = TileBudget {
        l2_a: 1024,
        l1_b: 2048,
    };
    let plan = FullReduce::new(64, 64, 16, 4, budget);
    assert_eq!(plan.a_panels_per_tile, 1);
    assert_eq!(plan.b_panels_per_tile, 4);
}
/// When the resident A panel alone (1024*16 bytes) exceeds the whole L1
/// budget, the saturating subtraction leaves zero for B — and the B panel
/// count still clamps up to one.
#[test]
fn l1_saturated_by_a_panel() {
    let budget = TileBudget {
        l2_a: 100_000,
        l1_b: 100,
    };
    let plan = FullReduce::new(1024, 64, 16, 4, budget);
    assert_eq!(plan.b_panels_per_tile, 1);
}
/// `tiled_reduce` asserts the padded A row count is panel-aligned: 9 rows
/// with an 8-row A panel must trip the assert before any work happens.
#[test]
#[should_panic(expected = "must be a multiple of A_PANEL")]
fn panics_on_unaligned_a_rows() {
    let k = 4;
    let a = vec![0.0f32; 9 * k];
    let b = vec![0.0f32; 2 * k];
    let mut scratch = vec![f32::MIN; 16];
    let ca = layouts::BlockTransposed::<f32, 8>::new();
    let cb = layouts::RowMajor::<f32>::new();
    // SAFETY: pointers and lengths describe the vectors above; the call is
    // expected to hit the alignment assert before touching any data.
    unsafe {
        super::tiled_reduce::<Scalar, F32Kernel<8>, _, _>(
            Scalar::new(),
            &ca,
            &cb,
            a.as_ptr(),
            9, // a_padded_nrows: deliberately not a multiple of A_PANEL (8)
            b.as_ptr(),
            2,
            k,
            &mut scratch,
            TileBudget::default(),
        );
    }
}
/// With k == 0 and a non-empty B, every inner product is 0.0, so the kernel
/// overwrites the whole scratch buffer with 0.0 and returns early.
#[test]
fn zero_dim_fills_scratch_and_returns() {
    let a_rows = 8;
    let b_rows = 3;
    let k = 0;
    let a = Vec::<f32>::new();
    let b = Vec::<f32>::new();
    // Seeded with f32::MIN so an untouched entry would be detected.
    let mut scratch = vec![f32::MIN; a_rows];
    let ca = layouts::BlockTransposed::<f32, 8>::new();
    let cb = layouts::RowMajor::<f32>::new();
    // SAFETY: with k == 0 both matrices occupy zero elements, so the empty
    // vectors' pointers are valid for the stated (zero-size) extents.
    unsafe {
        super::tiled_reduce::<Scalar, F32Kernel<8>, _, _>(
            Scalar::new(),
            &ca,
            &cb,
            a.as_ptr(),
            a_rows,
            b.as_ptr(),
            b_rows,
            k,
            &mut scratch,
            TileBudget::default(),
        );
    }
    for &v in &scratch {
        assert_eq!(v, 0.0, "zero-dim IP should be 0.0");
    }
}
/// With zero B rows (and k == 0) there is nothing to reduce over, so the
/// caller's scratch seed values must survive untouched.
#[test]
fn zero_dim_zero_docs_leaves_scratch_untouched() {
    let a_rows = 8;
    let mut scratch = vec![f32::MIN; a_rows];
    let ca = layouts::BlockTransposed::<f32, 8>::new();
    let cb = layouts::RowMajor::<f32>::new();
    // SAFETY: both matrices have zero total elements, so the empty-slice
    // pointers are valid for the stated (zero-size) extents.
    unsafe {
        super::tiled_reduce::<Scalar, F32Kernel<8>, _, _>(
            Scalar::new(),
            &ca,
            &cb,
            [].as_ptr(),
            a_rows,
            [].as_ptr(),
            0,
            0,
            &mut scratch,
            TileBudget::default(),
        );
    }
    for &v in &scratch {
        assert_eq!(v, f32::MIN, "zero docs should leave scratch untouched");
    }
}
/// (a_nrows, b_nrows, dim) shapes used to cross-check the tiled kernels
/// against the naive references: degenerate shapes (empty B, dim 0),
/// sub-panel sizes, exact panel multiples, and multi-tile sizes.
const NAIVE_CASES: &[(usize, usize, usize)] = &[
    (1, 1, 1),
    (1, 1, 2),
    (1, 1, 4),
    (1, 5, 8),
    (5, 1, 8),
    (3, 2, 0),
    (3, 0, 4),
    (3, 2, 3),
    (3, 4, 16),
    (5, 3, 5),
    (7, 7, 32),
    (2, 3, 7),
    (2, 3, 128),
    (8, 3, 4),
    (16, 5, 8),
    (16, 4, 64),
    (17, 4, 64),
    (32, 5, 16),
    (48, 3, 16),
    (16, 6, 32),
    (16, 7, 32),
    (16, 8, 32),
];
/// Reference max inner product over row-major f32 data: for each row of `a`,
/// the maximum dot product against every row of `b`; `f32::MIN` when `b` has
/// no rows.
fn naive_max_ip_f32(
    a: &[f32],
    a_nrows: usize,
    b: &[f32],
    b_nrows: usize,
    k: usize,
) -> Vec<f32> {
    let mut out = Vec::with_capacity(a_nrows);
    for i in 0..a_nrows {
        let mut best = f32::MIN;
        for j in 0..b_nrows {
            let mut dot = 0.0f32;
            for d in 0..k {
                dot += a[i * k + d] * b[j * k + d];
            }
            best = best.max(dot);
        }
        out.push(best);
    }
    out
}
/// f16 counterpart of `naive_max_ip_f32`: products are widened to f32 before
/// summing, matching the widening the kernels perform.
fn naive_max_ip_f16(
    a: &[half::f16],
    a_nrows: usize,
    b: &[half::f16],
    b_nrows: usize,
    k: usize,
) -> Vec<f32> {
    let mut out = Vec::with_capacity(a_nrows);
    for i in 0..a_nrows {
        let mut best = f32::MIN;
        for j in 0..b_nrows {
            let mut dot = 0.0f32;
            for d in 0..k {
                dot += a[i * k + d].to_f32() * b[j * k + d].to_f32();
            }
            best = best.max(dot);
        }
        out.push(best);
    }
    out
}
/// Runs `max_ip_kernel` for one architecture / element type / group-size
/// combination and compares every A row's result against `expected` within
/// an absolute tolerance `tol`.
#[allow(clippy::too_many_arguments)]
fn check_kernel<A, T, const GROUP: usize>(
    arch: A,
    arch_label: &str,
    tol: f32,
    a_data: &[T],
    a_nrows: usize,
    b_data: &[T],
    b_nrows: usize,
    dim: usize,
    expected: &[f32],
) where
    A: Architecture,
    T: Copy + Default,
    F32Kernel<GROUP>: Kernel<A>,
    layouts::BlockTransposed<T, GROUP>:
        ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left> + Layout<Element = T>,
    layouts::RowMajor<T>:
        ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right> + Layout<Element = T>,
{
    // A goes through the block-transposed layout the kernel's left side
    // expects; B stays row-major.
    let a_mat = MatRef::new(Standard::new(a_nrows, dim).unwrap(), a_data).unwrap();
    let a_bt = BlockTransposed::<T, GROUP>::from_matrix_view(a_mat.as_matrix_view());
    let b_mat = MatRef::new(Standard::new(b_nrows, dim).unwrap(), b_data).unwrap();
    // Seed with the max-reduction identity, sized to the *padded* row count.
    let mut scratch = vec![f32::MIN; a_bt.padded_nrows()];
    max_ip_kernel::<A, T, GROUP>(
        arch,
        a_bt.as_view(),
        b_mat,
        &mut scratch,
        TileBudget::default(),
    );
    // Only the first a_nrows entries are meaningful; padding rows are ignored.
    for i in 0..a_nrows {
        let actual = scratch[i];
        let exp = expected[i];
        assert!(
            (actual - exp).abs() < tol,
            "[{arch_label}] row {i} mismatch for ({a_nrows},{b_nrows},{dim}): actual={actual}, expected={exp}",
        );
    }
}
/// Cross-checks the f32 kernels against `naive_max_ip_f32` over every shape
/// in `NAIVE_CASES`, on the scalar path and (when the CPU supports it) the
/// x86-64-v3 path.
#[test]
fn tiled_reduce_f32_matches_naive() {
    for &(a_nrows, b_nrows, dim) in NAIVE_CASES {
        // Deterministic, distinct inputs: A counts 1, 2, 3, …; B counts by twos.
        let a_data: Vec<f32> = (1..=a_nrows * dim).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (1..=b_nrows * dim).map(|i| (2 * i) as f32).collect();
        let expected = naive_max_ip_f32(&a_data, a_nrows, &b_data, b_nrows, dim);
        check_kernel::<_, f32, 8>(
            Scalar::new(),
            "scalar",
            1e-10,
            &a_data,
            a_nrows,
            &b_data,
            b_nrows,
            dim,
            &expected,
        );
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            check_kernel::<_, f32, 16>(
                arch,
                "x86-64-v3",
                1e-10,
                &a_data,
                a_nrows,
                &b_data,
                b_nrows,
                dim,
                &expected,
            );
        }
    }
}
/// Drives the hand-picked `LOOP_COVERAGE_PLANS` through each supported
/// architecture and element type so every loop path of `tiled_reduce`
/// (full B tiles, trailing whole panels, sub-panel remainder) is exercised.
#[test]
fn tiled_reduce_all_loop_paths_match_naive() {
    check_tile_plan_paths::<_, f32, 8>(Scalar::new(), "scalar", gen_f32_data, naive_max_ip_f32);
    check_tile_plan_paths::<_, half::f16, 8>(
        Scalar::new(),
        "scalar",
        gen_f16_data,
        naive_max_ip_f16,
    );
    // Wider-group variants only run when the CPU reports x86-64-v3 support.
    #[cfg(target_arch = "x86_64")]
    if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
        check_tile_plan_paths::<_, f32, 16>(arch, "x86-64-v3", gen_f32_data, naive_max_ip_f32);
        check_tile_plan_paths::<_, half::f16, 16>(
            arch,
            "x86-64-v3",
            gen_f16_data,
            naive_max_ip_f16,
        );
    }
}
/// One hand-constructed tiling scenario for loop-path coverage, paired with
/// the iteration counts it is expected to produce inside `tiled_reduce`.
#[derive(Debug, Clone, Copy)]
struct LoopCoveragePlan {
    // Expected A panels per tile derived from the engineered budget.
    a_panels_per_tile: usize,
    // Expected B panels per tile derived from the engineered budget.
    b_panels_per_tile: usize,
    // Full B-tile iterations the kernel's first (full-tile) section runs.
    section_a_b_tile_iters: usize,
    // Whole B panels handled in the trailing (short-tile) section.
    section_b_full_b_panels: usize,
    // Whether a sub-panel B remainder exists in the trailing section.
    section_b_has_b_remainder: bool,
}
/// Plans chosen so that, together, they cover: a remainder-only trailing
/// tile, multiple full tiles with no trailing tile at all, and a trailing
/// tile containing both whole panels and a remainder.
const LOOP_COVERAGE_PLANS: &[LoopCoveragePlan] = &[
    // One full B tile, then only a sub-panel remainder.
    LoopCoveragePlan {
        a_panels_per_tile: 2,
        b_panels_per_tile: 2,
        section_a_b_tile_iters: 1,
        section_b_full_b_panels: 0,
        section_b_has_b_remainder: true,
    },
    // Two full B tiles with nothing left over (no trailing section).
    LoopCoveragePlan {
        a_panels_per_tile: 2,
        b_panels_per_tile: 3,
        section_a_b_tile_iters: 2,
        section_b_full_b_panels: 0,
        section_b_has_b_remainder: false,
    },
    // One full B tile, then two whole panels plus a remainder.
    LoopCoveragePlan {
        a_panels_per_tile: 3,
        b_panels_per_tile: 2,
        section_a_b_tile_iters: 1,
        section_b_full_b_panels: 2,
        section_b_has_b_remainder: true,
    },
];
/// Deterministic f32 test data: `len` values cycling through 0..ceil.
fn gen_f32_data(len: usize, ceil: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(len);
    for i in 0..len {
        out.push((i % ceil) as f32);
    }
    out
}
/// Deterministic f16 test data: `len` values cycling through 0..ceil,
/// cast to f16 (small integers, so exactly representable).
fn gen_f16_data(len: usize, ceil: usize) -> Vec<half::f16> {
    let mut out = Vec::with_capacity(len);
    for i in 0..len {
        out.push(diskann_wide::cast_f32_to_f16((i % ceil) as f32));
    }
    out
}
type NaiveMaxIp<T> = fn(&[T], usize, &[T], usize, usize) -> Vec<f32>;
/// For each `LoopCoveragePlan`: builds a budget engineered to produce exactly
/// the planned panel counts, sizes A and B so the kernel's loop sections run
/// the planned number of times, and checks the output against the naive
/// reference.
fn check_tile_plan_paths<A, T, const GROUP: usize>(
    arch: A,
    arch_label: &str,
    gen_data: fn(usize, usize) -> Vec<T>,
    naive: NaiveMaxIp<T>,
) where
    A: Architecture,
    T: Copy + Default,
    F32Kernel<GROUP>: Kernel<A>,
    layouts::BlockTransposed<T, GROUP>:
        ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Left> + Layout<Element = T>,
    layouts::RowMajor<T>:
        ConvertTo<A, <F32Kernel<GROUP> as Kernel<A>>::Right> + Layout<Element = T>,
{
    let a_panel = <F32Kernel<GROUP> as Kernel<A>>::A_PANEL;
    let b_panel = <F32Kernel<GROUP> as Kernel<A>>::B_PANEL;
    let dim = 8usize;
    let row_bytes = dim * std::mem::size_of::<T>();
    for &p in LOOP_COVERAGE_PLANS {
        // Budget that yields exactly p.{a,b}_panels_per_tile: L2 holds the
        // planned A panels, L1 holds one A panel plus the planned B panels.
        let budget = TileBudget {
            l2_a: p.a_panels_per_tile * a_panel * row_bytes,
            l1_b: a_panel * row_bytes + p.b_panels_per_tile * b_panel * row_bytes,
        };
        // Sanity-check the engineered budget before relying on it.
        let plan = FullReduce::new(row_bytes, row_bytes, a_panel, b_panel, budget);
        assert_eq!(
            plan.a_panels_per_tile, p.a_panels_per_tile,
            "[{arch_label}] a_panels_per_tile for plan {p:?}",
        );
        assert_eq!(
            plan.b_panels_per_tile, p.b_panels_per_tile,
            "[{arch_label}] b_panels_per_tile for plan {p:?}",
        );
        // +1 A row spills past a whole number of panels (exercising padding);
        // B rows are sized from the planned per-section iteration counts.
        let a_nrows = p.a_panels_per_tile * a_panel + 1;
        let b_nrows = p.section_a_b_tile_iters * p.b_panels_per_tile * b_panel
            + p.section_b_full_b_panels * b_panel
            + usize::from(p.section_b_has_b_remainder);
        let ceil = dim;
        let a_data = gen_data(a_nrows * dim, ceil);
        let b_data = gen_data(b_nrows * dim, ceil);
        let expected = naive(&a_data, a_nrows, &b_data, b_nrows, dim);
        let a_mat = MatRef::new(Standard::new(a_nrows, dim).unwrap(), &a_data).unwrap();
        let a_bt = BlockTransposed::<T, GROUP>::from_matrix_view(a_mat.as_matrix_view());
        let b_mat = MatRef::new(Standard::new(b_nrows, dim).unwrap(), &b_data).unwrap();
        let mut scratch = vec![f32::MIN; a_bt.padded_nrows()];
        max_ip_kernel::<A, T, GROUP>(arch, a_bt.as_view(), b_mat, &mut scratch, budget);
        for i in 0..a_nrows {
            assert!(
                (scratch[i] - expected[i]).abs() < 1e-10,
                "[{arch_label}] plan={p:?} row {i}: actual={} expected={}",
                scratch[i],
                expected[i],
            );
        }
    }
}
/// f16 analogue of `tiled_reduce_f32_matches_naive`: cross-checks the f16
/// kernels against the naive f16 reference over every `NAIVE_CASES` shape.
#[test]
fn tiled_reduce_f16_matches_naive() {
    for &(a_nrows, b_nrows, dim) in NAIVE_CASES {
        // Values are reduced mod `ceil` so each one is a small integer that
        // round-trips exactly through f16.
        let ceil = dim.max(1);
        let a_data: Vec<half::f16> = (1..=a_nrows * dim)
            .map(|i| diskann_wide::cast_f32_to_f16((i % ceil) as f32))
            .collect();
        let b_data: Vec<half::f16> = (1..=b_nrows * dim)
            .map(|i| diskann_wide::cast_f32_to_f16(((i * 2) % ceil) as f32))
            .collect();
        let expected = naive_max_ip_f16(&a_data, a_nrows, &b_data, b_nrows, dim);
        check_kernel::<_, half::f16, 8>(
            Scalar::new(),
            "scalar",
            1e-10,
            &a_data,
            a_nrows,
            &b_data,
            b_nrows,
            dim,
            &expected,
        );
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            check_kernel::<_, half::f16, 16>(
                arch,
                "x86-64-v3",
                1e-10,
                &a_data,
                a_nrows,
                &b_data,
                b_nrows,
                dim,
                &expected,
            );
        }
    }
}
}