use crate::pipeline::execution::InputCase;
use crate::spec::program::program_for_spec_input;
use crate::spec::OpSpec;
/// Failure categories that the determinism gate treats as infrastructure problems
/// rather than as evidence of nondeterminism.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum InfrastructureError {
    /// Despite the name, `dispatch_op` reports this for *any* backend dispatch error
    /// at a workgroup size above 1; the backend's own error text is not kept.
    CompileFailure,
}

/// Error returned by a single determinism dispatch attempt.
#[derive(Clone, Debug, Eq, PartialEq)]
enum DispatchError {
    /// Infrastructure failure; `enforce_determinism` stops checking that workgroup
    /// size without recording a divergence.
    Infra(InfrastructureError),
    /// Any other failure, carrying a human-readable message (usually with a "Fix:" hint).
    Other(String),
}

impl DispatchError {
    /// Converts the error into the raw bytes stored in a divergence record.
    ///
    /// `Infra` errors are rendered with their `Debug` representation (e.g.
    /// `CompileFailure`); `Other` messages are passed through unchanged as UTF-8.
    fn into_bytes(self) -> Vec<u8> {
        let text = match self {
            Self::Infra(kind) => format!("{kind:?}"),
            Self::Other(message) => message,
        };
        text.into_bytes()
    }
}
/// Outcome of checking one op for determinism across workgroup sizes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterminismReport {
    /// String form of the checked op's id.
    pub op_id: String,
    /// Every divergence found, in discovery order; empty means the op passed.
    pub divergences: Vec<Divergence>,
}

/// A single disagreement (or failure) observed while checking determinism.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Divergence {
    /// Label of the input case, or `configuration/repeats` for a bad repeat count.
    pub input_label: String,
    /// Workgroup size that produced `bytes_a` (`enforce_determinism` always uses 1).
    pub wg_a: u32,
    /// Workgroup size that produced `bytes_b`.
    pub wg_b: u32,
    /// Left-hand bytes: normally the baseline output; empty when the baseline itself
    /// failed; the requested repeat count (little-endian) for the configuration entry.
    pub bytes_a: Vec<u8>,
    /// Right-hand bytes: the diverging output, the error text of a failed dispatch, or
    /// the minimum repeat count (little-endian) for the configuration entry.
    pub bytes_b: Vec<u8>,
    /// Zero-based repeat index at which the divergence was seen, if it came from a repeat.
    pub run: Option<u32>,
    /// Empty for a plain output mismatch; otherwise a "Fix:" explanation.
    pub message: String,
}
#[inline]
pub fn enforce_determinism(
backend: &dyn vyre::VyreBackend,
op: &OpSpec,
inputs: &[InputCase],
repeats: u32,
) -> DeterminismReport {
let mut divergences = Vec::new();
if repeats < 10 {
divergences.push(Divergence {
input_label: "configuration/repeats".to_string(),
wg_a: 1,
wg_b: 1,
bytes_a: repeats.to_le_bytes().to_vec(),
bytes_b: 10u32.to_le_bytes().to_vec(),
run: None,
message: format!(
"Fix: determinism requires at least 10 repeats; caller requested {repeats}."
),
});
}
let repeats = repeats.max(10);
let workgroup_sizes = determinism_workgroup_sizes(op.workgroup_size);
for case in inputs {
let baseline = match dispatch_op(backend, op, &case.bytes, 1) {
Ok(bytes) => bytes,
Err(err) => {
divergences.push(Divergence {
input_label: case.report_label(),
wg_a: 1,
wg_b: 1,
bytes_a: Vec::new(),
bytes_b: err.into_bytes(),
run: None,
message: "Fix: backend must compile and run the canonical workgroup_size=1 baseline before determinism can be claimed.".to_string(),
});
continue;
}
};
for wg in workgroup_sizes.iter().copied().filter(|&wg| wg != 1) {
for run in 0..repeats {
match dispatch_op(backend, op, &case.bytes, wg) {
Ok(output) => {
if output != baseline {
divergences.push(Divergence {
input_label: case.report_label(),
wg_a: 1,
wg_b: wg,
bytes_a: baseline.clone(),
bytes_b: output,
run: Some(run),
message: String::new(),
});
break; }
}
Err(DispatchError::Infra(InfrastructureError::CompileFailure)) => {
break;
}
Err(err) => {
divergences.push(Divergence {
input_label: case.report_label(),
wg_a: 1,
wg_b: wg,
bytes_a: baseline.clone(),
bytes_b: err.into_bytes(),
run: Some(run),
message: format!(
"Fix: backend dispatch failed at workgroup_size={wg} after the baseline succeeded; unsupported sizes must be reported explicitly or constrained by the op spec."
),
});
break;
}
}
}
}
}
DeterminismReport {
op_id: op.id.to_string(),
divergences,
}
}
/// Runs `op` once on `backend` for `input` at the given workgroup size and returns the
/// first output buffer.
///
/// The op's CPU reference (`op.cpu_fn`) is evaluated only to learn the expected output
/// length; its bytes are not compared here.
///
/// # Errors
///
/// - `DispatchError::Other` when `input` is shorter than the signature's minimum, when
///   the program cannot be built, when the backend returns no output buffers, or when
///   the first output's length differs from the CPU reference's length.
/// - On a backend dispatch error: `DispatchError::Infra(CompileFailure)` if
///   `workgroup_size > 1` (the backend's error text is discarded), otherwise
///   `DispatchError::Other` with a descriptive message.
fn dispatch_op(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<Vec<u8>, DispatchError> {
    let min_bytes = op.signature.min_input_bytes();
    if min_bytes > 0 && input.len() < min_bytes {
        return Err(DispatchError::Other(format!(
            "undersized input: {} bytes for {} (minimum {min_bytes}). \
             Fix: generator produced input smaller than the op's type signature requires.",
            input.len(),
            op.id,
        )));
    }
    let expected_len = (op.cpu_fn)(input).len();

    let mut program = program_for_spec_input(op, input).map_err(DispatchError::Other)?;
    program.set_workgroup_size([workgroup_size, 1, 1]);

    let dispatched = backend.dispatch(
        &program,
        &[input.to_vec()],
        &vyre::DispatchConfig::default(),
    );
    let outputs = match dispatched {
        Ok(outputs) => outputs,
        // Failures above the baseline size are classified as infrastructure problems.
        Err(_) if workgroup_size > 1 => {
            return Err(DispatchError::Infra(InfrastructureError::CompileFailure));
        }
        Err(err) => {
            return Err(DispatchError::Other(format!(
                "backend dispatch failed on {} with workgroup_size={workgroup_size}: {err}. \
                 Fix: execute the canonical vyre IR program and return {} bytes.",
                backend.id(),
                expected_len
            )));
        }
    };

    let Some(output) = outputs.into_iter().next() else {
        return Err(DispatchError::Other(
            "backend returned zero output buffers. Fix: return the operation result as outputs[0]."
                .to_string(),
        ));
    };
    if output.len() != expected_len {
        return Err(DispatchError::Other(format!(
            "backend returned {} bytes, expected {}. Fix: size the first output buffer from the program output declaration.",
            output.len(),
            expected_len
        )));
    }
    Ok(output)
}
/// Returns the workgroup sizes the determinism gate exercises, in ascending order.
///
/// The candidate ladder is `1, 8, 32, 64, 128, 256, 1024`. `Some(max)` with `max > 0`
/// keeps only candidates `<= max`. `None` or `Some(0)` yields the full ladder. Size 1
/// (the baseline size) is therefore always present.
fn determinism_workgroup_sizes(preferred: Option<u32>) -> Vec<u32> {
    const LADDER: [u32; 7] = [1, 8, 32, 64, 128, 256, 1024];
    // A zero cap is treated the same as no cap.
    let cap = preferred.filter(|&max| max > 0).unwrap_or(u32::MAX);
    LADDER.iter().copied().filter(|&size| size <= cap).collect()
}
/// Enforcement gate that runs [`enforce_determinism`] over every spec in the registry.
pub struct DeterminismEnforcer;

impl crate::enforce::EnforceGate for DeterminismEnforcer {
    fn id(&self) -> &'static str {
        "determinism"
    }

    fn name(&self) -> &'static str {
        "determinism"
    }

    /// Checks each spec in `ctx.specs` on `ctx.backend`. Each check uses a single
    /// all-zero input of `max(min_input_bytes, 4)` bytes and 10 repeats per workgroup
    /// size.
    ///
    /// Every divergence becomes one message prefixed with `determinism(<op id>)`, the
    /// input label, and both workgroup sizes, so findings stay attributable when many
    /// specs are checked. For dispatch failures the underlying error text is appended
    /// as `Cause: …`. Messages are passed to `crate::enforce::finding_result`. When no
    /// backend is configured, a single aggregate finding is returned instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let Some(backend) = ctx.backend else {
            return vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![
                    "determinism: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string(),
                ],
            )];
        };
        let mut messages = Vec::new();
        for spec in ctx.specs {
            let input_len = spec.signature.min_input_bytes().max(4);
            let input = crate::pipeline::execution::InputCase::new(
                "registry",
                "zero".to_string(),
                vec![0; input_len],
            );
            let report = enforce_determinism(backend, spec, &[input], 10);
            for divergence in report.divergences {
                let context = format!(
                    "determinism({}): input={} wg_a={} wg_b={}",
                    report.op_id, divergence.input_label, divergence.wg_a, divergence.wg_b
                );
                let message = if divergence.message.is_empty() {
                    // Plain output mismatch: enforce_determinism leaves `message` empty.
                    format!("{context}. Fix: make dispatch output byte-identical.")
                } else if divergence.input_label == "configuration/repeats" {
                    // Here the byte fields hold little-endian repeat counts, not error text.
                    format!("{context}. {}", divergence.message)
                } else {
                    // Dispatch failures store the underlying error text in `bytes_b`.
                    format!(
                        "{context}. {} Cause: {}",
                        divergence.message,
                        String::from_utf8_lossy(&divergence.bytes_b)
                    )
                };
                messages.push(message);
            }
        }
        crate::enforce::finding_result(self.id(), messages)
    }
}

/// Registry entry for the determinism gate.
pub const REGISTERED: DeterminismEnforcer = DeterminismEnforcer;
#[cfg(test)]
mod tests {
    use super::*;

    /// Every candidate size, as returned when no cap applies.
    const FULL_LADDER: [u32; 7] = [1, 8, 32, 64, 128, 256, 1024];

    #[test]
    fn determinism_workgroup_sizes_unbounded() {
        let sizes = determinism_workgroup_sizes(None);
        assert_eq!(sizes, FULL_LADDER.to_vec());
    }

    #[test]
    fn determinism_workgroup_sizes_clamped() {
        let sizes = determinism_workgroup_sizes(Some(64));
        assert_eq!(sizes, vec![1, 8, 32, 64]);
    }

    #[test]
    fn determinism_workgroup_sizes_zero_max_is_unbounded() {
        let sizes = determinism_workgroup_sizes(Some(0));
        assert_eq!(sizes, FULL_LADDER.to_vec());
    }

    #[test]
    fn determinism_workgroup_sizes_cap_between_rungs_rounds_down() {
        assert_eq!(determinism_workgroup_sizes(Some(100)), vec![1, 8, 32, 64]);
    }

    #[test]
    fn determinism_workgroup_sizes_small_cap_keeps_only_baseline() {
        assert_eq!(determinism_workgroup_sizes(Some(1)), vec![1]);
        assert_eq!(determinism_workgroup_sizes(Some(7)), vec![1]);
    }

    #[test]
    fn determinism_workgroup_sizes_cap_above_ladder_is_unbounded() {
        assert_eq!(determinism_workgroup_sizes(Some(u32::MAX)), FULL_LADDER.to_vec());
    }

    #[test]
    fn dispatch_error_into_bytes_renders_text() {
        assert_eq!(
            DispatchError::Infra(InfrastructureError::CompileFailure).into_bytes(),
            b"CompileFailure".to_vec()
        );
        assert_eq!(
            DispatchError::Other("boom".to_string()).into_bytes(),
            b"boom".to_vec()
        );
    }
}