use checks::{
check_atomic, check_compare_exchange, make_atomic_program, make_compare_exchange_program,
parse_output,
};
use vyre::ir::{AtomicOp, Program};
/// Result of an atomic conformance sweep: every finding produced by the passes that failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AtomicsReport {
/// Findings in the order the failing passes were executed; empty when every pass succeeded.
pub findings: Vec<AtomicFinding>,
}
/// A single conformance failure observed while probing one atomic operation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AtomicFinding {
/// Atomic operation that was being probed when the failure occurred.
pub op: AtomicOp,
/// Number of invocations the probe was configured with.
pub invocations: u32,
/// Human-readable description of the failure, including a suggested fix.
pub message: String,
/// Index of the repeated run in which the failure was seen; 0 for failures detected before any dispatch.
pub run: u32,
}
/// Runs the full atomic conformance sweep against `backend`.
///
/// For every invocation count in the sweep this executes, in order:
/// 1. each atomic op with its representative operand,
/// 2. each atomic op with every boundary operand,
/// 3. the compare-exchange probe.
///
/// Returns a report containing one finding per failed pass, in execution order.
#[inline]
pub(crate) fn enforce_atomics(backend: &dyn vyre::VyreBackend) -> AtomicsReport {
    // One invocation, values just below/above 64, 128 and 256, and 1024.
    const INVOCATION_COUNTS: [u32; 8] = [1, 63, 65, 127, 129, 255, 257, 1024];
    const BOUNDARY_OPERANDS: [u32; 6] =
        [0, 1, u32::MAX, u32::MAX - 1, 0x8000_0000, 0x7FFF_FFFF];
    // Single source of truth for the probed ops; the boundary sweep reuses this list so the
    // two passes cannot drift apart.
    let atomic_configs: [(AtomicOp, u32); 7] = [
        (AtomicOp::Add, 1),
        (AtomicOp::Or, 0x1234_5678),
        (AtomicOp::And, 0xF0F0_F0F0),
        (AtomicOp::Xor, 0xAAAA_AAAA),
        (AtomicOp::Min, 42),
        (AtomicOp::Max, 42),
        (AtomicOp::Exchange, 0xDEAD_BEEF),
    ];

    let mut findings = Vec::new();
    for n in INVOCATION_COUNTS {
        for (op, operand) in atomic_configs.iter().cloned() {
            findings.extend(run_atomic_pass(backend, op, operand, n));
        }
        for operand in BOUNDARY_OPERANDS {
            for op in atomic_configs.iter().map(|(op, _)| op.clone()) {
                findings.extend(run_atomic_pass(backend, op, operand, n));
            }
        }
        findings.extend(run_compare_exchange_pass(backend, n));
    }
    AtomicsReport { findings }
}
/// Probes a single atomic `op` with the given `operand` and invocation count.
///
/// Public wrapper that forwards unchanged to `run_atomic_pass`. Returns `Some` with the first
/// failure encountered, or `None` when every repeated run passes.
#[inline]
pub fn enforce_atomic_operation(
backend: &dyn vyre::VyreBackend,
op: AtomicOp,
operand: u32,
invocations: u32,
) -> Option<AtomicFinding> {
run_atomic_pass(backend, op, operand, invocations)
}
/// Probes atomic compare-exchange with the given invocation count.
///
/// Public wrapper that forwards unchanged to `run_compare_exchange_pass`. Returns `Some` with
/// the first failure encountered, or `None` when every repeated run passes.
#[inline]
pub fn enforce_compare_exchange_operation(
backend: &dyn vyre::VyreBackend,
invocations: u32,
) -> Option<AtomicFinding> {
run_compare_exchange_pass(backend, invocations)
}
/// Builds the probe program for `op`, dispatches it five times and validates every run.
///
/// Each run must (a) dispatch successfully, (b) return exactly the expected number of bytes,
/// (c) parse, (d) satisfy the per-op checker, and (e) match the first run's parsed output.
/// When a dispatch fails and `n` is a multiple of 64 greater than 64, the same run is retried
/// once with a `[64, 1, 1]` workgroup size before reporting.
///
/// Returns the first failure as a finding, or `None` when all runs pass.
// NOTE(review): step (e) compares per-invocation pre-values across runs; confirm backends are
// really expected to reproduce the same interleaving every time.
fn run_atomic_pass(
    backend: &dyn vyre::VyreBackend,
    op: AtomicOp,
    operand: u32,
    n: u32,
) -> Option<AtomicFinding> {
    // All failure paths share the op and invocation count; only message and run vary.
    let report = |message: String, run: u32| AtomicFinding {
        op: op.clone(),
        invocations: n,
        message,
        run,
    };

    if let Some(rejected) = invalid_invocation_count_finding(op.clone(), n) {
        return Some(rejected);
    }
    let output_size = match atomic_output_size(n) {
        Ok(bytes) => bytes,
        Err(message) => return Some(report(message, 0)),
    };
    let program = match make_atomic_program(op.clone(), operand, n) {
        Ok(built) => built,
        Err(message) => return Some(report(message, 0)),
    };
    // The generic probe is dispatched with a single, empty input buffer.
    let inputs = [Vec::new()];
    let mut first_run: Option<(Vec<u32>, u32)> = None;

    for run in 0..5 {
        let output = match dispatch_exact(backend, &program, &inputs, output_size) {
            Ok(bytes) => bytes,
            Err(err) if n > 64 && n % 64 == 0 => {
                let fallback_program = with_workgroup_size(&program, [64, 1, 1]);
                match dispatch_exact(backend, &fallback_program, &inputs, output_size) {
                    Ok(bytes) => bytes,
                    Err(err2) => {
                        return Some(report(
                            format!(
                                "dispatch failed at {n} invocations (run {run}): {err2}. \
                                 Original error: {err}. Fix: ensure backend supports the requested workgroup configuration."
                            ),
                            run,
                        ));
                    }
                }
            }
            Err(err) => {
                return Some(report(
                    format!(
                        "dispatch failed at {n} invocations (run {run}): {err}. Fix: ensure backend can compile and run atomic workloads."
                    ),
                    run,
                ));
            }
        };

        let length_problem = output_len_finding(op.clone(), n, run, output.len(), output_size);
        if length_problem.is_some() {
            return length_problem;
        }

        let (pre_values, final_val) = match parse_output(&output, n) {
            Ok(parsed) => parsed,
            Err(message) => return Some(report(message, run)),
        };

        let verdict = check_atomic(
            op.clone(),
            &pre_values,
            final_val,
            n,
            operand,
            atomic_initial(op.clone()),
        );
        if let Some(message) = verdict {
            return Some(report(message, run));
        }

        match first_run {
            // The first successful run becomes the baseline for the remaining ones.
            None => first_run = Some((pre_values, final_val)),
            Some((ref first_pre, first_final)) => {
                let identical = first_pre == &pre_values && first_final == final_val;
                if !identical {
                    return Some(report(
                        format!(
                            "atomic {op:?} results diverged across runs: run 0 had final={first_final} pre-values={first_pre:?}, \
                             run {run} had final={final_val} pre-values={pre_values:?}. \
                             Fix: a conforming backend must produce sequentially-consistent atomic results deterministically."
                        ),
                        run,
                    ));
                }
            }
        }
    }
    None
}
/// Dispatches the compare-exchange probe five times and validates every run.
///
/// The probe uses `expected = 0` and `new_value = 1`, with the little-endian bytes of
/// `expected` as the only input buffer. Each run must dispatch, return exactly the expected
/// number of bytes, parse, satisfy `check_compare_exchange`, and match the first run's parsed
/// output. A failed dispatch is retried once with a `[64, 1, 1]` workgroup size when `n` is a
/// multiple of 64 greater than 64.
///
/// Returns the first failure as a finding, or `None` when all runs pass.
// NOTE(review): the cross-run comparison includes per-invocation pre-values; confirm backends
// are really expected to pick the same winning invocation every run.
fn run_compare_exchange_pass(backend: &dyn vyre::VyreBackend, n: u32) -> Option<AtomicFinding> {
    // All failure paths share the op and invocation count; only message and run vary.
    let report = |message: String, run: u32| AtomicFinding {
        op: AtomicOp::CompareExchange,
        invocations: n,
        message,
        run,
    };

    if let Some(rejected) = invalid_invocation_count_finding(AtomicOp::CompareExchange, n) {
        return Some(rejected);
    }
    let output_size = match atomic_output_size(n) {
        Ok(bytes) => bytes,
        Err(message) => return Some(report(message, 0)),
    };

    let expected = 0u32;
    let new_value = 1u32;
    let program = make_compare_exchange_program(expected, new_value, n);
    let inputs = [expected.to_le_bytes().to_vec()];
    let mut first_run: Option<(Vec<u32>, u32)> = None;

    for run in 0..5 {
        let output = match dispatch_exact(backend, &program, &inputs, output_size) {
            Ok(bytes) => bytes,
            Err(err) if n > 64 && n % 64 == 0 => {
                let fallback_program = with_workgroup_size(&program, [64, 1, 1]);
                match dispatch_exact(backend, &fallback_program, &inputs, output_size) {
                    Ok(bytes) => bytes,
                    Err(err2) => {
                        return Some(report(
                            format!(
                                "dispatch failed at {n} invocations (run {run}): {err2}. \
                                 Original error: {err}. Fix: ensure backend supports the requested workgroup configuration."
                            ),
                            run,
                        ));
                    }
                }
            }
            Err(err) => {
                return Some(report(
                    format!(
                        "dispatch failed at {n} invocations (run {run}): {err}. Fix: ensure backend can compile and run atomic workloads."
                    ),
                    run,
                ));
            }
        };

        let length_problem =
            output_len_finding(AtomicOp::CompareExchange, n, run, output.len(), output_size);
        if length_problem.is_some() {
            return length_problem;
        }

        let (pre_values, final_val) = match parse_output(&output, n) {
            Ok(parsed) => parsed,
            Err(message) => return Some(report(message, run)),
        };

        let verdict = check_compare_exchange(&pre_values, final_val, n, expected, new_value);
        if let Some(message) = verdict {
            return Some(report(message, run));
        }

        match first_run {
            // The first successful run becomes the baseline for the remaining ones.
            None => first_run = Some((pre_values, final_val)),
            Some((ref first_pre, first_final)) => {
                let identical = first_pre == &pre_values && first_final == final_val;
                if !identical {
                    return Some(report(
                        format!(
                            "atomic CompareExchange results diverged across runs: run 0 had final={first_final} pre-values={first_pre:?}, \
                             run {run} had final={final_val} pre-values={pre_values:?}. \
                             Fix: a conforming backend must produce sequentially-consistent atomic results deterministically."
                        ),
                        run,
                    ));
                }
            }
        }
    }
    None
}
use probe_glue::*;
mod checks {
use vyre::ir::{AtomicOp, BufferDecl, DataType, Expr, Node, Program};
pub(super) fn make_atomic_program(
op: AtomicOp,
operand: u32,
n: u32,
) -> Result<Program, String> {
let atomic_expr = match op {
AtomicOp::Add => Expr::atomic_add("output", Expr::u32(0), Expr::u32(operand)),
AtomicOp::Or => atomic_expr(AtomicOp::Or, operand),
AtomicOp::And => atomic_expr(AtomicOp::And, operand),
AtomicOp::Xor => atomic_expr(AtomicOp::Xor, operand),
AtomicOp::Min => atomic_expr(AtomicOp::Min, operand),
AtomicOp::Max => atomic_expr(AtomicOp::Max, operand),
AtomicOp::Exchange => atomic_expr(AtomicOp::Exchange, operand),
AtomicOp::CompareExchange => {
return Err(
"unsupported atomic op CompareExchange in generic atomic harness. Fix: use the compare-exchange harness."
.to_string(),
)
}
};
Ok(Program::new(
vec![
BufferDecl::read("input", 0, DataType::U32),
BufferDecl::read_write("output", 1, DataType::U32),
],
[n, 1, 1],
vec![
Node::let_bind("gid", Expr::gid_x()),
Node::if_then(
Expr::lt(Expr::var("gid"), Expr::u32(n)),
vec![
Node::let_bind("pre", atomic_expr),
Node::store(
"output",
Expr::add(Expr::var("gid"), Expr::u32(1)),
Expr::var("pre"),
),
],
),
],
))
}
/// Builds the compare-exchange probe program.
///
/// Same buffer layout and `[n, 1, 1]` workgroup size as the generic atomic probe. Every
/// invocation whose x id is below `n` performs a compare-exchange on `output[0]` and stores
/// the value it returned into `output[gid + 1]`.
///
/// The comparand is loaded from `input[0]` at run time; `_expected` is not used here. The
/// replacement value is `input[0] + new_value`.
// NOTE(review): the replacement equals `new_value` only when `input[0]` is 0 — confirm callers
// never pass a non-zero expected value.
pub(super) fn make_compare_exchange_program(_expected: u32, new_value: u32, n: u32) -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("output", 1, DataType::U32),
    ];

    let comparand = Expr::load("input", Expr::u32(0));
    let replacement = Expr::add(Expr::load("input", Expr::u32(0)), Expr::u32(new_value));
    let swap = Expr::atomic_compare_exchange("output", Expr::u32(0), comparand, replacement);

    // Slot 0 holds the shared counter, so invocation `gid` records into slot `gid + 1`.
    let record_pre_value = Node::store(
        "output",
        Expr::add(Expr::var("gid"), Expr::u32(1)),
        Expr::var("pre"),
    );
    let guarded_body = Node::if_then(
        Expr::lt(Expr::var("gid"), Expr::u32(n)),
        vec![Node::let_bind("pre", swap), record_pre_value],
    );
    let entry = vec![Node::let_bind("gid", Expr::gid_x()), guarded_body];

    Program::new(buffers, [n, 1, 1], entry)
}
/// Decodes a probe output buffer into `(pre_values, final_value)`.
///
/// `output` must be exactly `(n + 1) * 4` bytes of little-endian `u32` words: word 0 is the
/// final counter value and words `1..=n` are the per-invocation pre-values.
///
/// Returns `Err` when the expected size overflows or `output` has any other length.
pub(super) fn parse_output(output: &[u8], n: u32) -> Result<(Vec<u32>, u32), String> {
    let Some(word_count) = n.checked_add(1) else {
        return Err("atomic output word count overflowed u32. Fix: reject unrepresentable invocation counts before parsing.".to_string());
    };
    let Some(byte_len) = (word_count as usize).checked_mul(4) else {
        return Err("atomic output byte count overflowed usize. Fix: reject unrepresentable invocation counts before parsing.".to_string());
    };
    if output.len() != byte_len {
        return Err(format!(
            "atomic parser received {} bytes for {n} invocations, expected exactly {byte_len}. Fix: validate output length before parsing.",
            output.len()
        ));
    }

    // The length was validated above, so the buffer splits into whole 4-byte words.
    let mut words = output
        .chunks_exact(4)
        .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    match words.next() {
        Some(final_val) => Ok((words.collect(), final_val)),
        None => Err(
            "atomic parser received zero words. Fix: reject zero invocation tests before dispatch."
                .to_string(),
        ),
    }
}
/// Validates one run of the generic atomic probe by dispatching to the per-op checker.
///
/// * `pre_values` — value each invocation observed before its atomic took effect.
/// * `final_val` — counter value after all invocations finished.
/// * `n` — number of invocations.
/// * `operand` — operand every invocation applied.
/// * `initial` — counter value before the dispatch; only the Or and And checkers use it.
///
/// Returns `None` when the observations are consistent with some sequential ordering,
/// otherwise a message describing the violation. `CompareExchange` is always rejected because
/// it has a dedicated checker.
pub(super) fn check_atomic(
    op: AtomicOp,
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    operand: u32,
    initial: u32,
) -> Option<String> {
    match op {
        // The operand must be forwarded: the sweep also runs Add with operands other than 1,
        // and a stride-1 check reports false failures for those.
        AtomicOp::Add => check_add_with_operand(pre_values, final_val, n, operand),
        AtomicOp::Or => check_or(pre_values, final_val, n, operand, initial),
        AtomicOp::And => check_and(pre_values, final_val, n, operand, initial),
        AtomicOp::Xor => check_xor(pre_values, final_val, n, operand),
        AtomicOp::Min => check_min(pre_values, final_val, n, operand),
        AtomicOp::Max => check_max(pre_values, final_val, n, operand),
        AtomicOp::Exchange => check_exchange(pre_values, final_val, n, operand),
        AtomicOp::CompareExchange => Some(
            "unsupported atomic op CompareExchange in generic atomic checker. Fix: use the compare-exchange checker."
                .to_string(),
        ),
    }
}

/// Operand-aware AtomicAdd check.
///
/// With `n` invocations each adding `operand`, the pre-values must be exactly the partial
/// sums `initial + k * operand` (wrapping) for `k` in `0..n`, and the final counter must be
/// `initial + n * operand`. The initial value is derived from the final counter, so the check
/// does not depend on how the buffer was initialised.
///
/// Unit-stride probes are delegated to `check_add` so its diagnostics stay in use.
fn check_add_with_operand(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    operand: u32,
) -> Option<String> {
    if pre_values.len() != n as usize {
        return Some(format!(
            "AtomicAdd returned {} pre-values for {n} invocations. Fix: ensure every invocation writes exactly one pre-value.",
            pre_values.len()
        ));
    }
    if pre_values.is_empty() {
        return Some(
            "AtomicAdd received zero pre-values for zero invocations. Fix: run atomic conformance with at least one invocation so the final state is checked."
                .to_string(),
        );
    }
    if operand == 1 {
        return check_add(pre_values, final_val, n);
    }

    let derived_initial = final_val.wrapping_sub(n.wrapping_mul(operand));
    let mut expected: Vec<u32> = (0..n)
        .map(|step| derived_initial.wrapping_add(step.wrapping_mul(operand)))
        .collect();
    expected.sort_unstable();
    let mut observed = pre_values.to_vec();
    observed.sort_unstable();
    // Multiset comparison: operands such as 0 or 0x8000_0000 legitimately repeat values.
    if observed == expected {
        return None;
    }
    Some(format!(
        "AtomicAdd final counter = {final_val} with {n} invocations (operand={operand:#x}), but returned pre-values are not a permutation of the {n} partial sums starting at derived initial={derived_initial}. Fix: implement sequentially-consistent atomicAdd with wrapping addition."
    ))
}
/// Validates one run of the compare-exchange probe.
///
/// An invocation counts as a success when the pre-value it observed equals `expected`.
/// Violations reported:
/// * more than one success;
/// * exactly one success but `final_val` is not `new_value`;
/// * no success although `n > 0` and `final_val` is still `expected`.
///
/// Returns `None` when none of these apply.
pub(super) fn check_compare_exchange(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    expected: u32,
    new_value: u32,
) -> Option<String> {
    let successes = pre_values
        .iter()
        .copied()
        .filter(|observed| *observed == expected)
        .count();

    match successes {
        0 if n > 0 && final_val == expected => Some(format!(
            "AtomicCompareExchange reported zero successes even though final counter stayed at expected={expected}. \
             Fix: when the initial value equals expected, exactly one racing invocation must succeed."
        )),
        1 if final_val != new_value => Some(format!(
            "AtomicCompareExchange had exactly one success but final counter = {final_val} \
             instead of {new_value}. Fix: the successful compare-exchange must store the new value."
        )),
        0 | 1 => None,
        _ => Some(format!(
            "AtomicCompareExchange let {successes} invocations succeed with expected={expected} (limit is 1). \
             Fix: implement sequentially-consistent compare-exchange so that only one invocation \
             succeeds when multiple race on the same expected value.",
        )),
    }
}
/// Builds an `Expr::Atomic` that applies `op` with `operand` to `output[0]`, with no
/// `expected` comparand.
fn atomic_expr(op: AtomicOp, operand: u32) -> Expr {
    let index = Box::new(Expr::u32(0));
    let value = Box::new(Expr::u32(operand));
    Expr::Atomic {
        op,
        buffer: "output".into(),
        index,
        expected: None,
        value,
    }
}
/// Validates a unit-stride AtomicAdd run (every invocation adds 1).
///
/// The pre-values must be a permutation of `initial, initial + 1, …, initial + n - 1` using
/// wrapping arithmetic, and `final_val` must be `initial + n`. The initial value is derived
/// from `final_val`, so sequences that wrap past `u32::MAX` are accepted.
///
/// Returns `None` when consistent, otherwise a message describing the violation. An empty
/// observation set is reported rather than indexed into.
fn check_add(pre_values: &[u32], final_val: u32, n: u32) -> Option<String> {
    if pre_values.len() != n as usize {
        return Some(format!(
            "AtomicAdd returned {} pre-values for {n} invocations. Fix: ensure every invocation writes exactly one pre-value.",
            pre_values.len()
        ));
    }
    if pre_values.is_empty() {
        // Nothing was observed, so neither the initial value nor the final counter can be verified.
        return Some(
            "AtomicAdd received zero pre-values for zero invocations. Fix: run atomic conformance with at least one invocation so the final state is checked."
                .to_string(),
        );
    }

    // Measure every pre-value as a wrapping offset from the initial value implied by the final
    // counter; a consistent run yields each offset in 0..n exactly once, even across a wrap.
    let derived_initial = final_val.wrapping_sub(n);
    let mut offsets: Vec<u32> = pre_values
        .iter()
        .map(|value| value.wrapping_sub(derived_initial))
        .collect();
    offsets.sort_unstable();
    if offsets.iter().copied().eq(0..n) {
        return None;
    }

    // Inconsistent: decide which diagnostic describes the failure best.
    let mut sorted = pre_values.to_vec();
    sorted.sort_unstable();
    if sorted.windows(2).all(|w| w[1] == w[0].wrapping_add(1)) {
        // The pre-values form a clean run, so the final counter is what disagrees.
        let x = sorted[0];
        let expected_final = x.wrapping_add(n);
        return Some(format!(
            "AtomicAdd final counter = {final_val} with {n} invocations, but pre-values imply initial={x} so final should be {expected_final}. Fix: atomicAdd lost or duplicated an increment."
        ));
    }
    Some(format!(
        "AtomicAdd returned pre-values not a permutation of [{},{}) - final counter = {final_val}. Fix: implement sequentially-consistent atomicAdd with wrapping addition.",
        derived_initial,
        derived_initial.wrapping_add(n)
    ))
}
/// Validates an AtomicOr run by delegating to `check_two_state` with `v | operand` as the
/// state transition.
fn check_or(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    operand: u32,
    initial: u32,
) -> Option<String> {
    let set_bits = |current: u32| current | operand;
    check_two_state(pre_values, final_val, n, initial, set_bits, "AtomicOr", operand)
}
/// Validates an AtomicAnd run by delegating to `check_two_state` with `v & operand` as the
/// state transition.
fn check_and(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    operand: u32,
    initial: u32,
) -> Option<String> {
    let mask_bits = |current: u32| current & operand;
    check_two_state(pre_values, final_val, n, initial, mask_bits, "AtomicAnd", operand)
}
/// Validates an atomic whose counter can take at most two values during a run (initial value
/// and the value after the first application of `apply`).
///
/// Accepted shapes:
/// * every pre-value equals `final_val`, `final_val` equals `initial`, and `apply` leaves it
///   unchanged;
/// * all but one pre-value equal `final_val`, and applying `apply` to the odd one yields
///   `final_val`.
///
/// `name` and `operand` are only used in the failure message. Returns `None` when accepted.
fn check_two_state(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    initial: u32,
    apply: impl Fn(u32) -> u32,
    name: &str,
    operand: u32,
) -> Option<String> {
    let invocations = n as usize;
    let count_final = pre_values.iter().filter(|&&v| v == final_val).count();

    if count_final == invocations {
        if final_val == initial && apply(final_val) == final_val {
            return None;
        }
    } else if count_final + 1 == invocations {
        // Written as an addition so `n == 0` cannot underflow.
        let mut minority = pre_values.iter().filter(|&&v| v != final_val);
        if let (Some(&first_seen), None) = (minority.next(), minority.next()) {
            if apply(first_seen) == final_val {
                return None;
            }
        }
    }
    Some(format!(
        "{name} final counter = {final_val} with {n} invocations (operand={operand:#x}), but returned pre-values do not match any sequential ordering. Fix: implement sequentially-consistent {name}."
    ))
}
/// Validates an AtomicXor run.
///
/// Repeated XOR with a non-zero `operand` toggles the counter between two values, so the
/// pre-values must split into `n / 2` copies of `final_val` and `n - n / 2` copies of
/// `final_val ^ operand`. With a zero operand the counter never changes, so every pre-value
/// must equal `final_val`.
///
/// Returns `None` when consistent, otherwise a message describing the violation.
fn check_xor(pre_values: &[u32], final_val: u32, n: u32, operand: u32) -> Option<String> {
    let invocations = n as usize;
    let count_final = pre_values.iter().filter(|&&v| v == final_val).count();

    let consistent = if operand == 0 {
        // The half/half split must not apply here: it would accept runs where only half of the
        // invocations saw the unchanged value.
        count_final == invocations
    } else {
        let other = final_val ^ operand;
        let count_other = pre_values.iter().filter(|&&v| v == other).count();
        count_final == invocations / 2 && count_other == invocations - invocations / 2
    };
    if consistent {
        return None;
    }
    Some(format!(
        "AtomicXor final counter = {final_val} with {n} invocations (operand={operand:#x}), but returned pre-values do not match any sequential ordering. Fix: implement sequentially-consistent atomicXor."
    ))
}
/// Validates an AtomicMin run by delegating to `check_absorbing` with `std::cmp::min`.
fn check_min(pre_values: &[u32], final_val: u32, n: u32, operand: u32) -> Option<String> {
    let keep_smaller: fn(u32, u32) -> u32 = std::cmp::min;
    check_absorbing(pre_values, final_val, n, operand, keep_smaller, "AtomicMin")
}
/// Validates an AtomicMax run by delegating to `check_absorbing` with `std::cmp::max`.
fn check_max(pre_values: &[u32], final_val: u32, n: u32, operand: u32) -> Option<String> {
    let keep_larger: fn(u32, u32) -> u32 = std::cmp::max;
    check_absorbing(pre_values, final_val, n, operand, keep_larger, "AtomicMax")
}
/// Validates an atomic whose result stops changing after the first application (min/max).
///
/// In any sequential ordering the first invocation observes the initial value and every later
/// one observes `apply(initial, operand)`, which is also the final counter. Accepted shapes:
/// * every pre-value equals `final_val` and `apply(final_val, operand)` leaves it unchanged;
/// * exactly one pre-value differs from `final_val`, and applying the operation to it yields
///   `final_val`.
///
/// The decision depends only on the observed values, never on hash iteration order, so
/// two-invocation runs are judged deterministically. `name` is only used in the failure
/// message. Returns `None` when accepted.
fn check_absorbing(
    pre_values: &[u32],
    final_val: u32,
    n: u32,
    operand: u32,
    apply: fn(u32, u32) -> u32,
    name: &str,
) -> Option<String> {
    // Pre-values that differ from the final counter; a valid run has at most one (the initial).
    let mut outliers = pre_values.iter().copied().filter(|&v| v != final_val);
    let consistent = match (outliers.next(), outliers.next()) {
        (None, _) => !pre_values.is_empty() && apply(final_val, operand) == final_val,
        (Some(initial), None) => apply(initial, operand) == final_val,
        (Some(_), Some(_)) => false,
    };
    if consistent {
        return None;
    }
    Some(format!(
        "{name} final counter = {final_val} with {n} invocations (operand={operand}), but returned pre-values do not match any sequential ordering. Fix: implement sequentially-consistent {name}."
    ))
}
/// Validates an AtomicExchange run.
///
/// The final counter must equal `operand`. Every invocation except possibly one (the one that
/// observed the initial value) must have seen `operand` as its pre-value.
///
/// Returns `None` when consistent, otherwise a message describing the violation.
fn check_exchange(pre_values: &[u32], final_val: u32, n: u32, operand: u32) -> Option<String> {
    if final_val != operand {
        return Some(format!(
            "AtomicExchange final counter = {final_val} but operand was {operand}. Fix: atomicExchange must store the operand unconditionally."
        ));
    }
    let invocations = n as usize;
    let count_operand = pre_values.iter().filter(|&&v| v == operand).count();
    // `count + 1 == n` instead of `count == n - 1` so `n == 0` cannot underflow.
    if count_operand == invocations || count_operand + 1 == invocations {
        None
    } else {
        Some(format!(
            "AtomicExchange with {n} invocations returned pre-values inconsistent with any sequential ordering (operand={operand}). Fix: implement sequentially-consistent atomicExchange."
        ))
    }
}
}
mod probe_glue {
use crate::enforce::enforcers::atomics::*;
/// Returns a copy of `program` whose workgroup size is replaced by `workgroup_size`; the
/// original program is left untouched.
pub(super) fn with_workgroup_size(program: &Program, workgroup_size: [u32; 3]) -> Program {
    let mut resized = program.clone();
    resized.set_workgroup_size(workgroup_size);
    resized
}
/// Dispatches `program` and returns its first output buffer, requiring it to be exactly
/// `output_size` bytes.
///
/// The program is first rewritten by `program_with_output_size`, then dispatched with the
/// default `DispatchConfig`.
///
/// # Errors
/// Propagates any backend error, and returns a `BackendError` when the backend yields no
/// output buffers or the first buffer has a different length.
pub(super) fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized_program = program_with_output_size(program, output_size);
    let config = vyre::DispatchConfig::default();
    let mut outputs = backend.dispatch(&sized_program, inputs, &config)?;
    if outputs.is_empty() {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the atomic probe output as outputs[0].",
        ));
    }

    let output = outputs.remove(0);
    if output.len() == output_size {
        Ok(output)
    } else {
        Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the atomic probe output declaration.",
            output.len()
        )))
    }
}
/// Rebuilds `program` with its first read-write buffer marked as the output and sized to hold
/// `output_size` bytes.
///
/// The element count is `output_size` divided by 4, rounded up, saturating at `u32::MAX`.
/// Workgroup size and entry nodes are copied unchanged; a program without a read-write buffer
/// is rebuilt as-is.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let mut buffers = program.buffers().to_vec();
    let first_read_write = buffers
        .iter_mut()
        .find(|candidate| candidate.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(target) = first_read_write {
        target.is_output = true;
        target.count = output_size.div_ceil(4).try_into().unwrap_or(u32::MAX);
    }
    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}
/// Initial counter value the checkers assume for `op`: `0xCDCD_CDCD` for `And`, `0` for every
/// other operation.
pub(super) fn atomic_initial(op: AtomicOp) -> u32 {
    match op {
        AtomicOp::And => 0xCDCD_CDCD,
        _ => 0,
    }
}
/// Produces a finding when a run returned `actual` bytes instead of `expected`.
///
/// Returns `None` when the lengths match.
pub(super) fn output_len_finding(
    op: AtomicOp,
    n: u32,
    run: u32,
    actual: usize,
    expected: usize,
) -> Option<AtomicFinding> {
    if actual == expected {
        return None;
    }
    let message = format!(
        "atomic {op:?} returned {actual} bytes for {n} invocations, expected exactly {expected}. \
         Fix: return the final counter plus one pre-value per invocation without truncation or trailing bytes."
    );
    Some(AtomicFinding {
        op,
        invocations: n,
        message,
        run,
    })
}
/// Produces a finding (with `run` set to 0) when `n` is zero, since a probe with no
/// invocations cannot be checked.
///
/// Returns `None` for any non-zero `n`.
pub(super) fn invalid_invocation_count_finding(op: AtomicOp, n: u32) -> Option<AtomicFinding> {
    if n != 0 {
        return None;
    }
    let message = format!(
        "atomic {op:?} requested zero invocations. Fix: run atomic conformance with at least one invocation so the final state is checked."
    );
    Some(AtomicFinding {
        op,
        invocations: n,
        message,
        run: 0,
    })
}
/// Size in bytes of the probe output for `n` invocations: one `u32` for the final counter plus
/// one `u32` pre-value per invocation.
///
/// Returns `Err` when the word count overflows `u32` or the byte count overflows `usize`.
pub(super) fn atomic_output_size(n: u32) -> Result<usize, String> {
    let Some(words) = n.checked_add(1) else {
        return Err(format!(
            "atomic output word count overflowed for {n} invocations. Fix: reject invocation counts whose final counter plus per-invocation pre-values cannot be represented."
        ));
    };
    match (words as usize).checked_mul(4) {
        Some(bytes) => Ok(bytes),
        None => Err(format!(
            "atomic output byte count overflowed for {n} invocations. Fix: bound atomic output buffers before dispatch."
        )),
    }
}
}
/// Enforcement gate that runs the atomic conformance sweep against the backend supplied in
/// the enforcement context.
pub struct AtomicsRaceEnforcer;
impl crate::enforce::EnforceGate for AtomicsRaceEnforcer {
    /// Identifier of this gate.
    fn id(&self) -> &'static str {
        "atomics_race"
    }

    /// Display name of this gate; identical to the identifier.
    fn name(&self) -> &'static str {
        "atomics_race"
    }

    /// Runs the atomic sweep and converts its findings into gate findings.
    ///
    /// Without a backend in `ctx` a single aggregate finding asking for one is returned.
    /// Otherwise the message of every atomic finding is handed to `finding_result`.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            None => {
                let missing_backend =
                    "atomics_race: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string();
                vec![crate::enforce::aggregate_finding(
                    self.id(),
                    vec![missing_backend],
                )]
            }
            Some(backend) => {
                let messages: Vec<_> = enforce_atomics(backend)
                    .findings
                    .into_iter()
                    .map(|finding| finding.message)
                    .collect();
                crate::enforce::finding_result(self.id(), messages)
            }
        }
    }
}
/// Constant instance of the atomics gate exported by this file.
pub const REGISTERED: AtomicsRaceEnforcer = AtomicsRaceEnforcer;