use std::fmt::Write as FmtWrite;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
use crate::types::Transpose;
/// Builder that renders the PTX source text of a simple GEMM kernel in which
/// each thread computes a single element of the output matrix.
///
/// The configuration is fixed at construction time; see `generate` for the
/// emitted kernel and `kernel_name` for the entry-point naming scheme.
pub struct SimtGemmBuilder {
/// Target architecture; provides the `.version` and `.target` directives.
target: SmVersion,
/// Element type of the A, B and C operands as they are read from and
/// written to global memory; its byte size scales every element offset.
precision: PtxType,
/// Type of the `alpha`/`beta` kernel parameters and of the running sum
/// (the `mov`, `mul` and `fma` instructions use this type suffix).
accumulator: PtxType,
/// Layout selector for A: `NoTrans` addresses `row * lda + k`, while
/// `Trans` and `ConjTrans` both address `k * lda + row`.
trans_a: Transpose,
/// Layout selector for B: `NoTrans` addresses `k * ldb + col`, while
/// `Trans` and `ConjTrans` both address `col * ldb + k`.
trans_b: Transpose,
}
impl SimtGemmBuilder {
    /// Creates a builder for the given target, element types and operand layouts.
    ///
    /// Nothing is checked here; unsupported type combinations are reported by
    /// [`generate`](Self::generate).
    pub fn new(
        target: SmVersion,
        precision: PtxType,
        accumulator: PtxType,
        trans_a: Transpose,
        trans_b: Transpose,
    ) -> Self {
        Self {
            target,
            precision,
            accumulator,
            trans_a,
            trans_b,
        }
    }

    /// Returns the kernel entry-point name.
    ///
    /// The format is `simt_gemm_{precision}_{accumulator}_{ta}_{tb}`, where the
    /// type parts are the PTX type strings without their leading dot and the
    /// layout parts come from `trans_label` (for example
    /// `simt_gemm_f32_f32_tt_nn`).
    pub fn kernel_name(&self) -> String {
        let prec = self.precision.as_ptx_str().trim_start_matches('.');
        let acc = self.accumulator.as_ptx_str().trim_start_matches('.');
        let ta = trans_label(self.trans_a);
        let tb = trans_label(self.trans_b);
        format!("simt_gemm_{prec}_{acc}_{ta}_{tb}")
    }

    /// Generates the PTX module text for the configured kernel.
    ///
    /// Kernel parameters, in order: `a`, `b`, `c` (64-bit addresses), `m`, `n`,
    /// `k`, `lda`, `ldb`, `ldc` (`u32`), then `alpha` and `beta` in the
    /// accumulator type.
    ///
    /// Each thread derives `col = ctaid.x * ntid.x + tid.x` and
    /// `row = ctaid.y * ntid.y + tid.y`, exits if `row >= m` or `col >= n`,
    /// sums `A * B` over `k` in the accumulator type, and finally writes
    /// `alpha * sum + beta * C[row, col]`. `C` is read unconditionally, also
    /// when `beta` is zero.
    ///
    /// Elements whose storage type differs from the accumulator type are moved
    /// through a staging register and converted with `cvt`; 16-bit elements
    /// are accessed as `.b16` because `ld`/`st` have no `.f16`/`.bf16` form.
    /// Element offsets are computed in 64 bits so that `row * ld + col` cannot
    /// wrap for large matrices.
    ///
    /// NOTE(review): the `bf16` conversions emitted here need an sm_80-class
    /// target according to the PTX ISA; `target` is not checked against that.
    ///
    /// # Errors
    ///
    /// Returns [`BlasError::PtxGeneration`] if the precision is not
    /// F16/BF16/F32/F64, if the accumulator is not F32/F64, or if formatting
    /// into the output string fails.
    pub fn generate(&self) -> BlasResult<String> {
        self.validate()?;
        let acc_ty = self.accumulator.as_ptx_str();
        let kernel_name = self.kernel_name();
        // `0f…` is the 32-bit hex literal form; a 64-bit accumulator needs `0d…`.
        let zero = if matches!(self.accumulator, PtxType::F64) {
            "0d0000000000000000"
        } else {
            "0f00000000"
        };

        let mut ptx = String::with_capacity(4096);

        // Module header and kernel signature.
        write_line(&mut ptx, &format!(".version {}", self.target.ptx_version()))?;
        write_line(&mut ptx, &format!(".target {}", self.target.as_ptx_str()))?;
        write_line(&mut ptx, ".address_size 64")?;
        write_line(&mut ptx, "")?;
        write_line(&mut ptx, &format!(".visible .entry {kernel_name}("))?;
        Self::write_lines(
            &mut ptx,
            &[
                " .param .u64 %param_a,",
                " .param .u64 %param_b,",
                " .param .u64 %param_c,",
                " .param .u32 %param_m,",
                " .param .u32 %param_n,",
                " .param .u32 %param_k,",
                " .param .u32 %param_lda,",
                " .param .u32 %param_ldb,",
                " .param .u32 %param_ldc,",
            ],
        )?;
        write_line(&mut ptx, &format!(" .param {acc_ty} %param_alpha,"))?;
        write_line(&mut ptx, &format!(" .param {acc_ty} %param_beta"))?;
        write_line(&mut ptx, ")")?;
        write_line(&mut ptx, "{")?;

        // Register file. `%f` holds accumulator-typed values, so it must be
        // declared with the accumulator type rather than a fixed `.f32`.
        write_line(&mut ptx, " .reg .b32 %r<32>;")?;
        write_line(&mut ptx, " .reg .b64 %rd<16>;")?;
        write_line(&mut ptx, &format!(" .reg {acc_ty} %f<16>;"))?;
        write_line(&mut ptx, " .reg .pred %p<4>;")?;
        if !self.is_direct() {
            // Staging registers for elements in their in-memory format.
            write_line(&mut ptx, &format!(" .reg {} %st<4>;", self.memory_type()))?;
        }
        if self.needs_f32_hop() {
            // Intermediates for the two-step bf16 <-> f64 conversion.
            write_line(&mut ptx, " .reg .f32 %cv<2>;")?;
        }
        write_line(&mut ptx, "")?;

        // Thread coordinates.
        Self::write_lines(
            &mut ptx,
            &[
                " mov.u32 %r0, %tid.x;",
                " mov.u32 %r1, %tid.y;",
                " mov.u32 %r2, %ctaid.x;",
                " mov.u32 %r3, %ctaid.y;",
                " mov.u32 %r4, %ntid.x;",
                " mov.u32 %r5, %ntid.y;",
                " mad.lo.u32 %r6, %r2, %r4, %r0; // col",
                " mad.lo.u32 %r7, %r3, %r5, %r1; // row",
                "",
            ],
        )?;

        // Kernel arguments.
        Self::write_lines(
            &mut ptx,
            &[
                " ld.param.u64 %rd0, [%param_a];",
                " ld.param.u64 %rd1, [%param_b];",
                " ld.param.u64 %rd2, [%param_c];",
                " ld.param.u32 %r8, [%param_m];",
                " ld.param.u32 %r9, [%param_n];",
                " ld.param.u32 %r10, [%param_k];",
                " ld.param.u32 %r20, [%param_lda];",
                " ld.param.u32 %r21, [%param_ldb];",
                " ld.param.u32 %r22, [%param_ldc];",
            ],
        )?;
        write_line(
            &mut ptx,
            &format!(" ld.param{acc_ty} %f8, [%param_alpha];"),
        )?;
        write_line(
            &mut ptx,
            &format!(" ld.param{acc_ty} %f9, [%param_beta];"),
        )?;
        write_line(&mut ptx, "")?;

        // Threads outside the M x N output tile have nothing to do.
        Self::write_lines(
            &mut ptx,
            &[
                " setp.ge.u32 %p0, %r7, %r8;",
                " setp.ge.u32 %p1, %r6, %r9;",
                " @%p0 bra $SIMT_DONE;",
                " @%p1 bra $SIMT_DONE;",
                "",
            ],
        )?;

        // Running sum and loop counter.
        write_line(&mut ptx, &format!(" mov{acc_ty} %f0, {zero};"))?;
        write_line(&mut ptx, " mov.u32 %r11, 0;")?;
        write_line(&mut ptx, "")?;

        // Reduction over K.
        write_line(&mut ptx, "$SIMT_K_LOOP:")?;
        write_line(&mut ptx, " setp.ge.u32 %p2, %r11, %r10;")?;
        write_line(&mut ptx, " @%p2 bra $SIMT_K_DONE;")?;

        let (a_row_reg, a_col_reg) = match self.trans_a {
            Transpose::NoTrans => ("%r7", "%r11"),
            Transpose::Trans | Transpose::ConjTrans => ("%r11", "%r7"),
        };
        self.emit_byte_offset(&mut ptx, "%rd3", a_row_reg, "%r20", a_col_reg)?;
        write_line(&mut ptx, " add.u64 %rd4, %rd0, %rd3;")?;
        self.emit_load(&mut ptx, "%f1", "%rd4", "%st0")?;

        let (b_row_reg, b_col_reg) = match self.trans_b {
            Transpose::NoTrans => ("%r11", "%r6"),
            Transpose::Trans | Transpose::ConjTrans => ("%r6", "%r11"),
        };
        self.emit_byte_offset(&mut ptx, "%rd5", b_row_reg, "%r21", b_col_reg)?;
        write_line(&mut ptx, " add.u64 %rd6, %rd1, %rd5;")?;
        self.emit_load(&mut ptx, "%f2", "%rd6", "%st1")?;

        write_line(&mut ptx, &format!(" fma.rn{acc_ty} %f0, %f1, %f2, %f0;"))?;
        write_line(&mut ptx, " add.u32 %r11, %r11, 1;")?;
        write_line(&mut ptx, " bra $SIMT_K_LOOP;")?;
        write_line(&mut ptx, "$SIMT_K_DONE:")?;
        write_line(&mut ptx, "")?;

        // Epilogue: C[row, col] = alpha * sum + beta * C[row, col].
        self.emit_byte_offset(&mut ptx, "%rd7", "%r7", "%r22", "%r6")?;
        write_line(&mut ptx, " add.u64 %rd8, %rd2, %rd7;")?;
        self.emit_load(&mut ptx, "%f3", "%rd8", "%st2")?;
        write_line(&mut ptx, &format!(" mul{acc_ty} %f0, %f0, %f8;"))?;
        write_line(&mut ptx, &format!(" fma.rn{acc_ty} %f0, %f9, %f3, %f0;"))?;
        self.emit_store(&mut ptx, "%rd8", "%f0", "%st3")?;
        write_line(&mut ptx, "")?;
        write_line(&mut ptx, "$SIMT_DONE:")?;
        write_line(&mut ptx, " ret;")?;
        write_line(&mut ptx, "}")?;
        Ok(ptx)
    }

    /// Checks that the configured types are ones this builder can emit.
    fn validate(&self) -> BlasResult<()> {
        if !matches!(
            self.precision,
            PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
        ) {
            return Err(BlasError::PtxGeneration(format!(
                "SIMT GEMM requires F16/BF16/F32/F64, got {}",
                self.precision.as_ptx_str()
            )));
        }
        if !matches!(self.accumulator, PtxType::F32 | PtxType::F64) {
            return Err(BlasError::PtxGeneration(format!(
                "SIMT GEMM accumulator must be F32/F64, got {}",
                self.accumulator.as_ptx_str()
            )));
        }
        Ok(())
    }

    /// Writes each entry of `lines` as its own output line.
    fn write_lines(ptx: &mut String, lines: &[&str]) -> BlasResult<()> {
        lines.iter().try_for_each(|line| write_line(ptx, line))
    }

    /// True when storage and accumulator types coincide (F32/F32 or F64/F64),
    /// so elements can be loaded into and stored from `%f` registers directly.
    fn is_direct(&self) -> bool {
        let storage_f32 = matches!(self.precision, PtxType::F32);
        let storage_f64 = matches!(self.precision, PtxType::F64);
        let acc_f32 = matches!(self.accumulator, PtxType::F32);
        let acc_f64 = matches!(self.accumulator, PtxType::F64);
        (storage_f32 && acc_f32) || (storage_f64 && acc_f64)
    }

    /// True when the storage type is wider than the accumulator (F64 data with
    /// an F32 accumulator), i.e. the load-side conversion loses precision.
    fn load_narrows(&self) -> bool {
        matches!(self.precision, PtxType::F64) && matches!(self.accumulator, PtxType::F32)
    }

    /// True for BF16 data with an F64 accumulator. That pair is converted via
    /// F32 instead of a direct bf16 <-> f64 `cvt`; the store side therefore
    /// rounds twice (f64 -> f32 -> bf16).
    fn needs_f32_hop(&self) -> bool {
        matches!(self.precision, PtxType::BF16) && matches!(self.accumulator, PtxType::F64)
    }

    /// Type suffix used on `ld.global`/`st.global` and on the staging
    /// registers: `.b16` for 16-bit elements, otherwise the precision itself.
    fn memory_type(&self) -> &str {
        if matches!(self.precision, PtxType::F16 | PtxType::BF16) {
            ".b16"
        } else {
            self.precision.as_ptx_str()
        }
    }

    /// Emits `dst = (row * ld + col) * element_size` using 64-bit arithmetic.
    ///
    /// `row`, `ld` and `col` name 32-bit registers; `dst` names a 64-bit one.
    fn emit_byte_offset(
        &self,
        ptx: &mut String,
        dst: &str,
        row: &str,
        ld: &str,
        col: &str,
    ) -> BlasResult<()> {
        let elem_bytes = self.precision.size_bytes();
        write_line(ptx, &format!(" cvt.u64.u32 {dst}, {col};"))?;
        // Widening multiply-add: the 32x32 product is formed in 64 bits.
        write_line(ptx, &format!(" mad.wide.u32 {dst}, {row}, {ld}, {dst};"))?;
        write_line(ptx, &format!(" mul.lo.u64 {dst}, {dst}, {elem_bytes};"))
    }

    /// Emits a global load of one element at `addr` into the accumulator-typed
    /// register `dst`, converting through `stage` when the types differ.
    fn emit_load(&self, ptx: &mut String, dst: &str, addr: &str, stage: &str) -> BlasResult<()> {
        let ty = self.precision.as_ptx_str();
        let acc_ty = self.accumulator.as_ptx_str();
        if self.is_direct() {
            return write_line(ptx, &format!(" ld.global{ty} {dst}, [{addr}];"));
        }
        write_line(
            ptx,
            &format!(" ld.global{} {stage}, [{addr}];", self.memory_type()),
        )?;
        if self.needs_f32_hop() {
            // Both steps widen, so no rounding modifier is involved.
            write_line(ptx, &format!(" cvt.f32{ty} %cv0, {stage};"))?;
            write_line(ptx, &format!(" cvt{acc_ty}.f32 {dst}, %cv0;"))
        } else {
            // A rounding modifier is only needed on a narrowing conversion.
            let rnd = if self.load_narrows() { ".rn" } else { "" };
            write_line(ptx, &format!(" cvt{rnd}{acc_ty}{ty} {dst}, {stage};"))
        }
    }

    /// Emits a global store of the accumulator-typed register `src` to `addr`,
    /// converting through `stage` when the types differ.
    fn emit_store(&self, ptx: &mut String, addr: &str, src: &str, stage: &str) -> BlasResult<()> {
        let ty = self.precision.as_ptx_str();
        let acc_ty = self.accumulator.as_ptx_str();
        if self.is_direct() {
            return write_line(ptx, &format!(" st.global{ty} [{addr}], {src};"));
        }
        if self.needs_f32_hop() {
            write_line(ptx, &format!(" cvt.rn.f32{acc_ty} %cv1, {src};"))?;
            write_line(ptx, &format!(" cvt.rn{ty}.f32 {stage}, %cv1;"))?;
        } else {
            // The store direction is the inverse of the load direction: it
            // narrows unless the load did.
            let rnd = if self.load_narrows() { "" } else { ".rn" };
            write_line(ptx, &format!(" cvt{rnd}{ty}{acc_ty} {stage}, {src};"))?;
        }
        write_line(
            ptx,
            &format!(" st.global{} [{addr}], {stage};", self.memory_type()),
        )
    }
}
fn trans_label(t: Transpose) -> &'static str {
match t {
Transpose::NoTrans => "nn",
Transpose::Trans => "tt",
Transpose::ConjTrans => "ct",
}
}
fn write_line(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an SM80 builder with an F32 accumulator, the setup shared by
    /// most tests below.
    fn sm80_f32_acc(precision: PtxType, trans_a: Transpose, trans_b: Transpose) -> SimtGemmBuilder {
        SimtGemmBuilder::new(SmVersion::Sm80, precision, PtxType::F32, trans_a, trans_b)
    }

    /// The plain F32 kernel contains the entry point, the FMA and the K loop.
    #[test]
    fn generate_simt_f32_nn() {
        let generated = sm80_f32_acc(PtxType::F32, Transpose::NoTrans, Transpose::NoTrans)
            .generate()
            .expect("SIMT GEMM generation failed");
        assert!(generated.contains(".entry simt_gemm_"));
        assert!(generated.contains("fma.rn.f32"));
        assert!(generated.contains("$SIMT_K_LOOP"));
    }

    /// A transposed A operand is reflected in the emitted kernel name.
    #[test]
    fn generate_simt_f32_tn() {
        let generated = sm80_f32_acc(PtxType::F32, Transpose::Trans, Transpose::NoTrans)
            .generate()
            .expect("SIMT GEMM TN generation failed");
        assert!(generated.contains("simt_gemm_f32_f32_tt_nn"));
    }

    /// The kernel name encodes both types and both layout tags.
    #[test]
    fn kernel_name_format() {
        let name = SimtGemmBuilder::new(
            SmVersion::Sm75,
            PtxType::F64,
            PtxType::F64,
            Transpose::NoTrans,
            Transpose::Trans,
        )
        .kernel_name();
        assert_eq!(name, "simt_gemm_f64_f64_nn_tt");
    }

    /// Integer element types are rejected.
    #[test]
    fn invalid_precision() {
        let outcome =
            sm80_f32_acc(PtxType::U32, Transpose::NoTrans, Transpose::NoTrans).generate();
        assert!(outcome.is_err());
    }
}