use crate::pipeline::backend::ConformDispatchConfig;
use crate::spec::program::program_for_spec_input;
use crate::spec::types::{ChainSpec, OpSpec, ParityFailure};
use crate::verify::regression;
/// One generated input fed to an op or chain during parity checking.
#[derive(Clone, Debug)]
pub struct InputCase {
    /// Name of the generator that produced this input (e.g. `"regression"`).
    pub generator: String,
    /// Generator-specific label identifying this input within its generator.
    pub label: String,
    /// Raw input bytes; copied into `ParityFailure::input` when a mismatch is reported.
    pub bytes: Vec<u8>,
}
impl InputCase {
    /// Creates a case, taking an owned copy of the `generator` name.
    #[inline]
    pub fn new(generator: &str, label: String, bytes: Vec<u8>) -> Self {
        let generator = String::from(generator);
        Self {
            generator,
            label,
            bytes,
        }
    }

    /// Builds a [`ParityFailure`] for this case.
    ///
    /// The case's generator, label and input bytes are cloned into the report.
    /// `gpu` and `cpu` are stored as the backend and CPU-reference outputs, and
    /// `message`, `spec_version` and `workgroup_size` are recorded verbatim.
    #[inline]
    pub fn failure(
        &self,
        op_id: &str,
        gpu: Vec<u8>,
        cpu: Vec<u8>,
        message: String,
        spec_version: u32,
        workgroup_size: u32,
    ) -> ParityFailure {
        let Self {
            generator,
            label,
            bytes,
        } = self;
        ParityFailure {
            op_id: String::from(op_id),
            generator: generator.clone(),
            input_label: label.clone(),
            input: bytes.clone(),
            gpu_output: gpu,
            cpu_output: cpu,
            message,
            spec_version,
            workgroup_size,
        }
    }

    /// Returns the `generator/label` identifier used in reports.
    #[inline]
    pub fn report_label(&self) -> String {
        let mut joined = String::with_capacity(self.generator.len() + 1 + self.label.len());
        joined.push_str(&self.generator);
        joined.push('/');
        joined.push_str(&self.label);
        joined
    }
}
/// Runs a single op on `input` through both its CPU reference (`op.cpu_fn`) and
/// `backend`, returning `(gpu_output, cpu_output)`.
///
/// The backend is asked for an output exactly as long as the CPU result, so a
/// backend buffer of any other length is reported as an error.
///
/// # Errors
/// Returns a human-readable message when:
/// - `workgroup_size` is zero;
/// - the op signature declares a non-zero minimum input size and `input` is shorter;
/// - `crate::verify::budget::exec_budget_record` rejects the CPU reference's elapsed time;
/// - the required workgroup count does not fit in `u32`;
/// - program lowering or backend dispatch fails, or the backend returns no buffer
///   or a buffer of the wrong length.
#[inline]
pub(crate) fn execute_op(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<(Vec<u8>, Vec<u8>), String> {
    validate_workgroup_size(workgroup_size)?;

    let min_bytes = op.signature.min_input_bytes();
    // A declared minimum of 0 skips the size check entirely.
    let too_short = min_bytes > 0 && input.len() < min_bytes;
    if too_short {
        return Err(format!(
            "undersized input: {} bytes for {} (minimum {min_bytes}). \
             Fix: generator produced input smaller than the op's type signature requires.",
            input.len(),
            op.id,
        ));
    }

    // Only the CPU reference is timed; the elapsed time goes to the budget
    // recorder, and any error it returns aborts this run.
    let clock = std::time::Instant::now();
    let cpu = (op.cpu_fn)(input);
    if let Err(bomb) = crate::verify::budget::exec_budget_record(clock.elapsed()) {
        return Err(bomb.to_string());
    }

    let expected_len = cpu.len();
    let config = checked_dispatch_config(op, expected_len, workgroup_size)?;
    match dispatch_backend(backend, op, input, expected_len, config) {
        Ok(gpu) => Ok((gpu, cpu)),
        Err(err) => Err(err.to_string()),
    }
}
/// Runs every op of `chain` in order on both the CPU reference and `backend`,
/// returning `(gpu_output, cpu_output)` for the whole chain.
///
/// Each CPU step consumes the previous CPU output, and each backend step consumes
/// the previous backend output. The length of each CPU step's output sets the
/// output size requested from the backend for that step. When `chain.cpu_chain`
/// is set, the returned CPU result is computed by it directly from `input`. The
/// per-step CPU functions still run so they can size the backend dispatches.
///
/// # Errors
/// Returns a message when `workgroup_size` is zero, a step's workgroup count
/// does not fit in `u32`, or a backend step fails. Backend failures are tagged
/// with the chain id and step id.
#[inline]
pub(crate) fn execute_chain(
    backend: &dyn vyre::VyreBackend,
    chain: &ChainSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<(Vec<u8>, Vec<u8>), String> {
    // Checked up front so an empty chain still rejects a zero workgroup size.
    validate_workgroup_size(workgroup_size)?;
    let mut cpu_bytes = input.to_vec();
    let mut gpu_bytes = input.to_vec();
    for step in &chain.specs {
        let cpu_out = (step.cpu_fn)(&cpu_bytes);
        let expected_len = cpu_out.len();
        let config = checked_dispatch_config(step, expected_len, workgroup_size)?;
        gpu_bytes = match dispatch_backend(backend, step, &gpu_bytes, expected_len, config) {
            Ok(next) => next,
            Err(err) => {
                return Err(format!(
                    "backend dispatch failed in chain {} at {}: {err}. Fix: make every chained op accept the previous output bytes.",
                    chain.id, step.id
                ));
            }
        };
        cpu_bytes = cpu_out;
    }
    let cpu_final = match chain.cpu_chain {
        Some(whole_chain) => whole_chain(input),
        None => cpu_bytes,
    };
    Ok((gpu_bytes, cpu_final))
}
/// Loads the saved regression inputs for `op_id` via `regression::load`.
///
/// Each `(label, bytes)` pair is wrapped as an [`InputCase`] whose generator is
/// `"regression"`.
#[inline]
pub(crate) fn regression_inputs(op_id: &str) -> Vec<InputCase> {
    let mut cases = Vec::new();
    for (label, bytes) in regression::load(op_id) {
        cases.push(InputCase::new("regression", label, bytes));
    }
    cases
}
/// Saves `failure` via `regression::save` on a best-effort basis.
///
/// A save error is reported on stderr rather than propagated.
#[inline]
pub(crate) fn persist_failure(failure: &ParityFailure) {
    let Err(err) = regression::save(failure) else {
        return;
    };
    eprintln!(
        "vyre-conform: could not persist regression for {}: {err}. Fix: ensure regressions/ is writable.",
        failure.op_id
    );
}
/// Derives a deterministic 64-bit seed from `text` using the FNV-1a hash.
///
/// Each byte is XORed into the state and the state is then multiplied by the
/// FNV prime with wrapping arithmetic. The empty string maps to the offset basis.
#[inline]
pub(crate) fn seed_from(text: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;
    text.bytes().fold(FNV_OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Builds a best-effort dispatch config for an op producing `output_size` bytes.
///
/// This never fails. A zero `workgroup_size` is treated as 1, and a workgroup count
/// that does not fit in `u32` saturates to `u32::MAX`. Use `checked_dispatch_config`
/// when those cases must be reported as errors.
///
/// The count is chosen so `workgroup_count * workgroup_size` covers one invocation
/// per 4-byte output word. At least one word is assumed, so an empty output still
/// launches one workgroup.
#[inline]
pub(crate) fn dispatch_config(
    op: &OpSpec,
    output_size: usize,
    workgroup_size: u32,
) -> ConformDispatchConfig {
    // Clamp once and use the same value for both the count and the reported size.
    // A config advertising workgroup_size = 0 alongside a count computed for size 1
    // would be internally inconsistent.
    let workgroup_size = workgroup_size.max(1);
    let output_words = output_size.div_ceil(4).max(1);
    // A workgroup size wider than `usize` (only possible on sub-32-bit targets)
    // exceeds any word count, so a single workgroup suffices.
    let workgroup_count =
        usize::try_from(workgroup_size).map_or(1, |size| output_words.div_ceil(size));
    ConformDispatchConfig {
        workgroup_size,
        workgroup_count: u32::try_from(workgroup_count).unwrap_or(u32::MAX),
        convention: op.convention,
        lookup_data: None,
        buffer_init: crate::spec::types::BufferInitPolicy::default(),
    }
}
/// Builds the dispatch config for an op producing `output_size` bytes, reporting
/// invalid sizing as an error instead of clamping.
///
/// The count is chosen so `workgroup_count * workgroup_size` covers one invocation
/// per 4-byte output word. At least one word is assumed, so an empty output still
/// launches one workgroup.
///
/// # Errors
/// - `workgroup_size` is zero.
/// - The required workgroup count does not fit in `u32`.
fn checked_dispatch_config(
    op: &OpSpec,
    output_size: usize,
    workgroup_size: u32,
) -> Result<ConformDispatchConfig, String> {
    validate_workgroup_size(workgroup_size)?;
    let output_words = output_size.div_ceil(4).max(1);
    // Convert without a truncating `as` cast. A workgroup size wider than `usize`
    // (only possible on sub-32-bit targets) exceeds any word count, so one
    // workgroup suffices. Validation above guarantees the divisor is non-zero.
    let workgroup_count =
        usize::try_from(workgroup_size).map_or(1, |size| output_words.div_ceil(size));
    let workgroup_count = u32::try_from(workgroup_count).map_err(|_| {
        format!(
            "dispatch workgroup_count overflow: output_size={output_size}, output_words={output_words}, workgroup_size={workgroup_size}. Fix: split the output into multiple dispatches or reduce the requested output size."
        )
    })?;
    Ok(ConformDispatchConfig {
        workgroup_size,
        workgroup_count,
        convention: op.convention,
        lookup_data: None,
        buffer_init: crate::spec::types::BufferInitPolicy::default(),
    })
}
/// Rejects a zero `workgroup_size`; every non-zero value is accepted.
///
/// Dispatch sizing divides by the workgroup size with `div_ceil`, which panics on
/// a zero divisor, so this check must run first.
fn validate_workgroup_size(workgroup_size: u32) -> Result<(), String> {
    match workgroup_size {
        0 => Err(String::from(
            "invalid workgroup_size=0. Fix: configure at least one worker per workgroup.",
        )),
        _ => Ok(()),
    }
}
/// Failure raised while lowering an op to a program or dispatching it on a backend.
#[derive(Debug)]
enum ExecutionError {
    /// Lowering or dispatch failed, or the backend output was unusable.
    BackendDispatch {
        /// Backend identifier as a string.
        backend: String,
        /// Output length in bytes the caller expected.
        output_size: usize,
        /// Workgroup size the dispatch was configured with.
        workgroup_size: u32,
        /// Underlying failure message.
        source: String,
    },
}

impl std::fmt::Display for ExecutionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::BackendDispatch {
            backend,
            output_size,
            workgroup_size,
            source,
        } = self;
        write!(
            f,
            "backend dispatch failed on {backend} with workgroup_size={workgroup_size}: {source}. Fix: execute the canonical vyre IR program and return {output_size} bytes."
        )
    }
}

// The full message is rendered by `Display`. The `source` field is plain text, so
// no nested `Error::source` is exposed.
impl std::error::Error for ExecutionError {}
/// Lowers `op` to its vyre IR program for `input`, dispatches it on `backend`
/// with `config`, and returns the first output buffer.
///
/// # Errors
/// Returns [`ExecutionError::BackendDispatch`], carrying the backend id,
/// `output_size` and workgroup size, when:
/// - `program_for_spec_input` fails;
/// - the backend dispatch fails;
/// - the backend returns no output buffers;
/// - the first output buffer's length differs from `output_size`.
fn dispatch_backend(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
) -> Result<Vec<u8>, ExecutionError> {
    let workgroup_size = config.workgroup_size;
    // Every failure path shares the same context. The backend id string is only
    // built on failure, so the success path does not allocate it.
    let dispatch_error = |source: String| ExecutionError::BackendDispatch {
        backend: backend.id().to_string(),
        output_size,
        workgroup_size,
        source,
    };

    let program = program_for_spec_input(op, input).map_err(&dispatch_error)?;
    let outputs = backend
        .dispatch(&program, &[input.to_vec()], &config.to_core())
        .map_err(|error| dispatch_error(error.message))?;

    // Only the first buffer is the op result; any additional buffers are dropped.
    let Some(output) = outputs.into_iter().next() else {
        return Err(dispatch_error(
            "backend returned zero output buffers, expected one. Fix: return the operation result as outputs[0]."
                .to_string(),
        ));
    };
    if output.len() != output_size {
        return Err(dispatch_error(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the program output declaration.",
            output.len()
        )));
    }
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::{
        checked_dispatch_config, dispatch_config, seed_from, validate_workgroup_size, InputCase,
    };
    #[test]
    fn input_case_report_label() {
        let case = InputCase::new("random", "case_42".into(), vec![0xDE, 0xAD]);
        assert_eq!(case.report_label(), "random/case_42");
    }
    #[test]
    fn input_case_failure_preserves_fields() {
        let case = InputCase::new("edge", "max_val".into(), vec![0xFF; 4]);
        let f = case.failure("test.op", vec![0x00], vec![0xFF], "mismatch".into(), 2, 64);
        assert_eq!(f.op_id, "test.op");
        assert_eq!(f.generator, "edge");
        assert_eq!(f.input_label, "max_val");
        assert_eq!(f.input, vec![0xFF; 4]);
        assert_eq!(f.gpu_output, vec![0x00]);
        assert_eq!(f.cpu_output, vec![0xFF]);
        assert_eq!(f.spec_version, 2);
        assert_eq!(f.workgroup_size, 64);
    }
    #[test]
    fn seed_from_is_deterministic() {
        let a = seed_from("primitive.bitwise.xor");
        let b = seed_from("primitive.bitwise.xor");
        assert_eq!(a, b);
    }
    #[test]
    fn seed_from_differs_for_different_ops() {
        let a = seed_from("primitive.bitwise.xor");
        let b = seed_from("primitive.bitwise.and");
        assert_ne!(a, b);
    }
    #[test]
    fn seed_from_empty_string() {
        // FNV-1a of the empty string is exactly the 64-bit offset basis.
        assert_eq!(seed_from(""), 0xcbf2_9ce4_8422_2325);
    }
    #[test]
    fn seed_from_matches_fnv1a_reference_vector() {
        // Published FNV-1a 64-bit test vector for "a"; pins the algorithm so seeds
        // stay stable across refactors.
        assert_eq!(seed_from("a"), 0xaf63_dc4c_8601_ec8c);
    }
    #[test]
    fn dispatch_config_single_word() {
        let op = crate::spec::primitive::xor::spec();
        let config = dispatch_config(&op, 4, 1);
        assert_eq!(config.workgroup_size, 1);
        assert_eq!(config.workgroup_count, 1);
    }
    #[test]
    fn dispatch_config_multi_word() {
        let op = crate::spec::primitive::xor::spec();
        let config = dispatch_config(&op, 256, 64);
        assert_eq!(config.workgroup_count, 1);
    }
    #[test]
    fn dispatch_config_zero_output_clamps_to_one() {
        let op = crate::spec::primitive::xor::spec();
        let config = dispatch_config(&op, 0, 1);
        assert_eq!(config.workgroup_count, 1);
    }
    #[test]
    fn checked_dispatch_config_rounds_up_partial_workgroups() {
        let op = crate::spec::primitive::xor::spec();
        // 260 bytes = 65 words: one full workgroup of 64 plus one partial workgroup.
        let config = checked_dispatch_config(&op, 260, 64).unwrap();
        assert_eq!(config.workgroup_size, 64);
        assert_eq!(config.workgroup_count, 2);
    }
    #[test]
    fn checked_dispatch_config_rejects_zero_workgroup_size() {
        let op = crate::spec::primitive::xor::spec();
        let err = checked_dispatch_config(&op, 4, 0).unwrap_err();
        assert!(err.contains("workgroup_size=0"));
        assert!(err.contains("Fix:"));
    }
    #[test]
    fn validate_workgroup_size_accepts_nonzero() {
        assert!(validate_workgroup_size(1).is_ok());
        assert!(validate_workgroup_size(u32::MAX).is_ok());
    }
    // On narrower targets the byte count needed to overflow a u32 workgroup count
    // does not fit in `usize`, so the overflow is unreachable there and this test
    // would fail on its fixture rather than on the code under test.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn checked_dispatch_config_rejects_workgroup_count_overflow() {
        let op = crate::spec::primitive::xor::spec();
        // u32::MAX + 1 words at workgroup_size = 1 needs one more workgroup than u32 can count.
        let output_size = (u32::MAX as usize + 1) * 4;
        let err = checked_dispatch_config(&op, output_size, 1).unwrap_err();
        assert!(err.contains("workgroup_count overflow"));
        assert!(err.contains("Fix:"));
    }
}