#![allow(
private_interfaces,
missing_docs,
dead_code,
unused_imports,
unused_variables,
unreachable_patterns,
ambiguous_glob_reexports,
ambiguous_glob_imports,
clippy::all
)]
use crate::generate::generators::{default_generators, InputGenerator};
use crate::pipeline::execution::InputCase;
use crate::{OpSpec, ParityFailure};
use super::batch_execution;
use super::batch_execution::Batch;
use super::progress_reporting::{
FnProgress, NoProgress, ProgressSink, StreamingProgress, StreamingSummary,
};
use super::regression_sinking;
use super::workgroup_config;
// Batch size a freshly constructed runner starts with; `new` also uses it
// (cast to `u64`) as the initial progress interval.
const DEFAULT_BATCH_SIZE: usize = 1024;
// Initial value of `max_output_bytes` for a new runner (16 MiB).
const MAX_OUTPUT_BYTES_PER_CASE: usize = 16 * 1024 * 1024;
// Upper bound applied by the `batch_size` builder (2^20 cases).
const MAX_BATCH_SIZE: usize = 1_048_576;
/// Builder-style runner that executes op specs against a backend and collects
/// parity failures.
///
/// `P` is the progress sink type; it defaults to [`NoProgress`] and can be
/// replaced with a closure-backed sink through `on_progress`.
pub struct StreamingRunner<P: ProgressSink = NoProgress> {
/// Shard identifier set through `shard`; checked together with `shard_count`
/// by `config_failure` before a run starts.
pub(crate) shard_id: u64,
/// Shard count set through `shard`; a new runner starts with `1`.
pub(crate) shard_count: u64,
/// Value set through `skip`; `resume_from_dir` fills it from the checkpoint's
/// `next_test_id`. NOTE(review): how it is consumed lives in `batch_execution`.
pub(crate) skip_count: u64,
/// Batch size; the `batch_size` builder keeps it within `1..=MAX_BATCH_SIZE`.
pub(crate) batch_size: usize,
/// Progress interval; reset to the batch size whenever `batch_size` is called.
pub(crate) progress_interval: u64,
/// Output byte limit set through `max_output_bytes`; defaults to
/// `MAX_OUTPUT_BYTES_PER_CASE`. NOTE(review): enforcement is not visible here.
pub(crate) max_output_bytes: usize,
/// Checkpoint interval set through `checkpoint_interval`; defaults to `1024`.
pub(crate) checkpoint_interval: u64,
/// Input generators: starts as `default_generators()` and grows through
/// `with_generator`.
pub(crate) generators: Vec<Box<dyn InputGenerator>>,
/// The progress sink instance.
pub(crate) progress: P,
/// Running counters; `record_failure` bumps its `tested` and `failed` fields.
pub(crate) summary: StreamingSummary,
/// Sender half of the regression channel. `run_with_summary` installs it for
/// the duration of a run and `record_failure` hands it to
/// `regression_sinking::send_or_store_failure`.
pub(crate) regression_tx: Option<std::sync::mpsc::SyncSender<ParityFailure>>,
}
impl StreamingRunner<NoProgress> {
/// Builds a runner with the default configuration.
///
/// Defaults: `shard_id` 0 with `shard_count` 1, nothing skipped, batch size and
/// progress interval of `DEFAULT_BATCH_SIZE`, an output limit of
/// `MAX_OUTPUT_BYTES_PER_CASE`, a checkpoint interval of 1024, the generators
/// returned by `default_generators()`, the `NoProgress` sink, an empty summary
/// and no regression sender.
#[inline]
pub fn new() -> Self {
    // Build the non-trivial values first (generators before summary) so the
    // literal below is plain data.
    let generators = default_generators();
    let summary = StreamingSummary::default();
    Self {
        progress: NoProgress,
        regression_tx: None,
        generators,
        summary,
        shard_id: 0,
        shard_count: 1,
        skip_count: 0,
        batch_size: DEFAULT_BATCH_SIZE,
        progress_interval: DEFAULT_BATCH_SIZE as u64,
        max_output_bytes: MAX_OUTPUT_BYTES_PER_CASE,
        checkpoint_interval: 1024,
    }
}
/// Replaces the progress sink with a closure-backed [`FnProgress`], consuming
/// this runner and returning one typed over the new sink.
///
/// Every other field is moved across unchanged. That includes `regression_tx`:
/// a sender that is already attached is carried over rather than silently
/// dropped, so this conversion only ever changes the sink.
///
/// `callback` is wrapped as `FnProgress(callback)`; when and how often it is
/// invoked is decided by the progress-reporting code, not here.
#[inline]
pub fn on_progress<F>(self, callback: F) -> StreamingRunner<FnProgress<F>>
where
    F: FnMut(StreamingProgress) + Send + 'static,
{
    StreamingRunner {
        shard_id: self.shard_id,
        shard_count: self.shard_count,
        skip_count: self.skip_count,
        batch_size: self.batch_size,
        progress_interval: self.progress_interval,
        max_output_bytes: self.max_output_bytes,
        checkpoint_interval: self.checkpoint_interval,
        generators: self.generators,
        progress: FnProgress(callback),
        summary: self.summary,
        // Forwarded like every other field instead of being reset to `None`.
        regression_tx: self.regression_tx,
    }
}
/// Rebuilds a runner from the checkpoint file of `shard_id` inside `dir`.
///
/// The file read is `<dir>/streaming-progress-shard{shard_id}.json`. It must be
/// a JSON object with `next_test_id`, `shard_id` and `shard_count` fields. The
/// returned runner is `Self::new()` with the recorded sharding applied and
/// `skip` set to `next_test_id`.
///
/// # Errors
///
/// Returns an [`std::io::Error`] when the file cannot be read, when it does not
/// deserialize into the expected shape (the `serde_json` error is converted by
/// `?`), or — with kind `InvalidData` — when the shard id stored in the file
/// differs from the requested `shard_id`. The last case guards against a
/// misplaced or corrupted checkpoint silently resuming a different shard than
/// the one the caller asked for.
#[inline]
pub fn resume_from_dir(
    dir: impl AsRef<std::path::Path>,
    shard_id: u64,
) -> std::io::Result<Self> {
    #[derive(serde::Deserialize)]
    struct Checkpoint {
        next_test_id: u64,
        shard_id: u64,
        shard_count: u64,
    }
    let path = dir
        .as_ref()
        .join(format!("streaming-progress-shard{shard_id}.json"));
    let cp: Checkpoint = serde_json::from_slice(&std::fs::read(&path)?)?;
    // The file name and the file contents must agree on which shard this is.
    if cp.shard_id != shard_id {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "checkpoint {} records shard {} but shard {} was requested",
                path.display(),
                cp.shard_id,
                shard_id
            ),
        ));
    }
    Ok(Self::new()
        .shard(cp.shard_id, cp.shard_count)
        .skip(cp.next_test_id))
}
}
impl<P: ProgressSink> StreamingRunner<P> {
#[inline]
pub fn shard(mut self, shard_id: u64, shard_count: u64) -> Self {
self.shard_id = shard_id;
self.shard_count = shard_count;
self
}
#[inline]
pub fn skip(mut self, count: u64) -> Self {
self.skip_count = count;
self
}
#[inline]
pub fn batch_size(mut self, size: usize) -> Self {
self.batch_size = size.clamp(1, MAX_BATCH_SIZE);
self.progress_interval = self.batch_size as u64;
self
}
#[inline]
pub fn max_output_bytes(mut self, limit: usize) -> Self {
self.max_output_bytes = limit;
self
}
#[inline]
pub fn checkpoint_interval(mut self, interval: u64) -> Self {
self.checkpoint_interval = interval;
self
}
/// Appends `generator` to the end of the runner's generator list.
#[inline]
pub fn with_generator(self, generator: Box<dyn InputGenerator>) -> Self {
    let mut runner = self;
    runner.generators.push(generator);
    runner
}
/// Runs `specs` against `backend` and returns only the collected failures.
///
/// This is `run_with_summary` with the summary half of the result discarded.
#[inline]
pub fn run(self, backend: &dyn vyre::VyreBackend, specs: &[OpSpec]) -> Vec<ParityFailure> {
    let (_summary, failures) = self.run_with_summary(backend, specs);
    failures
}
/// Runs every spec against `backend` and returns the final summary together
/// with all recorded failures.
///
/// Steps:
/// 1. If `config_failure` reports a problem, return immediately with the
///    current summary and that single failure; no regression writer is started.
/// 2. Start the regression writer and keep its sender in `regression_tx` for
///    the duration of the run.
/// 3. For each spec, run the op itself (no alt label), then run it once per
///    entry of `alt_wgsl_fns` using a copy of the op whose `wgsl_fn` is the
///    alternate and whose `alt_wgsl_fns` list is empty. A single
///    `next_test_id` counter is shared across all of these runs.
/// 4. Drop the sender, then join the writer.
#[inline]
pub(crate) fn run_with_summary(
    mut self,
    backend: &dyn vyre::VyreBackend,
    specs: &[OpSpec],
) -> (StreamingSummary, Vec<ParityFailure>) {
    if let Some(failure) = self.config_failure() {
        // `self` is owned, so the summary is moved out rather than cloned.
        return (self.summary, vec![failure]);
    }
    let (writer_tx, writer) = regression_sinking::start_regression_writer();
    self.regression_tx = Some(writer_tx);
    let mut failures = Vec::new();
    let mut next_test_id = 0_u64;
    for op in specs {
        self.run_op(backend, op, None, &mut next_test_id, &mut failures);
        for (label, alt_wgsl_fn) in &op.alt_wgsl_fns {
            let alt = OpSpec {
                wgsl_fn: *alt_wgsl_fn,
                alt_wgsl_fns: Vec::new(),
                version_history: op.version_history.clone(),
                ..op.clone()
            };
            self.run_op(backend, &alt, Some(label), &mut next_test_id, &mut failures);
        }
    }
    // Release our sender before joining. NOTE(review): presumably this lets the
    // writer see the channel close and finish — confirm in `regression_sinking`.
    drop(self.regression_tx.take());
    // The join result is deliberately ignored (best effort).
    let _ = writer.join();
    (self.summary, failures)
}
/// Runs a single op by delegating to `batch_execution::run_op`, passing this
/// runner along as the first argument.
///
/// `run_with_summary` passes `None` as `alt_label` for an op's primary run and
/// `Some(label)` for each alternate WGSL variant. `next_test_id` and `failures`
/// are caller-owned accumulators handed through unchanged.
#[inline]
pub(crate) fn run_op(
&mut self,
backend: &dyn vyre::VyreBackend,
op: &OpSpec,
alt_label: Option<&str>,
next_test_id: &mut u64,
failures: &mut Vec<ParityFailure>,
) {
batch_execution::run_op(self, backend, op, alt_label, next_test_id, failures);
}
/// Hands one input case to `batch_execution::accept_case` and returns its
/// result unchanged.
///
/// The meaning of the `Err(String)` value is defined by that function; nothing
/// is added or interpreted here.
#[inline]
pub(crate) fn accept_case(
&mut self,
backend: &dyn vyre::VyreBackend,
op: &OpSpec,
case: InputCase,
next_test_id: &mut u64,
batch: &mut Batch,
failures: &mut Vec<ParityFailure>,
) -> Result<(), String> {
batch_execution::accept_case(self, backend, op, case, next_test_id, batch, failures)
}
/// Records one failure per case of a batch, all sharing the same `message`.
///
/// `cases` and `cpus` are walked in lockstep (`zip`, so the shorter of the two
/// bounds the loop). Each failure is built with an empty second output argument
/// and the paired CPU output, then passed to `record_failure`.
#[inline]
pub(crate) fn record_batch_dispatch_failure(
    &mut self,
    failures: &mut Vec<ParityFailure>,
    batch: &Batch,
    op: &OpSpec,
    cases: Vec<InputCase>,
    cpus: Vec<Vec<u8>>,
    message: String,
) {
    let paired = cases.into_iter().zip(cpus);
    for (input_case, expected) in paired {
        let failure = input_case.failure(
            op.id,
            Vec::new(),
            expected,
            batch.message(message.clone()),
            op.version,
            batch.workgroup_size,
        );
        self.record_failure(failures, failure);
    }
}
/// Registers a single failure.
///
/// Bumps the summary's `tested` and `failed` counters (saturating), offers the
/// failure to `regression_sinking::send_or_store_failure` with the current
/// sender, and finally appends it to `failures`.
#[inline]
pub(crate) fn record_failure(
    &mut self,
    failures: &mut Vec<ParityFailure>,
    failure: ParityFailure,
) {
    let counters = &mut self.summary;
    counters.tested = counters.tested.saturating_add(1);
    counters.failed = counters.failed.saturating_add(1);

    let sender = &self.regression_tx;
    regression_sinking::send_or_store_failure(sender, &failure);

    failures.push(failure);
}
/// Returns whatever `workgroup_config::validate_sharding` reports for the
/// current `shard_id` / `shard_count` pair: `Some(failure)` if it rejects the
/// configuration, `None` otherwise.
#[inline]
pub(crate) fn config_failure(&self) -> Option<ParityFailure> {
workgroup_config::validate_sharding(self.shard_id, self.shard_count)
}
}