use super::ops::is_unsupported_by_backend_message;
use super::ParitySummary;
use crate::generate::generators::default_generators;
use crate::pipeline::execution::{execute_op, regression_inputs, seed_from};
use crate::pipeline::streaming::StreamingRunner;
use crate::proof::comparator::{Comparator, ExactMatch};
use crate::spec::types::{OpSpec, ParityFailure};
use crate::verify::budget::{budget_for_op, with_exec_budget, Archetype, BudgetTracker};
/// Runs the deterministic GPU-vs-CPU parity suite for `spec` on `backend`.
///
/// The primary `wgsl_fn` is exercised first over every input from `parity_inputs` at every
/// size from `workgroup_sizes`. Each entry in `spec.alt_wgsl_fns` is then exercised the same
/// way, and its failures are appended to the same list. As soon as any run reports that the
/// backend does not support the op, no further WGSL variants are tried. The summary is then
/// returned with `unsupported` set, carrying the counts and failures collected up to that point.
pub(super) fn run_parity(backend: &dyn vyre::VyreBackend, spec: &OpSpec) -> ParitySummary {
    let inputs = parity_inputs(spec);
    let workgroup_sizes = workgroup_sizes(spec.workgroup_size);
    let mut cases_tested = 0_u64;
    let mut failures = Vec::new();
    let mut unsupported = false;
    run_parity_for_wgsl(
        backend,
        spec,
        &inputs,
        &workgroup_sizes,
        &mut cases_tested,
        &mut failures,
        &mut unsupported,
    );
    if !unsupported {
        for (_label, alt_wgsl_fn) in &spec.alt_wgsl_fns {
            // Only the entry point differs, and the alternate spec carries no alternates of
            // its own. `..spec.clone()` already copies every other field (including
            // `version_history`), so no field needs a second explicit clone.
            let alt_spec = OpSpec {
                wgsl_fn: *alt_wgsl_fn,
                alt_wgsl_fns: Vec::new(),
                ..spec.clone()
            };
            run_parity_for_wgsl(
                backend,
                &alt_spec,
                &inputs,
                &workgroup_sizes,
                &mut cases_tested,
                &mut failures,
                &mut unsupported,
            );
            if unsupported {
                break;
            }
        }
    }
    ParitySummary {
        cases_tested,
        failures,
        unsupported,
    }
}
/// Runs parity for `spec` through a `StreamingRunner` configured with a batch size of 1024.
///
/// The runner's `tested` count becomes `cases_tested`, and its failures are returned unchanged.
/// The summary is flagged `unsupported` when any failure message is recognised by
/// `is_unsupported_by_backend_message`.
pub(super) fn run_streaming_parity(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> ParitySummary {
    let specs = std::slice::from_ref(spec);
    let (stream_summary, stream_failures) = StreamingRunner::new()
        .batch_size(1024)
        .run_with_summary(backend, specs);
    let any_unsupported = stream_failures
        .iter()
        .map(|failure| &failure.message)
        .any(|message| is_unsupported_by_backend_message(message));
    ParitySummary {
        unsupported: any_unsupported,
        cases_tested: stream_summary.tested,
        failures: stream_failures,
    }
}
/// Runs every `(input, workgroup size)` combination for the WGSL entry point of `spec`.
///
/// Before each combination runs, `*cases_tested` grows by 2, because
/// `execute_deterministic_case` dispatches every case twice. If a case error's message is
/// recognised as "unsupported by backend", `*unsupported` is set and the remaining combinations
/// are skipped. Any other case error is appended to `failures`.
///
/// The sweep runs under the execution budget returned by `budget_for_op` for the op's inferred
/// archetype. If `with_exec_budget` reports an error, one extra failure is appended: generator
/// `"budget"`, label `"reference_bomb"`, empty buffers and workgroup size 0.
fn run_parity_for_wgsl(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    inputs: &[(String, String, Vec<u8>)],
    workgroup_sizes: &[u32],
    cases_tested: &mut u64,
    failures: &mut Vec<ParityFailure>,
    unsupported: &mut bool,
) {
    let archetype = infer_archetype(spec.id);
    let tracker = BudgetTracker::new(budget_for_op(spec.id, &archetype), spec.id);
    let budget_outcome = with_exec_budget(tracker, || {
        'sweep: for (generator, label, bytes) in inputs {
            for &workgroup_size in workgroup_sizes {
                *cases_tested = cases_tested.saturating_add(2);
                if let Err(case_error) =
                    execute_deterministic_case(backend, spec, bytes, workgroup_size)
                {
                    if is_unsupported_by_backend_message(&case_error.message) {
                        *unsupported = true;
                        // Nothing else can be learned on this backend; abandon the sweep.
                        break 'sweep;
                    }
                    failures.push(parity_failure(
                        spec,
                        generator,
                        label,
                        bytes,
                        case_error.gpu_output,
                        case_error.cpu_output,
                        case_error.message,
                        workgroup_size,
                    ));
                }
            }
        }
    });
    if let Err(bomb) = budget_outcome {
        let reason = bomb.to_string();
        failures.push(parity_failure(
            spec,
            "budget",
            "reference_bomb",
            &[],
            Vec::new(),
            Vec::new(),
            reason,
            0,
        ));
    }
}
/// Why a single parity case failed, plus the buffers that best illustrate the failure.
///
/// Produced by `execute_deterministic_case`. Both buffers are empty when the first dispatch
/// itself failed. For a nondeterminism failure, both buffers hold GPU results (the second
/// dispatch in `gpu_output`, the first in `cpu_output`) rather than a GPU/CPU pair.
///
/// `Debug` is derived so the error can be printed with `{:?}` or `dbg!`, and so
/// `Result<_, ParityCaseError>::unwrap` / `expect` are usable in tests and diagnostics.
#[derive(Debug)]
struct ParityCaseError {
    /// Human-readable description of the failure.
    message: String,
    /// Buffer reported as the GPU side of the failure.
    gpu_output: Vec<u8>,
    /// Buffer reported as the reference side of the failure.
    cpu_output: Vec<u8>,
}
/// Dispatches `spec` on `backend` twice with identical input and validates both runs.
///
/// Checks run in this order, and the first problem found is returned:
/// 1. The first dispatch fails: the error carries empty buffers.
/// 2. The second dispatch fails: the first dispatch's GPU/CPU outputs are attached.
/// 3. The two GPU outputs differ under `ExactMatch`: reported as nondeterminism. The second
///    GPU output goes in `gpu_output`, and the first GPU output (not a CPU result) goes in
///    `cpu_output`.
/// 4. `spec.comparator` rejects the first GPU output against the first CPU output.
/// 5. `spec.comparator` rejects the second GPU output against the second CPU output.
fn execute_deterministic_case(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    bytes: &[u8],
    workgroup_size: u32,
) -> Result<(), ParityCaseError> {
    let dispatch = || execute_op(backend, spec, bytes, workgroup_size);
    let case_error =
        |message: String, gpu_output: Vec<u8>, cpu_output: Vec<u8>| ParityCaseError {
            message,
            gpu_output,
            cpu_output,
        };

    let (gpu_a, cpu_a) =
        dispatch().map_err(|message| case_error(message, Vec::new(), Vec::new()))?;
    let (gpu_b, cpu_b) = match dispatch() {
        Ok(outputs) => outputs,
        Err(message) => return Err(case_error(message, gpu_a, cpu_a)),
    };

    if let Err(detail) = ExactMatch.compare(&gpu_a, &gpu_b) {
        let message = format!(
            "nondeterministic GPU output between identical dispatches: {detail} \
             Fix: remove data races, uninitialized reads, or workgroup-order dependence."
        );
        // Both buffers are GPU results here. They are parked side by side so the two
        // dispatches can be diffed directly.
        return Err(case_error(message, gpu_b, gpu_a));
    }

    spec.comparator
        .compare(&gpu_a, &cpu_a)
        .map_err(|detail| case_error(detail, gpu_a, cpu_a))?;
    spec.comparator
        .compare(&gpu_b, &cpu_b)
        .map_err(|detail| case_error(detail, gpu_b, cpu_b))?;
    Ok(())
}
/// Assembles the parity inputs for `spec` as `(generator name, case label, input bytes)` tuples.
///
/// Stored regression cases for `spec.id` come first, in the order `regression_inputs` yields
/// them. After that, every default generator whose `handles` accepts `spec.signature`
/// contributes its cases, all produced from the one seed returned by `seed_from(spec.id)`.
fn parity_inputs(spec: &OpSpec) -> Vec<(String, String, Vec<u8>)> {
    let mut collected = Vec::new();
    for regression in regression_inputs(spec.id) {
        collected.push((regression.generator, regression.label, regression.bytes));
    }
    let seed = seed_from(spec.id);
    let applicable = default_generators()
        .into_iter()
        .filter(|generator| generator.handles(&spec.signature));
    for generator in applicable {
        let cases = generator.generate_for_op(spec.id, &spec.signature, seed);
        collected.extend(
            cases
                .into_iter()
                .map(|(label, bytes)| (generator.name().to_string(), label, bytes)),
        );
    }
    collected
}
/// Packages one failed parity case into an owned `ParityFailure` record.
///
/// `spec.id` and `spec.version` identify the op revision under test. `generator` and `label`
/// identify the input that failed, and `input` is copied so the record owns the offending bytes.
/// `gpu_output`, `cpu_output`, `message` and `workgroup_size` are stored as given.
fn parity_failure(
    spec: &OpSpec,
    generator: &str,
    label: &str,
    input: &[u8],
    gpu_output: Vec<u8>,
    cpu_output: Vec<u8>,
    message: String,
    workgroup_size: u32,
) -> ParityFailure {
    let op_id = spec.id.to_string();
    let input_label = label.to_owned();
    ParityFailure {
        op_id,
        generator: String::from(generator),
        input_label,
        input: Vec::from(input),
        spec_version: spec.version,
        gpu_output,
        cpu_output,
        message,
        workgroup_size,
    }
}
/// Returns the workgroup sizes every parity case is dispatched at, ascending and de-duplicated.
///
/// Sizes 1 and 64 are always present. A non-zero `preferred` size is added alongside them;
/// `None` and `Some(0)` add nothing.
fn workgroup_sizes(preferred: Option<u32>) -> Vec<u32> {
    let mut candidates: Vec<u32> = vec![1, 64];
    candidates.extend(preferred.filter(|&size| size != 0));
    candidates.sort_unstable();
    candidates.dedup();
    candidates
}
/// Guesses the budget archetype of an op from keywords in its id.
///
/// Rules are tried in order, and the first whose keywords all occur in `op_id` wins. A hash op
/// naming both `u32` and `u64` therefore maps to the u32 archetype, and "decode" is checked
/// before "compress". Ids matching no rule map to `Archetype("unknown")`.
fn infer_archetype(op_id: &str) -> Archetype {
    const RULES: &[(&[&str], &str)] = &[
        (&["hash", "u32"], "hash-bytes-to-u32"),
        (&["hash", "u64"], "hash-bytes-to-u64"),
        (&["decode"], "decode-bytes-to-bytes"),
        (&["compress"], "compression-bytes-to-bytes"),
    ];
    let name = RULES
        .iter()
        .find(|(keywords, _)| keywords.iter().all(|keyword| op_id.contains(*keyword)))
        .map_or("unknown", |&(_, name)| name);
    Archetype(name)
}