use crate::generate::composers::{default_composers, CompositionStrategy};
use crate::generate::generators::{default_generators, InputGenerator};
use crate::pipeline::backend::{
ConformDispatchConfig, ExecutionModel, OneShotDispatch, WgslBackend,
};
use crate::pipeline::execution::{
execute_chain, execute_op, persist_failure, regression_inputs, seed_from, InputCase,
};
use crate::pipeline::reporter::{ConsoleReporter, Reporter};
use crate::spec::op_registry;
use crate::spec::types::{ChainSpec, OpSpec, ParityFailure};
#[cfg(loom)]
use loom::sync::Mutex;
#[cfg(not(loom))]
use std::sync::Mutex;
use crate::pipeline::{
chain_comparator, chain_version, chain_workgroup_sizes, notify, proof, workgroup_sizes,
};
use vyre::ir::{BufferAccess, DataTypeSizeBytes};
/// Parity-checking harness that runs WGSL ops (and composed op chains) on a
/// backend and compares the GPU output against the CPU reference output.
///
/// Holds the input generators, the chain composers, and the reporters that
/// are notified of start/pass/fail/done events.
pub struct ConformanceSuite {
    /// Input generators; each one contributes cases only for signatures its
    /// `handles` accepts.
    pub(super) generators: Vec<Box<dyn InputGenerator>>,
    /// Strategies that build composed `ChainSpec`s for the composition runs.
    composers: Vec<Box<dyn CompositionStrategy>>,
    /// Reporters passed to every `notify::*` call. The mutex is `loom`'s under
    /// `cfg(loom)` and `std`'s otherwise. NOTE(review): the `Option` presumably
    /// lets a reporter be taken out of its slot elsewhere — confirm.
    pub(super) reporters: Vec<Mutex<Option<Box<dyn Reporter>>>>,
}
impl ConformanceSuite {
#[inline]
pub fn new() -> Self {
Self {
generators: default_generators(),
composers: default_composers(),
reporters: vec![Mutex::new(Some(Box::<ConsoleReporter>::default()))],
}
}
/// Builder step: appends `generator` after the generators already registered.
#[inline]
pub fn with_generator(self, generator: Box<dyn InputGenerator>) -> Self {
    let mut suite = self;
    suite.generators.push(generator);
    suite
}
/// Builder step: appends `composer` to the strategies that generate op chains.
#[inline]
pub fn with_composer(self, composer: Box<dyn CompositionStrategy>) -> Self {
    let mut suite = self;
    suite.composers.push(composer);
    suite
}
/// Builder step: adds `reporter` to the reporters passed to every
/// `notify::*` call, after the ones already registered.
#[inline]
pub fn with_reporter(self, reporter: Box<dyn Reporter>) -> Self {
    let slot = Mutex::new(Some(reporter));
    let mut suite = self;
    suite.reporters.push(slot);
    suite
}
/// Runs the full parity check (regressions plus generated inputs) for every
/// op in the registry and returns every failure, in registry order.
#[inline]
pub(crate) fn run(&self, backend: &dyn WgslBackend) -> Vec<ParityFailure> {
    op_registry::all_specs()
        .iter()
        .flat_map(|spec| self.run_op(backend, spec, false))
        .collect()
}
/// Replays only the persisted regression inputs for every registered op;
/// generator inputs are not produced.
#[inline]
pub(crate) fn run_regressions_only(&self, backend: &dyn WgslBackend) -> Vec<ParityFailure> {
    op_registry::all_specs()
        .iter()
        .flat_map(|spec| self.run_op(backend, spec, true))
        .collect()
}
/// Runs every chain that each composer builds from the full op registry,
/// composer by composer, and returns all chain failures.
#[inline]
pub(crate) fn run_compositions(&self, backend: &dyn WgslBackend) -> Vec<ParityFailure> {
    let registry = op_registry::all_specs();
    self.composers
        .iter()
        .flat_map(|composer| composer.generate_chains(&registry))
        .flat_map(|chain| self.run_chain(backend, &chain))
        .collect()
}
/// Runs the full parity check for exactly the given `specs`, in order.
#[inline]
pub(crate) fn run_specs(
    &self,
    backend: &dyn WgslBackend,
    specs: &[OpSpec],
) -> Vec<ParityFailure> {
    specs
        .iter()
        .flat_map(|spec| self.run_op(backend, spec, false))
        .collect()
}
/// Runs the given `specs` individually and then every chain the composers
/// build from those same specs. Single-op failures come first in the result.
#[inline]
pub(crate) fn run_specs_with_compositions(
    &self,
    backend: &dyn WgslBackend,
    specs: &[OpSpec],
) -> Vec<ParityFailure> {
    let mut failures = self.run_specs(backend, specs);
    let chains = self
        .composers
        .iter()
        .flat_map(|composer| composer.generate_chains(specs));
    failures.extend(chains.flat_map(|chain| self.run_chain(backend, &chain)));
    failures
}
/// Runs the parity check for the registered op whose id equals `op_id`.
///
/// Returns `None` when no registered op has that id.
#[inline]
pub(crate) fn run_single(
    &self,
    backend: &dyn WgslBackend,
    op_id: &str,
) -> Option<Vec<ParityFailure>> {
    let registry = op_registry::all_specs();
    let target = registry.iter().find(|candidate| candidate.id == op_id);
    target.map(|spec| self.run_op(backend, spec, false))
}
/// Runs the parity check for every registered op whose id starts with
/// `prefix`, in registry order.
#[inline]
pub(crate) fn run_filtered(
    &self,
    backend: &dyn WgslBackend,
    prefix: &str,
) -> Vec<ParityFailure> {
    op_registry::all_specs()
        .iter()
        .filter(|candidate| candidate.id.starts_with(prefix))
        .flat_map(|spec| self.run_op(backend, spec, false))
        .collect()
}
/// Checks every registered op for nondeterministic GPU output.
///
/// For each op, up to `MAX_CASES_PER_OP` inputs are executed at every
/// workgroup size in the op's schedule. Each (input, workgroup size) pair is
/// executed `repeats` times in total, and every later run is compared with
/// the first. The first diverging run of a pair is reported (GPU output of
/// that run in `gpu_output`, first run's output in `cpu_output`), and the
/// remaining runs for that pair are skipped.
///
/// A later run that returns an error after the first run succeeded is also
/// reported, because an input that only sometimes executes is itself
/// nondeterministic. Two cases are skipped because the ordinary parity run
/// already reports them: ops whose workgroup schedule cannot be built, and
/// pairs whose first run fails. With `repeats < 2` there is nothing to
/// compare, so no GPU work is done and the result is empty.
pub fn run_nondeterminism_check(
    &self,
    backend: &dyn WgslBackend,
    repeats: u32,
) -> Vec<ParityFailure> {
    // Upper bound on inputs checked per op, keeping the repeat loop bounded.
    const MAX_CASES_PER_OP: usize = 100;

    let mut failures = Vec::new();
    if repeats < 2 {
        return failures;
    }
    let specs = op_registry::all_specs();
    let execute_backend = ExecuteFirstBackend::new(backend);
    for op in &specs {
        // Resolve the schedule first so misconfigured ops skip input generation.
        let Ok(schedule) = workgroup_sizes(op.workgroup_size) else {
            continue;
        };
        let inputs = self.inputs_for(op, false);
        for case in inputs.iter().take(MAX_CASES_PER_OP) {
            for &wg in &schedule {
                let Ok((first, _cpu)) = execute_op(&execute_backend, op, &case.bytes, wg) else {
                    continue;
                };
                for attempt in 1..repeats {
                    // 1-based run number: the baseline execution above is run 1.
                    let run = attempt + 1;
                    let (current, message) =
                        match execute_op(&execute_backend, op, &case.bytes, wg) {
                            Ok((current, _cpu)) if current == first => continue,
                            Ok((current, _cpu)) => (
                                current,
                                format!(
                                    "nondeterminism detected on run {run}/{repeats}, workgroup_size={wg}. \
                                     Fix: this op produces different GPU results on the same input — \
                                     check for workgroup race conditions or uninitialized memory."
                                ),
                            ),
                            Err(error) => (
                                Vec::new(),
                                format!(
                                    "nondeterminism detected on run {run}/{repeats}, workgroup_size={wg}: \
                                     execution failed after run 1 succeeded: {error}. \
                                     Fix: this op fails intermittently on the same input — \
                                     check for workgroup race conditions or uninitialized memory."
                                ),
                            ),
                        };
                    failures.push(case.failure(
                        op.id,
                        current,
                        first.clone(),
                        message,
                        op.version,
                        wg,
                    ));
                    break;
                }
            }
        }
    }
    failures
}
/// Runs the full case matrix for `op`: every input case at every scheduled
/// workgroup size, first against the primary WGSL source and then against
/// each alternative source in `op.alt_wgsl_fns`.
///
/// With `regressions_only` set, only persisted regression inputs are used.
/// Otherwise generator inputs are appended to them. If the workgroup-size
/// schedule cannot be built, a single `config-error` failure is returned and
/// nothing is executed or reported. Failures from an alternative source carry
/// an `[alt:<label>] ` message prefix, both in the returned value and in what
/// reporters and `persist_failure` receive.
fn run_op(
    &self,
    backend: &dyn WgslBackend,
    op: &OpSpec,
    regressions_only: bool,
) -> Vec<ParityFailure> {
    let inputs = self.inputs_for(op, regressions_only);
    let workgroup_sizes = match workgroup_sizes(op.workgroup_size) {
        Ok(sizes) => sizes,
        Err(message) => {
            return vec![ParityFailure {
                op_id: op.id.to_string(),
                generator: "workgroup_sizes".to_string(),
                input_label: "config-error".to_string(),
                input: Vec::new(),
                gpu_output: Vec::new(),
                cpu_output: Vec::new(),
                message,
                spec_version: op.version,
                workgroup_size: 0,
            }];
        }
    };
    // The primary source and every alternative run the same input matrix.
    let wgsl_source_count = 1 + op.alt_wgsl_fns.len();
    notify::start(
        &self.reporters,
        op.id,
        inputs.len() * workgroup_sizes.len() * wgsl_source_count,
    );
    let mut failures = Vec::new();
    let mut pass_count =
        self.run_op_wgsl(backend, op, &inputs, &workgroup_sizes, "", &mut failures);
    for (label, alt_wgsl_fn) in &op.alt_wgsl_fns {
        let alt_op = OpSpec {
            wgsl_fn: *alt_wgsl_fn,
            alt_wgsl_fns: Vec::new(),
            ..op.clone()
        };
        // `alt_op` shares `op`'s id and signature, which is everything
        // `inputs_for` reads, so the primary inputs are reused. Recomputing
        // them could also pick up regressions persisted during the primary
        // pass and diverge from the total announced to `notify::start`.
        let prefix = format!("[alt:{label}] ");
        pass_count += self.run_op_wgsl(
            backend,
            &alt_op,
            &inputs,
            &workgroup_sizes,
            &prefix,
            &mut failures,
        );
    }
    notify::done(&self.reporters, op.id, pass_count, failures.len());
    failures
}
/// Executes one WGSL source of `op` for every case at every workgroup size
/// and compares GPU output with the CPU reference via `op.comparator`.
///
/// `message_prefix` is prepended to each failure message (pass `""` for the
/// primary source) *before* the failure is reported and persisted. This keeps
/// reporters, the persisted record, and the returned failure consistent.
/// Failures are appended to `failures`. Returns the number of passing cases.
fn run_op_wgsl(
    &self,
    backend: &dyn WgslBackend,
    op: &OpSpec,
    inputs: &[InputCase],
    workgroup_sizes: &[u32],
    message_prefix: &str,
    failures: &mut Vec<ParityFailure>,
) -> usize {
    let execute_backend = ExecuteFirstBackend::new(backend);
    let mut pass_count = 0;
    for case in inputs {
        for &workgroup_size in workgroup_sizes {
            let (gpu, cpu, message) =
                match execute_op(&execute_backend, op, &case.bytes, workgroup_size) {
                    Ok((gpu, cpu)) => match op.comparator.compare(&gpu, &cpu) {
                        Ok(()) => {
                            pass_count += 1;
                            notify::pass(&self.reporters, op.id, &case.report_label());
                            continue;
                        }
                        Err(message) => (gpu, cpu, message),
                    },
                    // Execution errors carry no outputs to compare.
                    Err(message) => (Vec::new(), Vec::new(), message),
                };
            let mut failure =
                case.failure(op.id, gpu, cpu, message, op.version, workgroup_size);
            // Inserting an empty prefix is a no-op, so the primary source is unaffected.
            failure.message.insert_str(0, message_prefix);
            notify::fail(&self.reporters, &failure);
            persist_failure(&failure);
            failures.push(failure);
        }
    }
    pass_count
}
/// Runs one composed chain for every chain input at every workgroup size in
/// the chain's schedule, returning the parity failures.
///
/// Two early exits return a single failure without executing anything or
/// notifying reporters:
/// - a chain rejected by `proof::validate_chain` returns that rejection;
/// - a chain whose schedule cannot be built returns one `config-error` failure.
///
/// Otherwise each failure is reported and persisted as soon as it is found.
fn run_chain(&self, backend: &dyn WgslBackend, chain: &ChainSpec) -> Vec<ParityFailure> {
    if let Err(rejection) = proof::validate_chain(chain) {
        return vec![*rejection];
    }
    let inputs = self.inputs_for_chain(chain);
    let workgroup_sizes = match chain_workgroup_sizes(chain) {
        Ok(sizes) => sizes,
        Err(message) => {
            return vec![ParityFailure {
                op_id: chain.id.clone(),
                generator: "chain_workgroup_sizes".to_string(),
                input_label: "config-error".to_string(),
                input: Vec::new(),
                gpu_output: Vec::new(),
                cpu_output: Vec::new(),
                message,
                spec_version: chain_version(chain),
                workgroup_size: 0,
            }];
        }
    };
    let execute_backend = ExecuteFirstBackend::new(backend);
    let planned_cases = inputs.len() * workgroup_sizes.len();
    notify::start(&self.reporters, &chain.id, planned_cases);
    let comparator = chain_comparator(chain);
    let mut failures = Vec::new();
    let mut pass_count = 0;
    for case in &inputs {
        for &workgroup_size in &workgroup_sizes {
            let executed = execute_chain(&execute_backend, chain, &case.bytes, workgroup_size);
            let (gpu, cpu, message) = match executed {
                Err(message) => (Vec::new(), Vec::new(), message),
                Ok((gpu, cpu)) => match comparator.compare(&gpu, &cpu) {
                    Err(message) => (gpu, cpu, message),
                    Ok(()) => {
                        pass_count += 1;
                        notify::pass(&self.reporters, &chain.id, &case.report_label());
                        continue;
                    }
                },
            };
            let failure = case.failure(
                &chain.id,
                gpu,
                cpu,
                message,
                chain_version(chain),
                workgroup_size,
            );
            notify::fail(&self.reporters, &failure);
            persist_failure(&failure);
            failures.push(failure);
        }
    }
    notify::done(&self.reporters, &chain.id, pass_count, failures.len());
    failures
}
/// Collects the input cases for `op`.
///
/// Persisted regression inputs come first. Unless `regressions_only` is set,
/// they are followed by the cases of every generator whose `handles` accepts
/// the op's signature, in generator order, all seeded from the op id.
fn inputs_for(&self, op: &OpSpec, regressions_only: bool) -> Vec<InputCase> {
    let mut cases = regression_inputs(op.id);
    if !regressions_only {
        let seed = seed_from(op.id);
        cases.extend(
            self.generators
                .iter()
                .filter(|generator| generator.handles(&op.signature))
                .flat_map(|generator| {
                    generator
                        .generate_for_op(op.id, &op.signature, seed)
                        .into_iter()
                        .map(move |(label, bytes)| {
                            InputCase::new(generator.name(), label, bytes)
                        })
                }),
        );
    }
    cases
}
/// Collects generator input cases for `chain`, seeded from the chain id.
/// Only generators whose `handles` accepts the chain signature contribute;
/// no regression inputs are included.
fn inputs_for_chain(&self, chain: &ChainSpec) -> Vec<InputCase> {
    let seed = seed_from(&chain.id);
    let mut cases = Vec::new();
    for generator in &self.generators {
        if !generator.handles(&chain.signature) {
            continue;
        }
        for (label, bytes) in generator.generate_for_op(&chain.id, &chain.signature, seed) {
            cases.push(InputCase::new(generator.name(), label, bytes));
        }
    }
    cases
}
}
/// Adapter around a borrowed [`WgslBackend`]. Its `dispatch` is routed
/// through the wrapped backend's `execute` as a one-shot execution model, and
/// identity/capability queries pass straight through.
struct ExecuteFirstBackend<'a> {
    /// Backend that receives every forwarded call.
    inner: &'a dyn WgslBackend,
}
impl<'a> ExecuteFirstBackend<'a> {
    /// Wraps `inner` by reference; nothing is copied or owned.
    fn new(inner: &'a dyn WgslBackend) -> Self {
        ExecuteFirstBackend { inner }
    }
}
impl WgslBackend for ExecuteFirstBackend<'_> {
fn name(&self) -> &str {
self.inner.name()
}
fn version(&self) -> &str {
self.inner.version()
}
fn dispatch(
&self,
wgsl: &str,
input: &[u8],
output_size: usize,
config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
self.inner
.execute(&ExecutionModel::OneShot(OneShotDispatch {
wgsl: wgsl.to_string(),
input: input.to_vec(),
output_size,
config,
}))
}
fn supported_models(&self) -> &[crate::pipeline::backend::ExecutionModelKind] {
self.inner.supported_models()
}
fn max_convention(&self) -> crate::spec::types::Convention {
self.inner.max_convention()
}
}
/// Lets a `vyre::Program` run through the wrapped WGSL backend.
///
/// The program is wire-encoded and sent through `dispatch_program` as a
/// single 1-D launch. The workgroup size is the program's x dimension (at
/// least 1). The workgroup count is chosen so that
/// `workgroup_count * workgroup_size` covers `ceil(output_size / 4)` words,
/// saturating at `u32::MAX`. The caller's `DispatchConfig` is ignored.
/// NOTE(review): only `inputs[0]` is forwarded and any further input buffers
/// are dropped — confirm programs run through this adapter take one input.
impl vyre::VyreBackend for ExecuteFirstBackend<'_> {
    fn id(&self) -> &'static str {
        "wgsl-execute-first-adapter"
    }
    fn dispatch(
        &self,
        program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        let program_bytes = match program.to_wire() {
            Ok(bytes) => bytes,
            Err(err) => {
                return Err(vyre::BackendError::new(format!(
                    "{err}. Fix: dispatch a wire-encodable vyre Program."
                )));
            }
        };
        let output_size = program_output_size(program)?;
        // An empty `inputs` slice becomes an empty input buffer.
        let input = inputs.first().map_or_else(Vec::new, |first| first.clone());
        let workgroup_size = program.workgroup_size()[0].max(1);
        let output_words = output_size.div_ceil(4).max(1);
        let groups_needed = output_words.div_ceil(workgroup_size as usize);
        let workgroup_count = u32::try_from(groups_needed).map_or(u32::MAX, |count| count.max(1));
        let config = ConformDispatchConfig {
            workgroup_size,
            workgroup_count,
            convention: crate::spec::types::Convention::V1,
            lookup_data: None,
            buffer_init: crate::spec::types::BufferInitPolicy::default(),
        };
        match self.dispatch_program(&program_bytes, &input, output_size, config) {
            Ok(output) => Ok(vec![output]),
            Err(message) => Err(vyre::BackendError::new(message)),
        }
    }
}
/// Total byte size of every buffer that is read-write or flagged as output.
/// Each element counts as at least one byte.
///
/// # Errors
/// Returns an error in any of these cases:
/// - a buffer count does not fit `usize`;
/// - a per-buffer size or the running total overflows;
/// - no bytes are counted at all.
fn program_output_size(program: &vyre::Program) -> Result<usize, vyre::BackendError> {
    let mut total = 0usize;
    for buffer in program.buffers() {
        let carries_output = buffer.access() == BufferAccess::ReadWrite || buffer.is_output();
        if !carries_output {
            continue;
        }
        // Zero-sized element types still count as one byte per element.
        let element_size = buffer.element().size_bytes().max(1);
        let count = match usize::try_from(buffer.count) {
            Ok(count) => count,
            Err(err) => {
                return Err(vyre::BackendError::new(format!(
                    "output buffer `{}` count {} cannot fit usize: {err}. Fix: reduce the output buffer count.",
                    buffer.name, buffer.count
                )));
            }
        };
        let Some(buffer_bytes) = element_size.checked_mul(count) else {
            return Err(vyre::BackendError::new(format!(
                "output buffer `{}` byte size overflowed. Fix: reduce count or element width.",
                buffer.name
            )));
        };
        total = match total.checked_add(buffer_bytes) {
            Some(sum) => sum,
            None => {
                return Err(vyre::BackendError::new(
                    "total output byte size overflowed. Fix: reduce output buffer declarations.",
                ));
            }
        };
    }
    if total > 0 {
        Ok(total)
    } else {
        Err(vyre::BackendError::new(
            "program declares no output buffer. Fix: mark one read_write buffer as output.",
        ))
    }
}
impl Default for ConformanceSuite {
fn default() -> Self {
Self::new()
}
}