use crate::{spec::program::program_for_spec_input, spec::types::BufferInitPolicy, spec::OpSpec};
use vyre_reference::value::Value;
pub use vyre_reference::run;
/// Outcome of comparing one backend's dispatch output against the L3
/// reference interpreter (`vyre_reference::run`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParityCheckReport {
    /// Identifier reported by the backend's `id()` method.
    pub backend: String,
    /// `true` only when the backend output matched the reference output.
    pub passed: bool,
    /// Human-readable failure descriptions; empty when the check passed.
    /// The parity checks in this module end each entry with a `Fix:` hint.
    pub findings: Vec<String>,
}
#[inline]
pub(crate) fn parity_check(
program: &vyre::ir::Program,
backend: &dyn vyre::VyreBackend,
) -> ParityCheckReport {
parity_check_with_inputs(program, &[], backend)
}
/// Checks `backend` against the L3 reference interpreter using explicit
/// `inputs`.
///
/// The dispatch config comes from `default_config`, which records the
/// program's first workgroup dimension (minimum 1) and zero buffer init.
#[inline]
pub(crate) fn parity_check_with_inputs(
    program: &vyre::ir::Program,
    inputs: &[Value],
    backend: &dyn vyre::VyreBackend,
) -> ParityCheckReport {
    let config = default_config(program);
    parity_check_with_config(program, inputs, backend, config)
}
/// Compares `backend` against the L3 reference interpreter for one program.
///
/// Steps:
/// 1. The expected bytes come from `run(program, inputs)`, with all output
///    values flattened in order.
/// 2. The backend is dispatched on a copy of `program` whose first
///    read-write buffer is resized to the expected byte length
///    (`program_with_output_size`).
/// 3. `inputs` are serialized with `Value::to_bytes`, and `config` is passed
///    through unchanged.
/// 4. The check passes only when the backend's concatenated output bytes
///    equal the expected bytes.
///
/// Failures never panic or return `Err`. Each is recorded as a single entry
/// in `findings`:
/// - a reference-interpreter error;
/// - a backend dispatch error;
/// - an output mismatch. Mismatch findings include both byte lengths and the
///   offset of the first differing byte, so equal-length divergences are
///   still diagnosable.
#[inline]
pub(crate) fn parity_check_with_config(
    program: &vyre::ir::Program,
    inputs: &[Value],
    backend: &dyn vyre::VyreBackend,
    config: vyre::DispatchConfig,
) -> ParityCheckReport {
    let backend_id = backend.id().to_string();
    let reference_outputs = match run(program, inputs) {
        Ok(outputs) => outputs,
        Err(err) => {
            return ParityCheckReport {
                backend: backend_id,
                passed: false,
                findings: vec![format!(
                    "{err}. Fix: repair the IR program before backend parity."
                )],
            };
        }
    };
    let expected = flatten_outputs(&reference_outputs);
    let input_buffers = inputs.iter().map(Value::to_bytes).collect::<Vec<_>>();
    let dispatch_program = program_with_output_size(program, expected.len());
    let findings = match backend.dispatch(&dispatch_program, &input_buffers, &config) {
        Ok(buffers) => {
            // Flatten once and reuse for both the comparison and the diagnostic.
            let actual = flatten_buffers(&buffers);
            if actual == expected {
                Vec::new()
            } else {
                vec![parity_mismatch_finding(&expected, &actual)]
            }
        }
        Err(err) => vec![format!(
            "{err}. Fix: implement VyreBackend::dispatch for this IR program."
        )],
    };
    ParityCheckReport {
        backend: backend_id,
        // Every failure path above pushes exactly one finding.
        passed: findings.is_empty(),
        findings,
    }
}

/// Formats the finding for a backend whose output bytes differ from the reference.
///
/// The reported offset is the first index where the two byte streams disagree.
/// When one stream is a strict prefix of the other, it is the shorter length.
fn parity_mismatch_finding(expected: &[u8], actual: &[u8]) -> String {
    let first_difference = expected
        .iter()
        .zip(actual)
        .position(|(want, got)| want != got)
        .unwrap_or_else(|| expected.len().min(actual.len()));
    format!(
        "backend output differed from L3 reference: expected {} bytes, got {} bytes, first differing byte at offset {first_difference}. Fix: diff backend lowering against vyre_reference::run.",
        expected.len(),
        actual.len()
    )
}
/// Runs every spec's probe program through the L3 parity check and collects
/// failures.
///
/// Each probe is checked under every combination of:
/// - workgroup size: 1 and 64;
/// - buffer-init policy: `Zero` and `Poison`.
///
/// The backend is `ReferenceParityBackend`, which re-executes
/// `vyre_reference::run` and ignores the dispatch config.
///
/// Findings:
/// - A spec whose probe cannot be built contributes one `"{id}: {error}"`
///   entry and is skipped.
/// - Parity failures are prefixed with the spec id, workgroup size and
///   init policy.
///
/// An empty result means every spec passed.
#[inline]
pub fn enforce_registry(specs: &[OpSpec]) -> Vec<String> {
    let backend = ReferenceParityBackend;
    let mut failures = Vec::new();
    for spec in specs {
        let (probe, probe_inputs) = match probe_program_for_spec(spec) {
            Ok(built) => built,
            Err(reason) => {
                failures.push(format!("{}: {reason}", spec.id));
                continue;
            }
        };
        for wg in [1, 64] {
            // The workgroup size is the only program change, so resize once
            // per size and share it across both init policies.
            let mut sized = probe.clone();
            sized.set_workgroup_size([wg, 1, 1]);
            for init in [BufferInitPolicy::Zero, BufferInitPolicy::Poison] {
                let report = parity_check_with_config(
                    &sized,
                    &probe_inputs,
                    &backend,
                    l3_dispatch_config(wg, init),
                );
                if report.passed {
                    continue;
                }
                let prefix = format!("{}: wg={} init={:?}", spec.id, wg, init);
                failures.extend(
                    report
                        .findings
                        .into_iter()
                        .map(|detail| format!("{prefix}: {detail}")),
                );
            }
        }
    }
    failures
}
/// Dispatch config used when the caller supplies none.
///
/// Records the program's first workgroup dimension (clamped to at least 1)
/// with zero buffer initialization.
fn default_config(program: &vyre::ir::Program) -> vyre::DispatchConfig {
    let [lanes_x, ..] = program.workgroup_size();
    l3_dispatch_config(lanes_x.max(1), BufferInitPolicy::Zero)
}
/// Builds a default `DispatchConfig` tagged with an L3 profile string.
///
/// The workgroup size and buffer-init policy are recorded only in `profile`;
/// every other field keeps its `Default` value.
fn l3_dispatch_config(workgroup_size: u32, buffer_init: BufferInitPolicy) -> vyre::DispatchConfig {
    let profile =
        format!("conform-l3:workgroup_size={workgroup_size};buffer_init={buffer_init:?}");
    let mut config = vyre::DispatchConfig::default();
    config.profile = Some(profile);
    config
}
/// Builds an executable probe program and matching input values for `spec`.
///
/// The probe uses an all-zero primary input of the spec's minimum size. The
/// spec's CPU function is run on that input only to learn how many output
/// bytes to allocate.
///
/// # Errors
/// Propagates failures from `program_for_spec_input` or from input
/// construction.
fn probe_program_for_spec(spec: &OpSpec) -> Result<(vyre::ir::Program, Vec<Value>), String> {
    let zeroed_input = vec![0; spec.signature.min_input_bytes()];
    let oracle_len = (spec.cpu_fn)(&zeroed_input).len();
    let program = program_for_spec_input(spec, &zeroed_input)?;
    probe_inputs_for_program(&program, &zeroed_input, oracle_len).map(|inputs| (program, inputs))
}
/// Produces one `Value::Bytes` per non-workgroup buffer of `program`, in
/// declaration order.
///
/// Buffer handling:
/// - Output or read-write buffers get zeroed storage of at least
///   `output_size` bytes.
/// - The first remaining buffer receives a copy of `input`.
/// - Any later buffers get zeroed storage of their declared size.
///
/// # Errors
/// Returns the first error reported by `zeroed_buffer`.
fn probe_inputs_for_program(
    program: &vyre::ir::Program,
    input: &[u8],
    output_size: usize,
) -> Result<Vec<Value>, String> {
    let mut primary_pending = true;
    program
        .buffers()
        .iter()
        .filter(|decl| decl.access() != vyre::ir::BufferAccess::Workgroup)
        .map(|decl| -> Result<Value, String> {
            let writable =
                decl.is_output() || decl.access() == vyre::ir::BufferAccess::ReadWrite;
            let bytes = match (writable, primary_pending) {
                (true, _) => zeroed_buffer(decl, output_size)?,
                (false, true) => {
                    primary_pending = false;
                    input.to_vec()
                }
                (false, false) => zeroed_buffer(decl, 0)?,
            };
            Ok(Value::Bytes(bytes))
        })
        .collect()
}
/// Allocates a zero-filled byte buffer for `buffer`.
///
/// The length is the larger of the declared size (element count times
/// element width, saturating) and `fallback_size`.
///
/// # Errors
/// Fails when the declared element count does not fit in `usize`.
fn zeroed_buffer(buffer: &vyre::ir::BufferDecl, fallback_size: usize) -> Result<Vec<u8>, String> {
    let element_count = match usize::try_from(buffer.count()) {
        Ok(count) => count,
        Err(_) => {
            return Err(format!(
                "buffer `{}` declares an unrepresentable element count. Fix: reduce the probe output size.",
                buffer.name()
            ))
        }
    };
    let declared_bytes = element_count.saturating_mul(element_size_bytes(buffer.element()));
    let byte_len = declared_bytes.max(fallback_size);
    Ok(vec![0; byte_len])
}
/// Byte width assumed for one element of `data_type` when sizing buffers.
///
/// `Bytes` is 1, `U64`/`Vec2U32` are 8, `Vec4U32` is 16, and every other
/// type is treated as 4 bytes.
fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    use vyre::ir::DataType;
    match data_type {
        DataType::Bytes => 1,
        DataType::U64 | DataType::Vec2U32 => 8,
        DataType::Vec4U32 => 16,
        _ => 4,
    }
}
/// Backend that "dispatches" by re-running the L3 reference interpreter on
/// the raw input bytes.
///
/// The dispatch config is ignored.
struct ReferenceParityBackend;

impl vyre::VyreBackend for ReferenceParityBackend {
    fn id(&self) -> &'static str {
        "l3-reference-parity"
    }

    fn dispatch(
        &self,
        program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        let mut values = Vec::with_capacity(inputs.len());
        values.extend(inputs.iter().map(|raw| Value::Bytes(raw.clone())));
        match run(program, &values) {
            Ok(outputs) => Ok(outputs.iter().map(Value::to_bytes).collect()),
            Err(err) => Err(vyre::BackendError::new(format!(
                "{err}. Fix: repair the generated L3 probe before backend parity."
            ))),
        }
    }
}
fn flatten_outputs(values: &[Value]) -> Vec<u8> {
values.iter().flat_map(Value::to_bytes).collect()
}
/// Concatenates raw byte buffers, in order, into one contiguous vector.
fn flatten_buffers(values: &[Vec<u8>]) -> Vec<u8> {
    values.concat()
}
/// Returns a copy of `program` with its first read-write buffer marked as an
/// output and sized to hold `output_size` bytes.
///
/// The size is rounded up to whole elements, and the element count
/// saturates at `u32::MAX`. All other buffers, the workgroup size, the entry
/// nodes and `entry_op_id` are copied unchanged.
///
/// NOTE(review): `parity_check_with_config` passes the combined length of
/// all reference outputs, but only one buffer is resized here. Confirm this
/// is intended for programs with several output buffers.
fn program_with_output_size(program: &vyre::Program, output_size: usize) -> vyre::Program {
    let mut buffers = program.buffers().to_vec();
    if let Some(target) = buffers
        .iter_mut()
        .find(|decl| decl.access == vyre::ir::BufferAccess::ReadWrite)
    {
        let element_bytes = element_size_bytes(target.element);
        let element_count = output_size.div_ceil(element_bytes);
        target.is_output = true;
        target.count = element_count.try_into().unwrap_or(u32::MAX);
    }
    let mut resized =
        vyre::Program::new(buffers, program.workgroup_size(), program.entry().to_vec());
    resized.entry_op_id = program.entry_op_id.clone();
    resized
}
/// Enforcement gate identified as `layer3_reference_interp`.
///
/// NOTE(review): `run` ignores its context and always passes an empty message
/// list to `finding_result`. `enforce_registry` is not invoked from here.
/// Confirm whether the gate is intentionally inert or still needs wiring.
pub struct Layer3ReferenceInterpEnforcer;

impl crate::enforce::EnforceGate for Layer3ReferenceInterpEnforcer {
    fn id(&self) -> &'static str {
        "layer3_reference_interp"
    }

    fn name(&self) -> &'static str {
        // The display name is the same string as the id.
        self.id()
    }

    fn run(&self, _ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        crate::enforce::finding_result(self.id(), Vec::new())
    }
}

/// Gate instance exported for registration.
pub const REGISTERED: Layer3ReferenceInterpEnforcer = Layer3ReferenceInterpEnforcer;
#[cfg(test)]
mod tests {
    use super::*;

    /// The xor probe must meet three conditions:
    /// - it contains executable IR;
    /// - it reads an input buffer and writes an output buffer;
    /// - its reference output matches the spec's CPU function on the same
    ///   zeroed input.
    #[test]
    fn probe_program_for_spec_executes_real_ir_body() {
        let spec = crate::spec::primitive::xor::spec();
        let (program, inputs) = probe_program_for_spec(&spec).expect("xor probe must build");

        let has_executable_node = program
            .entry()
            .iter()
            .any(|node| !matches!(node, vyre::ir::Node::Return));
        assert!(
            has_executable_node,
            "probe must contain executable IR, not only Return"
        );

        let buffers = program.buffers();
        let reads_input = buffers
            .iter()
            .any(|decl| decl.access() == vyre::ir::BufferAccess::ReadOnly);
        assert!(reads_input, "probe must load from an input buffer");
        let writes_output = buffers.iter().any(vyre::ir::BufferDecl::is_output);
        assert!(writes_output, "probe must write an output buffer");

        let outputs = run(&program, &inputs).expect("probe must execute in reference interpreter");
        let oracle_input = vec![0; spec.signature.min_input_bytes()];
        assert_eq!(flatten_outputs(&outputs), (spec.cpu_fn)(&oracle_input));
    }
}