use std::fmt::Write as FmtWrite;
use crate::arch::SmVersion;
use crate::error::PtxGenError;
use crate::ir::PtxType;
/// Epilogue applied to the GEMM accumulator before it is written back to
/// the output matrix.
///
/// A fieldless enum, so it is cheap to copy and compare: `Copy`,
/// `PartialEq`/`Eq`, and `Hash` are derived in addition to the original
/// `Debug`/`Clone` so callers can pass it by value and use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpilogueKind {
    /// `C = alpha * acc + beta * C`.
    LinearCombination,
    /// Linear combination followed by ReLU.
    LinearCombinationRelu,
    /// Linear combination followed by GELU.
    LinearCombinationGelu,
    /// Linear combination plus a bias term.
    LinearCombinationBias,
    /// Linear combination plus bias, followed by ReLU.
    LinearCombinationBiasRelu,
}

impl EpilogueKind {
    /// Short snake_case tag for this epilogue, suitable for embedding in
    /// kernel names.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::LinearCombination => "lincomb",
            Self::LinearCombinationRelu => "lincomb_relu",
            Self::LinearCombinationGelu => "lincomb_gelu",
            Self::LinearCombinationBias => "lincomb_bias",
            Self::LinearCombinationBiasRelu => "lincomb_bias_relu",
        }
    }

    /// `true` for the variants that consume an extra bias operand.
    #[must_use]
    pub const fn needs_bias(&self) -> bool {
        matches!(
            self,
            Self::LinearCombinationBias | Self::LinearCombinationBiasRelu
        )
    }
}
/// Parameters describing a GEMM kernel to be emitted as PTX text by
/// [`GemmTemplate::generate`] / [`GemmTemplate::generate_pipelined`].
pub struct GemmTemplate {
    /// Threadblock tile extent along M; used in the kernel name and in the
    /// pipelined kernel's shared-memory sizing.
    pub tile_m: u32,
    /// Threadblock tile extent along N.
    pub tile_n: u32,
    /// Threadblock tile extent along K.
    pub tile_k: u32,
    /// Warp-level tile extent along M. NOTE(review): not read by any
    /// generator in this file — presumably reserved for a future warp-tiled
    /// path; confirm.
    pub warp_m: u32,
    /// Warp-level tile extent along N. NOTE(review): also unused here.
    pub warp_n: u32,
    /// Element type of the A/B operands (restricted by `validate`).
    pub precision: PtxType,
    /// Accumulator element type; `validate` requires F32 or F64.
    pub accumulator: PtxType,
    /// Selects the `mma.sync` tensor-core path in the pipelined generator
    /// (and the `_tc`/`_naive` suffix in the kernel name).
    pub use_tensor_core: bool,
    /// Software pipeline depth; `generate_pipelined` requires >= 2.
    pub stages: u32,
    /// Target SM architecture, providing the PTX version/target strings.
    pub target: SmVersion,
    /// Requested epilogue. NOTE(review): currently not consumed by either
    /// generator — the emitted epilogue is always a plain linear
    /// combination; confirm whether this is intentional.
    pub epilogue: EpilogueKind,
}
impl GemmTemplate {
    /// Builds the kernel entry-point name from the tile shape, element
    /// types, and code path, e.g. `gemm_128x128x32_f32_f32_naive`.
    #[must_use]
    pub fn kernel_name(&self) -> String {
        // `as_ptx_str` returns a dotted PTX type token (e.g. ".f32");
        // strip the leading dot so the name stays a valid identifier.
        let prec = self.precision.as_ptx_str().trim_start_matches('.');
        let acc = self.accumulator.as_ptx_str().trim_start_matches('.');
        let tc = if self.use_tensor_core { "tc" } else { "naive" };
        format!(
            "gemm_{}x{}x{}_{}_{}_{}",
            self.tile_m, self.tile_n, self.tile_k, prec, acc, tc
        )
    }

    /// Emits a single-stage GEMM kernel as PTX text: one thread per output
    /// element, a scalar K-loop, and the epilogue
    /// `C = alpha * acc + beta * C_old`.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate` rejects the precision/accumulator
    /// combination, or `PtxGenError::FormatError` if a `writeln!` into the
    /// buffer fails (writing to a `String` cannot actually fail, so the
    /// latter is effectively unreachable).
    pub fn generate(&self) -> Result<String, PtxGenError> {
        self.validate()?;
        // Both type strings keep their leading dot, so "ld.global{ty}"
        // concatenates into a full opcode such as "ld.global.f32".
        let ty = self.precision.as_ptx_str();
        let acc_ty = self.accumulator.as_ptx_str();
        let byte_size = self.precision.size_bytes();
        let kernel_name = self.kernel_name();
        let mut ptx = String::with_capacity(8192);
        // Module header: PTX ISA version, SM target, 64-bit addressing.
        writeln!(ptx, ".version {}", self.target.ptx_version())
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ".target {}", self.target.as_ptx_str()).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ".address_size 64").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        // Kernel signature: A/B/C pointers, M/N/K extents, alpha/beta scalars.
        writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_a,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_b,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_c,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_m,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_n,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_k,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param {acc_ty} %param_alpha,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param {acc_ty} %param_beta").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
        // Register file declarations.
        // NOTE(review): the %f registers are declared .f32 only, yet
        // `validate` also admits an F64 accumulator, in which case the
        // ld.param/mov/fma lines below would pair .f64 opcodes with .f32
        // registers — confirm the F64 path is actually usable.
        writeln!(ptx, " .reg .b32 %r<32>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .b64 %rd<16>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .f32 %f<16>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .pred %p<4>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Compute row and column indices").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r1, %tid.y;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r2, %ctaid.x;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r3, %ctaid.y;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r4, %ntid.x;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r5, %ntid.y;").map_err(PtxGenError::FormatError)?;
        // Global col = ctaid.x * ntid.x + tid.x (%r6);
        // global row = ctaid.y * ntid.y + tid.y (%r7).
        writeln!(ptx, " mad.lo.u32 %r6, %r2, %r4, %r0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mad.lo.u32 %r7, %r3, %r5, %r1;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd0, [%param_a];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd1, [%param_b];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd2, [%param_c];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u32 %r8, [%param_m];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u32 %r9, [%param_n];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u32 %r10, [%param_k];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param{acc_ty} %f8, [%param_alpha];")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param{acc_ty} %f9, [%param_beta];")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Bounds check: row < M && col < N")
            .map_err(PtxGenError::FormatError)?;
        // Condition is inverted: branch to the exit label when out of range.
        writeln!(ptx, " setp.ge.u32 %p0, %r7, %r8;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " setp.ge.u32 %p1, %r6, %r9;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " @%p0 bra $GEMM_DONE;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " @%p1 bra $GEMM_DONE;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Initialize accumulator to 0").map_err(PtxGenError::FormatError)?;
        // NOTE(review): "0f00000000" is an f32 immediate; for an F64
        // accumulator PTX expects a 16-hex-digit "0d..." literal — verify.
        writeln!(ptx, " mov{acc_ty} %f0, 0f00000000;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r11, 0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // K-loop: accumulate dot product").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$K_LOOP:").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " setp.ge.u32 %p2, %r11, %r10;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " @%p2 bra $K_DONE;").map_err(PtxGenError::FormatError)?;
        // A is addressed row-major: offset = (row * K + k) * elem_size.
        writeln!(ptx, " // A[row, k]").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mad.lo.u32 %r12, %r7, %r10, %r11;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " cvt.u64.u32 %rd3, %r12;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mul.lo.u64 %rd3, %rd3, {byte_size};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u64 %rd4, %rd0, %rd3;").map_err(PtxGenError::FormatError)?;
        // NOTE(review): for F16/BF16 precision this emits e.g.
        // "ld.global.f16 %f1" into a .f32 register — ptxas may reject the
        // type mismatch; confirm sub-f32 precisions are supported here.
        writeln!(ptx, " ld.global{ty} %f1, [%rd4];").map_err(PtxGenError::FormatError)?;
        // B is addressed row-major: offset = (k * N + col) * elem_size.
        writeln!(ptx, " // B[k, col]").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mad.lo.u32 %r13, %r11, %r9, %r6;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " cvt.u64.u32 %rd5, %r13;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mul.lo.u64 %rd5, %rd5, {byte_size};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u64 %rd6, %rd1, %rd5;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.global{ty} %f2, [%rd6];").map_err(PtxGenError::FormatError)?;
        // acc += a * b, rounded-to-nearest-even.
        writeln!(ptx, " fma.rn{acc_ty} %f0, %f1, %f2, %f0;")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u32 %r11, %r11, 1;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " bra $K_LOOP;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$K_DONE:").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Epilogue: C = alpha * acc + beta * C_old")
            .map_err(PtxGenError::FormatError)?;
        // C offset = (row * N + col) * elem_size.
        writeln!(ptx, " mad.lo.u32 %r14, %r7, %r9, %r6;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " cvt.u64.u32 %rd7, %r14;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {byte_size};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u64 %rd8, %rd2, %rd7;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.global{ty} %f3, [%rd8];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mul{acc_ty} %f0, %f0, %f8;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " fma.rn{acc_ty} %f0, %f9, %f3, %f0;")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " st.global{ty} [%rd8], %f0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$GEMM_DONE:").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
        Ok(ptx)
    }

    /// Emits a software-pipelined GEMM skeleton using `cp.async`
    /// shared-memory prefetch with `stages`-deep multi-buffering:
    /// prologue prefetch of `stages - 1` tiles, a steady-state loop
    /// (prefetch / wait / compute), and an epilogue that drains the
    /// remaining in-flight stages.
    ///
    /// The compute body is a placeholder (a single `mma.sync` or FMA);
    /// addressing of the prefetches is likewise schematic — see the
    /// NOTE(review) comments below.
    ///
    /// # Errors
    ///
    /// Returns `PtxGenError::GenerationFailed` when `stages < 2`, plus the
    /// same errors as [`GemmTemplate::generate`].
    #[allow(clippy::too_many_lines)]
    pub fn generate_pipelined(&self) -> Result<String, PtxGenError> {
        self.validate()?;
        if self.stages < 2 {
            return Err(PtxGenError::GenerationFailed(
                "generate_pipelined requires stages >= 2; use generate() for single-stage GEMM"
                    .to_string(),
            ));
        }
        let ty = self.precision.as_ptx_str();
        let acc_ty = self.accumulator.as_ptx_str();
        // NOTE(review): `kernel_name()` already starts with "gemm_", so this
        // produces names like "gemm_pipelined_gemm_128x..._3stage" — the
        // duplicated prefix looks unintentional; confirm.
        let kernel_name = format!("gemm_pipelined_{}_{}stage", self.kernel_name(), self.stages);
        let stages = self.stages;
        let mut ptx = String::with_capacity(16_384);
        writeln!(ptx, ".version {}", self.target.ptx_version())
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ".target {}", self.target.as_ptx_str()).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ".address_size 64").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        // Shared-memory footprint: one A tile + one B tile per stage.
        #[allow(clippy::cast_possible_truncation)]
        let elem_bytes = self.precision.size_bytes() as u32;
        let a_tile_bytes = self.tile_m * self.tile_k * elem_bytes;
        let b_tile_bytes = self.tile_k * self.tile_n * elem_bytes;
        let smem_per_stage = a_tile_bytes + b_tile_bytes;
        let smem_total = smem_per_stage * stages;
        writeln!(ptx, ".shared .align 128 .b8 smem_a_b[{smem_total}];")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_a,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_b,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u64 %param_c,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_m,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_n,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param .u32 %param_k,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param {acc_ty} %param_alpha,").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .param {acc_ty} %param_beta").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .b32 %r<64>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .b64 %rd<32>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .f32 %acc<32>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd0, [%param_a];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd1, [%param_b];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u64 %rd2, [%param_c];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.param.u32 %r0, [%param_k];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Zero accumulators").map_err(PtxGenError::FormatError)?;
        // NOTE(review): only %acc0 is zeroed, but the mma path below reads
        // %acc1-%acc3 as well — those start uninitialized; confirm.
        writeln!(ptx, " mov.f32 %acc0, 0f00000000;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(
            ptx,
            " // ── Pipeline prologue: prefetch stages 0..{} ──────",
            stages - 1
        )
        .map_err(PtxGenError::FormatError)?;
        // Fill stages 0..stages-2; each stage's A/B copies share one
        // commit_group so they can be waited on as a unit.
        for s in 0..(stages - 1) {
            writeln!(ptx, " // Stage {s}: prefetch A tile").map_err(PtxGenError::FormatError)?;
            // NOTE(review): the shared destination is offset per stage, but
            // the global source is always [%rd0]/[%rd1] with no per-stage or
            // per-k-tile offset — placeholder addressing; confirm intended.
            writeln!(
                ptx,
                " cp.async.ca.shared.global [smem_a_b+{offset}], [%rd0], 16; // stage {s} A",
                offset = s * smem_per_stage
            )
            .map_err(PtxGenError::FormatError)?;
            writeln!(ptx, " // Stage {s}: prefetch B tile").map_err(PtxGenError::FormatError)?;
            writeln!(
                ptx,
                " cp.async.ca.shared.global [smem_a_b+{offset}], [%rd1], 16; // stage {s} B",
                offset = s * smem_per_stage + a_tile_bytes
            )
            .map_err(PtxGenError::FormatError)?;
            writeln!(ptx, " cp.async.commit_group;").map_err(PtxGenError::FormatError)?;
        }
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(
            ptx,
            " // ── Main pipeline loop (steady state) ─────────────"
        )
        .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r1, 0; // stage_idx").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r2, 0; // k_tile").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$PIPE_LOOP:").map_err(PtxGenError::FormatError)?;
        // NOTE(review): the loop bound compares k_tile against K itself
        // (%r0), not K / tile_k — confirm the trip count is intended.
        writeln!(ptx, " setp.ge.u32 %p0, %r2, %r0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " @%p0 bra $PIPE_DONE;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Prefetch next k-tile").map_err(PtxGenError::FormatError)?;
        // NOTE(review): steady-state prefetch always targets offset 0 /
        // a_tile_bytes — %r1 (stage_idx) is incremented but never used for
        // multi-buffer addressing; confirm.
        writeln!(
            ptx,
            " cp.async.ca.shared.global [smem_a_b], [%rd0], 16; // prefetch A"
        )
        .map_err(PtxGenError::FormatError)?;
        writeln!(
            ptx,
            " cp.async.ca.shared.global [smem_a_b+{a_tile_bytes}], [%rd1], 16; // prefetch B"
        )
        .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " cp.async.commit_group;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        // NOTE(review): waiting with N = stages-1 outstanding groups; the
        // common cp.async multi-stage pattern waits with stages-2 so the
        // oldest stage is complete before compute — confirm intended depth.
        let wait_groups = stages.saturating_sub(1);
        writeln!(ptx, " // Drain oldest pipeline stage").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " cp.async.wait_group {wait_groups};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Tensor core compute").map_err(PtxGenError::FormatError)?;
        if self.use_tensor_core {
            writeln!(
                ptx,
                " mma.sync.aligned.m16n8k16.row.col.f32{ty}{ty}.f32 {{%acc0,%acc1,%acc2,%acc3}}, {{%r4,%r5,%r6,%r7}}, {{%r8,%r9}}, {{%acc0,%acc1,%acc2,%acc3}};"
            )
            .map_err(PtxGenError::FormatError)?;
        } else {
            writeln!(
                ptx,
                " fma.rn{acc_ty} %acc0, %acc0, %acc0, %acc0; // naive FMA placeholder"
            )
            .map_err(PtxGenError::FormatError)?;
        }
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u32 %r2, %r2, 1; // k_tile++").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " add.u32 %r1, %r1, 1; // stage_idx++")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " bra $PIPE_LOOP;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$PIPE_DONE:").map_err(PtxGenError::FormatError)?;
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(
            ptx,
            " // ── Pipeline epilogue: drain remaining stages ──────"
        )
        .map_err(PtxGenError::FormatError)?;
        // Count outstanding groups down so the last iteration waits with 0
        // (a full drain of all in-flight cp.async copies).
        for flush in 0..(stages - 1) {
            let remaining = (stages - 2).saturating_sub(flush);
            writeln!(
                ptx,
                " cp.async.wait_group {remaining}; // flush stage {flush}"
            )
            .map_err(PtxGenError::FormatError)?;
            writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
            if self.use_tensor_core {
                writeln!(
                    ptx,
                    " mma.sync.aligned.m16n8k16.row.col.f32{ty}{ty}.f32 {{%acc0,%acc1,%acc2,%acc3}}, {{%r4,%r5,%r6,%r7}}, {{%r8,%r9}}, {{%acc0,%acc1,%acc2,%acc3}}; // epilogue mma {flush}"
                )
                .map_err(PtxGenError::FormatError)?;
            } else {
                writeln!(
                    ptx,
                    " fma.rn{acc_ty} %acc0, %acc0, %acc0, %acc0; // epilogue FMA {flush}"
                )
                .map_err(PtxGenError::FormatError)?;
            }
        }
        writeln!(ptx).map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " // Write accumulator to C").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " st.global{acc_ty} [%rd2], %acc0;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
        Ok(ptx)
    }

    /// Checks that the precision/accumulator pair is one the generators
    /// can emit.
    ///
    /// # Errors
    ///
    /// Returns [`PtxGenError::InvalidType`] when the precision is not one
    /// of F16/BF16/F32/F64 or the accumulator is not F32/F64.
    ///
    /// NOTE(review): tile/warp/stage dimensions are not validated here
    /// (e.g. zero-sized tiles pass) — consider range checks.
    fn validate(&self) -> Result<(), PtxGenError> {
        if !matches!(
            self.precision,
            PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
        ) {
            return Err(PtxGenError::InvalidType(format!(
                "GEMM requires F16, BF16, F32, or F64 precision, got {}",
                self.precision.as_ptx_str()
            )));
        }
        if !matches!(self.accumulator, PtxType::F32 | PtxType::F64) {
            return Err(PtxGenError::InvalidType(format!(
                "GEMM accumulator must be F32 or F64, got {}",
                self.accumulator.as_ptx_str()
            )));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arch::SmVersion;

    /// Baseline 128x128x32 f32 naive template; individual tests override
    /// fields via struct-update syntax.
    fn base_template() -> GemmTemplate {
        GemmTemplate {
            tile_m: 128,
            tile_n: 128,
            tile_k: 32,
            warp_m: 64,
            warp_n: 64,
            precision: PtxType::F32,
            accumulator: PtxType::F32,
            use_tensor_core: false,
            stages: 2,
            target: SmVersion::Sm80,
            epilogue: EpilogueKind::LinearCombination,
        }
    }

    /// Small 16x16x16 single-stage variant used by the naive-GEMM tests.
    fn small_template() -> GemmTemplate {
        GemmTemplate {
            tile_m: 16,
            tile_n: 16,
            tile_k: 16,
            warp_m: 16,
            warp_n: 16,
            stages: 1,
            ..base_template()
        }
    }

    /// F16-precision template for the pipelined-generator tests.
    fn make_pipelined_template(stages: u32, use_tensor_core: bool) -> GemmTemplate {
        GemmTemplate {
            precision: PtxType::F16,
            use_tensor_core,
            stages,
            ..base_template()
        }
    }

    #[test]
    fn kernel_name_format() {
        assert_eq!(
            base_template().kernel_name(),
            "gemm_128x128x32_f32_f32_naive"
        );
    }

    #[test]
    fn kernel_name_tensor_core() {
        let tmpl = GemmTemplate {
            precision: PtxType::F16,
            use_tensor_core: true,
            stages: 3,
            epilogue: EpilogueKind::LinearCombinationRelu,
            ..base_template()
        };
        assert_eq!(tmpl.kernel_name(), "gemm_128x128x32_f16_f32_tc");
    }

    #[test]
    fn epilogue_kind_names() {
        assert_eq!(EpilogueKind::LinearCombination.as_str(), "lincomb");
        assert_eq!(
            EpilogueKind::LinearCombinationBiasRelu.as_str(),
            "lincomb_bias_relu"
        );
        assert!(EpilogueKind::LinearCombinationBias.needs_bias());
        assert!(!EpilogueKind::LinearCombination.needs_bias());
    }

    #[test]
    fn generate_naive_gemm_f32() {
        let ptx = small_template()
            .generate()
            .expect("should generate naive GEMM");
        // Entry point, FMA inner loop, and loop label must all be present.
        for needle in [".entry gemm_", "fma.rn.f32", "$K_LOOP"] {
            assert!(ptx.contains(needle));
        }
    }

    #[test]
    fn invalid_accumulator() {
        let tmpl = GemmTemplate {
            accumulator: PtxType::F16,
            ..small_template()
        };
        assert!(tmpl.generate().is_err());
    }

    #[test]
    fn test_3stage_pipeline_gemm_ptx_structure() {
        let ptx = make_pipelined_template(3, false)
            .generate_pipelined()
            .expect("3-stage pipelined GEMM should generate");
        assert!(
            ptx.contains(".entry gemm_pipelined_"),
            "expected pipelined entry point in PTX:\n{ptx}"
        );
        let cp_async_count = ptx.matches("cp.async.ca.shared.global").count();
        assert!(
            cp_async_count >= 3,
            "expected at least 3 cp.async instructions for 3-stage pipeline, got {cp_async_count}:\n{ptx}"
        );
        let commit_count = ptx.matches("cp.async.commit_group;").count();
        assert!(
            commit_count >= 2,
            "expected at least 2 cp.async.commit_group fences for 3-stage pipeline, got {commit_count}:\n{ptx}"
        );
        assert!(
            ptx.contains("cp.async.wait_group"),
            "expected cp.async.wait_group in 3-stage pipelined PTX:\n{ptx}"
        );
        assert!(
            ptx.contains("bar.sync 0;"),
            "expected bar.sync 0 between pipeline stages:\n{ptx}"
        );
        assert!(
            ptx.contains(".shared"),
            "expected shared memory declaration:\n{ptx}"
        );
    }

    #[test]
    fn test_4stage_pipeline_gemm_ptx_structure() {
        let ptx = make_pipelined_template(4, false)
            .generate_pipelined()
            .expect("4-stage pipelined GEMM should generate");
        assert!(
            ptx.contains(".entry gemm_pipelined_"),
            "expected pipelined entry point:\n{ptx}"
        );
        let cp_async_count = ptx.matches("cp.async.ca.shared.global").count();
        assert!(
            cp_async_count >= 4,
            "expected at least 4 cp.async instructions for 4-stage pipeline, got {cp_async_count}:\n{ptx}"
        );
        let commit_count = ptx.matches("cp.async.commit_group;").count();
        assert!(
            commit_count >= 3,
            "expected at least 3 cp.async.commit_group fences for 4-stage pipeline, got {commit_count}:\n{ptx}"
        );
        assert!(
            ptx.contains("cp.async.wait_group"),
            "expected cp.async.wait_group in 4-stage pipelined PTX:\n{ptx}"
        );
        assert!(
            ptx.contains("bar.sync 0;"),
            "expected bar.sync 0 between pipeline stages:\n{ptx}"
        );
    }

    #[test]
    fn test_3stage_pipeline_tensor_core_contains_mma() {
        let ptx = make_pipelined_template(3, true)
            .generate_pipelined()
            .expect("3-stage TC pipelined GEMM should generate");
        assert!(
            ptx.contains("mma.sync.aligned"),
            "expected mma.sync.aligned in tensor-core pipelined PTX:\n{ptx}"
        );
        assert!(
            ptx.contains("cp.async.ca.shared.global"),
            "expected cp.async for shared memory prefetch:\n{ptx}"
        );
    }

    #[test]
    fn test_pipeline_requires_stages_ge_2() {
        assert!(
            make_pipelined_template(1, false).generate_pipelined().is_err(),
            "generate_pipelined should reject stages < 2"
        );
    }

    #[test]
    fn test_pipeline_smem_declaration_scales_with_stages() {
        let ptx3 = make_pipelined_template(3, false)
            .generate_pipelined()
            .expect("3-stage should generate");
        let ptx4 = make_pipelined_template(4, false)
            .generate_pipelined()
            .expect("4-stage should generate");
        assert!(
            ptx3.contains(".shared"),
            "3-stage PTX must have .shared declaration"
        );
        assert!(
            ptx4.contains(".shared"),
            "4-stage PTX must have .shared declaration"
        );
        assert!(
            ptx4.len() > ptx3.len(),
            "4-stage PTX ({} bytes) should be longer than 3-stage PTX ({} bytes)",
            ptx4.len(),
            ptx3.len()
        );
    }
}