/// Outcome of comparing (or individually checking) PTX kernels for batched/single parity.
///
/// The validators in this module set `is_compatible` to `violations.is_empty()`.
#[derive(Clone, Debug)]
pub struct ParityResult {
    /// Whether no violation was recorded.
    pub is_compatible: bool,
    /// Every problem found, in the order the checks were run.
    pub violations: Vec<ParityViolation>,
    /// Name of the single-row kernel (empty when only a batched kernel was checked).
    pub single_name: String,
    /// Name of the batched kernel.
    pub batched_name: String,
}
/// A single parity problem: a machine-checkable category plus a human-readable explanation.
#[derive(Clone, Debug)]
pub struct ParityViolation {
    /// Category of the problem.
    pub kind: ParityViolationKind,
    /// Explanation naming the offending kernel(s) and the observed values.
    pub message: String,
}
/// Category of a [`ParityViolation`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ParityViolationKind {
    /// The two PTX texts contain a different number of `.param` lines.
    ParameterCountMismatch,
    /// The first `smem[...]` shared-memory declaration differs in size (or presence).
    SharedMemoryMismatch,
    /// The batched kernel shows no recognised row-dispatch mechanism.
    MissingBatchDispatch,
    /// A shared-memory load/store is addressed through a `%rd` (64-bit) register.
    SharedMemoryAddressingU64,
    /// The sequences of labels containing `loop` differ between the two kernels.
    LoopStructureMismatch,
    /// NOTE(review): not produced by any check visible in this file; presumably reserved — confirm.
    RegisterTypeMismatch,
}
impl std::fmt::Display for ParityViolationKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ParameterCountMismatch => write!(f, "PARAM_COUNT"),
Self::SharedMemoryMismatch => write!(f, "SHARED_MEM_SIZE"),
Self::MissingBatchDispatch => write!(f, "MISSING_CTAID_Y"),
Self::SharedMemoryAddressingU64 => write!(f, "SHARED_MEM_U64"),
Self::LoopStructureMismatch => write!(f, "LOOP_STRUCTURE"),
Self::RegisterTypeMismatch => write!(f, "REG_TYPE"),
}
}
}
/// Counts the lines of `ptx` that, after leading whitespace, begin with `.param`.
///
/// Instructions such as `ld.param` are not counted because they do not start with `.param`.
/// Every `.param` line in the text is counted, regardless of which entry/function it belongs to.
fn count_params(ptx: &str) -> usize {
    let mut total = 0;
    for raw in ptx.lines() {
        if raw.trim_start().starts_with(".param") {
            total += 1;
        }
    }
    total
}
fn extract_shared_memory_bytes(ptx: &str) -> Option<u32> {
ptx.lines()
.map(str::trim)
.filter(|line| line.contains(".shared") && line.contains("smem["))
.find_map(parse_smem_size)
}
/// Parses the size in bytes of a `smem[N]` shared-memory array declared on a single PTX line.
///
/// The bracketed element count `N` is multiplied by the width of the last recognised scalar
/// type token before `smem[`. For example, `.f32 smem[256]` yields 1024. When no such token
/// is present, a width of one byte (as for `.b8`) is assumed. Whitespace inside the brackets
/// is ignored.
///
/// Returns `None` when:
/// - the line has no `smem[...]`;
/// - the count is empty or not a valid `u32` (e.g. dynamic `smem[]`);
/// - the byte total overflows `u32`.
fn parse_smem_size(line: &str) -> Option<u32> {
    /// Byte width of a PTX scalar type token, or `None` for non-type tokens.
    fn element_width_bytes(token: &str) -> Option<u32> {
        match token {
            ".b8" | ".u8" | ".s8" => Some(1),
            ".b16" | ".u16" | ".s16" | ".f16" | ".bf16" => Some(2),
            ".b32" | ".u32" | ".s32" | ".f32" | ".f16x2" | ".bf16x2" => Some(4),
            ".b64" | ".u64" | ".s64" | ".f64" => Some(8),
            ".b128" => Some(16),
            _ => None,
        }
    }

    const MARKER: &str = "smem[";
    let start = line.find(MARKER)?;
    // MARKER is ASCII, so both slice boundaries are valid char boundaries.
    let head = &line[..start];
    let tail = &line[start + MARKER.len()..];
    let end = tail.find(']')?;
    let count: u32 = tail[..end].trim().parse().ok()?;

    // The element type is the closest type token preceding the array name.
    let width = head
        .split_whitespace()
        .rev()
        .find_map(element_width_bytes)
        .unwrap_or(1);
    count.checked_mul(width)
}
/// Collects, in source order, the labels whose names contain `loop`.
///
/// A label is a trimmed line ending in `:` that is not a `//` comment. All trailing colons are
/// stripped from the returned names.
fn extract_loop_labels(ptx: &str) -> Vec<String> {
    ptx.lines()
        .map(str::trim)
        .filter(|text| text.ends_with(':') && !text.starts_with("//"))
        .map(|text| text.trim_end_matches(':'))
        .filter(|name| name.contains("loop"))
        .map(String::from)
        .collect()
}
/// Ways a batched kernel may dispatch rows.
///
/// NOTE(review): not referenced by the code in this file; the detectors below return `bool`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BatchDispatchStrategy {
    /// Rows selected via the grid Y index (`%ctaid.y`).
    GridY,
    /// Rows handled via an `m_dim`-driven register-unrolled path.
    RegisterUnroll,
}
/// Returns `true` if the text references the `%ctaid.y` special register anywhere
/// (including inside comments).
fn has_grid_y_dispatch(ptx: &str) -> bool {
    const CTAID_Y: &str = "%ctaid.y";
    ptx.contains(CTAID_Y)
}
/// Heuristic: returns `true` if the substring `m_dim` occurs anywhere in the text.
///
/// This is a plain substring test, so it also matches comments and longer identifiers that
/// merely contain `m_dim` (e.g. `num_dim`).
fn has_register_unroll_dispatch(ptx: &str) -> bool {
    const M_DIM_MARKER: &str = "m_dim";
    ptx.contains(M_DIM_MARKER)
}
/// Returns `true` if either recognised row-dispatch mechanism is present: grid-Y
/// (`%ctaid.y`) or register-unroll (`m_dim`).
fn has_batch_dispatch(ptx: &str) -> bool {
    if has_grid_y_dispatch(ptx) {
        return true;
    }
    has_register_unroll_dispatch(ptx)
}
/// Returns `true` if any `ld.shared` / `st.shared` instruction uses a `%rd` register as its
/// address operand (i.e. the text contains `[%rd` on that line).
///
/// Anything after `//` on a line is ignored. As a result, commented-out instructions, or a
/// trailing comment that mentions `[%rd`, do not produce a false positive. This matches how
/// `extract_loop_labels` ignores comment lines.
fn has_u64_shared_memory_addressing(ptx: &str) -> bool {
    ptx.lines().any(|line| {
        let code = line
            .split_once("//")
            .map_or(line, |(before_comment, _)| before_comment)
            .trim();
        (code.contains("st.shared") || code.contains("ld.shared")) && code.contains("[%rd")
    })
}
pub fn validate_parity(
single_ptx: &str,
batched_ptx: &str,
single_name: &str,
batched_name: &str,
) -> ParityResult {
let mut violations = Vec::new();
let single_params = count_params(single_ptx);
let batched_params = count_params(batched_ptx);
if single_params != batched_params {
violations.push(ParityViolation {
kind: ParityViolationKind::ParameterCountMismatch,
message: format!(
"Single kernel '{}' has {} params, batched '{}' has {} params",
single_name, single_params, batched_name, batched_params
),
});
}
let single_smem = extract_shared_memory_bytes(single_ptx);
let batched_smem = extract_shared_memory_bytes(batched_ptx);
if single_smem != batched_smem {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryMismatch,
message: format!(
"Shared memory mismatch: single={:?} bytes, batched={:?} bytes",
single_smem, batched_smem
),
});
}
if !has_batch_dispatch(batched_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::MissingBatchDispatch,
message: format!(
"Batched kernel '{}' does not use %ctaid.y for row dispatch",
batched_name
),
});
}
if has_u64_shared_memory_addressing(batched_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryAddressingU64,
message: format!(
"Batched kernel '{}' uses u64 registers (%rd) for shared memory addressing; \
use u32 (%r) for portability",
batched_name
),
});
}
if has_u64_shared_memory_addressing(single_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryAddressingU64,
message: format!(
"Single kernel '{}' uses u64 registers (%rd) for shared memory addressing; \
use u32 (%r) for portability",
single_name
),
});
}
let single_loops = extract_loop_labels(single_ptx);
let batched_loops = extract_loop_labels(batched_ptx);
if single_loops != batched_loops {
violations.push(ParityViolation {
kind: ParityViolationKind::LoopStructureMismatch,
message: format!(
"Loop structure differs: single has {:?}, batched has {:?}",
single_loops, batched_loops
),
});
}
ParityResult {
is_compatible: violations.is_empty(),
violations,
single_name: single_name.to_string(),
batched_name: batched_name.to_string(),
}
}
pub fn validate_batched_kernel(ptx: &str, kernel_name: &str) -> ParityResult {
let mut violations = Vec::new();
if !has_batch_dispatch(ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::MissingBatchDispatch,
message: format!(
"Batched kernel '{}' does not use %ctaid.y for row dispatch",
kernel_name
),
});
}
if has_u64_shared_memory_addressing(ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryAddressingU64,
message: format!(
"Batched kernel '{}' uses u64 registers for shared memory addressing",
kernel_name
),
});
}
ParityResult {
is_compatible: violations.is_empty(),
violations,
single_name: String::new(),
batched_name: kernel_name.to_string(),
}
}
#[cfg(test)]
mod tests;