/// Outcome of comparing a batched kernel's PTX against a single-row kernel
/// (or of validating a batched kernel on its own).
#[derive(Debug, Clone)]
pub struct ParityResult {
    /// True when no violations were collected.
    pub is_compatible: bool,
    /// Every parity violation found; empty when `is_compatible` is true.
    pub violations: Vec<ParityViolation>,
    /// Name of the single-row kernel; empty for standalone batched checks.
    pub single_name: String,
    /// Name of the batched kernel under validation.
    pub batched_name: String,
}
/// A single parity problem: a machine-readable kind plus a human-readable
/// explanation naming the offending kernel(s).
#[derive(Debug, Clone)]
pub struct ParityViolation {
    /// Category of the violation (stable code via its `Display` impl).
    pub kind: ParityViolationKind,
    /// Free-form diagnostic text describing the specific mismatch.
    pub message: String,
}
/// Categories of single-vs-batched kernel parity violations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParityViolationKind {
    /// The two kernels declare different numbers of `.param` entries.
    ParameterCountMismatch,
    /// The two kernels declare different `smem[N]` shared-memory sizes.
    SharedMemoryMismatch,
    /// The batched kernel has no recognized batch-dispatch mechanism
    /// (`%ctaid.y` or an `m_dim` parameter).
    MissingBatchDispatch,
    /// Shared memory is addressed through u64 (`%rd`) registers.
    SharedMemoryAddressingU64,
    /// The `*loop*` labels differ between the two kernels.
    LoopStructureMismatch,
    /// Register type mismatch. NOTE(review): not produced by any check
    /// visible in this file — confirm it is constructed elsewhere.
    RegisterTypeMismatch,
}
impl std::fmt::Display for ParityViolationKind {
    /// Renders the violation kind as a short, stable diagnostic code.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = match self {
            Self::ParameterCountMismatch => "PARAM_COUNT",
            Self::SharedMemoryMismatch => "SHARED_MEM_SIZE",
            Self::MissingBatchDispatch => "MISSING_CTAID_Y",
            Self::SharedMemoryAddressingU64 => "SHARED_MEM_U64",
            Self::LoopStructureMismatch => "LOOP_STRUCTURE",
            Self::RegisterTypeMismatch => "REG_TYPE",
        };
        f.write_str(code)
    }
}
/// Counts the `.param` declarations in a PTX listing (one per line).
fn count_params(ptx: &str) -> usize {
    ptx.lines()
        .map(str::trim)
        .filter(|line| line.starts_with(".param"))
        .count()
}
/// Extracts the declared shared-memory size in bytes from a line of the
/// form `.shared ... smem[N];`, or `None` if no such declaration parses.
fn extract_shared_memory_bytes(ptx: &str) -> Option<u32> {
    ptx.lines().find_map(|raw| {
        let line = raw.trim();
        if !line.contains(".shared") {
            return None;
        }
        // Grab the digits between "smem[" and the closing bracket.
        let start = line.find("smem[")?;
        let rest = &line[start + "smem[".len()..];
        let close = rest.find(']')?;
        rest[..close].parse::<u32>().ok()
    })
}
/// Collects, in order, every PTX label (a `name:` line) whose name
/// contains "loop". Comment lines are ignored.
fn extract_loop_labels(ptx: &str) -> Vec<String> {
    ptx.lines()
        .filter_map(|raw| {
            let line = raw.trim();
            if line.starts_with("//") || !line.ends_with(':') {
                return None;
            }
            let label = line.trim_end_matches(':');
            label.contains("loop").then(|| label.to_string())
        })
        .collect()
}
/// How a batched kernel distributes rows across work.
/// NOTE(review): the validators in this file detect both strategies via
/// `has_batch_dispatch` but do not reference this enum directly — confirm
/// it is consumed elsewhere.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchDispatchStrategy {
    /// One row per grid-Y slot, read from `%ctaid.y`.
    GridY,
    /// Rows unrolled in registers, signalled by an `m_dim` parameter.
    RegisterUnroll,
}
/// True if the PTX reads the grid-Y block index, i.e. dispatches one
/// batch row per `%ctaid.y` slot.
fn has_grid_y_dispatch(ptx: &str) -> bool {
    const GRID_Y_REGISTER: &str = "%ctaid.y";
    ptx.contains(GRID_Y_REGISTER)
}
/// True if the PTX mentions an `m_dim` symbol, the marker for the
/// register-unroll batch-dispatch strategy.
fn has_register_unroll_dispatch(ptx: &str) -> bool {
    const M_DIM_MARKER: &str = "m_dim";
    ptx.contains(M_DIM_MARKER)
}
/// True if the PTX uses either recognized batch-dispatch strategy
/// (grid-Y or register unroll).
fn has_batch_dispatch(ptx: &str) -> bool {
    if has_grid_y_dispatch(ptx) {
        return true;
    }
    has_register_unroll_dispatch(ptx)
}
/// True if any shared-memory load/store addresses through a u64 (`%rd`)
/// register instead of a u32 (`%r`) register.
fn has_u64_shared_memory_addressing(ptx: &str) -> bool {
    ptx.lines().map(str::trim).any(|line| {
        let is_smem_access = line.contains("st.shared") || line.contains("ld.shared");
        is_smem_access && line.contains("[%rd")
    })
}
pub fn validate_parity(
single_ptx: &str,
batched_ptx: &str,
single_name: &str,
batched_name: &str,
) -> ParityResult {
let mut violations = Vec::new();
let single_params = count_params(single_ptx);
let batched_params = count_params(batched_ptx);
if single_params != batched_params {
violations.push(ParityViolation {
kind: ParityViolationKind::ParameterCountMismatch,
message: format!(
"Single kernel '{}' has {} params, batched '{}' has {} params",
single_name, single_params, batched_name, batched_params
),
});
}
let single_smem = extract_shared_memory_bytes(single_ptx);
let batched_smem = extract_shared_memory_bytes(batched_ptx);
if single_smem != batched_smem {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryMismatch,
message: format!(
"Shared memory mismatch: single={:?} bytes, batched={:?} bytes",
single_smem, batched_smem
),
});
}
if !has_batch_dispatch(batched_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::MissingBatchDispatch,
message: format!(
"Batched kernel '{}' does not use %ctaid.y for row dispatch",
batched_name
),
});
}
if has_u64_shared_memory_addressing(batched_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryAddressingU64,
message: format!(
"Batched kernel '{}' uses u64 registers (%rd) for shared memory addressing; \
use u32 (%r) for portability",
batched_name
),
});
}
if has_u64_shared_memory_addressing(single_ptx) {
violations.push(ParityViolation {
kind: ParityViolationKind::SharedMemoryAddressingU64,
message: format!(
"Single kernel '{}' uses u64 registers (%rd) for shared memory addressing; \
use u32 (%r) for portability",
single_name
),
});
}
let single_loops = extract_loop_labels(single_ptx);
let batched_loops = extract_loop_labels(batched_ptx);
if single_loops != batched_loops {
violations.push(ParityViolation {
kind: ParityViolationKind::LoopStructureMismatch,
message: format!(
"Loop structure differs: single has {:?}, batched has {:?}",
single_loops, batched_loops
),
});
}
ParityResult {
is_compatible: violations.is_empty(),
violations,
single_name: single_name.to_string(),
batched_name: batched_name.to_string(),
}
}
/// Validates a batched kernel's PTX on its own, without a single-row
/// reference kernel.
///
/// Runs only the batched-specific checks: presence of a batch-dispatch
/// mechanism (`%ctaid.y` grid-Y or an `m_dim` register-unroll parameter)
/// and absence of u64 shared-memory addressing. The returned
/// [`ParityResult`] carries an empty `single_name`.
pub fn validate_batched_kernel(ptx: &str, kernel_name: &str) -> ParityResult {
    let mut violations = Vec::new();

    // The message names both accepted dispatch strategies so the fix is
    // actionable, not just "add %ctaid.y".
    if !has_batch_dispatch(ptx) {
        violations.push(ParityViolation {
            kind: ParityViolationKind::MissingBatchDispatch,
            message: format!(
                "Batched kernel '{}' has no batch dispatch: neither %ctaid.y \
                 (grid-Y) nor an m_dim parameter (register unroll) was found",
                kernel_name
            ),
        });
    }

    if has_u64_shared_memory_addressing(ptx) {
        violations.push(ParityViolation {
            kind: ParityViolationKind::SharedMemoryAddressingU64,
            message: format!(
                "Batched kernel '{}' uses u64 registers for shared memory addressing",
                kernel_name
            ),
        });
    }

    ParityResult {
        is_compatible: violations.is_empty(),
        violations,
        // No single-row counterpart in standalone mode.
        single_name: String::new(),
        batched_name: kernel_name.to_string(),
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the PTX parity checks. The PTX fixtures are minimal
// hand-written snippets: only the features each check scans for
// (.param lines, smem[N], labels, %ctaid.y / m_dim, [%rd...]) matter.
use super::*;
// Three `.param` lines inside an entry declaration are counted.
#[test]
fn test_count_params_basic() {
let ptx = r#"
.visible .entry test(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 c_ptr
) {
ret;
}
"#;
assert_eq!(count_params(ptx), 3);
}
// smem[32] parses to Some(32); a line with no shared decl yields None.
#[test]
fn test_extract_shared_memory_bytes() {
let ptx = " .shared .align 16 .b8 smem[32];";
assert_eq!(extract_shared_memory_bytes(ptx), Some(32));
let ptx_none = " .reg .f32 %f<10>;";
assert_eq!(extract_shared_memory_bytes(ptx_none), None);
}
// Only labels containing "loop" are collected, in source order;
// `exit:` is skipped.
#[test]
fn test_extract_loop_labels() {
let ptx = r#"
sum_loop:
add.u32 %r6, %r6, 256;
bra sum_loop;
sum_loop_end:
norm_loop:
bra norm_loop;
exit:
ret;
"#;
let labels = extract_loop_labels(ptx);
assert_eq!(
labels,
vec!["sum_loop", "sum_loop_end", "norm_loop"]
);
}
// Either %ctaid.y or an m_dim parameter counts as batch dispatch;
// %ctaid.x alone does not.
#[test]
fn test_has_batch_dispatch() {
assert!(has_batch_dispatch(" mov.u32 %r1, %ctaid.y;"));
assert!(has_batch_dispatch(" .param .u32 m_dim"));
assert!(!has_batch_dispatch(" mov.u32 %r1, %ctaid.x;"));
}
// The two dispatch detectors are mutually exclusive on these inputs.
#[test]
fn test_batch_dispatch_strategies() {
assert!(has_grid_y_dispatch(" mov.u32 %r1, %ctaid.y;"));
assert!(!has_grid_y_dispatch(" .param .u32 m_dim"));
assert!(has_register_unroll_dispatch(" .param .u32 m_dim"));
assert!(!has_register_unroll_dispatch(" mov.u32 %r1, %ctaid.y;"));
}
// [%rd...] on a shared-memory store is flagged; [%r...] is not.
#[test]
fn test_has_u64_shared_memory_addressing() {
assert!(has_u64_shared_memory_addressing(
" st.shared.f32 [%rd3], %f0;"
));
assert!(!has_u64_shared_memory_addressing(
" st.shared.f32 [%r3], %f0;"
));
}
// Full happy path: same params, same smem size, same loop labels,
// batched kernel adds %ctaid.y -> compatible.
#[test]
fn test_validate_parity_matching_kernels() {
let single = r#"
.version 8.0
.target sm_89
.address_size 64
.visible .entry rmsnorm(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 gamma_ptr
) {
.shared .align 16 .b8 smem[32];
mov.u32 %r0, %tid.x;
sum_loop:
bra sum_loop;
sum_loop_end:
norm_loop:
bra norm_loop;
exit:
ret;
}
"#;
let batched = r#"
.version 8.0
.target sm_89
.address_size 64
.visible .entry batched_rmsnorm(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 gamma_ptr
) {
.shared .align 16 .b8 smem[32];
mov.u32 %r0, %tid.x;
mov.u32 %r1, %ctaid.y;
sum_loop:
bra sum_loop;
sum_loop_end:
norm_loop:
bra norm_loop;
exit:
ret;
}
"#;
let result = validate_parity(single, batched, "rmsnorm", "batched_rmsnorm");
assert!(
result.is_compatible,
"Should be compatible: {:?}",
result.violations
);
}
// 2 vs 3 params must produce a ParameterCountMismatch violation.
#[test]
fn test_validate_parity_param_mismatch() {
let single = r#"
.visible .entry test(
.param .u64 a,
.param .u64 b
) { ret; }
"#;
let batched = r#"
.visible .entry test_batched(
.param .u64 a,
.param .u64 b,
.param .u32 batch_size
) {
mov.u32 %r1, %ctaid.y;
ret;
}
"#;
let result = validate_parity(single, batched, "test", "test_batched");
assert!(!result.is_compatible);
assert!(result
.violations
.iter()
.any(|v| v.kind == ParityViolationKind::ParameterCountMismatch));
}
// A batched kernel with no dispatch mechanism at all is flagged.
#[test]
fn test_validate_parity_missing_ctaid_y() {
let single = r#"
.visible .entry test(
.param .u64 a
) { ret; }
"#;
let batched = r#"
.visible .entry test_batched(
.param .u64 a
) { ret; }
"#;
let result = validate_parity(single, batched, "test", "test_batched");
assert!(!result.is_compatible);
assert!(result
.violations
.iter()
.any(|v| v.kind == ParityViolationKind::MissingBatchDispatch));
}
// Batched kernel addressing smem via [%rd3] triggers the u64 violation
// even though the single kernel uses [%r3].
#[test]
fn test_validate_parity_u64_shared_memory() {
let single = r#"
.visible .entry test(
.param .u64 a
) {
.shared .align 16 .b8 smem[32];
st.shared.f32 [%r3], %f0;
ret;
}
"#;
let batched = r#"
.visible .entry test_batched(
.param .u64 a
) {
.shared .align 16 .b8 smem[32];
mov.u32 %r1, %ctaid.y;
st.shared.f32 [%rd3], %f0;
ret;
}
"#;
let result = validate_parity(single, batched, "test", "test_batched");
assert!(!result.is_compatible);
assert!(result
.violations
.iter()
.any(|v| v.kind == ParityViolationKind::SharedMemoryAddressingU64));
}
// Standalone validation: grid-Y and register-unroll kernels both pass;
// a kernel with neither dispatch AND u64 smem addressing gets exactly
// two violations.
#[test]
fn test_validate_batched_kernel_standalone() {
let good_grid = r#"
.visible .entry good_batched(
.param .u64 a
) {
mov.u32 %r1, %ctaid.y;
st.shared.f32 [%r3], %f0;
ret;
}
"#;
let result = validate_batched_kernel(good_grid, "good_batched");
assert!(result.is_compatible);
let good_reg = r#"
.visible .entry good_reg_batched(
.param .u64 a,
.param .u32 m_dim
) {
ret;
}
"#;
let result = validate_batched_kernel(good_reg, "good_reg_batched");
assert!(result.is_compatible);
let bad = r#"
.visible .entry bad_batched(
.param .u64 a
) {
st.shared.f32 [%rd3], %f0;
ret;
}
"#;
let result = validate_batched_kernel(bad, "bad_batched");
assert!(!result.is_compatible);
assert_eq!(result.violations.len(), 2); }
// Display codes are stable strings consumed by downstream tooling.
#[test]
fn test_parity_violation_display() {
assert_eq!(
ParityViolationKind::ParameterCountMismatch.to_string(),
"PARAM_COUNT"
);
assert_eq!(
ParityViolationKind::SharedMemoryAddressingU64.to_string(),
"SHARED_MEM_U64"
);
assert_eq!(
ParityViolationKind::MissingBatchDispatch.to_string(),
"MISSING_CTAID_Y"
);
}
}