use vyre::VyreBackend;
use crate::spec::program::program_for_spec_input;
use crate::spec::types::OpSpec;
/// Largest output size, in bytes, that `inferred_output_size` accepts (64 MiB).
const MAX_INFERRED_OUTPUT_SIZE: usize = 64 * 1024 * 1024;
/// Largest `repeats` value `detect_race` accepts; anything above it is rejected
/// with `RaceError::RepeatsTooLarge`.
pub const MAX_RACE_REPEATS: u32 = 10_000;
/// Argument-validation errors returned by `detect_race`.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum RaceError {
/// `repeats` was zero, so no comparison run would have been executed.
#[error("RaceZeroRepeats: repeats is zero. Fix: run at least one repeat.")]
ZeroRepeats,
/// `repeats` was larger than `MAX_RACE_REPEATS`.
#[error(
"RaceRepeatsTooLarge: repeats {actual} exceeds limit {max}. Fix: run a bounded race probe."
)]
RepeatsTooLarge {
/// The rejected `repeats` value.
actual: u32,
/// The limit that `actual` exceeded.
max: u32,
},
}
/// Result of a race probe: an overall verdict plus the per-op findings behind it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RaceReport {
/// `true` when any probed op produced unstable output or failed to dispatch.
pub has_race: bool,
/// Details for each op that contributed a finding.
pub findings: Vec<RaceFinding>,
}
/// Evidence collected while probing a single op or program for nondeterminism.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RaceFinding {
/// Identifier of the probed op (`"ir-program"` when produced by `detect_race`).
pub op_id: String,
/// Output of the baseline dispatch; empty when no baseline could be obtained.
pub first_bytes: Vec<u8>,
/// One entry per comparison run: the byte offsets that differed from the
/// baseline. The entry is empty when the run matched or its dispatch failed.
pub divergent_bytes: Vec<Vec<usize>>,
/// Per-offset count of comparison runs in which that byte diverged.
pub byte_heatmap: Vec<u32>,
/// Human-readable diagnostics, each carrying a `Fix:` hint.
pub messages: Vec<String>,
}
impl RaceReport {
    /// Builds a report that records no race and holds no findings.
    ///
    /// This is the starting value that probes accumulate their findings into.
    #[must_use]
    #[inline]
    pub fn empty() -> Self {
        let findings = Vec::new();
        Self {
            findings,
            has_race: false,
        }
    }
}
/// Dispatches `program` repeatedly on `backend` and reports output bytes that
/// change between runs.
///
/// One baseline dispatch is performed first; each of the `repeats` follow-up
/// dispatches is then compared byte-for-byte against that baseline. The output
/// size is inferred as four bytes per invocation (`workgroup_count` multiplied
/// by the program's three workgroup dimensions), with a four-byte minimum.
///
/// On success the report always carries exactly one finding with op id
/// `"ir-program"`. `has_race` is `true` when any comparison run diverged or
/// failed to dispatch. It is also `true` when the inferred output size is
/// rejected or the baseline dispatch fails; in those two cases the finding
/// holds only the explanatory message.
///
/// # Errors
///
/// Returns [`RaceError::ZeroRepeats`] when `repeats` is zero and
/// [`RaceError::RepeatsTooLarge`] when it exceeds [`MAX_RACE_REPEATS`].
#[inline]
pub fn detect_race(
    backend: &dyn VyreBackend,
    program: &vyre::ir::Program,
    input: &[u8],
    workgroup_count: u32,
    repeats: u32,
) -> Result<RaceReport, RaceError> {
    // Validate arguments before allocating anything.
    if repeats == 0 {
        return Err(RaceError::ZeroRepeats);
    }
    if repeats > MAX_RACE_REPEATS {
        return Err(RaceError::RepeatsTooLarge {
            actual: repeats,
            max: MAX_RACE_REPEATS,
        });
    }

    let mut finding = RaceFinding {
        op_id: "ir-program".to_string(),
        first_bytes: Vec::new(),
        divergent_bytes: Vec::new(),
        byte_heatmap: Vec::new(),
        messages: Vec::new(),
    };

    // Saturating arithmetic: an absurd dispatch shape ends up rejected by the
    // size check below instead of wrapping around to a small value.
    let workgroup_size = program.workgroup_size();
    let invocations = (workgroup_count as usize)
        .saturating_mul(workgroup_size[0] as usize)
        .saturating_mul(workgroup_size[1] as usize)
        .saturating_mul(workgroup_size[2] as usize);
    let output_size = match inferred_output_size(invocations) {
        Ok(size) => size,
        Err(message) => {
            finding.messages.push(message);
            return Ok(RaceReport {
                has_race: true,
                findings: vec![finding],
            });
        }
    };

    // The input is identical for every dispatch, so copy it exactly once and
    // reuse the same buffer list for the baseline and all comparison runs.
    let inputs = [input.to_vec()];

    let baseline = match dispatch_exact(backend, program, &inputs, output_size) {
        Ok(bytes) => bytes,
        Err(err) => {
            finding.messages.push(format!(
                "dispatch failed on {}: {err}. Fix: check that the IR program lowers to valid backend code.",
                backend.id()
            ));
            return Ok(RaceReport {
                has_race: true,
                findings: vec![finding],
            });
        }
    };
    finding.first_bytes = baseline.clone();
    finding.byte_heatmap = vec![0u32; baseline.len()];

    let mut has_race = false;
    for run in 0..repeats {
        match dispatch_exact(backend, program, &inputs, output_size) {
            Ok(output) => {
                let divergent = divergent_offsets(&baseline, &output);
                if !divergent.is_empty() {
                    has_race = true;
                    // Make sure every reported offset has a heat-map slot.
                    resize_heatmap_for_output(&mut finding.byte_heatmap, &baseline, &output);
                    for &offset in &divergent {
                        finding.byte_heatmap[offset] =
                            finding.byte_heatmap[offset].saturating_add(1);
                    }
                    finding.messages.push(format!(
                        "Fix: byte {} unstable across runs. Either the program has a data race (add barrier or use atomic) or the backend has a nondeterminism bug.",
                        divergent.iter().map(|o| o.to_string()).collect::<Vec<_>>().join(", ")
                    ));
                }
                // One entry per run, even when the run matched the baseline.
                finding.divergent_bytes.push(divergent);
            }
            Err(err) => {
                // A dispatch that fails mid-probe counts as instability.
                has_race = true;
                finding.messages.push(format!(
                    "dispatch failed on {} at run {run}: {err}. Fix: check backend stability.",
                    backend.id()
                ));
                finding.divergent_bytes.push(Vec::new());
            }
        }
    }

    Ok(RaceReport {
        has_race,
        findings: vec![finding],
    })
}
/// Probes every spec in `specs` on `backend` and collects findings for the
/// ones whose output is not stable across repeated dispatches.
///
/// For each spec:
/// - the input is a zero-filled buffer of `spec.signature.min_input_bytes()` bytes;
/// - the expected output length is the length of what `spec.cpu_fn` returns
///   for that input;
/// - the program is dispatched with workgroup size
///   `[spec.workgroup_size.unwrap_or(64), 1, 1]`;
/// - one baseline dispatch is followed by 100 comparison dispatches.
///
/// A spec whose program cannot be built, or whose baseline dispatch fails, is
/// reported as a finding and sets `has_race`. A spec that stays stable through
/// all comparison runs contributes no finding.
#[inline]
pub(crate) fn verify_no_race(backend: &dyn VyreBackend, specs: &[OpSpec]) -> RaceReport {
    // Number of comparison dispatches performed per spec after the baseline.
    const COMPARISON_RUNS: u32 = 100;

    let mut report = RaceReport::empty();
    for spec in specs {
        let input = vec![0u8; spec.signature.min_input_bytes()];
        // Only the length of the CPU reference output is needed here.
        let output_size = (spec.cpu_fn)(&input).len();
        let workgroup_size = spec.workgroup_size.unwrap_or(64);
        let mut program = match program_for_spec_input(spec, &input) {
            Ok(program) => program,
            Err(err) => {
                report.findings.push(RaceFinding {
                    op_id: spec.id.to_string(),
                    first_bytes: Vec::new(),
                    divergent_bytes: Vec::new(),
                    byte_heatmap: Vec::new(),
                    messages: vec![err],
                });
                report.has_race = true;
                continue;
            }
        };
        program.set_workgroup_size([workgroup_size, 1, 1]);

        // Every dispatch uses the same input, so move it into the buffer list
        // once instead of cloning it for each of the 101 dispatches.
        let inputs = [input];

        let baseline = match dispatch_exact(backend, &program, &inputs, output_size) {
            Ok(bytes) => bytes,
            Err(err) => {
                report.findings.push(RaceFinding {
                    op_id: spec.id.to_string(),
                    first_bytes: Vec::new(),
                    divergent_bytes: Vec::new(),
                    byte_heatmap: Vec::new(),
                    messages: vec![format!(
                        "backend dispatch failed on {} for {}: {err}. Fix: execute the canonical vyre IR and return {output_size} bytes.",
                        backend.id(),
                        spec.id,
                    )],
                });
                report.has_race = true;
                continue;
            }
        };

        let mut finding = RaceFinding {
            op_id: spec.id.to_string(),
            first_bytes: baseline.clone(),
            divergent_bytes: Vec::new(),
            byte_heatmap: vec![0u32; baseline.len()],
            messages: Vec::new(),
        };
        let mut spec_has_race = false;
        for run in 0..COMPARISON_RUNS {
            match dispatch_exact(backend, &program, &inputs, output_size) {
                Ok(output) => {
                    let divergent = divergent_offsets(&baseline, &output);
                    if !divergent.is_empty() {
                        spec_has_race = true;
                        // Make sure every reported offset has a heat-map slot.
                        resize_heatmap_for_output(&mut finding.byte_heatmap, &baseline, &output);
                        for &offset in &divergent {
                            finding.byte_heatmap[offset] =
                                finding.byte_heatmap[offset].saturating_add(1);
                        }
                        finding.messages.push(format!(
                            "Fix: byte {} unstable across runs in {} at run {run}. Either the program has a data race (add barrier or use atomic) or the backend has a nondeterminism bug.",
                            divergent.iter().map(|o| o.to_string()).collect::<Vec<_>>().join(", "),
                            spec.id,
                        ));
                    }
                    // One entry per run, even when the run matched the baseline.
                    finding.divergent_bytes.push(divergent);
                }
                Err(err) => {
                    // A dispatch that fails mid-probe counts as instability.
                    spec_has_race = true;
                    finding.messages.push(format!(
                        "backend dispatch failed on {} for {} at run {run}: {err}. Fix: check backend stability.",
                        backend.id(),
                        spec.id,
                    ));
                    finding.divergent_bytes.push(Vec::new());
                }
            }
        }

        // Stable specs are left out of the report entirely.
        if spec_has_race {
            report.has_race = true;
            report.findings.push(finding);
        }
    }
    report
}
/// Returns the ascending byte offsets at which `a` and `b` disagree.
///
/// Only the common prefix is compared element by element. When the slices have
/// different lengths, the length of that common prefix is appended as one
/// extra offset marking where the shorter slice ends.
fn divergent_offsets(a: &[u8], b: &[u8]) -> Vec<usize> {
    let shared = a.len().min(b.len());
    // `zip` stops at the shorter slice, which is exactly the common prefix.
    let mut offsets: Vec<usize> = a
        .iter()
        .zip(b)
        .enumerate()
        .filter_map(|(idx, (lhs, rhs))| (lhs != rhs).then_some(idx))
        .collect();
    if a.len() != b.len() {
        offsets.push(shared);
    }
    offsets
}
/// Grows `heatmap` with zeroed slots until it covers the longer of `baseline`
/// and `output`. The heat map is never shrunk and existing counts are kept.
fn resize_heatmap_for_output(heatmap: &mut Vec<u32>, baseline: &[u8], output: &[u8]) {
    let longest_buffer = output.len().max(baseline.len());
    // Resizing to the current length is a no-op, so taking the max with the
    // existing length guarantees the vector only ever grows.
    let target_len = heatmap.len().max(longest_buffer);
    heatmap.resize(target_len, 0);
}
/// Infers the probe's output size in bytes from the number of invocations.
///
/// The size is four bytes per invocation (saturating on overflow) and never
/// less than four bytes. Sizes above `MAX_INFERRED_OUTPUT_SIZE` are rejected
/// with a diagnostic message instead of being returned.
fn inferred_output_size(invocations: usize) -> Result<usize, String> {
    let output_size = std::cmp::max(4, invocations.saturating_mul(4));
    if output_size <= MAX_INFERRED_OUTPUT_SIZE {
        return Ok(output_size);
    }
    Err(format!(
        "inferred data-race output size {output_size} bytes exceeds the {MAX_INFERRED_OUTPUT_SIZE}-byte conformance cap. Fix: bound workgroup dimensions or provide a smaller race probe."
    ))
}
/// Dispatches `program` once and returns its first output buffer, insisting
/// that the buffer is exactly `output_size` bytes long.
///
/// The program is first passed through `program_with_output_size`, and the
/// dispatch uses the default `vyre::DispatchConfig`. Backend errors are
/// propagated unchanged. A dispatch that yields no output buffers, or whose
/// first buffer has the wrong length, is turned into a `vyre::BackendError`.
fn dispatch_exact(
    backend: &dyn VyreBackend,
    program: &vyre::Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized_program = program_with_output_size(program, output_size);
    let config = vyre::DispatchConfig::default();
    let buffers = backend.dispatch(&sized_program, inputs, &config)?;

    // Only the first buffer is the probe output; any others are discarded.
    let Some(output) = buffers.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the race probe output as outputs[0].",
        ));
    };

    if output.len() == output_size {
        return Ok(output);
    }
    Err(vyre::BackendError::new(format!(
        "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the race probe output declaration.",
        output.len()
    )))
}
/// Returns a copy of `program` in which the first `ReadWrite` buffer is marked
/// as the output and sized to hold `output_size` bytes.
///
/// The element count is `output_size` divided by four, rounded up, and clamped
/// to `u32::MAX` if it does not fit. When the program has no `ReadWrite`
/// buffer, the buffers are copied unchanged. Workgroup size and entry are
/// carried over as-is.
fn program_with_output_size(program: &vyre::Program, output_size: usize) -> vyre::Program {
    let mut buffers = program.buffers().to_vec();
    let element_count: u32 = output_size.div_ceil(4).try_into().unwrap_or(u32::MAX);

    let first_read_write = buffers
        .iter_mut()
        .find(|buffer| buffer.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(target) = first_read_write {
        target.is_output = true;
        target.count = element_count;
    }

    vyre::Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}
/// Enforcement gate that runs the data-race probe over every spec in the
/// enforcement context.
pub struct DataRaceEnforcer;

impl crate::enforce::EnforceGate for DataRaceEnforcer {
    /// Identifier of this gate.
    fn id(&self) -> &'static str {
        "data_race"
    }

    /// Name of this gate; the same string as its id.
    fn name(&self) -> &'static str {
        "data_race"
    }

    /// Runs `verify_no_race` on the context's backend and specs and converts
    /// every collected diagnostic message into the gate's findings.
    ///
    /// When the context carries no backend, a single aggregate finding
    /// reporting the missing backend is returned instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            None => vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![
                    "data_race: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string(),
                ],
            )],
            Some(backend) => {
                // Flatten the per-op messages, preserving finding order.
                let mut messages = Vec::new();
                for finding in verify_no_race(backend, ctx.specs).findings {
                    messages.extend(finding.messages);
                }
                crate::enforce::finding_result(self.id(), messages)
            }
        }
    }
}
/// Ready-made `DataRaceEnforcer` value exported by this module.
pub const REGISTERED: DataRaceEnforcer = DataRaceEnforcer;
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Mock backend that returns the same bytes on every dispatch.
    struct DeterministicBackend {
        output: Vec<u8>,
    }

    impl vyre::VyreBackend for DeterministicBackend {
        fn id(&self) -> &'static str {
            "deterministic-mock"
        }

        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            let bytes = self.output.clone();
            Ok(vec![bytes])
        }
    }

    /// Mock backend that returns `baseline` for call indices `0..=flake_run`
    /// and `flake_output` for every later call.
    struct FlakyBackend {
        baseline: Vec<u8>,
        flake_run: u32,
        flake_output: Vec<u8>,
        /// Number of dispatches served so far.
        calls: AtomicU32,
    }

    impl vyre::VyreBackend for FlakyBackend {
        fn id(&self) -> &'static str {
            "flaky-mock"
        }

        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            // `fetch_add` yields the zero-based index of this call.
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            let bytes = if call_index > self.flake_run {
                &self.flake_output
            } else {
                &self.baseline
            };
            Ok(vec![bytes.clone()])
        }
    }

    #[test]
    fn deterministic_passes() {
        let backend = DeterministicBackend {
            output: vec![0xAB, 0xCD, 0xEF, 0x00],
        };
        let specs = [crate::spec::primitive::xor::spec()];
        let report = verify_no_race(&backend, &specs);
        assert!(
            !report.has_race,
            "expected no race, got: {:?}",
            report.findings
        );
    }

    #[test]
    fn flaky_backend_flags_race() {
        let backend = FlakyBackend {
            baseline: vec![0x00, 0x00, 0x00, 0x00],
            flake_run: 3,
            flake_output: vec![0xFF, 0x00, 0xFF, 0x00],
            calls: AtomicU32::new(0),
        };
        let specs = [crate::spec::primitive::xor::spec()];
        let report = verify_no_race(&backend, &specs);
        assert!(report.has_race, "expected race detection");

        let finding = &report.findings[0];
        assert_eq!(finding.first_bytes, vec![0x00, 0x00, 0x00, 0x00]);
        // Call 0 is the baseline and calls 1..=100 are the comparison runs.
        // The mock flakes from call 4 onward, so 97 runs diverge at bytes 0 and 2.
        assert_eq!(finding.byte_heatmap, vec![97, 0, 97, 0]);
    }
}