#![allow(
missing_docs,
dead_code,
unused_imports,
unused_variables,
unreachable_patterns,
clippy::all
)]
use crate::pipeline::execution::dispatch_config;
use crate::pipeline::execution::InputCase;
use crate::spec::program::program_for_spec_input;
use crate::spec::types::ParityFailure;
use crate::OpSpec;
use super::workgroup_config;
use super::{progress_reporting, StreamingRunner};
/// A group of pending input cases that all share one workgroup size.
///
/// When an alt label is present, every diagnostic produced for the batch is tagged
/// with it via [`Batch::message`].
pub(super) struct Batch {
    /// Cases queued and waiting for the next flush.
    pub(super) cases: Vec<InputCase>,
    /// Workgroup size this batch was scheduled for.
    pub(super) workgroup_size: u32,
    /// Optional label prefixed onto every message produced for this batch.
    alt_label: Option<String>,
}
impl Batch {
    /// Creates an empty batch for `workgroup_size`, storing an owned copy of `alt_label`.
    ///
    /// `_op_id` and `_version` are accepted but not used by this constructor.
    pub(super) fn new(
        _op_id: &str,
        _version: u32,
        workgroup_size: u32,
        alt_label: Option<&str>,
    ) -> Self {
        let alt_label = alt_label.map(String::from);
        Self {
            cases: Vec::new(),
            workgroup_size,
            alt_label,
        }
    }
    /// Returns `message` prefixed with `[alt:<label>] ` when the batch carries an alt
    /// label, or `message` unchanged otherwise.
    pub(super) fn message(&self, message: String) -> String {
        if let Some(label) = self.alt_label.as_deref() {
            format!("[alt:{label}] {message}")
        } else {
            message
        }
    }
}
/// Runs every regression and generated case for `op` on `backend`, once per workgroup
/// size returned by `workgroup_config::resolve_workgroup_sizes`.
///
/// For each workgroup size, a fresh [`Batch`] is filled in two steps:
/// 1. with the op's regression inputs;
/// 2. with the cases produced by every generator that handles `op.signature`.
///
/// Whatever is still pending is then flushed. If the workgroup sizes cannot be
/// resolved, a single `invalid_workgroup_size` failure is pushed and nothing runs. If
/// `accept_case` reports an error, the run for `op` stops immediately.
pub(super) fn run_op<P: crate::pipeline::streaming::ProgressSink>(
    runner: &mut StreamingRunner<P>,
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    alt_label: Option<&str>,
    next_test_id: &mut u64,
    failures: &mut Vec<ParityFailure>,
) {
    let regressions = crate::pipeline::execution::regression_inputs(op.id);
    let workgroup_sizes = match workgroup_config::resolve_workgroup_sizes(op) {
        Ok(sizes) => sizes,
        Err(message) => {
            failures.push(ParityFailure::invalid_workgroup_size(op.id, message));
            return;
        }
    };
    for workgroup_size in workgroup_sizes {
        let mut batch = Batch::new(op.id, op.version, workgroup_size, alt_label);

        let regression_run = regressions.iter().cloned().try_for_each(|case| {
            accept_case(runner, backend, op, case, next_test_id, &mut batch, failures)
        });
        if regression_run.is_err() {
            return;
        }

        // The generators are borrowed from `runner`, which `accept_case` needs mutably,
        // so their output is buffered before being fed into the batch.
        let mut generated = Vec::new();
        for generator in &runner.generators {
            if generator.handles(&op.signature) {
                let generator_name = generator.name().to_string();
                generator.generate_for_op_streaming(
                    op.id,
                    &op.signature,
                    crate::pipeline::execution::seed_from(op.id),
                    &mut |label, bytes| {
                        generated.push(InputCase::new(&generator_name, label, bytes));
                    },
                );
            }
        }

        let generated_run = generated.into_iter().try_for_each(|case| {
            accept_case(runner, backend, op, case, next_test_id, &mut batch, failures)
        });
        if generated_run.is_err() {
            return;
        }

        flush_batch(runner, backend, op, &mut batch, failures, *next_test_id);
    }
}
/// Assigns the next test id to `case` and queues it in `batch` if this shard owns it.
///
/// The id is consumed even when the case is not queued. A case is dropped without
/// being queued when:
/// - its id is below `runner.skip_count`; or
/// - `test_id % runner.shard_count != runner.shard_id`.
///
/// Once the batch holds at least `runner.batch_size` cases, it is flushed immediately.
///
/// # Errors
/// Returns `Err` when `next_test_id` cannot be incremented without overflowing `u64`.
/// Before returning, a failure is recorded for `case`, with its message tagged by the
/// batch's alt label.
pub(super) fn accept_case<P: crate::pipeline::streaming::ProgressSink>(
    runner: &mut StreamingRunner<P>,
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    case: InputCase,
    next_test_id: &mut u64,
    batch: &mut Batch,
    failures: &mut Vec<ParityFailure>,
) -> Result<(), String> {
    let test_id = *next_test_id;
    let Some(incremented) = test_id.checked_add(1) else {
        // Route through `batch.message` like every other batch failure so the
        // alt-backend label is not lost on this path.
        let message = batch.message(
            "test_id overflow after 2^64 tests. Fix: runner state is corrupted.".to_string(),
        );
        runner.record_failure(
            failures,
            case.failure(
                op.id,
                Vec::new(),
                Vec::new(),
                message.clone(),
                op.version,
                batch.workgroup_size,
            ),
        );
        return Err(message);
    };
    *next_test_id = incremented;
    // `||` short-circuits, so skipped ids never reach the shard modulo.
    if test_id < runner.skip_count || test_id % runner.shard_count != runner.shard_id {
        return Ok(());
    }
    batch.cases.push(case);
    if batch.cases.len() >= runner.batch_size {
        flush_batch(runner, backend, op, batch, failures, *next_test_id);
    }
    Ok(())
}
/// Returns the summed byte width of all of `signature`'s inputs.
///
/// Returns `None` when an input type has no fixed width listed here, or when the total
/// would overflow `usize`.
fn max_input_bytes(signature: &crate::spec::types::OpSignature) -> Option<usize> {
    use crate::spec::types::DataType;
    signature.inputs.iter().try_fold(0usize, |total, dtype| {
        let width = match dtype {
            DataType::U32 | DataType::I32 | DataType::F32 => 4,
            DataType::U64 | DataType::Vec2U32 => 8,
            DataType::Vec4U32 => 16,
            DataType::F16 | DataType::BF16 => 2,
            // Unknown width: stop folding and report "no maximum".
            _ => return None,
        };
        total.checked_add(width)
    })
}
/// Computes the CPU reference output for `case` after checking its input size.
///
/// # Errors
/// Returns a diagnostic string in two cases:
/// - the input is shorter than `op.signature.min_input_bytes()` (when that minimum is
///   non-zero);
/// - the input is longer than the fixed maximum computed by `max_input_bytes`.
fn cpu_reference(op: &OpSpec, case: &InputCase) -> Result<Vec<u8>, String> {
    let len = case.bytes.len();
    let min = op.signature.min_input_bytes();
    if min > 0 && len < min {
        return Err(format!(
            "undersized input: {len} bytes for {} (minimum {min}). Fix: generator produced input smaller than the op's type signature requires.",
            op.id
        ));
    }
    match max_input_bytes(&op.signature) {
        Some(max) if len > max => Err(format!(
            "oversized input: {len} bytes for {} (maximum {max}). Fix: generator produced input larger than the op's type signature allows.",
            op.id
        )),
        _ => Ok((op.cpu_fn)(&case.bytes)),
    }
}
fn flush_batch<P: crate::pipeline::streaming::ProgressSink>(
runner: &mut StreamingRunner<P>,
backend: &dyn vyre::VyreBackend,
op: &OpSpec,
batch: &mut Batch,
failures: &mut Vec<ParityFailure>,
next_test_id: u64,
) {
if batch.cases.is_empty() {
progress_reporting::report_progress(
&mut runner.progress,
&runner.summary,
op.id,
next_test_id,
runner.progress_interval,
runner.checkpoint_interval,
runner.shard_id,
runner.shard_count,
);
return;
}
let mut valid_cases = Vec::with_capacity(batch.cases.len());
let mut inputs = Vec::with_capacity(batch.cases.len());
let mut cpu_outputs = Vec::with_capacity(batch.cases.len());
let mut output_sizes = Vec::with_capacity(batch.cases.len());
let op_version = op.version;
let drained: Vec<_> = batch.cases.drain(..).collect();
for case in drained {
match cpu_reference(op, &case) {
Ok(cpu) if cpu.len() <= runner.max_output_bytes => {
output_sizes.push(cpu.len());
inputs.push(case.bytes.clone());
cpu_outputs.push(cpu);
valid_cases.push(case);
}
Ok(cpu) => {
let message = batch.message(format!(
"cpu_fn returned {} bytes, expected <= {}. Fix: cpu_fn output exceeds per-case limit.",
cpu.len(),
runner.max_output_bytes
));
runner.record_failure(
failures,
case.failure(
op.id,
Vec::new(),
cpu,
message,
op_version,
batch.workgroup_size,
),
);
}
Err(message) => {
let message = batch.message(message);
runner.record_failure(
failures,
case.failure(
op.id,
Vec::new(),
Vec::new(),
message,
op_version,
batch.workgroup_size,
),
);
}
}
}
if valid_cases.is_empty() {
progress_reporting::report_progress(
&mut runner.progress,
&runner.summary,
op.id,
next_test_id,
runner.progress_interval,
runner.checkpoint_interval,
runner.shard_id,
runner.shard_count,
);
return;
}
let _config = dispatch_config(
op,
output_sizes.iter().copied().max().unwrap_or(0),
batch.workgroup_size,
);
for ((case, cpu), output_size) in valid_cases.into_iter().zip(cpu_outputs).zip(output_sizes) {
let first_gpu = match dispatch_case(backend, op, &case.bytes, output_size) {
Ok(output) => output,
Err(message) => {
runner.record_failure(
failures,
case.failure(
op.id,
Vec::new(),
cpu,
batch.message(message),
op.version,
batch.workgroup_size,
),
);
continue;
}
};
let second_gpu = match dispatch_case(backend, op, &case.bytes, output_size) {
Ok(output) => output,
Err(message) => {
runner.record_failure(
failures,
case.failure(
op.id,
first_gpu,
cpu,
batch.message(message),
op.version,
batch.workgroup_size,
),
);
continue;
}
};
if first_gpu != second_gpu {
runner.record_failure(
failures,
case.failure(
op.id,
second_gpu,
first_gpu,
batch.message("nondeterministic GPU output between identical streaming dispatches. Fix: remove data races, uninitialized reads, or workgroup-order dependence.".to_string()),
op.version,
batch.workgroup_size,
),
);
continue;
}
if let Err(message) = op.comparator.compare(&first_gpu, &cpu) {
runner.record_failure(
failures,
case.failure(
op.id,
first_gpu,
cpu,
batch.message(message),
op.version,
batch.workgroup_size,
),
);
} else {
runner.summary.tested = runner.summary.tested.saturating_add(1);
runner.summary.passed = runner.summary.passed.saturating_add(1);
}
}
progress_reporting::report_progress(
&mut runner.progress,
&runner.summary,
op.id,
next_test_id,
runner.progress_interval,
runner.checkpoint_interval,
runner.shard_id,
runner.shard_count,
);
}
/// Builds the program for `input`, dispatches it once on `backend`, and returns the
/// first output buffer.
///
/// # Errors
/// Returns a diagnostic string in any of these cases:
/// - program construction fails;
/// - the backend dispatch fails;
/// - no output buffer is returned;
/// - the first buffer's length differs from `output_size`.
fn dispatch_case(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    output_size: usize,
) -> Result<Vec<u8>, String> {
    let program = program_for_spec_input(op, input)?;
    let buffers = [input.to_vec()];
    let config = vyre::DispatchConfig::default();
    let outputs = backend
        .dispatch(&program, &buffers, &config)
        .map_err(|error| error.message)?;
    // Only the first buffer is inspected; any extra buffers are ignored.
    let Some(output) = outputs.into_iter().next() else {
        return Err("backend returned zero output buffers, expected one. Fix: return the operation result as outputs[0].".to_string());
    };
    if output.len() == output_size {
        Ok(output)
    } else {
        Err(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size outputs[0] from the vyre Program output declaration.",
            output.len()
        ))
    }
}