use std::fmt::Write as FmtWrite;
use crate::arch::SmVersion;
use crate::error::PtxGenError;
use crate::ir::PtxType;
/// Configuration for generating a tiled matrix-transpose PTX kernel.
///
/// Built with [`TransposeTemplate::new`] plus the `tile_dim`/`block_rows`
/// builder methods; consumed by `generate` after passing `validate`.
#[derive(Debug, Clone)]
pub struct TransposeTemplate {
    /// Element type of the matrix; `validate` accepts only `F32` and `F64`.
    pub precision: PtxType,
    /// Side length of the square shared-memory tile (power of two, >= 16).
    pub tile_dim: u32,
    /// Rows processed per loop iteration; must divide `tile_dim` evenly.
    pub block_rows: u32,
}
impl TransposeTemplate {
#[must_use]
pub const fn new(precision: PtxType) -> Self {
Self {
precision,
tile_dim: 32,
block_rows: 8,
}
}
#[must_use]
pub const fn tile_dim(mut self, dim: u32) -> Self {
self.tile_dim = dim;
self
}
#[must_use]
pub const fn block_rows(mut self, rows: u32) -> Self {
self.block_rows = rows;
self
}
#[must_use]
pub fn kernel_name(&self) -> String {
let type_str = self.precision.as_ptx_str().trim_start_matches('.');
if self.tile_dim == 32 {
format!("transpose_{type_str}")
} else {
format!("transpose_{type_str}_t{}", self.tile_dim)
}
}
pub fn validate(&self) -> Result<(), PtxGenError> {
if !matches!(self.precision, PtxType::F32 | PtxType::F64) {
return Err(PtxGenError::InvalidType(format!(
"transpose requires F32 or F64, got {}",
self.precision.as_ptx_str()
)));
}
if self.tile_dim < 16 {
return Err(PtxGenError::GenerationFailed(format!(
"tile_dim must be >= 16, got {}",
self.tile_dim
)));
}
if !self.tile_dim.is_power_of_two() {
return Err(PtxGenError::GenerationFailed(format!(
"tile_dim must be a power of 2, got {}",
self.tile_dim
)));
}
if self.block_rows == 0 {
return Err(PtxGenError::GenerationFailed(
"block_rows must be > 0".to_string(),
));
}
if self.tile_dim % self.block_rows != 0 {
return Err(PtxGenError::GenerationFailed(format!(
"block_rows ({}) must divide tile_dim ({}) evenly",
self.block_rows, self.tile_dim
)));
}
Ok(())
}
#[allow(clippy::too_many_lines)]
pub fn generate(&self, sm: SmVersion) -> Result<String, PtxGenError> {
self.validate()?;
let ty = self.precision.as_ptx_str();
let byte_size = self.precision.size_bytes();
let tile_dim = self.tile_dim;
let block_rows = self.block_rows;
let iterations = tile_dim / block_rows;
let kernel_name = self.kernel_name();
let smem_pitch = tile_dim + 1;
let smem_elements = tile_dim * smem_pitch;
let smem_bytes = smem_elements as usize * byte_size;
let reg_count: u32 = 32;
let rd_count: u32 = 16;
let val_count: u32 = (iterations + 4).max(8);
let p_count: u32 = (iterations * 2 + 4).max(8);
let mut ptx = String::with_capacity(4096);
emit_header(&mut ptx, sm)?;
writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_output,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_width,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_height").map_err(PtxGenError::FormatError)?;
writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .maxntid {tile_dim}, {block_rows}, 1;")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .u32 %r<{reg_count}>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .u64 %rd<{rd_count}>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg {ty} %val<{val_count}>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .pred %p<{p_count}>;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" .shared .align {} .b8 smem_tile[{}];",
byte_size.max(4),
smem_bytes
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Load kernel parameters").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd0, [%param_output];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd1, [%param_input];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r0, [%param_width];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r1, [%param_height];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Compute tile origin and thread indices")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r2, %tid.x;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " mov.u32 %r3, %tid.y;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " mov.u32 %r4, %ctaid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r5, %ctaid.y;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r6, %r4, {tile_dim};").map_err(PtxGenError::FormatError)?; writeln!(ptx, " mul.lo.u32 %r7, %r5, {tile_dim};").map_err(PtxGenError::FormatError)?; writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u64 %rd2, smem_tile;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" // Phase 1: Load tile from global to shared memory (coalesced)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r8, %r6, %r2;").map_err(PtxGenError::FormatError)?;
for i in 0..iterations {
let row_offset = i * block_rows;
let pred_col = format!("%p{}", i * 2);
let pred_row = format!("%p{}", i * 2 + 1);
let val_reg = format!("%val{i}");
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Iteration {i}: row offset = {row_offset}")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r9, %r7, %r3;").map_err(PtxGenError::FormatError)?;
if row_offset > 0 {
writeln!(ptx, " add.u32 %r9, %r9, {row_offset};")
.map_err(PtxGenError::FormatError)?;
}
writeln!(ptx, " setp.lt.u32 {pred_col}, %r8, %r0;")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.lt.u32 {pred_row}, %r9, %r1;")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " and.pred {pred_col}, {pred_col}, {pred_row};")
.map_err(PtxGenError::FormatError)?;
emit_zero_mov(&mut ptx, &val_reg, self.precision)?;
writeln!(ptx, " @!{pred_col} bra $SKIP_LOAD_{i};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r10, %r9, %r0;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " add.u32 %r10, %r10, %r8;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " cvt.u64.u32 %rd3, %r10;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd3, %rd3, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd4, %rd1, %rd3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} {val_reg}, [%rd4];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$SKIP_LOAD_{i}:").map_err(PtxGenError::FormatError)?;
let smem_row = format!("ty_plus_{i}");
let _ = smem_row; writeln!(ptx, " add.u32 %r11, %r3, {row_offset};")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " mul.lo.u32 %r12, %r11, {smem_pitch};")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " add.u32 %r12, %r12, %r2;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " cvt.u64.u32 %rd5, %r12;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd5, %rd5, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd6, %rd2, %rd5;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.shared{ty} [%rd6], {val_reg};")
.map_err(PtxGenError::FormatError)?;
}
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Synchronize after loading tile").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" // Phase 2: Write transposed tile from shared to global (coalesced)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r13, %r7, %r2;").map_err(PtxGenError::FormatError)?;
for i in 0..iterations {
let row_offset = i * block_rows;
let pred_base = iterations * 2; let pred_col = format!("%p{}", pred_base + i * 2);
let pred_row = format!("%p{}", pred_base + i * 2 + 1);
let val_reg = format!("%val{i}");
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Write iteration {i}: row offset = {row_offset}")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r14, %r6, %r3;").map_err(PtxGenError::FormatError)?;
if row_offset > 0 {
writeln!(ptx, " add.u32 %r14, %r14, {row_offset};")
.map_err(PtxGenError::FormatError)?;
}
writeln!(ptx, " setp.lt.u32 {pred_col}, %r13, %r1;")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " setp.lt.u32 {pred_row}, %r14, %r0;")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " and.pred {pred_col}, {pred_col}, {pred_row};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r15, %r2, {smem_pitch};")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " add.u32 %r16, %r3, {row_offset};")
.map_err(PtxGenError::FormatError)?; writeln!(ptx, " add.u32 %r15, %r15, %r16;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " cvt.u64.u32 %rd7, %r15;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd8, %rd2, %rd7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.shared{ty} {val_reg}, [%rd8];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @!{pred_col} bra $SKIP_STORE_{i};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r17, %r14, %r1;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " add.u32 %r17, %r17, %r13;").map_err(PtxGenError::FormatError)?; writeln!(ptx, " cvt.u64.u32 %rd9, %r17;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd0, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.global{ty} [%rd10], {val_reg};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$SKIP_STORE_{i}:").map_err(PtxGenError::FormatError)?;
}
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
Ok(ptx)
}
}
/// Writes the standard PTX module preamble: `.version`, `.target`,
/// `.address_size 64`, followed by a blank separator line.
fn emit_header(ptx: &mut String, sm: SmVersion) -> Result<(), PtxGenError> {
    write!(
        ptx,
        ".version {}\n.target {}\n.address_size 64\n\n",
        sm.ptx_version(),
        sm.as_ptx_str()
    )
    .map_err(PtxGenError::FormatError)
}
/// Appends a `mov` instruction that zero-initializes `reg` for type `ty`.
///
/// # Errors
///
/// Returns [`PtxGenError::InvalidType`] for any type other than `F32` or
/// `F64`, since only those have zero literals defined here.
fn emit_zero_mov(ptx: &mut String, reg: &str, ty: PtxType) -> Result<(), PtxGenError> {
    let line = match ty {
        PtxType::F32 => format!(" mov.f32 {reg}, 0f00000000;"),
        PtxType::F64 => format!(" mov.f64 {reg}, 0d0000000000000000;"),
        _ => {
            return Err(PtxGenError::InvalidType(format!(
                "zero literal not supported for {}",
                ty.as_ptx_str()
            )));
        }
    };
    writeln!(ptx, "{line}").map_err(PtxGenError::FormatError)
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Builder and naming -------------------------------------------------

    #[test]
    fn test_default_config() {
        let t = TransposeTemplate::new(PtxType::F32);
        assert_eq!(t.tile_dim, 32);
        assert_eq!(t.block_rows, 8);
        assert_eq!(t.precision, PtxType::F32);
    }

    #[test]
    fn test_builder_methods() {
        let t = TransposeTemplate::new(PtxType::F64)
            .tile_dim(16)
            .block_rows(4);
        assert_eq!(t.tile_dim, 16);
        assert_eq!(t.block_rows, 4);
        assert_eq!(t.precision, PtxType::F64);
    }

    #[test]
    fn test_kernel_name_default_tile() {
        // Default tile (32) keeps the short, suffix-free kernel name.
        let t = TransposeTemplate::new(PtxType::F32);
        assert_eq!(t.kernel_name(), "transpose_f32");
    }

    #[test]
    fn test_kernel_name_custom_tile() {
        // Non-default tiles carry a _t{dim} suffix.
        let t = TransposeTemplate::new(PtxType::F64).tile_dim(16);
        assert_eq!(t.kernel_name(), "transpose_f64_t16");
    }

    #[test]
    fn test_kernel_name_f64() {
        let t = TransposeTemplate::new(PtxType::F64);
        assert_eq!(t.kernel_name(), "transpose_f64");
    }

    // --- validate() ---------------------------------------------------------

    #[test]
    fn test_validate_ok_f32() {
        let t = TransposeTemplate::new(PtxType::F32);
        assert!(t.validate().is_ok());
    }

    #[test]
    fn test_validate_ok_f64() {
        let t = TransposeTemplate::new(PtxType::F64);
        assert!(t.validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_type() {
        // Only float precisions are supported by the transpose template.
        let t = TransposeTemplate::new(PtxType::U32);
        let err = t.validate().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("F32 or F64"), "unexpected error: {msg}");
    }

    #[test]
    fn test_validate_tile_dim_too_small() {
        let t = TransposeTemplate::new(PtxType::F32).tile_dim(8);
        let err = t.validate().unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("tile_dim must be >= 16"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_validate_tile_dim_not_power_of_two() {
        let t = TransposeTemplate::new(PtxType::F32).tile_dim(24);
        let err = t.validate().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("power of 2"), "unexpected error: {msg}");
    }

    #[test]
    fn test_validate_block_rows_not_divisor() {
        // 6 does not divide 32, so the iteration count would be fractional.
        let t = TransposeTemplate::new(PtxType::F32)
            .tile_dim(32)
            .block_rows(6);
        let err = t.validate().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("divide"), "unexpected error: {msg}");
    }

    #[test]
    fn test_validate_block_rows_zero() {
        let t = TransposeTemplate::new(PtxType::F32).block_rows(0);
        let err = t.validate().unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("block_rows must be > 0"),
            "unexpected error: {msg}"
        );
    }

    // --- generate(): structure of the emitted PTX ---------------------------

    #[test]
    fn test_generate_f32_contains_shared_memory() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains(".shared"), "PTX must declare shared memory");
        assert!(ptx.contains("smem_tile"), "PTX must use smem_tile label");
    }

    #[test]
    fn test_generate_f32_bank_conflict_free_padding() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        // 32 rows x 33-element padded pitch x 4 bytes per f32.
        let expected_bytes = 32 * 33 * 4;
        let pattern = format!("smem_tile[{expected_bytes}]");
        assert!(
            ptx.contains(&pattern),
            "expected shared memory size {expected_bytes}, PTX:\n{ptx}"
        );
    }

    #[test]
    fn test_generate_f64_bank_conflict_free_padding() {
        let t = TransposeTemplate::new(PtxType::F64);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        // 32 rows x 33-element padded pitch x 8 bytes per f64.
        let expected_bytes = 32 * 33 * 8;
        let pattern = format!("smem_tile[{expected_bytes}]");
        assert!(
            ptx.contains(&pattern),
            "expected shared memory size {expected_bytes}, PTX:\n{ptx}"
        );
    }

    #[test]
    fn test_generate_f32_contains_kernel_name() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(
            ptx.contains("transpose_f32"),
            "kernel name not found in PTX"
        );
    }

    #[test]
    fn test_generate_f64_valid() {
        let t = TransposeTemplate::new(PtxType::F64);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains("transpose_f64"));
        assert!(ptx.contains(".f64"));
        assert!(ptx.contains("ld.global.f64"));
        assert!(ptx.contains("st.global.f64"));
    }

    #[test]
    fn test_generate_contains_bar_sync() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(
            ptx.contains("bar.sync"),
            "PTX must contain bar.sync for shared memory coherence"
        );
    }

    #[test]
    fn test_generate_contains_kernel_params() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains("%param_output"), "missing output param");
        assert!(ptx.contains("%param_input"), "missing input param");
        assert!(ptx.contains("%param_width"), "missing width param");
        assert!(ptx.contains("%param_height"), "missing height param");
    }

    #[test]
    fn test_generate_coalesced_reads_and_writes() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(
            ptx.contains("ld.global.f32"),
            "missing coalesced global loads"
        );
        assert!(
            ptx.contains("st.global.f32"),
            "missing coalesced global stores"
        );
        assert!(ptx.contains("ld.shared.f32"), "missing shared memory loads");
        assert!(
            ptx.contains("st.shared.f32"),
            "missing shared memory stores"
        );
    }

    #[test]
    fn test_generate_custom_tile_16() {
        // 16 x 17 x 4 bytes = 1088.
        let t = TransposeTemplate::new(PtxType::F32)
            .tile_dim(16)
            .block_rows(4);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains("transpose_f32_t16"));
        assert!(ptx.contains("smem_tile[1088]"));
    }

    #[test]
    fn test_generate_custom_tile_64() {
        // 64 x 65 x 4 bytes = 16640.
        let t = TransposeTemplate::new(PtxType::F32)
            .tile_dim(64)
            .block_rows(16);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains("transpose_f32_t64"));
        assert!(ptx.contains("smem_tile[16640]"));
    }

    #[test]
    fn test_generate_iterations_count() {
        // Default config: 32 / 8 = 4 iterations, labels 0 through 3.
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(ptx.contains("$SKIP_LOAD_0:"));
        assert!(ptx.contains("$SKIP_LOAD_3:"));
        assert!(ptx.contains("$SKIP_STORE_0:"));
        assert!(ptx.contains("$SKIP_STORE_3:"));
    }

    #[test]
    fn test_generate_invalid_config_returns_error() {
        let t = TransposeTemplate::new(PtxType::U32);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }

    #[test]
    fn test_generate_sm75() {
        let t = TransposeTemplate::new(PtxType::F32);
        let ptx = t.generate(SmVersion::Sm75).expect("Sm75 should work");
        assert!(ptx.contains(".target"));
    }

    #[test]
    fn test_generate_maxntid_matches_block_config() {
        let t = TransposeTemplate::new(PtxType::F32)
            .tile_dim(16)
            .block_rows(4);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("generation should succeed");
        assert!(
            ptx.contains(".maxntid 16, 4, 1"),
            "maxntid should match tile_dim x block_rows"
        );
    }
}