vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
#![allow(
    private_interfaces,
    missing_docs,
    dead_code,
    unused_imports,
    unused_variables,
    unreachable_patterns,
    ambiguous_glob_reexports,
    ambiguous_glob_imports,
    clippy::all
)]
//! Top-level streaming conformance runner orchestration.

use crate::generate::generators::{default_generators, InputGenerator};
use crate::pipeline::execution::InputCase;
use crate::{OpSpec, ParityFailure};

use super::batch_execution;
use super::batch_execution::Batch;
use super::progress_reporting::{
    FnProgress, NoProgress, ProgressSink, StreamingProgress, StreamingSummary,
};
use super::regression_sinking;
use super::workgroup_config;

/// Batch size installed by `StreamingRunner::new`; also used (as `u64`) for the default
/// progress interval.
const DEFAULT_BATCH_SIZE: usize = 1024;
/// Default value of `StreamingRunner::max_output_bytes` (16 MiB).
const MAX_OUTPUT_BYTES_PER_CASE: usize = 16 * 1024 * 1024;
/// Upper bound that `StreamingRunner::batch_size` clamps requested sizes to.
const MAX_BATCH_SIZE: usize = 1_048_576;

/// Streaming conformance runner for large suites with bounded memory.
///
/// Configured through the builder-style methods (`shard`, `skip`, `batch_size`, ...) and
/// consumed by `run` / `run_with_summary`. `P` is the progress sink type; it defaults to
/// `NoProgress` and is replaced by `on_progress`.
pub struct StreamingRunner<P: ProgressSink = NoProgress> {
    /// Shard selected via `shard`; checked together with `shard_count` by `config_failure`.
    pub(crate) shard_id: u64,
    /// Total number of shards set via `shard` (`new` starts with 1).
    pub(crate) shard_count: u64,
    /// Value set via `skip` (or restored from a checkpoint's `next_test_id`).
    pub(crate) skip_count: u64,
    /// Cases per dispatch; the `batch_size` setter keeps it within `1..=MAX_BATCH_SIZE`.
    pub(crate) batch_size: usize,
    /// Kept equal to the batch size by `new` and `batch_size`.
    /// NOTE(review): presumably the number of cases between progress reports — confirm
    /// against the batch execution module.
    pub(crate) progress_interval: u64,
    /// Output size limit set via `max_output_bytes` (default `MAX_OUTPUT_BYTES_PER_CASE`).
    pub(crate) max_output_bytes: usize,
    /// Checkpoint cadence set via `checkpoint_interval` (`new` starts with 1024).
    pub(crate) checkpoint_interval: u64,
    /// Input generators: `default_generators()` plus anything added by `with_generator`.
    pub(crate) generators: Vec<Box<dyn InputGenerator>>,
    /// Progress sink instance.
    pub(crate) progress: P,
    /// Running counters; `record_failure` bumps `tested` and `failed`.
    pub(crate) summary: StreamingSummary,
    /// Sender to the regression writer. `run_with_summary` sets it for the duration of a
    /// run and takes it back out before joining the writer.
    pub(crate) regression_tx: Option<std::sync::mpsc::SyncSender<ParityFailure>>,
}

impl StreamingRunner<NoProgress> {
    /// Builds a runner with the stock configuration.
    ///
    /// Defaults: shard 0 of 1, nothing skipped, `DEFAULT_BATCH_SIZE` cases per batch (the
    /// progress interval mirrors that value), `MAX_OUTPUT_BYTES_PER_CASE` as the output
    /// limit, a checkpoint interval of 1024, the generators from `default_generators()`,
    /// the `NoProgress` sink, an empty summary and no regression sender.
    #[inline]
    pub fn new() -> Self {
        let batch_size = DEFAULT_BATCH_SIZE;
        let generators = default_generators();
        let summary = StreamingSummary::default();

        Self {
            progress: NoProgress,
            regression_tx: None,
            summary,
            generators,
            checkpoint_interval: 1024,
            max_output_bytes: MAX_OUTPUT_BYTES_PER_CASE,
            progress_interval: batch_size as u64,
            batch_size,
            skip_count: 0,
            shard_count: 1,
            shard_id: 0,
        }
    }

    /// Installs a custom progress callback and returns a runner with that sink.
    ///
    /// The callback is wrapped in `FnProgress` and replaces the `NoProgress` sink. All other
    /// configuration and state — including the regression sender, if one is set — is
    /// carried over unchanged. Because this method only exists on
    /// `StreamingRunner<NoProgress>`, a callback can be installed once per runner.
    #[inline]
    pub fn on_progress<F>(self, callback: F) -> StreamingRunner<FnProgress<F>>
    where
        F: FnMut(StreamingProgress) + Send + 'static,
    {
        StreamingRunner {
            shard_id: self.shard_id,
            shard_count: self.shard_count,
            skip_count: self.skip_count,
            batch_size: self.batch_size,
            progress_interval: self.progress_interval,
            max_output_bytes: self.max_output_bytes,
            checkpoint_interval: self.checkpoint_interval,
            generators: self.generators,
            progress: FnProgress(callback),
            summary: self.summary,
            // Forward the sender like every other field instead of resetting it to `None`,
            // so swapping the sink never silently drops an installed regression channel.
            regression_tx: self.regression_tx,
        }
    }

    /// Rebuilds a runner from the checkpoint file `streaming-progress-shard{shard_id}.json`
    /// inside `dir`.
    ///
    /// The result is a `new()` runner with the checkpoint's shard id / shard count applied
    /// and its `next_test_id` installed as the skip count. No other setting is restored.
    ///
    /// NOTE(review): the `shard_id` argument only selects the file name; the shard id that
    /// ends up on the runner is the one stored inside the checkpoint. The two are not
    /// compared — confirm that a mismatch cannot occur in practice.
    ///
    /// # Errors
    ///
    /// Returns the error from reading the file, or the JSON parse error converted into
    /// `std::io::Error`.
    #[inline]
    pub fn resume_from_dir(
        dir: impl AsRef<std::path::Path>,
        shard_id: u64,
    ) -> std::io::Result<Self> {
        // On-disk checkpoint layout; only the fields needed to resume are declared.
        #[derive(serde::Deserialize)]
        struct Checkpoint {
            next_test_id: u64,
            shard_id: u64,
            shard_count: u64,
        }

        let file_name = format!("streaming-progress-shard{shard_id}.json");
        let checkpoint_path = dir.as_ref().join(file_name);
        let raw = std::fs::read(checkpoint_path)?;
        let saved: Checkpoint = serde_json::from_slice(&raw)?;

        let runner = Self::new()
            .shard(saved.shard_id, saved.shard_count)
            .skip(saved.next_test_id);
        Ok(runner)
    }
}

impl<P: ProgressSink> StreamingRunner<P> {
    /// Limits execution to a shard partition.
    #[inline]
    pub fn shard(mut self, shard_id: u64, shard_count: u64) -> Self {
        self.shard_id = shard_id;
        self.shard_count = shard_count;
        self
    }

    /// Skips the first `count` generated case ids.
    #[inline]
    pub fn skip(mut self, count: u64) -> Self {
        self.skip_count = count;
        self
    }

    /// Sets the number of cases batched per dispatch.
    ///
    /// `size` is forced into `1..=MAX_BATCH_SIZE`; the progress interval is updated to the
    /// same (clamped) value.
    #[inline]
    pub fn batch_size(mut self, size: usize) -> Self {
        let clamped = size.max(1).min(MAX_BATCH_SIZE);
        self.progress_interval = clamped as u64;
        self.batch_size = clamped;
        self
    }

    /// Sets maximum output size retained from a CPU reference before hard-failing that case.
    #[inline]
    pub fn max_output_bytes(mut self, limit: usize) -> Self {
        self.max_output_bytes = limit;
        self
    }

    /// Configures how often checkpoint files are written.
    #[inline]
    pub fn checkpoint_interval(mut self, interval: u64) -> Self {
        self.checkpoint_interval = interval;
        self
    }

    /// Appends `generator` to the runner's generator list (after the ones already present).
    #[inline]
    pub fn with_generator(self, generator: Box<dyn InputGenerator>) -> Self {
        let mut runner = self;
        runner.generators.push(generator);
        runner
    }

    /// Executes every op in `specs` against `backend` and returns the collected failures.
    ///
    /// Thin wrapper over `run_with_summary` that discards the summary.
    #[inline]
    pub fn run(self, backend: &dyn vyre::VyreBackend, specs: &[OpSpec]) -> Vec<ParityFailure> {
        let (_summary, failures) = self.run_with_summary(backend, specs);
        failures
    }

    /// Runs every spec, plus each of its alternate WGSL entry points, against `backend`.
    ///
    /// Returns the final summary together with all recorded failures. If `config_failure`
    /// rejects the shard configuration, nothing is executed: the untouched summary and that
    /// single failure are returned immediately.
    ///
    /// Test ids are assigned from one counter shared across all specs and their alternates.
    #[inline]
    pub(crate) fn run_with_summary(
        mut self,
        backend: &dyn vyre::VyreBackend,
        specs: &[OpSpec],
    ) -> (StreamingSummary, Vec<ParityFailure>) {
        if let Some(failure) = self.config_failure() {
            // `self` is owned and discarded here, so move the summary out rather than clone it.
            return (self.summary, vec![failure]);
        }

        let (writer_tx, writer) = regression_sinking::start_regression_writer();
        self.regression_tx = Some(writer_tx);

        let mut failures = Vec::new();
        let mut next_test_id = 0_u64;

        for op in specs {
            // Primary implementation first, then every labelled alternate.
            self.run_op(backend, op, None, &mut next_test_id, &mut failures);
            for (label, alt_wgsl_fn) in &op.alt_wgsl_fns {
                // Same spec with the alternate entry point swapped in; the copy carries no
                // alternates of its own.
                let alt = OpSpec {
                    wgsl_fn: *alt_wgsl_fn,
                    alt_wgsl_fns: Vec::new(),
                    version_history: op.version_history.clone(),
                    ..op.clone()
                };
                self.run_op(backend, &alt, Some(label), &mut next_test_id, &mut failures);
            }
        }

        // Release our sender before waiting on the writer — presumably this lets the writer
        // observe the closed channel and finish. The join result is deliberately ignored.
        drop(self.regression_tx.take());
        let _ = writer.join();

        (self.summary, failures)
    }

    /// Runs a single op by delegating to `batch_execution::run_op`.
    ///
    /// `alt_label` is `None` for the primary implementation and `Some(label)` when an
    /// alternate WGSL entry point is being run. `next_test_id` and `failures` are passed
    /// through mutably so the callee can advance the id counter and append failures.
    #[inline]
    pub(crate) fn run_op(
        &mut self,
        backend: &dyn vyre::VyreBackend,
        op: &OpSpec,
        alt_label: Option<&str>,
        next_test_id: &mut u64,
        failures: &mut Vec<ParityFailure>,
    ) {
        batch_execution::run_op(self, backend, op, alt_label, next_test_id, failures);
    }

    /// Hands one generated input case to `batch_execution::accept_case` and returns its
    /// result unchanged.
    ///
    /// `batch`, `next_test_id` and `failures` are passed through mutably to the callee.
    /// NOTE(review): the meaning of the `Err(String)` value is defined by the callee —
    /// confirm there before relying on it.
    #[inline]
    pub(crate) fn accept_case(
        &mut self,
        backend: &dyn vyre::VyreBackend,
        op: &OpSpec,
        case: InputCase,
        next_test_id: &mut u64,
        batch: &mut Batch,
        failures: &mut Vec<ParityFailure>,
    ) -> Result<(), String> {
        batch_execution::accept_case(self, backend, op, case, next_test_id, batch, failures)
    }

    /// Records one failure per `(case, cpu_output)` pair after a whole batch failed to
    /// dispatch.
    ///
    /// Every failure is built with `case.failure(..)` using an empty byte vector as its
    /// second argument (presumably the backend output — confirm against
    /// `InputCase::failure`), the paired CPU output, and `message` wrapped by
    /// `batch.message`. Each one is then passed to `record_failure`.
    ///
    /// `cases` and `cpus` are paired positionally; if their lengths differ, the surplus
    /// entries of the longer one are ignored.
    #[inline]
    pub(crate) fn record_batch_dispatch_failure(
        &mut self,
        failures: &mut Vec<ParityFailure>,
        batch: &Batch,
        op: &OpSpec,
        cases: Vec<InputCase>,
        cpus: Vec<Vec<u8>>,
        message: String,
    ) {
        let pairs = cases.into_iter().zip(cpus);
        for (case, cpu_output) in pairs {
            // The same dispatch error text is attached to every case of the batch.
            let detail = batch.message(message.clone());
            let failure = case.failure(
                op.id,
                Vec::new(),
                cpu_output,
                detail,
                op.version,
                batch.workgroup_size,
            );
            self.record_failure(failures, failure);
        }
    }

    /// Accounts for a single failed case.
    ///
    /// Increments the summary's `tested` and `failed` counters (saturating), passes the
    /// failure to `regression_sinking::send_or_store_failure` together with the current
    /// regression sender, and finally appends it to `failures`.
    #[inline]
    pub(crate) fn record_failure(
        &mut self,
        failures: &mut Vec<ParityFailure>,
        failure: ParityFailure,
    ) {
        let counters = &mut self.summary;
        counters.tested = counters.tested.saturating_add(1);
        counters.failed = counters.failed.saturating_add(1);

        // Forward to the regression sink by reference before the value is moved into the list.
        regression_sinking::send_or_store_failure(&self.regression_tx, &failure);
        failures.push(failure);
    }

    /// Checks the shard settings with `workgroup_config::validate_sharding`.
    ///
    /// Returns `Some(failure)` when the validator reports a problem, `None` otherwise.
    #[inline]
    pub(crate) fn config_failure(&self) -> Option<ParityFailure> {
        let (id, count) = (self.shard_id, self.shard_count);
        workgroup_config::validate_sharding(id, count)
    }
}