use std::collections::HashSet;
use crate::arch::SmVersion;
use crate::ir::{Instruction, MemorySpace, Operand, WmmaOp};
/// Outcome of validating a PTX text module.
///
/// Only `errors` decide whether validation failed; `warnings` are
/// human-readable notes that never make [`ValidationResult::is_ok`] false.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Problems that fail validation.
    pub errors: Vec<ValidationError>,
    /// Non-fatal diagnostics, already formatted as messages.
    pub warnings: Vec<String>,
}

impl ValidationResult {
    /// Returns `true` when no errors were recorded (warnings are ignored).
    #[must_use]
    pub fn is_ok(&self) -> bool {
        !self.has_errors()
    }

    /// Returns `true` when at least one error was recorded.
    #[must_use]
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }
}
/// A hard failure found while validating PTX text.
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// The text contains no `.version` directive.
    MissingVersionDirective,
    /// The text contains no `.target` directive.
    MissingTargetDirective,
    /// A register was referenced without being declared.
    UndefinedRegister(String),
    /// An operand's type differs from the one required.
    TypeMismatch {
        /// Type the context required.
        expected: String,
        /// Type that was actually present.
        found: String,
    },
    /// Total declared shared memory exceeds the per-block limit.
    InvalidSharedMemSize {
        /// Declared shared memory in bytes.
        declared: usize,
        /// Limit for the target, in bytes.
        max_allowed: usize,
    },
    /// A problem with the `.address_size` setting.
    InvalidAddressSize(String),
    /// The PTX uses a feature that the target architecture lacks.
    SmIncompatibleInstruction {
        /// Feature or instruction name.
        instruction: String,
        /// Minimum architecture name, e.g. `sm_80`.
        required_sm: String,
        /// Architecture being validated against.
        found_sm: String,
    },
    /// Too many distinct registers are referenced.
    RegisterPressureExceeded {
        /// Distinct registers counted.
        count: usize,
        /// Per-thread limit.
        max_allowed: usize,
    },
    /// Any other failure, carried as a ready-made message.
    Other(String),
}

impl std::fmt::Display for ValidationError {
    /// Renders a one-line, lower-case description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages.
            Self::MissingVersionDirective => f.write_str("missing .version directive"),
            Self::MissingTargetDirective => f.write_str("missing .target directive"),
            Self::Other(msg) => f.write_str(msg),
            // Messages built around a single payload string.
            Self::UndefinedRegister(name) => write!(f, "undefined register: {name}"),
            Self::InvalidAddressSize(msg) => write!(f, "address size issue: {msg}"),
            // Messages built from several fields.
            Self::TypeMismatch { expected, found } => {
                write!(f, "type mismatch: expected {expected}, found {found}")
            }
            Self::InvalidSharedMemSize {
                declared,
                max_allowed,
            } => write!(
                f,
                "shared memory {declared} bytes exceeds limit of {max_allowed} bytes"
            ),
            Self::SmIncompatibleInstruction {
                instruction,
                required_sm,
                found_sm,
            } => write!(
                f,
                "instruction '{instruction}' requires {required_sm} but target is {found_sm}"
            ),
            Self::RegisterPressureExceeded { count, max_allowed } => write!(
                f,
                "register count {count} exceeds per-thread limit of {max_allowed}"
            ),
        }
    }
}
/// Validates PTX text, taking the target architecture from its `.target` directive.
///
/// Shared-memory limits and SM-specific feature checks are applied only when
/// the directive names an architecture that `parse_sm_version` recognizes.
#[must_use]
pub fn validate_ptx(ptx: &str) -> ValidationResult {
    run_ptx_checks(ptx, extract_target_sm(ptx))
}

/// Validates PTX text against `target`, ignoring the text's own `.target` directive.
#[must_use]
pub fn validate_ptx_for_target(ptx: &str, target: SmVersion) -> ValidationResult {
    run_ptx_checks(ptx, Some(target))
}

/// Runs every textual PTX check in a fixed order and gathers the diagnostics.
///
/// With `target == None` the shared-memory check has no upper bound and the
/// SM-compatibility check is skipped.
fn run_ptx_checks(ptx: &str, target: Option<SmVersion>) -> ValidationResult {
    let mut errors = Vec::new();
    let mut warnings = Vec::new();

    // Directive presence is a plain substring test.
    if !ptx.contains(".version") {
        errors.push(ValidationError::MissingVersionDirective);
    }
    if !ptx.contains(".target") {
        errors.push(ValidationError::MissingTargetDirective);
    }

    check_shared_memory(ptx, target, &mut errors, &mut warnings);
    check_register_declarations(ptx, &mut warnings);
    check_register_pressure(ptx, &mut errors, &mut warnings);
    if let Some(sm) = target {
        check_sm_compatibility(ptx, sm, &mut errors, &mut warnings);
    }
    check_structure(ptx, &mut warnings);

    ValidationResult { errors, warnings }
}
fn extract_target_sm(ptx: &str) -> Option<SmVersion> {
for line in ptx.lines() {
let trimmed = line.trim();
if trimmed.starts_with(".target") {
let parts: Vec<&str> = trimmed.split_whitespace().collect();
if parts.len() >= 2 {
return parse_sm_version(parts[1].trim_end_matches(';'));
}
}
}
None
}
/// Maps a PTX architecture name such as `sm_80` or `sm_90a` to its [`SmVersion`].
///
/// Only the architectures listed below are recognized. Any other name,
/// including one without the `sm_` prefix, yields `None`.
fn parse_sm_version(s: &str) -> Option<SmVersion> {
    let version = match s.strip_prefix("sm_")? {
        "75" => SmVersion::Sm75,
        "80" => SmVersion::Sm80,
        "86" => SmVersion::Sm86,
        "89" => SmVersion::Sm89,
        "90" => SmVersion::Sm90,
        "90a" => SmVersion::Sm90a,
        "100" => SmVersion::Sm100,
        "120" => SmVersion::Sm120,
        _ => return None,
    };
    Some(version)
}
/// Sums every shared-memory declaration in `ptx` and compares the total with
/// the per-block limit of `target`.
///
/// If the total exceeds `max_shared_mem_per_block()`, an `InvalidSharedMemSize`
/// error is recorded. Otherwise, if a target is known and the total exceeds
/// 48 KiB, a warning notes that an explicit opt-in may be needed. With no
/// target the limit is `usize::MAX`, so nothing is reported.
fn check_shared_memory(
    ptx: &str,
    target: Option<SmVersion>,
    errors: &mut Vec<ValidationError>,
    warnings: &mut Vec<String>,
) {
    // Threshold above which the warning below is emitted (48 KiB).
    const DEFAULT_SMEM_LIMIT: usize = 48 * 1024;

    let total_smem = ptx
        .lines()
        .filter_map(|line| extract_shared_mem_size(line.trim()))
        .fold(0_usize, usize::saturating_add);
    let max_smem = match target {
        Some(sm) => sm.max_shared_mem_per_block() as usize,
        None => usize::MAX,
    };

    if total_smem > max_smem {
        errors.push(ValidationError::InvalidSharedMemSize {
            declared: total_smem,
            max_allowed: max_smem,
        });
        return;
    }
    if target.is_some() && total_smem > DEFAULT_SMEM_LIMIT {
        warnings.push(format!(
            "shared memory usage ({total_smem} bytes) exceeds default limit (49152); \
             may require opt-in via cuFuncSetAttribute"
        ));
    }
}
/// Returns the size in bytes of the `.shared` variable(s) declared on `line`.
///
/// A line counts as a declaration only when `.shared` is a standalone
/// whitespace-separated token (`.shared .align 4 .b8 smem[4096];`). This keeps
/// instructions such as `ld.shared.f32 %f0, [16];` from being read as a 16-byte
/// declaration. Anything after `//` is ignored.
///
/// Each comma-separated declarator contributes the product of its `[N]`
/// dimensions, or 1 element if it has none. The element count is multiplied by:
/// - the width of the type directive (`.b8` = 1 … `.f64` = 8; unknown or
///   missing types count as 1 byte), and
/// - any `.v2`/`.v4`/`.v8` vector factor.
///
/// Returns `None` for:
/// - non-declarations;
/// - declarations with no array dimension at all (scalars);
/// - unsized arrays such as `.extern .shared .b8 smem[]`;
/// - non-numeric dimensions.
///
/// Overflow saturates, so an absurd declaration still exceeds any limit.
fn extract_shared_mem_size(line: &str) -> Option<usize> {
    let code = line.split("//").next().unwrap_or(line);
    if !code.split_whitespace().any(|token| token == ".shared") {
        return None;
    }

    // Element width and vector factor come from the directive tokens.
    let mut element_bytes = 1_usize;
    let mut vector_lanes = 1_usize;
    for token in code.split_whitespace() {
        match token {
            ".b8" | ".u8" | ".s8" => element_bytes = 1,
            ".b16" | ".u16" | ".s16" | ".f16" | ".bf16" => element_bytes = 2,
            ".b32" | ".u32" | ".s32" | ".f32" | ".f16x2" | ".bf16x2" => element_bytes = 4,
            ".b64" | ".u64" | ".s64" | ".f64" => element_bytes = 8,
            ".b128" => element_bytes = 16,
            ".v2" => vector_lanes = 2,
            ".v4" => vector_lanes = 4,
            ".v8" => vector_lanes = 8,
            _ => {}
        }
    }

    // Element count: sum over declarators of the product of their dimensions.
    let mut total_elements = 0_usize;
    let mut saw_dimension = false;
    for declarator in code.split(',') {
        let mut elements = 1_usize;
        let mut rest = declarator;
        while let Some(open) = rest.find('[') {
            let close = open + rest[open..].find(']')?;
            let dim: usize = rest[open + 1..close].trim().parse().ok()?;
            elements = elements.saturating_mul(dim);
            saw_dimension = true;
            rest = &rest[close + 1..];
        }
        total_elements = total_elements.saturating_add(elements);
    }
    if !saw_dimension {
        return None;
    }

    Some(
        total_elements
            .saturating_mul(element_bytes)
            .saturating_mul(vector_lanes),
    )
}
/// Warns when `ptx` contains a `.entry` but no line starting with `.reg`.
///
/// `.entry` is searched anywhere in a line. `.reg` must begin the line after
/// leading whitespace is trimmed.
fn check_register_declarations(ptx: &str, warnings: &mut Vec<String>) {
    let has_entry = ptx.lines().any(|line| line.contains(".entry"));
    let has_reg_decl = ptx.lines().any(|line| line.trim().starts_with(".reg"));
    if has_entry && !has_reg_decl {
        warnings.push(String::from(
            "kernel has no .reg declarations; all registers may be declared via raw PTX",
        ));
    }
}
/// Warns when the number of `{` characters in `ptx` differs from the number of `}`.
///
/// Every character is counted, including those inside comments; nesting
/// order is not checked.
fn check_structure(ptx: &str, warnings: &mut Vec<String>) {
    let (open_braces, close_braces) =
        ptx.chars()
            .fold((0_usize, 0_usize), |(open, close), ch| match ch {
                '{' => (open + 1, close),
                '}' => (open, close + 1),
                _ => (open, close),
            });
    if open_braces == close_braces {
        return;
    }
    warnings.push(format!(
        "mismatched braces: {open_braces} opening vs {close_braces} closing"
    ));
}
/// A text marker whose presence in PTX implies a minimum SM architecture.
struct SmRequirement {
    /// Substring searched for anywhere in the PTX text.
    pattern: &'static str,
    /// Oldest architecture accepted when `pattern` is present.
    min_sm: SmVersion,
    /// Human-readable feature name used in error messages.
    name: &'static str,
}

impl SmRequirement {
    /// Builds a table entry; exists only to keep [`SM_REQUIREMENTS`] compact.
    const fn new(pattern: &'static str, min_sm: SmVersion, name: &'static str) -> Self {
        Self {
            pattern,
            min_sm,
            name,
        }
    }
}

/// Feature markers checked by `check_sm_compatibility`, in reporting order.
const SM_REQUIREMENTS: &[SmRequirement] = &[
    SmRequirement::new("cp.async", SmVersion::Sm80, "cp.async"),
    SmRequirement::new("wgmma", SmVersion::Sm90, "wgmma"),
    SmRequirement::new("mma.sync", SmVersion::Sm75, "mma.sync (tensor core)"),
    SmRequirement::new("ldmatrix", SmVersion::Sm75, "ldmatrix"),
    SmRequirement::new(".e4m3", SmVersion::Sm89, "fp8 e4m3 type"),
    SmRequirement::new(".e5m2", SmVersion::Sm89, "fp8 e5m2 type"),
    SmRequirement::new("tcgen05", SmVersion::Sm100, "tcgen05"),
];

/// Records one `SmIncompatibleInstruction` error for every [`SM_REQUIREMENTS`]
/// entry whose pattern occurs in `ptx` while `sm` is older than its `min_sm`.
///
/// Matching is a plain substring search over the whole text, so comments are
/// searched too. `_warnings` is currently unused.
fn check_sm_compatibility(
    ptx: &str,
    sm: SmVersion,
    errors: &mut Vec<ValidationError>,
    _warnings: &mut Vec<String>,
) {
    let target_name = sm.as_ptx_str();
    errors.extend(
        SM_REQUIREMENTS
            .iter()
            .filter(|req| sm < req.min_sm && ptx.contains(req.pattern))
            .map(|req| ValidationError::SmIncompatibleInstruction {
                instruction: req.name.to_string(),
                required_sm: req.min_sm.as_ptx_str().to_string(),
                found_sm: target_name.to_string(),
            }),
    );
}
const MAX_REGISTERS_PER_THREAD: usize = 255;
const REGISTER_PRESSURE_WARNING_THRESHOLD: usize = 200;
fn check_register_pressure(
ptx: &str,
errors: &mut Vec<ValidationError>,
warnings: &mut Vec<String>,
) {
use std::collections::HashSet;
let mut seen: HashSet<&str> = HashSet::new();
let bytes = ptx.as_bytes();
let len = bytes.len();
let mut i = 0;
while i < len {
if bytes[i] == b'%' {
let start = i;
i += 1;
while i < len && bytes[i].is_ascii_alphabetic() {
i += 1;
}
if i < len && bytes[i].is_ascii_digit() {
while i < len && bytes[i].is_ascii_digit() {
i += 1;
}
let token = &ptx[start..i];
let name_part = &token[1..]; let is_special = name_part.starts_with("tid")
|| name_part.starts_with("ntid")
|| name_part.starts_with("ctaid")
|| name_part.starts_with("nctaid")
|| name_part.starts_with("laneid")
|| name_part.starts_with("warpid")
|| name_part.starts_with("smid")
|| name_part.starts_with("pm")
|| name_part.starts_with("envreg")
|| name_part.starts_with("globaltimer")
|| name_part.starts_with("param_");
if !is_special {
seen.insert(token);
}
}
} else {
i += 1;
}
}
let count = seen.len();
if count > MAX_REGISTERS_PER_THREAD {
errors.push(ValidationError::RegisterPressureExceeded {
count,
max_allowed: MAX_REGISTERS_PER_THREAD,
});
} else if count > REGISTER_PRESSURE_WARNING_THRESHOLD {
warnings.push(format!(
"register count ({count}) is approaching the per-thread limit of \
{MAX_REGISTERS_PER_THREAD}; consider reducing register pressure"
));
}
}
/// Diagnostics produced by the IR-level validators.
#[derive(Debug, Clone)]
pub struct IrValidationResult {
    /// Problems that make the instruction stream invalid.
    pub errors: Vec<IrValidationError>,
    /// Suspicious patterns that do not fail validation.
    pub warnings: Vec<IrValidationWarning>,
}

impl IrValidationResult {
    /// Returns `true` when no errors were recorded (warnings are ignored).
    #[must_use]
    pub fn is_ok(&self) -> bool {
        !self.has_errors()
    }

    /// Returns `true` when at least one error was recorded.
    #[must_use]
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }

    /// Appends copies of `other`'s errors and warnings after this result's own.
    fn merge(&mut self, other: &Self) {
        self.errors.extend_from_slice(&other.errors);
        self.warnings.extend_from_slice(&other.warnings);
    }
}

impl std::fmt::Display for IrValidationResult {
    /// Writes a pass line when there are no diagnostics. Otherwise writes an
    /// `Errors (n):` and/or `Warnings (n):` section, one line per entry, each
    /// prefixed with its right-aligned instruction index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_errors = !self.errors.is_empty();
        let has_warnings = !self.warnings.is_empty();
        if !has_errors && !has_warnings {
            return f.write_str("IR validation passed: no errors, no warnings");
        }

        if has_errors {
            writeln!(f, "Errors ({}):", self.errors.len())?;
            self.errors.iter().try_for_each(|err| {
                writeln!(
                    f,
                    " [{:>3}] {}: {}",
                    err.instruction_index, err.kind, err.message
                )
            })?;
        }
        if has_warnings {
            writeln!(f, "Warnings ({}):", self.warnings.len())?;
            self.warnings.iter().try_for_each(|warn| {
                writeln!(f, " [{:>3}] {}", warn.instruction_index, warn.message)
            })?;
        }
        Ok(())
    }
}
/// A single IR validation error, tied to the instruction that triggered it.
#[derive(Debug, Clone)]
pub struct IrValidationError {
    /// Position of the offending instruction in the validated slice.
    pub instruction_index: usize,
    /// Category of the problem.
    pub kind: IrErrorKind,
    /// Human-readable details.
    pub message: String,
}

/// Category of an [`IrValidationError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IrErrorKind {
    /// An operand or destination type disagrees with the instruction type.
    TypeMismatch,
    /// A register is read before any earlier instruction defines it.
    UseBeforeDef,
    /// A memory access is invalid for its state space.
    InvalidMemorySpace,
    /// An operand has a form the instruction does not accept.
    InvalidOperand,
    /// A barrier sits in control flow that not all threads may reach.
    BarrierInDivergent,
    /// A register-lifetime problem.
    RegisterLifetime,
}

impl std::fmt::Display for IrErrorKind {
    /// Writes the variant name verbatim, e.g. `TypeMismatch`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::TypeMismatch => "TypeMismatch",
            Self::UseBeforeDef => "UseBeforeDef",
            Self::InvalidMemorySpace => "InvalidMemorySpace",
            Self::InvalidOperand => "InvalidOperand",
            Self::BarrierInDivergent => "BarrierInDivergent",
            Self::RegisterLifetime => "RegisterLifetime",
        };
        f.write_str(label)
    }
}

/// A non-fatal IR validation finding, tied to the instruction that triggered it.
#[derive(Debug, Clone)]
pub struct IrValidationWarning {
    /// Position of the instruction in the validated slice.
    pub instruction_index: usize,
    /// Human-readable details.
    pub message: String,
}
/// Appends the name of the register read through `op`.
///
/// A register operand contributes its own name and an address operand its
/// base register. Immediates and symbols contribute nothing.
fn push_operand_names(op: &Operand, names: &mut Vec<String>) {
    let reg = match op {
        Operand::Register(r) => r,
        Operand::Address { base, .. } => base,
        _ => return,
    };
    names.push(reg.name.clone());
}
#[allow(clippy::too_many_lines)]
fn collect_src_register_names(inst: &Instruction) -> Vec<String> {
let mut names = Vec::new();
match inst {
Instruction::Add { a, b, .. }
| Instruction::Sub { a, b, .. }
| Instruction::Mul { a, b, .. }
| Instruction::Min { a, b, .. }
| Instruction::Max { a, b, .. }
| Instruction::Div { a, b, .. }
| Instruction::Rem { a, b, .. }
| Instruction::And { a, b, .. }
| Instruction::Or { a, b, .. }
| Instruction::Xor { a, b, .. }
| Instruction::SetP { a, b, .. } => {
push_operand_names(a, &mut names);
push_operand_names(b, &mut names);
}
Instruction::Mad { a, b, c, .. }
| Instruction::MadLo { a, b, c, .. }
| Instruction::MadHi { a, b, c, .. }
| Instruction::MadWide { a, b, c, .. }
| Instruction::Fma { a, b, c, .. }
| Instruction::Dp4a { a, b, c, .. }
| Instruction::Dp2a { a, b, c, .. } => {
push_operand_names(a, &mut names);
push_operand_names(b, &mut names);
push_operand_names(c, &mut names);
}
Instruction::Neg { src, .. }
| Instruction::Abs { src, .. }
| Instruction::Brev { src, .. }
| Instruction::Clz { src, .. }
| Instruction::Popc { src, .. }
| Instruction::Bfind { src, .. }
| Instruction::Rcp { src, .. }
| Instruction::Rsqrt { src, .. }
| Instruction::Sqrt { src, .. }
| Instruction::Ex2 { src, .. }
| Instruction::Lg2 { src, .. }
| Instruction::Sin { src, .. }
| Instruction::Cos { src, .. }
| Instruction::Cvt { src, .. }
| Instruction::Redux { src, .. } => {
push_operand_names(src, &mut names);
}
Instruction::Bfe {
src, start, len, ..
} => {
push_operand_names(src, &mut names);
push_operand_names(start, &mut names);
push_operand_names(len, &mut names);
}
Instruction::Bfi {
insert,
base,
start,
len,
..
} => {
push_operand_names(insert, &mut names);
push_operand_names(base, &mut names);
push_operand_names(start, &mut names);
push_operand_names(len, &mut names);
}
Instruction::Shl { src, amount, .. } | Instruction::Shr { src, amount, .. } => {
push_operand_names(src, &mut names);
push_operand_names(amount, &mut names);
}
Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr } => {
push_operand_names(addr, &mut names);
}
Instruction::Store { addr, src, .. } => {
push_operand_names(addr, &mut names);
names.push(src.name.clone());
}
Instruction::CpAsync {
dst_shared,
src_global,
..
} => {
push_operand_names(dst_shared, &mut names);
push_operand_names(src_global, &mut names);
}
Instruction::Branch { predicate, .. } => {
if let Some((r, _)) = predicate {
names.push(r.name.clone());
}
}
Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
push_operand_names(addr, &mut names);
push_operand_names(src, &mut names);
}
Instruction::AtomCas {
addr,
compare,
value,
..
} => {
push_operand_names(addr, &mut names);
push_operand_names(compare, &mut names);
push_operand_names(value, &mut names);
}
Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
push_operand_names(coord, &mut names);
}
Instruction::Tex2d {
coord_x, coord_y, ..
} => {
push_operand_names(coord_x, &mut names);
push_operand_names(coord_y, &mut names);
}
Instruction::Tex3d {
coord_x,
coord_y,
coord_z,
..
} => {
push_operand_names(coord_x, &mut names);
push_operand_names(coord_y, &mut names);
push_operand_names(coord_z, &mut names);
}
Instruction::SurfStore { coord, src, .. } => {
push_operand_names(coord, &mut names);
names.push(src.name.clone());
}
Instruction::Wmma {
fragments,
addr,
stride,
..
} => {
for frag in fragments {
names.push(frag.name.clone());
}
if let Some(a) = addr {
push_operand_names(a, &mut names);
}
if let Some(s) = stride {
push_operand_names(s, &mut names);
}
}
Instruction::Mma {
a_regs,
b_regs,
c_regs,
..
} => {
for r in a_regs.iter().chain(b_regs).chain(c_regs) {
names.push(r.name.clone());
}
}
Instruction::Wgmma { desc_a, desc_b, .. } => {
names.push(desc_a.name.clone());
names.push(desc_b.name.clone());
}
Instruction::TmaLoad {
desc,
coords,
barrier,
dst_shared,
..
} => {
names.push(desc.name.clone());
for c in coords {
names.push(c.name.clone());
}
names.push(barrier.name.clone());
push_operand_names(dst_shared, &mut names);
}
Instruction::Stmatrix { dst_addr, src, .. } => {
push_operand_names(dst_addr, &mut names);
names.push(src.name.clone());
}
Instruction::MbarrierInit { addr, count } => {
push_operand_names(addr, &mut names);
push_operand_names(count, &mut names);
}
Instruction::MbarrierWait { addr, phase } => {
push_operand_names(addr, &mut names);
push_operand_names(phase, &mut names);
}
Instruction::MovSpecial { .. }
| Instruction::LoadParam { .. }
| Instruction::Label(_)
| Instruction::Return
| Instruction::Comment(_)
| Instruction::Raw(_)
| Instruction::Pragma(_)
| Instruction::BarSync { .. }
| Instruction::BarArrive { .. }
| Instruction::FenceAcqRel { .. }
| Instruction::FenceProxy { .. }
| Instruction::CpAsyncCommit
| Instruction::CpAsyncWait { .. }
| Instruction::ElectSync { .. }
| Instruction::Setmaxnreg { .. }
| Instruction::Griddepcontrol { .. }
| Instruction::BarrierCluster
| Instruction::FenceCluster => {}
Instruction::Tcgen05Mma { a_desc, b_desc } => {
names.push(a_desc.name.clone());
names.push(b_desc.name.clone());
}
Instruction::CpAsyncBulk {
dst_smem,
src_gmem,
desc,
} => {
names.push(dst_smem.name.clone());
names.push(src_gmem.name.clone());
names.push(desc.name.clone());
}
Instruction::Ldmatrix { src_addr, .. } => {
push_operand_names(src_addr, &mut names);
}
}
names
}
/// Returns the name of the register written by `inst`, if it has a single `dst`.
///
/// For `Mma` and `Wgmma` only the first entry of `d_regs` is returned; callers
/// that need every destination must inspect `d_regs` themselves. Every other
/// variant yields `None`.
fn dst_register_name(inst: &Instruction) -> Option<String> {
    match inst {
        // Arithmetic.
        Instruction::Add { dst, .. }
        | Instruction::Sub { dst, .. }
        | Instruction::Mul { dst, .. }
        | Instruction::Min { dst, .. }
        | Instruction::Max { dst, .. }
        | Instruction::Div { dst, .. }
        | Instruction::Rem { dst, .. }
        | Instruction::Neg { dst, .. }
        | Instruction::Abs { dst, .. }
        | Instruction::Mad { dst, .. }
        | Instruction::MadLo { dst, .. }
        | Instruction::MadHi { dst, .. }
        | Instruction::MadWide { dst, .. }
        | Instruction::Fma { dst, .. }
        | Instruction::Dp4a { dst, .. }
        | Instruction::Dp2a { dst, .. }
        // Bitwise and bit-field.
        | Instruction::And { dst, .. }
        | Instruction::Or { dst, .. }
        | Instruction::Xor { dst, .. }
        | Instruction::Shl { dst, .. }
        | Instruction::Shr { dst, .. }
        | Instruction::Brev { dst, .. }
        | Instruction::Clz { dst, .. }
        | Instruction::Popc { dst, .. }
        | Instruction::Bfind { dst, .. }
        | Instruction::Bfe { dst, .. }
        | Instruction::Bfi { dst, .. }
        // Approximate math.
        | Instruction::Rcp { dst, .. }
        | Instruction::Rsqrt { dst, .. }
        | Instruction::Sqrt { dst, .. }
        | Instruction::Ex2 { dst, .. }
        | Instruction::Lg2 { dst, .. }
        | Instruction::Sin { dst, .. }
        | Instruction::Cos { dst, .. }
        // Comparison and conversion.
        | Instruction::SetP { dst, .. }
        | Instruction::Cvt { dst, .. }
        // Memory, atomics and parameters.
        | Instruction::Load { dst, .. }
        | Instruction::Atom { dst, .. }
        | Instruction::AtomCas { dst, .. }
        | Instruction::LoadParam { dst, .. }
        // Texture and surface reads.
        | Instruction::Tex1d { dst, .. }
        | Instruction::Tex2d { dst, .. }
        | Instruction::Tex3d { dst, .. }
        | Instruction::SurfLoad { dst, .. }
        // Special registers and warp-level operations.
        | Instruction::MovSpecial { dst, .. }
        | Instruction::Redux { dst, .. }
        | Instruction::ElectSync { dst, .. } => Some(dst.name.clone()),
        Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
            d_regs.first().map(|r| r.name.clone())
        }
        _ => None,
    }
}
/// Returns `false` only for a register operand whose declared type differs
/// from `expected_ty`; immediates, symbols and addresses are always accepted.
fn operand_type_compatible(op: &Operand, expected_ty: crate::ir::PtxType) -> bool {
    match op {
        Operand::Immediate(_) | Operand::Symbol(_) | Operand::Address { .. } => true,
        Operand::Register(reg) => expected_ty == reg.ty,
    }
}
/// Runs every IR validator over `instructions` and returns the combined result.
///
/// Diagnostics are ordered as follows:
/// 1. register-lifetime findings;
/// 2. memory-consistency findings;
/// 3. per-instruction type-safety, memory-space and tensor-core findings, in
///    instruction order.
#[must_use]
pub fn validate_ir_instructions(instructions: &[Instruction]) -> IrValidationResult {
    let mut result = validate_register_lifetimes(instructions);
    result.merge(&validate_memory_consistency(instructions));
    instructions.iter().enumerate().for_each(|(idx, inst)| {
        validate_type_safety(inst, idx, &mut result);
        validate_memory_spaces(inst, idx, &mut result);
        validate_tensor_core_operands(inst, idx, &mut result);
    });
    result
}
/// Reports registers that are read before any earlier instruction defines them.
///
/// The scan follows slice order and ignores control flow: a register counts as
/// defined once any preceding instruction writes it. An instruction's reads are
/// checked before its own writes are recorded. Each read of an undefined name
/// yields one `UseBeforeDef` error.
///
/// Definitions come from:
/// - `dst_register_name`;
/// - every `d_regs` entry of `Mma`/`Wgmma`;
/// - the fragments of `Wmma` `LoadA`/`LoadB`.
#[must_use]
pub fn validate_register_lifetimes(instructions: &[Instruction]) -> IrValidationResult {
    let mut defined: HashSet<String> = HashSet::new();
    let mut errors = Vec::new();

    for (idx, inst) in instructions.iter().enumerate() {
        errors.extend(
            collect_src_register_names(inst)
                .into_iter()
                .filter(|name| !defined.contains(name))
                .map(|name| IrValidationError {
                    instruction_index: idx,
                    kind: IrErrorKind::UseBeforeDef,
                    message: format!("register {name} used before definition"),
                }),
        );

        defined.extend(dst_register_name(inst));
        match inst {
            Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
                defined.extend(d_regs.iter().map(|r| r.name.clone()));
            }
            Instruction::Wmma {
                op: WmmaOp::LoadA | WmmaOp::LoadB,
                fragments,
                ..
            } => {
                defined.extend(fragments.iter().map(|frag| frag.name.clone()));
            }
            _ => {}
        }
    }

    IrValidationResult {
        errors,
        warnings: Vec::new(),
    }
}
#[must_use]
pub fn validate_memory_consistency(instructions: &[Instruction]) -> IrValidationResult {
let mut result = IrValidationResult {
errors: Vec::new(),
warnings: Vec::new(),
};
check_barrier_divergence(instructions, &mut result);
check_shared_memory_races(instructions, &mut result);
result
}
fn validate_type_safety(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
match inst {
Instruction::Add { ty, dst, a, b }
| Instruction::Sub { ty, dst, a, b }
| Instruction::Min { ty, dst, a, b }
| Instruction::Max { ty, dst, a, b } => {
if dst.ty != *ty {
result.errors.push(IrValidationError {
instruction_index: idx,
kind: IrErrorKind::TypeMismatch,
message: format!(
"dst register {} has type {:?} but instruction type is {:?}",
dst.name, dst.ty, ty
),
});
}
if !operand_type_compatible(a, *ty) {
result.errors.push(IrValidationError {
instruction_index: idx,
kind: IrErrorKind::TypeMismatch,
message: format!("operand a type mismatch with instruction type {ty:?}"),
});
}
if !operand_type_compatible(b, *ty) {
result.errors.push(IrValidationError {
instruction_index: idx,
kind: IrErrorKind::TypeMismatch,
message: format!("operand b type mismatch with instruction type {ty:?}"),
});
}
}
Instruction::Mul { ty, dst, a, b, .. } => {
if !operand_type_compatible(a, *ty) {
result.errors.push(IrValidationError {
instruction_index: idx,
kind: IrErrorKind::TypeMismatch,
message: format!("mul operand a type mismatch with instruction type {ty:?}"),
});
}
if !operand_type_compatible(b, *ty) {
result.errors.push(IrValidationError {
instruction_index: idx,
kind: IrErrorKind::TypeMismatch,
message: format!("mul operand b type mismatch with instruction type {ty:?}"),
});
}
if dst.ty != *ty {
result.warnings.push(IrValidationWarning {
instruction_index: idx,
message: format!(
"mul dst register {} type {:?} differs from instruction type {:?}",
dst.name, dst.ty, ty
),
});
}
}
_ => {}
}
}
/// Flags memory accesses whose operands do not suit their state space.
///
/// Findings:
/// - a warning when `CpAsync` uses a bare register as `dst_shared`;
/// - an `InvalidMemorySpace` error when a shared-memory `Load` or `Store`
///   uses an immediate address.
fn validate_memory_spaces(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
    match inst {
        Instruction::CpAsync {
            dst_shared: Operand::Register(r),
            ..
        } => {
            result.warnings.push(IrValidationWarning {
                instruction_index: idx,
                message: format!(
                    "cp.async dst_shared uses register {} directly; expected a shared memory address",
                    r.name
                ),
            });
        }
        Instruction::Load {
            space: MemorySpace::Shared,
            addr: Operand::Immediate(_),
            ..
        } => {
            result.errors.push(IrValidationError {
                instruction_index: idx,
                kind: IrErrorKind::InvalidMemorySpace,
                message: "shared memory load with immediate address is invalid".to_string(),
            });
        }
        Instruction::Store {
            space: MemorySpace::Shared,
            addr: Operand::Immediate(_),
            ..
        } => {
            result.errors.push(IrValidationError {
                instruction_index: idx,
                kind: IrErrorKind::InvalidMemorySpace,
                message: "shared memory store with immediate address is invalid".to_string(),
            });
        }
        _ => {}
    }
}
/// Rejects malformed operands on tensor-core instructions.
///
/// Each of the following is an `InvalidOperand` error:
/// - an immediate `addr` or `stride` on `Wmma`;
/// - an empty A/B/C/D register list on `Mma`;
/// - an empty `d_regs` on `Wgmma`.
fn validate_tensor_core_operands(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
    let invalid = |message: &str| IrValidationError {
        instruction_index: idx,
        kind: IrErrorKind::InvalidOperand,
        message: message.to_string(),
    };

    match inst {
        Instruction::Wmma { addr, stride, .. } => {
            if matches!(addr, Some(Operand::Immediate(_))) {
                result
                    .errors
                    .push(invalid("wmma address operand must not be an immediate value"));
            }
            if matches!(stride, Some(Operand::Immediate(_))) {
                result
                    .errors
                    .push(invalid("wmma stride operand must not be an immediate value"));
            }
        }
        Instruction::Mma {
            a_regs,
            b_regs,
            c_regs,
            d_regs,
            ..
        } => {
            let has_empty_fragment = a_regs.is_empty()
                || b_regs.is_empty()
                || c_regs.is_empty()
                || d_regs.is_empty();
            if has_empty_fragment {
                result
                    .errors
                    .push(invalid("mma instruction requires non-empty register fragments"));
            }
        }
        Instruction::Wgmma { d_regs, .. } => {
            if d_regs.is_empty() {
                result.errors.push(invalid(
                    "wgmma instruction requires non-empty destination registers",
                ));
            }
        }
        _ => {}
    }
}
/// Warns about `bar.sync` instructions that some threads may skip.
///
/// A conditional branch whose target label appears later in the stream opens a
/// potentially divergent region. That region stays open until its own target
/// label is reached. A `bar.sync` while any region is open is reported, citing
/// the most recently opened branch.
///
/// Backward conditional branches (loop back-edges) do not open a region, so a
/// `bar.sync` after a loop is not flagged. Branches to labels that do not exist
/// in `instructions` are ignored. The analysis never inspects the predicate, so
/// it cannot tell uniform branches from divergent ones; findings are warnings.
fn check_barrier_divergence(instructions: &[Instruction], result: &mut IrValidationResult) {
    // Position of every label, used to tell forward branches from back-edges.
    let label_positions: std::collections::HashMap<&str, usize> = instructions
        .iter()
        .enumerate()
        .filter_map(|(idx, inst)| match inst {
            Instruction::Label(name) => Some((name.as_str(), idx)),
            _ => None,
        })
        .collect();

    // (target label, branch index) for each forward conditional branch whose
    // target label has not been reached yet.
    let mut open_regions: Vec<(&str, usize)> = Vec::new();

    for (idx, inst) in instructions.iter().enumerate() {
        match inst {
            Instruction::Branch {
                predicate: Some(_),
                target,
                ..
            } => {
                let is_forward = matches!(
                    label_positions.get(target.as_str()),
                    Some(&pos) if pos > idx
                );
                if is_forward {
                    open_regions.push((target.as_str(), idx));
                }
            }
            Instruction::Label(name) => {
                // Reaching a branch's target ends the region that branch opened.
                open_regions.retain(|&(pending, _)| pending != name.as_str());
            }
            Instruction::BarSync { .. } => {
                if let Some(&(_, conditional_branch_idx)) = open_regions.last() {
                    result.warnings.push(IrValidationWarning {
                        instruction_index: idx,
                        message: format!(
                            "bar.sync inside potentially divergent control flow \
                             (conditional branch at instruction {conditional_branch_idx}); \
                             this may cause deadlock if not all threads reach the barrier"
                        ),
                    });
                }
            }
            _ => {}
        }
    }
}
/// Warns about shared-memory loads that follow a shared-memory store with no
/// `bar.sync` in between, scanning in slice order (control flow is ignored).
///
/// Only the most recent unsynchronized store is cited. Every shared load
/// before the next `bar.sync` is flagged.
fn check_shared_memory_races(instructions: &[Instruction], result: &mut IrValidationResult) {
    let mut unsynced_store: Option<usize> = None;
    for (idx, inst) in instructions.iter().enumerate() {
        unsynced_store = match inst {
            Instruction::Store {
                space: MemorySpace::Shared,
                ..
            } => Some(idx),
            Instruction::BarSync { .. } => None,
            Instruction::Load {
                space: MemorySpace::Shared,
                ..
            } => {
                if let Some(store_idx) = unsynced_store {
                    result.warnings.push(IrValidationWarning {
                        instruction_index: idx,
                        message: format!(
                            "shared memory load without bar.sync after shared memory \
                             store at instruction {store_idx}; potential race condition"
                        ),
                    });
                }
                unsynced_store
            }
            _ => unsynced_store,
        };
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{
CacheQualifier, ImmValue, Instruction, MemorySpace, Operand, PtxType, Register, SpecialReg,
VectorWidth, WmmaLayout, WmmaOp, WmmaShape,
};
#[test]
fn valid_minimal_ptx() {
let ptx = ".version 8.5\n.target sm_90a\n.address_size 64\n";
let result = validate_ptx(ptx);
assert!(result.is_ok());
assert!(result.errors.is_empty());
}
#[test]
fn missing_version() {
let ptx = ".target sm_80\n.address_size 64\n";
let result = validate_ptx(ptx);
assert!(result.has_errors());
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::MissingVersionDirective))
);
}
#[test]
fn missing_target() {
let ptx = ".version 8.5\n.address_size 64\n";
let result = validate_ptx(ptx);
assert!(result.has_errors());
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::MissingTargetDirective))
);
}
#[test]
fn shared_memory_within_limits() {
let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
.shared .align 4 .b8 smem[4096];\n";
let result = validate_ptx(ptx);
assert!(result.is_ok());
}
#[test]
fn shared_memory_exceeds_limits() {
let ptx = ".version 6.4\n.target sm_75\n.address_size 64\n\
.shared .align 4 .b8 smem[100000];\n";
let result = validate_ptx(ptx);
assert!(result.has_errors());
assert!(
result
.errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidSharedMemSize { .. }))
);
}
#[test]
fn validate_for_specific_target() {
let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
.shared .align 4 .b8 smem[200000];\n";
let result = validate_ptx_for_target(ptx, SmVersion::Sm80);
assert!(result.has_errors());
}
#[test]
fn extract_shared_mem_size_fn() {
assert_eq!(
extract_shared_mem_size(" .shared .align 4 .b8 smem[4096];"),
Some(4096)
);
assert_eq!(
extract_shared_mem_size(" .shared .align 16 .b8 tile[65536];"),
Some(65536)
);
assert_eq!(extract_shared_mem_size(" mov.u32 %r0, 0;"), None);
}
#[test]
fn parse_sm_version_fn() {
assert_eq!(parse_sm_version("sm_80"), Some(SmVersion::Sm80));
assert_eq!(parse_sm_version("sm_90a"), Some(SmVersion::Sm90a));
assert_eq!(parse_sm_version("sm_100"), Some(SmVersion::Sm100));
assert_eq!(parse_sm_version("sm_999"), None);
}
#[test]
fn mismatched_braces_warning() {
let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n{\n";
let result = validate_ptx(ptx);
assert!(!result.warnings.is_empty());
}
#[test]
fn validation_error_display() {
let err = ValidationError::MissingVersionDirective;
assert_eq!(format!("{err}"), "missing .version directive");
let err = ValidationError::InvalidSharedMemSize {
declared: 100_000,
max_allowed: 65536,
};
assert!(format!("{err}").contains("100000"));
}
fn reg(name: &str, ty: PtxType) -> Register {
Register {
name: name.to_string(),
ty,
}
}
fn reg_op(name: &str, ty: PtxType) -> Operand {
Operand::Register(reg(name, ty))
}
#[test]
fn ir_type_compatible_arithmetic_passes() {
let instructions = vec![
Instruction::LoadParam {
ty: PtxType::F32,
dst: reg("%f0", PtxType::F32),
param_name: "a".to_string(),
},
Instruction::LoadParam {
ty: PtxType::F32,
dst: reg("%f1", PtxType::F32),
param_name: "b".to_string(),
},
Instruction::Add {
ty: PtxType::F32,
dst: reg("%f2", PtxType::F32),
a: reg_op("%f0", PtxType::F32),
b: reg_op("%f1", PtxType::F32),
},
];
let result = validate_ir_instructions(&instructions);
assert!(
result.errors.is_empty(),
"expected no errors, got: {:?}",
result.errors
);
}
#[test]
fn ir_type_mismatched_arithmetic_fails() {
let instructions = vec![
Instruction::LoadParam {
ty: PtxType::F32,
dst: reg("%f0", PtxType::F32),
param_name: "a".to_string(),
},
Instruction::LoadParam {
ty: PtxType::U32,
dst: reg("%r0", PtxType::U32),
param_name: "b".to_string(),
},
Instruction::Add {
ty: PtxType::F32,
dst: reg("%f1", PtxType::F32),
a: reg_op("%f0", PtxType::F32),
b: reg_op("%r0", PtxType::U32), },
];
let result = validate_ir_instructions(&instructions);
assert!(result.has_errors());
assert!(
result
.errors
.iter()
.any(|e| e.kind == IrErrorKind::TypeMismatch)
);
}
#[test]
fn ir_use_before_def_detection() {
let instructions = vec![Instruction::Add {
ty: PtxType::F32,
dst: reg("%f2", PtxType::F32),
a: reg_op("%f0", PtxType::F32), b: reg_op("%f1", PtxType::F32), }];
let result = validate_ir_instructions(&instructions);
assert!(result.has_errors());
let ubd_count = result
.errors
.iter()
.filter(|e| e.kind == IrErrorKind::UseBeforeDef)
.count();
assert!(ubd_count >= 2, "expected at least 2 use-before-def errors");
}
#[test]
fn ir_load_param_counted_as_definition() {
let instructions = vec![
Instruction::LoadParam {
ty: PtxType::U64,
dst: reg("%rd0", PtxType::U64),
param_name: "ptr".to_string(),
},
Instruction::Load {
space: MemorySpace::Global,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
dst: reg("%f0", PtxType::F32),
addr: Operand::Address {
base: reg("%rd0", PtxType::U64),
offset: None,
},
},
];
let result = validate_register_lifetimes(&instructions);
assert!(
result.errors.is_empty(),
"LoadParam should count as definition: {:?}",
result.errors
);
}
#[test]
fn ir_mov_special_counted_as_definition() {
let instructions = vec![
Instruction::MovSpecial {
dst: reg("%r0", PtxType::U32),
special: SpecialReg::TidX,
},
Instruction::Add {
ty: PtxType::U32,
dst: reg("%r1", PtxType::U32),
a: reg_op("%r0", PtxType::U32),
b: Operand::Immediate(ImmValue::U32(1)),
},
];
let result = validate_register_lifetimes(&instructions);
assert!(
result.errors.is_empty(),
"MovSpecial should count as definition: {:?}",
result.errors
);
}
#[test]
fn ir_shared_store_without_barrier_warns() {
let addr_reg = reg("%rd0", PtxType::U64);
let instructions = vec![
Instruction::LoadParam {
ty: PtxType::U64,
dst: addr_reg.clone(),
param_name: "addr".to_string(),
},
Instruction::LoadParam {
ty: PtxType::F32,
dst: reg("%f0", PtxType::F32),
param_name: "val".to_string(),
},
Instruction::Store {
space: MemorySpace::Shared,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
addr: Operand::Address {
base: addr_reg.clone(),
offset: None,
},
src: reg("%f0", PtxType::F32),
},
Instruction::Load {
space: MemorySpace::Shared,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
dst: reg("%f1", PtxType::F32),
addr: Operand::Address {
base: addr_reg,
offset: Some(4),
},
},
];
let result = validate_memory_consistency(&instructions);
assert!(
!result.warnings.is_empty(),
"expected race condition warning"
);
assert!(
result.warnings[0].message.contains("race condition"),
"warning should mention race condition"
);
}
#[test]
fn ir_barrier_after_shared_store_no_warning() {
    // The same shared store/load pair as the race test, but with `BarSync`
    // placed between them; the validator is expected to stay silent.
    let addr_reg = reg("%rd0", PtxType::U64);
    let shared_at = |offset| Operand::Address {
        base: addr_reg.clone(),
        offset,
    };

    let store = Instruction::Store {
        space: MemorySpace::Shared,
        qualifier: CacheQualifier::None,
        vec: VectorWidth::V1,
        ty: PtxType::F32,
        addr: shared_at(None),
        src: reg("%f0", PtxType::F32),
    };
    let barrier = Instruction::BarSync { id: 0 };
    let load = Instruction::Load {
        space: MemorySpace::Shared,
        qualifier: CacheQualifier::None,
        vec: VectorWidth::V1,
        ty: PtxType::F32,
        dst: reg("%f1", PtxType::F32),
        addr: shared_at(Some(4)),
    };

    let instructions = vec![store, barrier, load];
    let result = validate_memory_consistency(&instructions);
    assert!(
        result.warnings.is_empty(),
        "expected no warnings when barrier separates store/load"
    );
}
#[test]
fn ir_empty_instruction_list_no_errors() {
    // An empty instruction stream must validate cleanly: no errors and no
    // warnings.
    let no_instructions: [Instruction; 0] = [];
    let result = validate_ir_instructions(&no_instructions);
    assert!(result.is_ok());
    assert!(result.warnings.is_empty());
}
#[test]
fn ir_complex_sequence_multiple_issues() {
    // Two independent problems in one sequence: %f0 is read without ever
    // being defined, and the Sub declares F32 but writes into a U32 register.
    // Both kinds of error must be reported.
    let reads_undefined = Instruction::Add {
        ty: PtxType::F32,
        dst: reg("%f1", PtxType::F32),
        a: reg_op("%f0", PtxType::F32),
        b: Operand::Immediate(ImmValue::F32(1.0)),
    };
    let wrong_dst_type = Instruction::Sub {
        ty: PtxType::F32,
        dst: reg("%r0", PtxType::U32),
        a: reg_op("%f1", PtxType::F32),
        b: Operand::Immediate(ImmValue::F32(2.0)),
    };

    let instructions = vec![reads_undefined, wrong_dst_type];
    let result = validate_ir_instructions(&instructions);
    assert!(result.has_errors());

    let has_kind = |kind: IrErrorKind| result.errors.iter().any(|e| e.kind == kind);
    let has_ubd = has_kind(IrErrorKind::UseBeforeDef);
    let has_type_mismatch = has_kind(IrErrorKind::TypeMismatch);
    assert!(has_ubd, "expected use-before-def error");
    assert!(has_type_mismatch, "expected type mismatch error");
}
#[test]
fn ir_validate_register_lifetimes_standalone() {
    // %f0 (LoadParam) and %f1 (Neg) are defined before they are read, but
    // %f99 never is. The lifetime pass alone should report exactly one
    // error, and that error should name %f99.
    let load_x = Instruction::LoadParam {
        ty: PtxType::F32,
        dst: reg("%f0", PtxType::F32),
        param_name: "x".to_string(),
    };
    let negate = Instruction::Neg {
        ty: PtxType::F32,
        dst: reg("%f1", PtxType::F32),
        src: reg_op("%f0", PtxType::F32),
    };
    let add_undefined = Instruction::Add {
        ty: PtxType::F32,
        dst: reg("%f2", PtxType::F32),
        a: reg_op("%f1", PtxType::F32),
        b: reg_op("%f99", PtxType::F32),
    };

    let instructions = vec![load_x, negate, add_undefined];
    let result = validate_register_lifetimes(&instructions);
    assert!(result.has_errors());
    assert_eq!(result.errors.len(), 1);
    assert!(result.errors[0].message.contains("%f99"));
}
#[test]
fn ir_validate_memory_consistency_standalone() {
    // A `BarSync` that a predicated branch can jump past should produce a
    // warning that mentions divergence.
    let instructions = vec![
        // NOTE(review): `ty` is U32 while `dst` is a Pred register. Only
        // validate_memory_consistency runs here, so the mismatch does not
        // affect this test — confirm whether it is intentional.
        Instruction::LoadParam {
            ty: PtxType::U32,
            dst: reg("%p0", PtxType::Pred),
            param_name: "pred".to_string(),
        },
        Instruction::Branch {
            target: "skip".to_string(),
            predicate: Some((reg("%p0", PtxType::Pred), false)),
        },
        Instruction::BarSync { id: 0 },
        Instruction::Label("skip".to_string()),
    ];
    let result = validate_memory_consistency(&instructions);
    assert!(!result.warnings.is_empty(), "expected divergence warning");
    // Check every warning rather than only the first one, so the test does
    // not depend on the order in which the validator emits its warnings.
    assert!(
        result
            .warnings
            .iter()
            .any(|w| w.message.contains("divergent")),
        "expected a warning mentioning divergence, got: {:?}",
        result
            .warnings
            .iter()
            .map(|w| w.message.as_str())
            .collect::<Vec<_>>()
    );
}
#[test]
fn ir_validation_result_display() {
    // A result that has entries should list the errors (including the kind
    // name) and the warnings, each with a count.
    let mismatch = IrValidationError {
        instruction_index: 3,
        kind: IrErrorKind::TypeMismatch,
        message: String::from("dst type does not match"),
    };
    let race = IrValidationWarning {
        instruction_index: 7,
        message: String::from("possible race"),
    };
    let result = IrValidationResult {
        errors: vec![mismatch],
        warnings: vec![race],
    };
    let display = result.to_string();
    assert!(display.contains("Errors (1)"));
    assert!(display.contains("TypeMismatch"));
    assert!(display.contains("Warnings (1)"));
    assert!(display.contains("possible race"));

    // An empty result should render as a "passed" summary.
    let ok_result = IrValidationResult {
        errors: vec![],
        warnings: vec![],
    };
    let ok_display = ok_result.to_string();
    assert!(ok_display.contains("passed"));
}
#[test]
fn ir_wmma_with_immediate_operand_flagged() {
    // Immediates are passed in both the `addr` and `stride` slots of a wmma
    // LoadA. The validator is expected to report at least two
    // InvalidOperand errors for them.
    let wmma_load = Instruction::Wmma {
        op: WmmaOp::LoadA,
        shape: WmmaShape::M16N16K16,
        layout: WmmaLayout::RowMajor,
        ty: PtxType::F16,
        fragments: vec![reg("%f0", PtxType::F16)],
        addr: Some(Operand::Immediate(ImmValue::U32(0))),
        stride: Some(Operand::Immediate(ImmValue::U32(16))),
    };

    let instructions = vec![wmma_load];
    let result = validate_ir_instructions(&instructions);
    let invalid_operand_count = result
        .errors
        .iter()
        .filter(|e| e.kind == IrErrorKind::InvalidOperand)
        .count();
    assert!(
        invalid_operand_count >= 2,
        "expected at least 2 InvalidOperand errors for wmma immediates, got {}",
        invalid_operand_count
    );
}
#[test]
fn ir_mixed_valid_and_invalid_instructions() {
    // Returns every error of the requested kind, in emission order.
    fn errors_of_kind<'a>(
        errors: &'a [IrValidationError],
        kind: &IrErrorKind,
    ) -> Vec<&'a IrValidationError> {
        errors.iter().filter(|e| e.kind == *kind).collect()
    }

    // Every register is defined before it is read, so no use-before-def
    // errors are expected. The only defect is the Sub that declares F32 but
    // writes into the U32 register %bad.
    let instructions = vec![
        Instruction::LoadParam {
            ty: PtxType::F32,
            dst: reg("%f0", PtxType::F32),
            param_name: "x".to_string(),
        },
        Instruction::MovSpecial {
            dst: reg("%r0", PtxType::U32),
            special: SpecialReg::TidX,
        },
        Instruction::Add {
            ty: PtxType::F32,
            dst: reg("%f1", PtxType::F32),
            a: reg_op("%f0", PtxType::F32),
            b: Operand::Immediate(ImmValue::F32(1.0)),
        },
        Instruction::Sub {
            ty: PtxType::F32,
            dst: reg("%bad", PtxType::U32),
            a: reg_op("%f1", PtxType::F32),
            b: Operand::Immediate(ImmValue::F32(0.5)),
        },
        Instruction::Comment("test".to_string()),
        Instruction::Return,
    ];
    let result = validate_ir_instructions(&instructions);

    let type_errors = errors_of_kind(&result.errors, &IrErrorKind::TypeMismatch);
    assert_eq!(
        type_errors.len(),
        1,
        "expected exactly 1 type mismatch, got {}: {:?}",
        type_errors.len(),
        type_errors
    );

    let ubd_errors = errors_of_kind(&result.errors, &IrErrorKind::UseBeforeDef);
    assert!(
        ubd_errors.is_empty(),
        "expected no use-before-def errors: {ubd_errors:?}",
    );
}
}
// A second test module, compiled only under `cfg(test)`. Its source is loaded
// from `validator_tests.rs` through the `#[path]` attribute instead of the
// default module file location.
// NOTE(review): the name suggests SM-version-specific tests — confirm
// against the contents of validator_tests.rs.
#[cfg(test)]
#[path = "validator_tests.rs"]
mod sm_tests;