use std::fmt;
use super::instruction::PtxInstruction;
use super::kernel::PtxKernel;
use crate::instr::{MemoryOp, TensorCoreOp};
use crate::types::PtxType;
/// A PTX module: header settings plus the kernels it contains.
#[derive(Debug, Clone)]
pub struct PtxModule {
    /// Version string; [`PtxModule::new`] sets `"8.7"`.
    pub version: String,
    /// Target architecture, e.g. `"sm_80"`. Only `sm_<N>` targets (optionally
    /// carrying a single `a`/`f` suffix) take part in the minimum-SM check
    /// performed by [`PtxModule::validate`].
    pub target: String,
    /// Address size in bits; [`PtxModule::new`] sets `64`.
    pub address_size: u32,
    /// Kernels, in the order they were added.
    pub kernels: Vec<PtxKernel>,
}

impl PtxModule {
    /// Creates an empty module for `target` with version `"8.7"` and a 64-bit
    /// address size.
    pub fn new(target: &str) -> Self {
        Self {
            version: "8.7".to_string(),
            target: target.to_string(),
            address_size: 64,
            kernels: Vec::new(),
        }
    }

    /// Appends `kernel` after any kernels already in the module.
    pub fn add_kernel(&mut self, kernel: PtxKernel) {
        self.kernels.push(kernel);
    }

    /// Returns the numeric SM version of an `sm_<N>` target.
    ///
    /// A single trailing `a` (architecture-specific, e.g. `sm_90a`) or `f`
    /// (family-specific, e.g. `sm_100f`) suffix is stripped. Those targets
    /// include the full `sm_<N>` feature set, so `<N>` is the correct value for
    /// a minimum-SM comparison. Without this, such targets would parse as `None`
    /// and the SM check would be silently skipped for them.
    ///
    /// Returns `None` for anything else, e.g. `compute_90a` or a version part
    /// that is not plain ASCII digits.
    fn parse_sm_target(&self) -> Option<u32> {
        let rest = self.target.strip_prefix("sm_")?;
        let digits = rest
            .strip_suffix(|c: char| c == 'a' || c == 'f')
            .unwrap_or(rest);
        // `u32::from_str` also accepts a leading `+`; require bare digits.
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }

    /// Checks every instruction of every kernel.
    ///
    /// Instructions are visited in kernel order, then in body order. For each one:
    /// 1. Tensor-core ops get target-independent structural checks. These run
    ///    even when `target` is not an `sm_<N>` target.
    /// 2. If the target's SM number could be parsed, any instruction with a
    ///    minimum-SM requirement is compared against it.
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] found; later instructions are not
    /// examined.
    pub fn validate(&self) -> Result<(), ValidationError> {
        let target_sm = self.parse_sm_target();
        for kernel in &self.kernels {
            for instr in &kernel.body {
                if let PtxInstruction::TensorCore(op) = instr {
                    validate_tensor_core_op(op)?;
                }
                // Skipped entirely when the target has no parseable SM number.
                if let Some(target_sm) = target_sm
                    && let Some((required, feature)) = instruction_sm_requirement(instr)
                    && target_sm < required
                {
                    return Err(ValidationError::SmTooLow {
                        required,
                        actual: target_sm,
                        feature,
                    });
                }
            }
        }
        Ok(())
    }
}
/// Target-independent structural checks on a single tensor-core op.
///
/// * `MmaSync` must not carry `PtxType::BF16` in `a_ty` or `b_ty`; `a_ty` is
///   checked first.
/// * `LdMatrix` must use `PtxType::U32` for every destination register and for
///   the address register; the destinations are checked first.
///
/// Every other `TensorCoreOp` variant passes.
///
/// # Errors
///
/// Returns [`ValidationError::MmaSyncBf16Rejected`] or
/// [`ValidationError::LdMatrixBadRegType`], naming the first offending operand.
fn validate_tensor_core_op(op: &TensorCoreOp) -> Result<(), ValidationError> {
    match op {
        TensorCoreOp::MmaSync { a_ty, b_ty, .. } => {
            let tagged = [("a_ty", a_ty), ("b_ty", b_ty)];
            match tagged.into_iter().find(|(_, ty)| **ty == PtxType::BF16) {
                Some((operand, _)) => Err(ValidationError::MmaSyncBf16Rejected { operand }),
                None => Ok(()),
            }
        }
        TensorCoreOp::LdMatrix { dst, addr, .. } => {
            let first_bad_dst = dst
                .regs()
                .into_iter()
                .map(|r| r.ptx_type)
                .find(|ty| *ty != PtxType::U32);
            if let Some(found) = first_bad_dst {
                return Err(ValidationError::LdMatrixBadRegType {
                    operand: "dst",
                    found,
                });
            }
            if addr.ptx_type == PtxType::U32 {
                Ok(())
            } else {
                Err(ValidationError::LdMatrixBadRegType {
                    operand: "addr",
                    found: addr.ptx_type,
                })
            }
        }
        _ => Ok(()),
    }
}
/// Returns the minimum SM version an instruction needs, paired with the feature
/// label that is reported in [`ValidationError::SmTooLow`].
///
/// Tensor-core ops report their own requirement through `min_sm` and
/// `feature_label`. The three `cp.async` memory ops share a fixed sm_80
/// requirement. Returns `None` for instructions without an SM gate.
fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
    // cp.async and its commit/wait-group companions require sm_80.
    const CP_ASYNC_MIN_SM: u32 = 80;

    match instr {
        PtxInstruction::TensorCore(tc) => {
            let required = tc.min_sm();
            let label = tc.feature_label();
            Some((required, label))
        }
        PtxInstruction::Memory(mem) => {
            let is_cp_async = matches!(
                mem,
                MemoryOp::CpAsyncCaSharedGlobal { .. }
                    | MemoryOp::CpAsyncCommitGroup
                    | MemoryOp::CpAsyncWaitGroup { .. }
            );
            is_cp_async.then(|| (CP_ASYNC_MIN_SM, String::from("cp.async")))
        }
        _ => None,
    }
}
/// Reasons [`PtxModule::validate`] can reject a module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// An instruction needs a newer SM than the module's target provides.
    SmTooLow {
        /// Minimum SM version the instruction requires.
        required: u32,
        /// SM version parsed from the module's target.
        actual: u32,
        /// Label identifying the gated feature, e.g. `"cp.async"`.
        feature: String,
    },
    /// `TensorCoreOp::MmaSync` was tagged with `PtxType::BF16`; `operand` is
    /// `"a_ty"` or `"b_ty"`.
    MmaSyncBf16Rejected {
        operand: &'static str,
    },
    /// A `TensorCoreOp::LdMatrix` register is not `PtxType::U32`; `operand` is
    /// `"dst"` or `"addr"`, and `found` is the type that was present instead.
    LdMatrixBadRegType {
        operand: &'static str,
        found: PtxType,
    },
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ValidationError::SmTooLow {
                ref feature,
                required,
                actual,
            } => write!(f, "{feature} requires sm_{required}+, target is sm_{actual}"),
            ValidationError::MmaSyncBf16Rejected { operand } => write!(
                f,
                "TensorCoreOp::MmaSync with PtxType::BF16 on {operand} is rejected; use TensorCoreOp::MmaSyncBf16 for bf16 emission"
            ),
            ValidationError::LdMatrixBadRegType { operand, ref found } => write!(
                f,
                "TensorCoreOp::LdMatrix {operand} register must be PtxType::U32 (.b32 packed-pair convention, see alloc_packed_half2), found {found:?}"
            ),
        }
    }
}

impl std::error::Error for ValidationError {}
// Unit tests for `PtxModule::validate`. They cover two kinds of checks:
// minimum-SM gating per instruction (mma.sync, ldmatrix, cp.async), and
// target-independent structural checks on tensor-core ops (BF16 tags on
// `MmaSync`, register types on `LdMatrix`).
#[cfg(test)]
mod tests {
use super::*;
use crate::fragment::{alloc_a_f16, alloc_b_f16, alloc_c};
use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
use crate::types::{PtxType, RegKind};
// Builds a `Register` directly from its fields, without going through
// `RegisterAllocator`.
fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
Register {
kind,
index,
ptx_type,
}
}
// Kernel with one f16 `MmaSync` (m16n8k16, f32 C/D). The tests below expect it
// to require sm_80, with the feature label "mma.sync.m16n8k16".
fn tc_kernel() -> PtxKernel {
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("has_mma");
k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
d: alloc_c(&mut alloc),
a: alloc_a_f16(&mut alloc),
b: alloc_b_f16(&mut alloc),
c: alloc_c(&mut alloc),
shape: MmaShape::M16N8K16,
d_ty: PtxType::F32,
a_ty: PtxType::F16,
b_ty: PtxType::F16,
c_ty: PtxType::F32,
}));
k
}
#[test]
fn validate_rejects_mma_on_sm_70() {
let mut module = PtxModule::new("sm_70");
module.add_kernel(tc_kernel());
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::SmTooLow {
required: 80,
actual: 70,
feature: "mma.sync.m16n8k16".to_string(),
}
);
// Pins the exact user-facing Display text as well as the variant.
assert_eq!(
err.to_string(),
"mma.sync.m16n8k16 requires sm_80+, target is sm_70"
);
}
#[test]
fn validate_accepts_mma_on_sm_80() {
let mut module = PtxModule::new("sm_80");
module.add_kernel(tc_kernel());
assert!(module.validate().is_ok());
}
#[test]
fn validate_accepts_mma_on_sm_89() {
let mut module = PtxModule::new("sm_89");
module.add_kernel(tc_kernel());
assert!(module.validate().is_ok());
}
// Kernel with one `MmaSyncInt8` built from the M16N8K32 fragment allocators.
// The tests below expect sm_80, with the label "mma.sync.m16n8k32.s8.s8.s32".
fn tc_int8_kernel() -> PtxKernel {
use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("has_mma_int8");
k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncInt8 {
d: alloc_c_M16N8K32(&mut alloc),
a: alloc_a_M16N8K32(&mut alloc),
b: alloc_b_M16N8K32(&mut alloc),
c: alloc_c_M16N8K32(&mut alloc),
}));
k
}
#[test]
fn validate_rejects_mma_int8_on_sm_70() {
let mut module = PtxModule::new("sm_70");
module.add_kernel(tc_int8_kernel());
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::SmTooLow {
required: 80,
actual: 70,
feature: "mma.sync.m16n8k32.s8.s8.s32".to_string(),
}
);
assert_eq!(
err.to_string(),
"mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
);
}
#[test]
fn validate_accepts_mma_int8_on_sm_80() {
let mut module = PtxModule::new("sm_80");
module.add_kernel(tc_int8_kernel());
assert!(module.validate().is_ok());
}
#[test]
fn validate_accepts_mma_int8_on_sm_89() {
let mut module = PtxModule::new("sm_89");
module.add_kernel(tc_int8_kernel());
assert!(module.validate().is_ok());
}
// Kernel with one x4 `LdMatrix`: destination registers come from
// `alloc_packed_half2` and the address register is U32. The tests below expect
// it to pass the structural checks and to require sm_75, with the label
// "ldmatrix.m8n8.x4".
fn ldmatrix_kernel() -> PtxKernel {
use crate::instr::LdMatrixDst;
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("has_ldmatrix");
k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
dst: LdMatrixDst::X4([
alloc.alloc_packed_half2(),
alloc.alloc_packed_half2(),
alloc.alloc_packed_half2(),
alloc.alloc_packed_half2(),
]),
addr: alloc.alloc(PtxType::U32),
trans: false,
}));
k
}
#[test]
fn validate_accepts_ldmatrix_on_sm_75() {
let mut module = PtxModule::new("sm_75");
module.add_kernel(ldmatrix_kernel());
assert!(module.validate().is_ok());
}
#[test]
fn validate_rejects_ldmatrix_on_sm_70() {
let mut module = PtxModule::new("sm_70");
module.add_kernel(ldmatrix_kernel());
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::SmTooLow {
required: 75,
actual: 70,
feature: "ldmatrix.m8n8.x4".to_string(),
}
);
assert_eq!(
err.to_string(),
"ldmatrix.m8n8.x4 requires sm_75+, target is sm_70"
);
}
// At sm_75 the ldmatrix kernel (added first) passes, so the reported failure is
// the mma.sync in the second kernel.
#[test]
fn validate_rejects_mma_at_sm_75_even_with_ldmatrix_present() {
let mut module = PtxModule::new("sm_75");
module.add_kernel(ldmatrix_kernel());
module.add_kernel(tc_kernel());
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::SmTooLow {
required: 80,
actual: 75,
feature: "mma.sync.m16n8k16".to_string(),
}
);
}
// The third destination register is F32 rather than U32. The target (sm_80) is
// high enough, so the structural check is what fails.
#[test]
fn validate_rejects_ldmatrix_bad_dst_reg_type() {
use crate::instr::LdMatrixDst;
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("bad_ldmatrix_dst");
k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
dst: LdMatrixDst::X4([
alloc.alloc_packed_half2(),
alloc.alloc_packed_half2(),
alloc.alloc(PtxType::F32),
alloc.alloc_packed_half2(),
]),
addr: alloc.alloc(PtxType::U32),
trans: false,
}));
let mut module = PtxModule::new("sm_80");
module.add_kernel(k);
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::LdMatrixBadRegType {
operand: "dst",
found: PtxType::F32,
}
);
}
// The destination registers are valid, but the address register is U64.
// Validation must reject it with operand "addr".
#[test]
fn validate_rejects_ldmatrix_bad_addr_reg_type() {
use crate::instr::LdMatrixDst;
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("bad_ldmatrix_addr");
k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
dst: LdMatrixDst::X2([alloc.alloc_packed_half2(), alloc.alloc_packed_half2()]),
addr: alloc.alloc(PtxType::U64),
trans: true,
}));
let mut module = PtxModule::new("sm_80");
module.add_kernel(k);
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::LdMatrixBadRegType {
operand: "addr",
found: PtxType::U64,
}
);
// The message points users at the allocator helper that yields the right type.
assert!(err.to_string().contains("alloc_packed_half2"));
}
// A cp.async op built via `MemoryOp::new_cp_async_ca` requires sm_80, so an
// sm_75 target is rejected.
#[test]
fn validate_rejects_cp_async_on_sm_75() {
let mut module = PtxModule::new("sm_75");
let mut k = PtxKernel::new("has_cp_async");
k.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
reg(RegKind::R, 0, PtxType::U32),
reg(RegKind::Rd, 0, PtxType::U64),
16,
)));
module.add_kernel(k);
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::SmTooLow {
required: 80,
actual: 75,
feature: "cp.async".to_string(),
}
);
}
// An empty kernel has no gated instructions, so even sm_70 passes.
#[test]
fn validate_accepts_scalar_kernel_on_sm_70() {
let mut module = PtxModule::new("sm_70");
let k = PtxKernel::new("scalar_only");
module.add_kernel(k);
assert!(module.validate().is_ok());
}
// "compute_90a" has no "sm_" prefix, so no SM number is parsed and the
// minimum-SM check is skipped.
#[test]
fn validate_skips_unparseable_target() {
let mut module = PtxModule::new("compute_90a");
module.add_kernel(tc_kernel());
assert!(module.validate().is_ok());
}
#[test]
fn parse_sm_target() {
let m = PtxModule::new("sm_89");
assert_eq!(m.parse_sm_target(), Some(89));
let m2 = PtxModule::new("sm_80");
assert_eq!(m2.parse_sm_target(), Some(80));
let m3 = PtxModule::new("compute_90a");
assert_eq!(m3.parse_sm_target(), None);
}
// `MmaSync` tagged BF16 on both a and b (with f16 fragment allocators). It must
// be rejected structurally, and a_ty is reported because it is checked first.
fn mma_sync_with_bf16_tags() -> PtxKernel {
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("legacy_bf16_on_mma_sync");
k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
d: alloc_c(&mut alloc),
a: alloc_a_f16(&mut alloc),
b: alloc_b_f16(&mut alloc),
c: alloc_c(&mut alloc),
shape: MmaShape::M16N8K16,
d_ty: PtxType::F32,
a_ty: PtxType::BF16,
b_ty: PtxType::BF16,
c_ty: PtxType::F32,
}));
k
}
#[test]
fn validate_rejects_mma_sync_bf16_a_ty() {
let mut module = PtxModule::new("sm_89");
module.add_kernel(mma_sync_with_bf16_tags());
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
);
// The `\` line continuation drops the next line's leading whitespace, so this
// is a single-line message.
assert_eq!(
err.to_string(),
"TensorCoreOp::MmaSync with PtxType::BF16 on a_ty is rejected; \
use TensorCoreOp::MmaSyncBf16 for bf16 emission"
);
}
// Only b_ty is BF16, so the error must name "b_ty".
#[test]
fn validate_rejects_mma_sync_bf16_b_ty_only() {
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("mixed_bf16_b_only");
k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
d: alloc_c(&mut alloc),
a: alloc_a_f16(&mut alloc),
b: alloc_b_f16(&mut alloc),
c: alloc_c(&mut alloc),
shape: MmaShape::M16N8K16,
d_ty: PtxType::F32,
a_ty: PtxType::F16,
b_ty: PtxType::BF16,
c_ty: PtxType::F32,
}));
let mut module = PtxModule::new("sm_89");
module.add_kernel(k);
let err = module.validate().unwrap_err();
assert_eq!(
err,
ValidationError::MmaSyncBf16Rejected { operand: "b_ty" }
);
}
// The structural BF16 check runs even when the SM check is skipped because the
// target is not parseable.
#[test]
fn validate_rejects_mma_sync_bf16_even_on_unparseable_target() {
let mut module = PtxModule::new("compute_90a");
module.add_kernel(mma_sync_with_bf16_tags());
assert_eq!(
module.validate().unwrap_err(),
ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
);
}
// The dedicated `MmaSyncBf16` variant, built with the bf16 fragment allocators.
// It is the accepted way to express bf16 mma.
fn mma_sync_bf16_kernel() -> PtxKernel {
use crate::fragment::{alloc_a_bf16, alloc_b_bf16};
let mut alloc = RegisterAllocator::new();
let mut k = PtxKernel::new("native_bf16");
k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncBf16 {
d: alloc_c(&mut alloc),
a: alloc_a_bf16(&mut alloc),
b: alloc_b_bf16(&mut alloc),
c: alloc_c(&mut alloc),
}));
k
}
#[test]
fn validate_accepts_mma_sync_bf16_dedicated_variant() {
let mut module = PtxModule::new("sm_89");
module.add_kernel(mma_sync_bf16_kernel());
assert!(module.validate().is_ok());
}
}