use std::fmt;
use super::instruction::PtxInstruction;
use super::kernel::PtxKernel;
use crate::instr::MemoryOp;
/// A complete PTX module: the translation unit handed to the CUDA driver.
///
/// Carries the directives emitted at the top of a PTX file (`.version`,
/// `.target`, `.address_size`) plus the kernels the module contains.
#[derive(Debug, Clone)]
pub struct PtxModule {
    /// PTX ISA version written to the `.version` directive (e.g. "8.7").
    pub version: String,
    /// Target architecture for the `.target` directive (e.g. "sm_80").
    pub target: String,
    /// Value of the `.address_size` directive; 64 on modern targets.
    pub address_size: u32,
    /// Kernels contained in this module, in emission order.
    pub kernels: Vec<PtxKernel>,
}
impl PtxModule {
    /// Creates an empty module for `target` (e.g. `"sm_80"`), defaulting to
    /// PTX ISA version 8.7 with 64-bit addressing.
    pub fn new(target: &str) -> Self {
        Self {
            version: "8.7".to_string(),
            target: target.to_string(),
            address_size: 64,
            kernels: Vec::new(),
        }
    }

    /// Appends a kernel to the module.
    pub fn add_kernel(&mut self, kernel: PtxKernel) {
        self.kernels.push(kernel);
    }

    /// Extracts the numeric SM version from the target string.
    ///
    /// Accepts plain targets (`"sm_80"` -> `Some(80)`) and arch-suffixed
    /// targets (`"sm_90a"` -> `Some(90)`). Returns `None` when the target
    /// does not start with `"sm_"` followed by at least one digit
    /// (e.g. `"compute_90a"`), in which case validation is skipped.
    fn parse_sm_target(&self) -> Option<u32> {
        let rest = self.target.strip_prefix("sm_")?;
        // Only parse the leading digit run so architecture-suffixed targets
        // like "sm_90a" still validate against their numeric SM version
        // instead of silently skipping validation.
        let digits_end = rest
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(rest.len());
        rest[..digits_end].parse().ok()
    }

    /// Checks that every instruction in every kernel is supported by the
    /// module's target SM version.
    ///
    /// Targets whose SM version cannot be parsed are accepted vacuously,
    /// since there is no number to compare against.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::SmTooLow`] for the first instruction whose
    /// minimum SM requirement exceeds the target.
    pub fn validate(&self) -> Result<(), ValidationError> {
        let Some(target_sm) = self.parse_sm_target() else {
            return Ok(());
        };
        for kernel in &self.kernels {
            for instr in &kernel.body {
                if let Some((required, feature)) = instruction_sm_requirement(instr)
                    && target_sm < required
                {
                    return Err(ValidationError::SmTooLow {
                        required,
                        actual: target_sm,
                        feature,
                    });
                }
            }
        }
        Ok(())
    }
}
/// Returns the minimum SM version and a human-readable feature label for
/// instructions gated on newer architectures, or `None` for instructions
/// available on every supported target.
fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
    // Tensor-core ops carry their own requirement metadata.
    if let PtxInstruction::TensorCore(op) = instr {
        return Some((op.min_sm(), op.feature_label()));
    }
    // The cp.async family (issue / commit / wait) first appeared on sm_80.
    if let PtxInstruction::Memory(mem) = instr {
        let is_cp_async = matches!(
            mem,
            MemoryOp::CpAsyncCaSharedGlobal { .. }
                | MemoryOp::CpAsyncCommitGroup
                | MemoryOp::CpAsyncWaitGroup { .. }
        );
        if is_cp_async {
            return Some((80, "cp.async".to_string()));
        }
    }
    None
}
/// Errors produced by [`PtxModule::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The module uses an instruction that needs a newer SM version than
    /// the module's `.target` provides.
    SmTooLow {
        /// Minimum SM version the offending feature needs.
        required: u32,
        /// SM version parsed from the module's target string.
        actual: u32,
        /// Human-readable label of the offending feature (e.g. "cp.async").
        feature: String,
    },
}
impl fmt::Display for ValidationError {
    /// Formats the error as a one-line, user-facing diagnostic, e.g.
    /// `cp.async requires sm_80+, target is sm_75`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SmTooLow { required, actual, feature } => write!(
                f,
                "{feature} requires sm_{required}+, target is sm_{actual}"
            ),
        }
    }
}
// Marker impl so ValidationError works with `Box<dyn Error>` and `?`;
// Display and Debug (required supertraits) are implemented above/derived.
impl std::error::Error for ValidationError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    // Shorthand constructor for a `Register` literal used in fixtures.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    // Fixture: kernel containing a single f16 mma.sync (m16n8k16) tensor-core
    // instruction, which requires sm_80+.
    fn tc_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        }));
        k
    }

    // mma.sync must be rejected below sm_80, with both the structured error
    // and its Display rendering checked.
    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k16".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    // Exactly at the minimum required version: accepted.
    #[test]
    fn validate_accepts_mma_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    // Above the minimum required version: accepted.
    #[test]
    fn validate_accepts_mma_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    // Fixture: kernel with a single int8 mma.sync (m16n8k32) instruction,
    // which also requires sm_80+.
    fn tc_int8_kernel() -> PtxKernel {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma_int8");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        }));
        k
    }

    // int8 mma.sync rejected below sm_80, including the Display message.
    #[test]
    fn validate_rejects_mma_int8_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_int8_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k32.s8.s8.s32".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
        );
    }

    // int8 mma.sync accepted at exactly sm_80.
    #[test]
    fn validate_accepts_mma_int8_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    // int8 mma.sync accepted above sm_80.
    #[test]
    fn validate_accepts_mma_int8_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    // cp.async requires sm_80; sm_75 (Turing) must be rejected.
    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let mut module = PtxModule::new("sm_75");
        let mut k = PtxKernel::new("has_cp_async");
        k.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        )));
        module.add_kernel(k);
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 75,
                feature: "cp.async".to_string(),
            }
        );
    }

    // Kernels with no gated instructions validate on any parseable target.
    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        let k = PtxKernel::new("scalar_only");
        module.add_kernel(k);
        assert!(module.validate().is_ok());
    }

    // Targets without an "sm_" prefix cannot be compared, so validation
    // passes vacuously even when the module contains gated instructions.
    #[test]
    fn validate_skips_unparseable_target() {
        let mut module = PtxModule::new("compute_90a");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    // Direct checks on the target-string parser.
    #[test]
    fn parse_sm_target() {
        let m = PtxModule::new("sm_89");
        assert_eq!(m.parse_sm_target(), Some(89));
        let m2 = PtxModule::new("sm_80");
        assert_eq!(m2.parse_sm_target(), Some(80));
        let m3 = PtxModule::new("compute_90a");
        assert_eq!(m3.parse_sm_target(), None);
    }
}