use vyre::ir::{validate, Program};
/// Workgroup size handed to every barrier probe started by `enforce_barrier`.
const DEFAULT_WORKGROUP_SIZE: u32 = 64;
/// Number of times each dispatched probe program is executed before a verdict is reached.
const REPEATS: u32 = 10;
/// Marker word stored by the probe programs and looked for in their expected outputs.
pub(super) const SENTINEL: u32 = 0xA5A5_5A5A;
/// Aggregated outcome of a single `enforce_barrier` invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BarrierReport {
    /// Result of each barrier probe, in the order the probes were run.
    pub passes: Vec<BarrierPassResult>,
}
/// Evidence and verdict recorded for one barrier probe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BarrierPassResult {
    /// Human-readable name of the probe.
    pub pass_name: String,
    /// Wire-encoded probe program; empty when the program could not be serialized.
    pub input_bytes: Vec<u8>,
    /// Bytes observed for the probe: backend output, or the UTF-8 text of an
    /// error / validation message on the failure paths.
    pub observed_output: Vec<u8>,
    /// Bytes the probe was required to produce.
    pub expected_output: Vec<u8>,
    /// Whether the probe passed, and which rule was violated if it did not.
    pub verdict: BarrierVerdict,
    /// Actionable failure description; empty for a passing probe.
    pub message: String,
}
/// Outcome of one barrier probe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BarrierVerdict {
    /// The probe observed the required behavior.
    Passed,
    /// The probe failed; the payload names the violated barrier rule.
    Failed(BarrierViolation),
}
/// Barrier rule that a failing probe is attributed to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BarrierViolation {
    /// Reported by the "post-barrier visibility" probe.
    PostVisibilityMissing,
    /// Reported by the "pre-barrier isolation" probe.
    PreIsolationViolated,
    /// Reported by the "cross-workgroup non-synchronization" probe.
    CrossWorkgroupSynced,
    /// Reported when validation accepts a barrier placed under divergent control flow.
    DivergentBarrierAccepted,
}
/// Runs every barrier probe against `backend` and collects the results.
///
/// The probes run in a fixed order: post-barrier visibility, pre-barrier
/// isolation, cross-workgroup non-synchronization, and finally the
/// divergent-barrier validation check (which does not touch the backend).
/// All of them use `DEFAULT_WORKGROUP_SIZE`.
#[inline]
pub(crate) fn enforce_barrier(backend: &dyn vyre::VyreBackend) -> BarrierReport {
    let size = DEFAULT_WORKGROUP_SIZE;
    let mut passes = Vec::with_capacity(4);
    passes.push(run_post_visibility(backend, size));
    passes.push(run_pre_isolation(backend, size));
    passes.push(run_cross_workgroup(backend, size));
    passes.push(run_divergent_validation(size));
    BarrierReport { passes }
}
/// Runs the "post-barrier visibility" probe.
///
/// Builds the probe program and its expected bytes for `workgroup_size`, then
/// delegates the repeated dispatch and comparison to
/// `run_repeated_program_pass`. Mismatches are described by
/// `post_visibility_message` and attributed to
/// `BarrierViolation::PostVisibilityMissing`.
fn run_post_visibility(backend: &dyn vyre::VyreBackend, workgroup_size: u32) -> BarrierPassResult {
    run_repeated_program_pass(
        backend,
        workgroup_size,
        "post-barrier visibility",
        &post_visibility_program(workgroup_size),
        post_visibility_expected(workgroup_size),
        BarrierViolation::PostVisibilityMissing,
        |got, want| post_visibility_message(got, want, workgroup_size),
    )
}
/// Runs the "pre-barrier isolation" probe.
///
/// Builds the probe program and its expected bytes for `workgroup_size`, then
/// delegates the repeated dispatch and comparison to
/// `run_repeated_program_pass`. Mismatches are described by
/// `pre_isolation_message` and attributed to
/// `BarrierViolation::PreIsolationViolated`.
fn run_pre_isolation(backend: &dyn vyre::VyreBackend, workgroup_size: u32) -> BarrierPassResult {
    run_repeated_program_pass(
        backend,
        workgroup_size,
        "pre-barrier isolation",
        &pre_isolation_program(workgroup_size),
        pre_isolation_expected(workgroup_size),
        BarrierViolation::PreIsolationViolated,
        |got, want| pre_isolation_message(got, want, workgroup_size),
    )
}
/// Runs the "cross-workgroup non-synchronization" probe.
///
/// The probe program is serialized, validated locally, and then dispatched
/// `REPEATS` times. Each observation is classified as:
///
/// * the expected bytes `[0, SENTINEL]`,
/// * the "globally synchronized" bytes `[SENTINEL, SENTINEL]`, or
/// * anything else.
///
/// The pass fails with `BarrierViolation::CrossWorkgroupSynced` when the
/// program cannot be serialized or validated, when a dispatch errors, when any
/// run shows the globally synchronized bytes, or when no run shows the
/// expected bytes.
///
/// When global synchronization is detected, the reported `observed_output` is
/// the offending byte pattern itself rather than whatever run 0 happened to
/// return, so the evidence always matches the verdict. In every other outcome
/// the observation from run 0 is reported.
///
/// NOTE(review): the probe can only detect cross-workgroup visibility if the
/// backend actually launches a second workgroup; dispatch goes through
/// `DispatchConfig::default()` — confirm the resulting grid covers it.
fn run_cross_workgroup(backend: &dyn vyre::VyreBackend, workgroup_size: u32) -> BarrierPassResult {
    const PASS_NAME: &str = "cross-workgroup non-synchronization";

    let program = cross_workgroup_program(workgroup_size);
    let expected = word_bytes(&[0, SENTINEL]);

    let program_bytes = match program.to_wire() {
        Ok(bytes) => bytes,
        Err(err) => {
            return failed_result(
                PASS_NAME,
                Vec::new(),
                err.to_string().into_bytes(),
                expected,
                BarrierViolation::CrossWorkgroupSynced,
                "cross-workgroup barrier program failed IR serialization. Fix: construct a wire-encodable vyre IR Program.".to_string(),
            );
        }
    };

    if let Err(message) = validate_program(&program) {
        return failed_result(
            PASS_NAME,
            program_bytes,
            message.into_bytes(),
            expected,
            BarrierViolation::CrossWorkgroupSynced,
            "cross-workgroup barrier program failed local validation. Fix: keep every barrier in uniform control flow and all buffers declared.".to_string(),
        );
    }

    let global_sync = word_bytes(&[SENTINEL, SENTINEL]);
    let mut first_observed = Vec::new();
    let mut global_sync_runs = 0;
    let mut expected_runs = 0;

    for run in 0..REPEATS {
        match dispatch_exact(backend, &program, &[], expected.len()) {
            Ok(observed) => {
                if observed == global_sync {
                    global_sync_runs += 1;
                }
                if observed == expected {
                    expected_runs += 1;
                }
                // Classify first, then move the buffer: no per-run clone needed.
                if run == 0 {
                    first_observed = observed;
                }
            }
            Err(err) => {
                return failed_result(
                    PASS_NAME,
                    program_bytes,
                    err.to_string().into_bytes(),
                    expected,
                    BarrierViolation::CrossWorkgroupSynced,
                    format!(
                        "backend failed cross-workgroup barrier dispatch on run {run}: Fix: compile and execute uniform-barrier IR programs with two workgroups."
                    ),
                );
            }
        }
    }

    if global_sync_runs > 0 {
        // Every offending run returned exactly `global_sync`, so report those
        // bytes; run 0 may have looked perfectly fine.
        failed_result(
            PASS_NAME,
            program_bytes,
            global_sync,
            expected,
            BarrierViolation::CrossWorkgroupSynced,
            format!(
                "Fix: barrier made workgroup 0's pre-barrier storage write visible to workgroup 1 after its barrier on {global_sync_runs}/{REPEATS} runs. Barriers must be workgroup-local; remove any device-wide synchronization from Barrier lowering or dispatch scheduling."
            ),
        )
    } else if expected_runs == 0 {
        failed_result(
            PASS_NAME,
            program_bytes,
            first_observed,
            expected,
            BarrierViolation::CrossWorkgroupSynced,
            format!(
                "Fix: cross-workgroup barrier probe never produced the required workgroup-local result in {REPEATS} runs. Return exactly the expected bytes; unexpected or truncated output cannot prove barrier scope."
            ),
        )
    } else {
        passed_result(PASS_NAME, program_bytes, first_observed, expected)
    }
}
/// Runs the "uniform-control-flow validation" probe.
///
/// Builds a program whose only barrier sits inside a conditional branch and
/// feeds it to `validate`. The pass succeeds when validation reports at least
/// one error and fails with `BarrierViolation::DivergentBarrierAccepted` when
/// validation reports none (or when the program cannot be serialized). The
/// joined validation messages are recorded as the observed output.
///
/// NOTE(review): any validation error counts as a pass, not specifically one
/// about the barrier — confirm that is the intended strictness.
fn run_divergent_validation(workgroup_size: u32) -> BarrierPassResult {
    const PASS_NAME: &str = "uniform-control-flow validation";
    const EXPECTED: &[u8] = b"validation error rejecting divergent barrier";

    let program = divergent_barrier_program(workgroup_size);
    let program_bytes = match program.to_wire() {
        Ok(bytes) => bytes,
        Err(err) => {
            return failed_result(
                PASS_NAME,
                Vec::new(),
                err.to_string().into_bytes(),
                EXPECTED.to_vec(),
                BarrierViolation::DivergentBarrierAccepted,
                "divergent barrier program failed IR serialization. Fix: construct a wire-encodable vyre IR Program.".to_string(),
            );
        }
    };

    let errors = validate(&program);
    let observed = validation_messages(&errors).into_bytes();

    // Rejection is the desired outcome here.
    if !errors.is_empty() {
        return passed_result(PASS_NAME, program_bytes, observed, EXPECTED.to_vec());
    }

    failed_result(
        PASS_NAME,
        program_bytes,
        observed,
        EXPECTED.to_vec(),
        BarrierViolation::DivergentBarrierAccepted,
        "upstream: vyre::ir::validate fails to reject divergent barrier. Fix: add a validation rule that rejects Node::Barrier in non-uniform control flow.".to_string(),
    )
}
/// Serializes, validates, and repeatedly dispatches `program`, requiring every
/// run to return exactly `expected`.
///
/// * `backend` – backend under test.
/// * `workgroup_size` – currently unused by this function; kept so existing
///   callers do not change.
/// * `name` – pass name recorded in the result.
/// * `program` – probe program to dispatch.
/// * `expected` – bytes every run must return.
/// * `violation` – violation recorded on any failure.
/// * `message` – builds the mismatch description from `(observed, expected)`.
///
/// Returns a failed result as soon as serialization, local validation, a
/// dispatch, or an output comparison fails; the message of a mismatch or
/// dispatch failure names the run index. Returns a passed result (with
/// `expected` as the observed output) after `REPEATS` matching runs.
fn run_repeated_program_pass(
    backend: &dyn vyre::VyreBackend,
    workgroup_size: u32,
    name: &str,
    program: &Program,
    expected: Vec<u8>,
    violation: BarrierViolation,
    message: impl Fn(&[u8], &[u8]) -> String,
) -> BarrierPassResult {
    // Not needed here: callers capture the size in their `message` closure.
    let _ = workgroup_size;

    let program_bytes = match program.to_wire() {
        Ok(bytes) => bytes,
        Err(err) => {
            return failed_result(
                name,
                Vec::new(),
                err.to_string().into_bytes(),
                expected,
                violation,
                "barrier test program failed IR serialization. Fix: construct a wire-encodable vyre IR Program.".to_string(),
            );
        }
    };

    if let Err(errors) = validate_program(program) {
        // `errors` is already an owned `String`; no need to copy it first.
        return failed_result(
            name,
            program_bytes,
            errors.into_bytes(),
            expected,
            violation,
            "barrier test program failed local validation. Fix: keep test IR structurally valid before dispatch.".to_string(),
        );
    }

    for run in 0..REPEATS {
        match dispatch_exact(backend, program, &[], expected.len()) {
            Ok(observed) if observed == expected => {}
            Ok(observed) => {
                // Build the description while both buffers are still borrowed,
                // then move them into the result instead of cloning.
                let detail = format!("run {run}: {}", message(&observed, &expected));
                return failed_result(name, program_bytes, observed, expected, violation, detail);
            }
            Err(err) => {
                return failed_result(
                    name,
                    program_bytes,
                    err.to_string().into_bytes(),
                    expected,
                    violation,
                    format!(
                        "backend failed barrier dispatch on run {run}: Fix: compile and execute serialized vyre IR programs containing Node::Barrier."
                    ),
                );
            }
        }
    }

    // Every run matched, so the observed bytes are exactly the expected ones.
    passed_result(name, program_bytes, expected.clone(), expected)
}
/// Dispatches `program` once and returns its first output buffer, insisting on
/// an exact byte length.
///
/// The program is first rewritten by `program_with_output_size` so that its
/// output declaration matches `output_size`, then dispatched with the default
/// `DispatchConfig`.
///
/// # Errors
///
/// Propagates any backend dispatch error, and returns a `BackendError` when
/// the backend yields no output buffers or when the first buffer is not
/// exactly `output_size` bytes long.
fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized_program = program_with_output_size(program, output_size);
    let config = vyre::DispatchConfig::default();
    let mut outputs = backend.dispatch(&sized_program, inputs, &config)?;

    if outputs.is_empty() {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the barrier probe output as outputs[0].",
        ));
    }

    let first = outputs.remove(0);
    let returned_len = first.len();
    if returned_len == output_size {
        Ok(first)
    } else {
        Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the barrier probe output declaration.",
            returned_len
        )))
    }
}
/// Returns a copy of `program` whose first read-write buffer is marked as the
/// output and sized to hold `output_size` bytes.
///
/// The element count is `output_size` divided by four, rounded up, and
/// clamped to `u32::MAX`. Only the first buffer with
/// `BufferAccess::ReadWrite` is touched; if there is none the buffers are
/// copied unchanged. Workgroup size and entry nodes are carried over as-is.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let word_count = output_size.div_ceil(4).try_into().unwrap_or(u32::MAX);
    let mut buffers = program.buffers().to_vec();

    let first_read_write = buffers
        .iter_mut()
        .find(|decl| decl.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(target) = first_read_write {
        target.is_output = true;
        target.count = word_count;
    }

    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}
use programs::*;
mod programs {
use crate::enforce::enforcers::barrier::*;
use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
/// Builds the "post-barrier visibility" probe program.
///
/// With `lid` the local id on axis 0, each invocation:
///
/// 1. stores `lid + 1` to `out[lid]`,
/// 2. passes a barrier,
/// 3. stores `out[workgroup_size - 1 - lid]` to `out[workgroup_size + lid]`.
///
/// The second half of `out` therefore depends on another invocation's
/// pre-barrier store.
pub(super) fn post_visibility_program(workgroup_size: u32) -> Program {
    let local = Expr::LocalId { axis: 0 };

    // Index of the slot written before the barrier by the mirrored invocation.
    let mirrored_index = Expr::sub(Expr::u32(workgroup_size - 1), local.clone());

    let write_own_slot = Node::store(
        "out",
        local.clone(),
        Expr::add(local.clone(), Expr::u32(1)),
    );
    let copy_mirrored_slot = Node::store(
        "out",
        Expr::add(Expr::u32(workgroup_size), local),
        Expr::load("out", mirrored_index),
    );

    let buffers = vec![BufferDecl::read_write("out", 0, DataType::U32)];
    let body = vec![write_own_slot, Node::Barrier, copy_mirrored_slot];
    Program::new(buffers, [workgroup_size, 1, 1], body)
}
/// Builds the "pre-barrier isolation" probe program.
///
/// The program runs in phases separated by barriers, with `lid` the local id
/// on axis 0:
///
/// 1. the invocation with `lid == 0` stores `7` to `out[0]`;
/// 2. every invocation copies `out[0]` to `out[1 + lid]`;
/// 3. the invocation with `lid == 0` stores `SENTINEL` to `out[0]`;
/// 4. every invocation copies `out[0]` to `out[workgroup_size + 1 + lid]`.
pub(super) fn pre_isolation_program(workgroup_size: u32) -> Program {
    let local = Expr::LocalId { axis: 0 };

    // Only the invocation with local id 0 writes `value` into out[0].
    let leader_writes = |value: u32| {
        Node::if_then(
            Expr::eq(local.clone(), Expr::u32(0)),
            vec![Node::store("out", Expr::u32(0), Expr::u32(value))],
        )
    };
    // Every invocation copies out[0] into its own slot starting at `base`.
    let snapshot_into = |base: u32| {
        Node::store(
            "out",
            Expr::add(Expr::u32(base), local.clone()),
            Expr::load("out", Expr::u32(0)),
        )
    };

    let body = vec![
        leader_writes(7),
        Node::Barrier,
        snapshot_into(1),
        Node::Barrier,
        leader_writes(SENTINEL),
        Node::Barrier,
        snapshot_into(workgroup_size + 1),
    ];

    Program::new(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [workgroup_size, 1, 1],
        body,
    )
}
/// Builds the "cross-workgroup non-synchronization" probe program.
///
/// The invocation whose global x id is `0` stores `SENTINEL` to `out[1]`.
/// After a barrier, the invocation whose global x id equals `workgroup_size`
/// copies `out[1]` to `out[0]`.
pub(super) fn cross_workgroup_program(workgroup_size: u32) -> Program {
    let publish_from_first = Node::if_then(
        Expr::eq(Expr::gid_x(), Expr::u32(0)),
        vec![Node::store("out", Expr::u32(1), Expr::u32(SENTINEL))],
    );
    let relay = Node::store("out", Expr::u32(0), Expr::load("out", Expr::u32(1)));
    let observe_from_second = Node::if_then(
        Expr::eq(Expr::gid_x(), Expr::u32(workgroup_size)),
        vec![relay],
    );

    let buffers = vec![BufferDecl::read_write("out", 0, DataType::U32)];
    let body = vec![publish_from_first, Node::Barrier, observe_from_second];
    Program::new(buffers, [workgroup_size, 1, 1], body)
}
/// Builds a buffer-less program whose only barrier sits in the `then` branch
/// of a conditional on the global x id being `0`.
///
/// Used to check that validation rejects such a placement.
pub(super) fn divergent_barrier_program(workgroup_size: u32) -> Program {
    let only_first_invocation = Expr::eq(Expr::gid_x(), Expr::u32(0));
    let guarded_barrier = Node::If {
        cond: only_first_invocation,
        then: vec![Node::Barrier],
        otherwise: Vec::new(),
    };
    Program::new(Vec::new(), [workgroup_size, 1, 1], vec![guarded_barrier])
}
/// Expected output bytes of the "post-barrier visibility" probe.
///
/// The words are `1..=workgroup_size` followed by the same sequence reversed,
/// encoded by `word_bytes`.
///
/// The element count is never computed in `u32`, so very large workgroup
/// sizes cannot overflow while sizing the buffer.
pub(super) fn post_visibility_expected(workgroup_size: u32) -> Vec<u8> {
    let ascending = 1..=workgroup_size;
    let words: Vec<u32> = ascending.clone().chain(ascending.rev()).collect();
    word_bytes(&words)
}
/// Expected output bytes of the "pre-barrier isolation" probe.
///
/// The words are `SENTINEL`, then `workgroup_size` copies of `7`, then
/// `workgroup_size` copies of `SENTINEL`, encoded by `word_bytes`.
///
/// The capacity hint is computed in `usize` with saturating arithmetic, so it
/// cannot overflow for very large workgroup sizes.
pub(super) fn pre_isolation_expected(workgroup_size: u32) -> Vec<u8> {
    let lanes = workgroup_size as usize;
    let mut words = Vec::with_capacity(lanes.saturating_mul(2).saturating_add(1));
    words.push(SENTINEL);
    words.extend(std::iter::repeat_n(7, lanes));
    words.extend(std::iter::repeat_n(SENTINEL, lanes));
    word_bytes(&words)
}
/// Validates `program` and folds the outcome into a `Result`.
///
/// # Errors
///
/// Returns the validation messages, joined by `validation_messages`, when
/// `validate` reports at least one error.
pub(super) fn validate_program(program: &Program) -> Result<(), String> {
    let errors = validate(program);
    if !errors.is_empty() {
        return Err(validation_messages(&errors));
    }
    Ok(())
}
/// Renders each validation error with `to_string` and joins them with newlines.
///
/// An empty slice produces an empty string.
pub(super) fn validation_messages(errors: &[vyre::ir::ValidationError]) -> String {
    let mut joined = String::new();
    for (position, error) in errors.iter().enumerate() {
        // Separator goes between entries only, never before the first.
        if position > 0 {
            joined.push('\n');
        }
        joined.push_str(&error.to_string());
    }
    joined
}
/// Encodes `words` as a flat byte vector, four little-endian bytes per word.
pub(super) fn word_bytes(words: &[u32]) -> Vec<u8> {
    words
        .iter()
        .flat_map(|word| word.to_le_bytes())
        .collect()
}
/// Reads the little-endian `u32` at word position `index` in `bytes`.
///
/// Returns `None` when the four bytes at `index * 4` are not fully inside
/// `bytes`, including when the byte offset itself would overflow `usize`.
pub(super) fn read_word(bytes: &[u8], index: usize) -> Option<u32> {
    // Checked arithmetic: a huge `index` must yield `None`, not a panic or a
    // wrapped offset that aliases an earlier word.
    let start = index.checked_mul(4)?;
    let end = start.checked_add(4)?;
    match bytes.get(start..end)? {
        &[b0, b1, b2, b3] => Some(u32::from_le_bytes([b0, b1, b2, b3])),
        _ => None,
    }
}
/// Describes why the "post-barrier visibility" output did not match.
///
/// Scans word slots `workgroup_size..workgroup_size * 2` and reports the
/// first slot whose observed word differs from the expected one; a word that
/// cannot be read is shown as `0`. If none of those slots differ, a generic
/// length/prefix message is returned.
pub(super) fn post_visibility_message(
    observed: &[u8],
    expected: &[u8],
    workgroup_size: u32,
) -> String {
    let lanes = workgroup_size as usize;
    let first_bad_slot =
        (lanes..lanes * 2).find(|&slot| read_word(observed, slot) != read_word(expected, slot));

    match first_bad_slot {
        Some(slot) => {
            let got = read_word(observed, slot).unwrap_or(0);
            let want = read_word(expected, slot).unwrap_or(0);
            format!(
                "Fix: post-barrier read at slot {slot} returned {got} but expected {want}. Barrier did not establish visibility. Check storageBarrier()/workgroupBarrier() placement in the lowered shader."
            )
        }
        None => "Fix: post-barrier output length or prefix differed from expected barrier-visible writes."
            .to_string(),
    }
}
/// Describes why the "pre-barrier isolation" output did not match.
///
/// Scans word slots `1..=workgroup_size` and reports the first slot whose
/// observed word differs from the expected one; a word that cannot be read is
/// shown as `0`. If none of those slots differ, a generic message is returned.
pub(super) fn pre_isolation_message(
    observed: &[u8],
    expected: &[u8],
    workgroup_size: u32,
) -> String {
    let first_bad_slot = (1..=workgroup_size as usize)
        .find(|&slot| read_word(observed, slot) != read_word(expected, slot));

    if let Some(slot) = first_bad_slot {
        let got = read_word(observed, slot).unwrap_or(0);
        return format!(
            "Fix: pre-barrier read at slot {slot} returned {got} after a post-barrier write was scheduled. Writes after a barrier must not travel backward before the barrier."
        );
    }

    "Fix: pre/post barrier isolation output differed from expected phased storage writes."
        .to_string()
}
/// Builds a `BarrierPassResult` with a `Passed` verdict and an empty message.
pub(super) fn passed_result(
    name: &str,
    input_bytes: Vec<u8>,
    observed_output: Vec<u8>,
    expected_output: Vec<u8>,
) -> BarrierPassResult {
    let pass_name = String::from(name);
    let verdict = BarrierVerdict::Passed;
    let message = String::new();
    BarrierPassResult {
        pass_name,
        input_bytes,
        observed_output,
        expected_output,
        verdict,
        message,
    }
}
/// Builds a `BarrierPassResult` with a `Failed(violation)` verdict carrying
/// the given failure `message`.
pub(super) fn failed_result(
    name: &str,
    input_bytes: Vec<u8>,
    observed_output: Vec<u8>,
    expected_output: Vec<u8>,
    violation: BarrierViolation,
    message: String,
) -> BarrierPassResult {
    let pass_name = String::from(name);
    let verdict = BarrierVerdict::Failed(violation);
    BarrierPassResult {
        pass_name,
        input_bytes,
        observed_output,
        expected_output,
        verdict,
        message,
    }
}
}
/// Enforcement gate that runs the barrier probes against the configured backend.
pub struct BarrierPlacementEnforcer;

impl crate::enforce::EnforceGate for BarrierPlacementEnforcer {
    /// Identifier of this gate.
    fn id(&self) -> &'static str {
        "barrier_placement"
    }

    /// Display name of this gate; identical to its id.
    fn name(&self) -> &'static str {
        "barrier_placement"
    }

    /// Runs `enforce_barrier` and turns the messages of all failed passes into
    /// findings. Without a backend in `ctx`, a single aggregate finding asking
    /// for one is returned instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let backend = match ctx.backend {
            Some(backend) => backend,
            None => {
                return vec![crate::enforce::aggregate_finding(self.id(), vec!["barrier_placement: backend is required. Fix: provide a VyreBackend in EnforceCtx.".to_string()])];
            }
        };

        let mut messages = Vec::new();
        for pass in enforce_barrier(backend).passes {
            // Passing probes contribute nothing to the findings.
            if let BarrierVerdict::Failed(_) = pass.verdict {
                messages.push(pass.message);
            }
        }

        crate::enforce::finding_result(self.id(), messages)
    }
}
/// Constant instance of [`BarrierPlacementEnforcer`].
// NOTE(review): presumably picked up by a gate registry elsewhere in the crate — confirm.
pub const REGISTERED: BarrierPlacementEnforcer = BarrierPlacementEnforcer;