use std::collections::HashSet;
/// Outcome of running [`analyze`] over a piece of PTX text.
#[derive(Debug, Clone, PartialEq)]
pub struct BarrierSafetyResult {
/// `true` when `violations` is empty.
pub is_safe: bool,
/// Every violation recorded during the scan, in the order the lines were visited.
pub violations: Vec<BarrierViolation>,
/// Number of lines that contain `bar.sync` or `bar.arrive`.
pub barrier_count: usize,
/// Number of lines that contain `bra exit`, plus the lines that are exactly `ret;`
/// after trimming.
pub exit_count: usize,
}
/// A single problem reported by [`analyze`].
#[derive(Debug, Clone, PartialEq)]
pub struct BarrierViolation {
/// 1-based line number of the offending instruction within the analyzed text.
pub line: usize,
/// Category of the violation.
pub kind: ViolationKind,
/// The offending line with surrounding whitespace trimmed.
pub instruction: String,
/// Human-readable location hint; `analyze` fills this with
/// `"loop starting at line N"`.
pub context: String,
}
/// Categories of barrier-safety violations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ViolationKind {
/// An unpredicated `bra exit` (the line does not start with `@`) found inside a
/// loop before any barrier was seen in that loop.
EarlyExitBeforeBarrier,
/// A predicated `bra exit` (the line starts with `@`) found inside a loop before
/// any barrier was seen in that loop.
ConditionalExitBeforeBarrier,
/// A shared-memory access that is not followed by a barrier.
/// NOTE(review): nothing in the code visible here produces this variant — confirm
/// whether another analysis emits it.
MissingBarrierAfterSharedAccess,
}
impl std::fmt::Display for ViolationKind {
    /// Writes the fixed, human-readable message associated with the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the message first so there is a single write at the end.
        let message = match *self {
            Self::EarlyExitBeforeBarrier => "PARITY-114: Early exit before barrier",
            Self::ConditionalExitBeforeBarrier => {
                "PARITY-114: Conditional exit may cause divergence"
            }
            Self::MissingBarrierAfterSharedAccess => "Missing barrier after shared memory access",
        };
        f.write_str(message)
    }
}
/// Heuristically scans PTX text for `bra exit` instructions that occur inside a
/// loop before any barrier has been seen in that loop.
///
/// The scan is purely textual and line-based:
///
/// * A *label* is a trimmed line ending in `:` that does not start with `.`.
/// * A label is treated as a *loop header* when it does not contain `exit` and
///   some `bra <label>` instruction targets exactly that label anywhere in the
///   text. Forward branches count as well, so this is an approximation of real
///   loop detection.
/// * A loop ends at `<label>_end`, `<label>_done`, or one of a fixed set of
///   well-known end labels.
/// * A *barrier* is any line containing `bar.sync` or `bar.arrive`.
///
/// Violations are only reported when the text contains at least one barrier.
/// A `bra exit` line starting with `@` yields
/// [`ViolationKind::ConditionalExitBeforeBarrier`]; otherwise it yields
/// [`ViolationKind::EarlyExitBeforeBarrier`].
///
/// Returns the collected violations together with the number of barrier lines
/// and the number of exit lines (`bra exit` lines plus lines equal to `ret;`).
#[must_use]
pub fn analyze(ptx: &str) -> BarrierSafetyResult {
    /// Loop-end labels that are always recognised, independent of the loop
    /// header names found in the text.
    const KNOWN_LOOP_END_LABELS: [&str; 6] = [
        "k_tile_end",
        "kv_loop_end",
        "loop_end",
        "sb_loop_done",
        "sub_block_done",
        "k_block_done",
    ];

    /// Returns `true` for characters that may continue a PTX identifier.
    fn is_identifier_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_' || c == '$'
    }

    /// Returns `true` when some `bra <label>` in `ptx` targets exactly `label`.
    ///
    /// The character following the match must not continue an identifier, so
    /// `bra k_tile;` is not mistaken for a branch to the label `k`.
    fn is_branch_target(ptx: &str, label: &str) -> bool {
        let needle = format!("bra {}", label);
        ptx.match_indices(needle.as_str()).any(|(start, found)| {
            ptx[start + found.len()..]
                .chars()
                .next()
                .map_or(true, |next| !is_identifier_char(next))
        })
    }

    // Exits are only hazardous when the kernel synchronises at all.
    let has_barriers = ptx.contains("bar.sync") || ptx.contains("bar.arrive");

    // Pass 1: find loop headers and derive the labels that close them.
    let mut loop_labels: HashSet<&str> = HashSet::new();
    let mut loop_end_labels: HashSet<String> = KNOWN_LOOP_END_LABELS
        .iter()
        .map(|label| (*label).to_string())
        .collect();
    for line in ptx.lines() {
        let trimmed = line.trim();
        let is_candidate =
            trimmed.ends_with(':') && !trimmed.starts_with('.') && !trimmed.contains("exit");
        if !is_candidate {
            continue;
        }
        let label = trimmed.trim_end_matches(':');
        if is_branch_target(ptx, label) {
            loop_labels.insert(label);
            loop_end_labels.insert(format!("{}_end", label));
            loop_end_labels.insert(format!("{}_done", label));
        }
    }

    // Pass 2: walk the text, tracking whether we are inside a loop and whether
    // a barrier has already been seen since that loop's header.
    let mut violations = Vec::new();
    let mut barrier_count = 0;
    let mut exit_count = 0;
    let mut in_loop = false;
    let mut loop_start_line = 0;
    let mut barrier_seen_in_current_loop = false;

    for (idx, line) in ptx.lines().enumerate() {
        let line_num = idx + 1;
        let trimmed = line.trim();

        if trimmed.contains("bar.sync") || trimmed.contains("bar.arrive") {
            barrier_count += 1;
            if in_loop {
                barrier_seen_in_current_loop = true;
            }
        }

        if trimmed.ends_with(':') && !trimmed.starts_with('.') {
            let label = trimmed.trim_end_matches(':');
            if loop_labels.contains(label) {
                in_loop = true;
                loop_start_line = line_num;
                barrier_seen_in_current_loop = false;
            }
            // Checked after the header test so a label that is both a branch
            // target and a known end label closes the loop.
            if loop_end_labels.contains(label) {
                in_loop = false;
            }
        }

        if trimmed.contains("bra exit") {
            exit_count += 1;
            if has_barriers && in_loop && !barrier_seen_in_current_loop {
                // A leading `@` marks a predicated (conditional) branch.
                let kind = if trimmed.starts_with('@') {
                    ViolationKind::ConditionalExitBeforeBarrier
                } else {
                    ViolationKind::EarlyExitBeforeBarrier
                };
                violations.push(BarrierViolation {
                    line: line_num,
                    kind,
                    instruction: trimmed.to_string(),
                    context: format!("loop starting at line {}", loop_start_line),
                });
            }
        }

        if trimmed == "ret;" {
            exit_count += 1;
        }
    }

    BarrierSafetyResult {
        is_safe: violations.is_empty(),
        violations,
        barrier_count,
        exit_count,
    }
}
/// Runs [`analyze`] on `ptx` and converts the outcome into a `Result`.
///
/// # Errors
///
/// Returns a multi-line report listing every violation (line number, kind,
/// instruction and context) when the analysis is not safe.
pub fn validate(ptx: &str) -> Result<(), String> {
    let report = analyze(ptx);
    if report.is_safe {
        return Ok(());
    }

    // One two-line entry per violation, concatenated in discovery order.
    let details: String = report
        .violations
        .iter()
        .map(|violation| {
            format!(
                " Line {}: {} - {}\n Context: {}\n",
                violation.line, violation.kind, violation.instruction, violation.context
            )
        })
        .collect();

    Err(format!("Barrier safety violations found:\n{}", details))
}
// Unit tests for `analyze` and `validate`, driven by small hand-written PTX fixtures.
#[cfg(test)]
mod tests {
use super::*;
// A loop whose body contains a barrier and no exit; the only `bra exit` comes
// after the `loop_start_end` label, so nothing is reported.
#[test]
fn test_barrier_safe_ptx() {
let ptx = r#"
.entry kernel() {
mov.u32 %r0, %tid.x;
setp.lt.u32 %p0, %r0, 32;
loop_start:
ld.shared.f32 %f0, [%r0];
bar.sync 0;
st.shared.f32 [%r0], %f0;
bra loop_start;
loop_start_end:
@!%p0 bra exit;
st.global.f32 [%r1], %f0;
exit:
ret;
}
"#;
let result = analyze(ptx);
assert!(result.is_safe, "Should be safe: {:?}", result.violations);
assert_eq!(result.barrier_count, 1);
}
// A predicated `bra exit` is the first instruction of the loop, before the
// barrier: exactly one conditional-exit violation is expected.
#[test]
fn test_barrier_unsafe_early_exit() {
let ptx = r#"
.entry kernel() {
mov.u32 %r0, %tid.x;
setp.lt.u32 %p0, %r0, 32;
loop_start:
@!%p0 bra exit;
ld.shared.f32 %f0, [%r0];
bar.sync 0;
st.shared.f32 [%r0], %f0;
bra loop_start;
loop_start_end:
done:
ret;
}
"#;
let result = analyze(ptx);
assert!(!result.is_safe, "Should detect early exit");
assert_eq!(result.violations.len(), 1);
assert_eq!(
result.violations[0].kind,
ViolationKind::ConditionalExitBeforeBarrier
);
}
// Same shape as above but the exit is unpredicated, so the violation kind is
// `EarlyExitBeforeBarrier`.
#[test]
fn test_unconditional_early_exit() {
let ptx = r#"
.entry kernel() {
loop_start:
bra exit;
bar.sync 0;
bra loop_start;
loop_start_end:
done:
ret;
}
"#;
let result = analyze(ptx);
assert!(!result.is_safe);
assert_eq!(
result.violations[0].kind,
ViolationKind::EarlyExitBeforeBarrier
);
}
// `validate` must turn a violation into an `Err` whose text carries the
// "PARITY-114" tag from the `Display` impl.
#[test]
fn test_validate_returns_error() {
let unsafe_ptx = r#"
.entry kernel() {
loop_start:
bra exit;
bar.sync 0;
bra loop_start;
loop_start_end:
done:
ret;
}
"#;
let result = validate(unsafe_ptx);
assert!(result.is_err());
assert!(result.unwrap_err().contains("PARITY-114"));
}
// The loop is closed by the built-in `k_tile_end` label (not `<label>_end`),
// so the exit that follows it is outside the loop.
#[test]
fn test_exit_after_loop_ok() {
let ptx = r#"
.entry kernel() {
k_tile_loop:
bar.sync 0;
ld.shared.f32 %f0, [%r0];
bra k_tile_loop;
k_tile_end:
@!%p0 bra exit;
st.global.f32 [%r1], %f0;
done:
ret;
}
"#;
let result = analyze(ptx);
assert!(
result.is_safe,
"Exit after loop should be OK: {:?}",
result.violations
);
}
// Loop named `kv_loop` closed by `kv_loop_end`; the exit after the loop is fine.
#[test]
fn test_kv_loop_pattern() {
let ptx = r#"
.entry attention() {
kv_loop:
bar.sync 0;
wmma.mma.sync.aligned.row.col.m16n16k16.f32.f16.f16.f32 ...;
bra kv_loop;
kv_loop_end:
@!%p_valid bra exit;
st.global.f32 [%out], %f0;
done:
ret;
}
"#;
let result = analyze(ptx);
assert!(result.is_safe, "KV loop pattern should be safe");
}
// The fixture has a conditional exit inside a loop but contains no
// `bar.sync`/`bar.arrive` at all, so no violation is reported.
#[test]
fn test_warp_only_kernel_safe() {
let ptx = r#"
.entry rmsnorm() {
mov.u32 %r0, %tid.x;
setp.lt.u32 %p0, %r0, 32;
sum_loop:
@!%p1 bra exit;
ld.global.f32 %f0, [%addr];
shfl.sync.down.b32 %f1, %f0, 16, 0x1f, 0xffffffff;
add.f32 %f0, %f0, %f1;
add.u32 %idx, %idx, 32;
setp.lt.u32 %p1, %idx, %n;
bra sum_loop;
sum_loop_end:
st.global.f32 [%out], %f0;
exit:
ret;
}
"#;
let result = analyze(ptx);
assert!(
result.is_safe,
"Warp-only kernel should be safe: {:?}",
result.violations
);
assert_eq!(result.barrier_count, 0, "No barriers in warp-only kernel");
}
// Minimal variant of the previous case: conditional exit in a loop, no barrier.
#[test]
fn test_no_barrier_conditional_exit_safe() {
let ptx = r#"
.entry kernel() {
loop:
@%p0 bra exit;
ld.global.f32 %f0, [%r0];
bra loop;
loop_end:
exit:
ret;
}
"#;
let result = analyze(ptx);
assert!(
result.is_safe,
"No-barrier kernel with conditional exit should be safe"
);
}
}
// Property-based tests that generate PTX text and check invariants of `analyze`.
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
proptest! {
// Generated loop: `loop_body_len` `mov` instructions, then a barrier and the
// back-edge. The body contains no `bra exit`, so the result must be safe.
#[test]
fn barrier_after_exits_is_safe(loop_body_len in 1usize..10) {
let mut ptx = String::from(".entry test() {\nloop:\n");
for i in 0..loop_body_len {
ptx.push_str(&format!(" mov.u32 %r{}, 0;\n", i));
}
ptx.push_str(" bar.sync 0;\n");
ptx.push_str(" bra loop;\nloop_end:\nexit:\n ret;\n}\n");
let result = analyze(&ptx);
prop_assert!(result.is_safe, "Generated safe PTX should pass: {}", ptx);
}
// Any number of conditional exits with no loop label (and no barrier) is safe.
#[test]
fn no_loops_always_safe(num_exits in 0usize..5) {
let mut ptx = String::from(".entry test() {\n");
for _ in 0..num_exits {
ptx.push_str(" @%p0 bra exit;\n");
}
ptx.push_str("exit:\n ret;\n}\n");
let result = analyze(&ptx);
prop_assert!(result.is_safe, "No-loop PTX should be safe");
}
// `barrier_count` must equal the number of generated `bar.sync` lines.
#[test]
fn barrier_count_accurate(num_barriers in 0usize..5) {
let mut ptx = String::from(".entry test() {\n");
for i in 0..num_barriers {
ptx.push_str(&format!(" bar.sync {};\n", i % 16));
}
ptx.push_str(" ret;\n}\n");
let result = analyze(&ptx);
prop_assert_eq!(result.barrier_count, num_barriers);
}
}
}