use super::{
BatchedQ4KGemvKernel, BatchedQ6KGemvKernel, BatchedResidualAddKernel, BatchedRopeKernel,
BatchedSwigluKernel, BatchedVectorizedRmsNormKernel, FusedSwigluKernel, KernelParity,
Q4KGemvKernel, Q6KGemvKernel, ResidualAddKernel, RopeKernel, VectorizedRmsNormKernel,
};
/// Maps the batched RMSNorm kernel to its single-vector counterpart.
impl KernelParity for BatchedVectorizedRmsNormKernel {
    type SingleVector = VectorizedRmsNormKernel;

    /// Build a single-vector reference kernel sharing this kernel's hidden
    /// size and epsilon, so PTX comparisons are apples-to-apples.
    fn single_vector_reference(&self) -> Self::SingleVector {
        let reference = VectorizedRmsNormKernel::new(self.hidden_size);
        reference.with_epsilon(self.epsilon)
    }
}
/// Maps the batched Q4_K GEMV kernel to its single-vector counterpart.
impl KernelParity for BatchedQ4KGemvKernel {
    type SingleVector = Q4KGemvKernel;

    /// Build a single-vector reference kernel with the same (k, n) shape.
    fn single_vector_reference(&self) -> Self::SingleVector {
        let (k, n) = (self.k, self.n);
        Q4KGemvKernel::new(k, n)
    }
}
/// Maps the batched Q6_K GEMV kernel to its single-vector counterpart.
impl KernelParity for BatchedQ6KGemvKernel {
    type SingleVector = Q6KGemvKernel;

    /// Build a single-vector reference kernel with the same (k, n) shape.
    fn single_vector_reference(&self) -> Self::SingleVector {
        let (k, n) = (self.k, self.n);
        Q6KGemvKernel::new(k, n)
    }
}
/// Maps the batched residual-add kernel to its single-vector counterpart.
impl KernelParity for BatchedResidualAddKernel {
    type SingleVector = ResidualAddKernel;

    /// Build a single-vector reference kernel over the same element count.
    fn single_vector_reference(&self) -> Self::SingleVector {
        let elems = self.n;
        ResidualAddKernel::new(elems)
    }
}
/// Maps the batched RoPE kernel to its single-vector counterpart.
impl KernelParity for BatchedRopeKernel {
    type SingleVector = RopeKernel;

    /// Build a single-vector reference kernel with identical head layout
    /// and rotation base (theta).
    fn single_vector_reference(&self) -> Self::SingleVector {
        let (heads, dim, theta) = (self.num_heads, self.head_dim, self.theta);
        RopeKernel::new(heads, dim, theta)
    }
}
/// Maps the batched SwiGLU kernel to its fused single-vector counterpart.
impl KernelParity for BatchedSwigluKernel {
    type SingleVector = FusedSwigluKernel;

    /// Build a fused single-vector reference kernel over the same width.
    fn single_vector_reference(&self) -> Self::SingleVector {
        let width = self.n;
        FusedSwigluKernel::new(width)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::Kernel;

    // Representative model dimensions used across the parity tests.
    const HIDDEN_1536: u32 = 1536;
    const HIDDEN_3584: u32 = 3584;
    const INTERMEDIATE_4864: u32 = 4864;
    const NUM_HEADS_12: u32 = 12;
    const HEAD_DIM_128: u32 = 128;
    const ROPE_THETA: f32 = 1_000_000.0;

    /// How a batched kernel is expected to dispatch over the batch dimension.
    ///
    /// Replaces the previous stringly-typed `"grid_y"` / `"register_unroll"`
    /// flags so a typo becomes a compile error instead of a silently skipped
    /// check in a `_ => {}` arm.
    #[derive(Clone, Copy)]
    enum DispatchStrategy {
        /// Batch index taken from the grid Y dimension (`%ctaid.y` in PTX).
        GridY,
        /// Batch handled by register-unrolled rows selected via an `m_dim`
        /// kernel parameter.
        RegisterUnroll,
    }

    /// Collect trimmed PTX lines that address shared memory through a 64-bit
    /// `%rd` register (`st.shared`/`ld.shared` with a `[%rd...]` operand).
    /// Such lines are the GH-219 u64 shared-memory addressing defect.
    /// Shared between the per-kernel and the sweep tests so both scan PTX
    /// identically.
    fn u64_shared_mem_lines(ptx: &str) -> Vec<String> {
        ptx.lines()
            .map(str::trim)
            .filter(|line| {
                (line.contains("st.shared") || line.contains("ld.shared"))
                    && line.contains("[%rd")
            })
            .map(str::to_string)
            .collect()
    }

    /// Extract the declared byte size of the `.shared` array named `smem`
    /// from a PTX listing (e.g. `.shared .align 4 .b8 smem[128];` -> 128).
    /// Returns `None` when no parseable declaration is found.
    fn extract_smem(ptx: &str) -> Option<u32> {
        ptx.lines().find_map(|line| {
            let trimmed = line.trim();
            if !trimmed.contains(".shared") {
                return None;
            }
            let start = trimmed.find("smem[")?;
            let after = &trimmed[start + "smem[".len()..];
            let end = after.find(']')?;
            after[..end].parse::<u32>().ok()
        })
    }

    #[test]
    fn test_parity_rmsnorm_batch_dispatch() {
        let kernel = BatchedVectorizedRmsNormKernel::new(HIDDEN_1536, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(
            result.is_compatible,
            "BatchedVectorizedRmsNormKernel missing ctaid.y: {:?}",
            result.violations
        );
    }

    #[test]
    fn test_parity_rmsnorm_no_u64_shared_mem() {
        let kernel = BatchedVectorizedRmsNormKernel::new(HIDDEN_1536, 1);
        let offending = u64_shared_mem_lines(&kernel.emit_ptx());
        assert!(
            offending.is_empty(),
            "GH-219: BatchedVectorizedRmsNormKernel uses u64 (%rd) for shared memory: {}",
            offending.join("\n")
        );
    }

    #[test]
    fn test_parity_rmsnorm_shared_memory_size() {
        // The batched kernel must not inflate shared memory relative to the
        // single-vector reference it is validated against.
        let batched = BatchedVectorizedRmsNormKernel::new(HIDDEN_1536, 1);
        let single = batched.single_vector_reference();
        let batched_smem = extract_smem(&batched.emit_ptx());
        let single_smem = extract_smem(&single.emit_ptx());
        assert_eq!(
            batched_smem, single_smem,
            "Shared memory size mismatch: single={:?}, batched={:?}",
            single_smem, batched_smem
        );
    }

    #[test]
    fn test_parity_rmsnorm_7b() {
        // Same dispatch check at the 7B-model hidden size.
        let kernel = BatchedVectorizedRmsNormKernel::new(HIDDEN_3584, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(result.is_compatible, "7B BatchedRmsNorm parity: {:?}", result.violations);
    }

    #[test]
    fn test_parity_q4k_gemv_batch_dispatch() {
        let kernel = BatchedQ4KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(
            result.is_compatible,
            "BatchedQ4KGemvKernel missing batch dispatch: {:?}",
            result.violations
        );
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("m_dim"),
            "Q4K batched should use register-unrolled dispatch via m_dim parameter"
        );
    }

    #[test]
    fn test_parity_q6k_gemv_batch_dispatch() {
        let kernel = BatchedQ6KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(
            result.is_compatible,
            "BatchedQ6KGemvKernel missing batch dispatch: {:?}",
            result.violations
        );
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("m_dim"),
            "Q6K batched should use register-unrolled dispatch via m_dim parameter"
        );
    }

    #[test]
    fn test_parity_residual_add_batch_dispatch() {
        let kernel = BatchedResidualAddKernel::new(HIDDEN_1536, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(
            result.is_compatible,
            "BatchedResidualAddKernel missing ctaid.y: {:?}",
            result.violations
        );
    }

    #[test]
    fn test_parity_rope_batch_dispatch() {
        let kernel = BatchedRopeKernel::new(NUM_HEADS_12, HEAD_DIM_128, 1, ROPE_THETA);
        let result = kernel.validate_batch_dispatch();
        assert!(result.is_compatible, "BatchedRopeKernel missing ctaid.y: {:?}", result.violations);
    }

    #[test]
    fn test_parity_swiglu_batch_dispatch() {
        let kernel = BatchedSwigluKernel::new(INTERMEDIATE_4864, 1);
        let result = kernel.validate_batch_dispatch();
        assert!(
            result.is_compatible,
            "BatchedSwigluKernel missing ctaid.y: {:?}",
            result.violations
        );
    }

    /// Sweep every batched kernel, checking (a) that some batch-dispatch
    /// mechanism is present in the PTX and (b) that it matches the strategy
    /// each kernel is expected to use.
    #[test]
    fn test_all_batched_kernels_have_batch_dispatch() {
        use DispatchStrategy::{GridY, RegisterUnroll};
        let kernels: Vec<(&str, Box<dyn Kernel>, DispatchStrategy)> = vec![
            (
                "BatchedVectorizedRmsNormKernel",
                Box::new(BatchedVectorizedRmsNormKernel::new(HIDDEN_1536, 1)),
                GridY,
            ),
            (
                "BatchedQ4KGemvKernel",
                Box::new(BatchedQ4KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1)),
                RegisterUnroll,
            ),
            (
                "BatchedQ6KGemvKernel",
                Box::new(BatchedQ6KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1)),
                RegisterUnroll,
            ),
            (
                "BatchedResidualAddKernel",
                Box::new(BatchedResidualAddKernel::new(HIDDEN_1536, 1)),
                GridY,
            ),
            (
                "BatchedRopeKernel",
                Box::new(BatchedRopeKernel::new(NUM_HEADS_12, HEAD_DIM_128, 1, ROPE_THETA)),
                GridY,
            ),
            (
                "BatchedSwigluKernel",
                Box::new(BatchedSwigluKernel::new(INTERMEDIATE_4864, 1)),
                GridY,
            ),
        ];
        let mut failures = Vec::new();
        for (name, kernel, expected) in &kernels {
            let ptx = kernel.emit_ptx();
            let has_grid_y = ptx.contains("%ctaid.y");
            let has_m_dim = ptx.contains("m_dim");
            if !has_grid_y && !has_m_dim {
                failures.push(format!("{} missing batch dispatch (no ctaid.y or m_dim)", name));
            }
            // Exhaustive: adding a strategy variant forces this match to be
            // updated, unlike the former string-keyed version.
            let (strategy_ok, message) = match expected {
                GridY => (has_grid_y, "expected grid_y dispatch but missing ctaid.y"),
                RegisterUnroll => (has_m_dim, "expected register_unroll but missing m_dim"),
            };
            if !strategy_ok {
                failures.push(format!("{} {}", name, message));
            }
        }
        assert!(
            failures.is_empty(),
            "GH-219: Batched kernel dispatch validation failures:\n{}",
            failures.join("\n")
        );
    }

    /// Sweep the shared-memory-using batched kernels for the GH-219 u64
    /// addressing defect, reporting every offending PTX line at once.
    #[test]
    fn test_no_u64_shared_memory_in_batched_kernels() {
        let kernels: Vec<(&str, Box<dyn Kernel>)> = vec![
            (
                "BatchedVectorizedRmsNormKernel",
                Box::new(BatchedVectorizedRmsNormKernel::new(HIDDEN_1536, 1)),
            ),
            (
                "BatchedQ4KGemvKernel",
                Box::new(BatchedQ4KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1)),
            ),
            (
                "BatchedQ6KGemvKernel",
                Box::new(BatchedQ6KGemvKernel::new(HIDDEN_1536, HIDDEN_1536, 1)),
            ),
        ];
        let mut failures = Vec::new();
        for (name, kernel) in &kernels {
            // Helper already trims each line, fixing the former redundant
            // `trimmed.trim()` in this loop.
            for line in u64_shared_mem_lines(&kernel.emit_ptx()) {
                failures.push(format!("{}: u64 shared mem address: {}", name, line));
            }
        }
        assert!(
            failures.is_empty(),
            "GH-219: Batched kernels with u64 shared memory addressing:\n{}",
            failures.join("\n")
        );
    }
}