use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
mlx::steel::attn::steel_attention_mma::mt_sdpa_prefill_mma,
spec::{BatchedDecodeVariant, BenchDispatch, BenchSpec},
};
// Bench registration: op "sdpa", subop "sdpa_decode_batched_q8".
//
// Submits a `BenchSpec` that runs the `mt_sdpa_prefill_mma` kernel through the
// `BenchDispatch::SdpaBatchedDecode` path with `batch_q: 8`. The same 2e-2
// tolerance is used for all three listed dtypes. `mlx_src` / `mlx_pattern` are
// left unset and `shapes` is empty; the problem dimensions are carried by the
// dispatch payload below.
//
// NOTE(review): `batch_q` (8) is smaller than the tile's `bq` (32) — presumably
// the kernel handles the partially filled query tile; confirm.
// NOTE(review): `tpg` presumably means threads per threadgroup
// (wm * wn * 32 = 128) — confirm against the dispatcher.
inventory::submit! {
    BenchSpec {
        // Identity of the benchmark entry.
        op: "sdpa",
        subop: "sdpa_decode_batched_q8",

        // Kernel under test and how its IR is produced.
        kernel_name: "mt_sdpa_prefill_mma",
        kernel_ir: mt_sdpa_prefill_mma::kernel_ir_for,
        kernel_mode: Some(KernelMode::SimdGroup2D),

        // Element types exercised and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 2e-2,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,

        // No generic shape list; see `dispatch` for the dimensions.
        shapes: &[],

        dispatch: BenchDispatch::SdpaBatchedDecode {
            head_dim: 128,
            n_kv: 4096,
            n_q_heads: 32,
            gqa_factor: 4,
            batch_q: 8,
            variant: BatchedDecodeVariant::PrefillTile {
                bq: 32,
                bk: 16,
                wm: 4,
                wn: 1,
            },
            tpg: 128,
        },
    }
}
// Bench registration: op "sdpa", subop "sdpa_decode_batched_q16".
//
// Submits a `BenchSpec` that runs the `mt_sdpa_prefill_mma` kernel through the
// `BenchDispatch::SdpaBatchedDecode` path with `batch_q: 16`. The same 2e-2
// tolerance is used for all three listed dtypes. `mlx_src` / `mlx_pattern` are
// left unset and `shapes` is empty; the problem dimensions are carried by the
// dispatch payload below.
//
// NOTE(review): `batch_q` (16) is smaller than the tile's `bq` (32) —
// presumably the kernel handles the partially filled query tile; confirm.
// NOTE(review): `tpg` presumably means threads per threadgroup
// (wm * wn * 32 = 128) — confirm against the dispatcher.
inventory::submit! {
    BenchSpec {
        // Identity of the benchmark entry.
        op: "sdpa",
        subop: "sdpa_decode_batched_q16",

        // Kernel under test and how its IR is produced.
        kernel_name: "mt_sdpa_prefill_mma",
        kernel_ir: mt_sdpa_prefill_mma::kernel_ir_for,
        kernel_mode: Some(KernelMode::SimdGroup2D),

        // Element types exercised and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 2e-2,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,

        // No generic shape list; see `dispatch` for the dimensions.
        shapes: &[],

        dispatch: BenchDispatch::SdpaBatchedDecode {
            head_dim: 128,
            n_kv: 4096,
            n_q_heads: 32,
            gqa_factor: 4,
            batch_q: 16,
            variant: BatchedDecodeVariant::PrefillTile {
                bq: 32,
                bk: 16,
                wm: 4,
                wn: 1,
            },
            tpg: 128,
        },
    }
}