vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! H5 — determinism harness.
//!
//! Runs a vyre IR program many times across multiple threads and asserts that
//! every result is byte-identical. This tests invariant I1 (Determinism).
//!
//! The module wraps the existing `crate::enforce::enforcers::determinism` infrastructure
//! by re-exporting its report types and applying the same dispatch-and-compare
//! logic in a multi-threaded context.

use std::sync::Arc;
use std::thread;

use vyre::ir::Program;

pub use crate::enforce::enforcers::determinism::{DeterminismReport, Divergence};
use crate::spec::value::Value;

/// One output mismatch between a single run and the suite's baseline run.
///
/// The suites in this module use the result with the lowest iteration index
/// as the baseline and record one `ThreadDivergence` for every other result
/// that does not compare equal to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadDivergence {
    /// Global iteration index (counted across all workers) of the run that differed.
    pub iteration: usize,
    /// Index of the worker thread that produced the differing run.
    pub thread_id: usize,
    /// Serialized bytes of the baseline output.
    pub expected: Vec<u8>,
    /// Serialized bytes of the differing output.
    pub actual: Vec<u8>,
}

/// Output of one runner invocation, tagged with where it came from.
///
/// Private bookkeeping: once the results of all workers have been merged,
/// the tags let a divergence be attributed to an iteration and a worker.
#[derive(Debug)]
struct TaggedResult {
    /// Global iteration index: the worker's starting offset plus its local loop index.
    iteration: usize,
    /// Index of the worker thread that produced `values`.
    thread_id: usize,
    /// Values returned by this run.
    values: Vec<Value>,
}

/// Report from a multi-threaded determinism suite.
///
/// The run is considered deterministic when `divergences` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadDeterminismReport {
    /// Total number of iterations requested. This is the sum over all workers,
    /// not a per-thread count: the suites split it across the worker threads.
    pub iterations: usize,
    /// Number of worker threads requested. When this exceeds `iterations`,
    /// the surplus workers are still spawned but are assigned zero runs.
    pub threads: usize,
    /// Every run whose output differed from the baseline run.
    pub divergences: Vec<ThreadDivergence>,
}

/// Check the reference interpreter for determinism under concurrency.
///
/// Executes `program` with `inputs` a total of `iterations` times, spread
/// over `threads` worker threads, by delegating to
/// [`run_determinism_suite_with_runner`] with a runner that calls
/// `vyre_reference::run`.
///
/// Returns `Ok(report)`; `report.divergences` is empty when every run
/// produced the same output.
///
/// # Errors
///
/// Returns `Err` when the reference interpreter fails (its error is converted
/// with `to_string`), or when the delegated suite rejects the arguments or
/// fails to run its workers.
#[inline]
pub fn run_determinism_suite(
    program: &Program,
    inputs: &[Value],
    iterations: usize,
    threads: usize,
) -> Result<ThreadDeterminismReport, String> {
    // Adapter between the conformance `Value` type and the reference
    // interpreter's own value representation.
    let reference_runner = |prog: &Program, args: &[Value]| -> Result<Vec<Value>, String> {
        let reference_inputs = Value::to_reference_values(args);
        match vyre_reference::run(prog, &reference_inputs) {
            Ok(outputs) => Ok(Value::from_reference_values(outputs)),
            Err(err) => Err(err.to_string()),
        }
    };

    run_determinism_suite_with_runner(program, inputs, iterations, threads, reference_runner)
}

/// Same as [`run_determinism_suite`] but accepts a custom runner.
///
/// `iterations` is the total number of runs. They are split as evenly as
/// possible across `threads` workers: the first `iterations % threads`
/// workers take one extra run. The result with the lowest iteration index is
/// the baseline, and every other result that differs from it is reported as a
/// [`ThreadDivergence`].
///
/// The runner is invoked from multiple threads; it must be `Send`, `Sync`,
/// and `Clone`. (It is shared between workers through an `Arc`; the `Clone`
/// bound is kept so the public signature stays unchanged.) This is exposed so
/// tests can inject non-determinism without modifying the reference
/// interpreter.
///
/// Every worker that was spawned is joined before this function returns, on
/// error paths too, so no worker keeps running in the background after the
/// call has completed.
///
/// # Errors
///
/// Returns `Err` when:
/// - `threads` is `0`;
/// - `iterations` is less than `2`;
/// - a worker thread cannot be spawned;
/// - the runner returns an error (the first one in worker order is reported);
/// - a worker thread panics.
#[inline]
pub fn run_determinism_suite_with_runner<F>(
    program: &Program,
    inputs: &[Value],
    iterations: usize,
    threads: usize,
    runner: F,
) -> Result<ThreadDeterminismReport, String>
where
    F: Fn(&Program, &[Value]) -> Result<Vec<Value>, String> + Send + Sync + Clone + 'static,
{
    if threads == 0 {
        return Err("threads must be > 0".to_string());
    }
    if iterations < 2 {
        return Err(
            "iterations must be >= 2 to detect nondeterminism. Fix: run at least two observations."
                .to_string(),
        );
    }

    // `thread::Builder::spawn` needs `'static` data, so the borrowed arguments
    // are copied once and shared between workers through `Arc`.
    let program = Arc::new(program.clone());
    let inputs: Arc<[Value]> = inputs.to_vec().into();
    let runner = Arc::new(runner);

    let per_thread = iterations / threads;
    let remainder = iterations % threads;

    let mut handles = Vec::with_capacity(threads);
    let mut spawn_error = None;
    // Running offset: global iteration index of the next worker's first run.
    let mut next_iteration = 0usize;
    for t in 0..threads {
        let count = per_thread + usize::from(t < remainder);
        let start_iteration = next_iteration;
        next_iteration += count;

        let prog = Arc::clone(&program);
        let ins = Arc::clone(&inputs);
        let run = Arc::clone(&runner);
        let spawned = thread::Builder::new()
            .name(format!("vyre-h5-worker-{t}"))
            .spawn(move || -> Result<Vec<TaggedResult>, String> {
                let mut results = Vec::with_capacity(count);
                for local_index in 0..count {
                    let values = run(&prog, &ins)?;
                    results.push(TaggedResult {
                        iteration: start_iteration + local_index,
                        thread_id: t,
                        values,
                    });
                }
                Ok(results)
            });
        match spawned {
            Ok(handle) => handles.push(handle),
            Err(err) => {
                // Stop spawning, but fall through so the workers that are
                // already running get joined instead of being detached.
                spawn_error = Some(format!("failed to spawn determinism worker {t}: {err}"));
                break;
            }
        }
    }

    // Join every worker before inspecting any outcome: dropping a `JoinHandle`
    // detaches its thread, which would leave it running after we return.
    let outcomes: Vec<_> = handles.into_iter().map(|handle| handle.join()).collect();
    if let Some(err) = spawn_error {
        return Err(err);
    }

    let mut all_results: Vec<TaggedResult> = Vec::with_capacity(iterations);
    for outcome in outcomes {
        match outcome {
            Ok(Ok(results)) => all_results.extend(results),
            Ok(Err(e)) => return Err(e),
            Err(_) => return Err("thread panicked during determinism suite".to_string()),
        }
    }

    // Order by global iteration index so the baseline is always iteration 0,
    // regardless of which worker finished first.
    all_results.sort_by_key(|result| result.iteration);
    let mut divergences = Vec::new();
    if let Some((first, rest)) = all_results.split_first() {
        let baseline = &first.values;
        for result in rest {
            if &result.values != baseline {
                divergences.push(ThreadDivergence {
                    iteration: result.iteration,
                    thread_id: result.thread_id,
                    expected: values_to_bytes(baseline),
                    actual: values_to_bytes(&result.values),
                });
            }
        }
    }

    Ok(ThreadDeterminismReport {
        iterations,
        threads,
        divergences,
    })
}

/// Serializes every value with `to_bytes` and concatenates the results in
/// slice order.
fn values_to_bytes(values: &[Value]) -> Vec<u8> {
    let mut bytes = Vec::new();
    for value in values {
        bytes.extend(value.to_bytes());
    }
    bytes
}

/// Run `op` on the GPU backend `iterations` times across `threads` threads.
/// Collect all GPU output bytes and assert every byte is identical.
///
/// Returns `Ok(report)` where `report.divergences` is empty on success.
/// If the GPU backend itself errors, that error is returned as `Err`.
#[inline]
pub fn run_gpu_determinism_suite(
    op: &crate::OpSpec,
    input: &[u8],
    iterations: usize,
    threads: usize,
) -> Result<ThreadDeterminismReport, String> {
    if threads == 0 {
        return Err("threads must be > 0".to_string());
    }
    if iterations < 2 {
        return Err(
            "iterations must be >= 2 to detect nondeterminism. Fix: run at least two observations."
                .to_string(),
        );
    }

    let op = std::sync::Arc::new(op.clone());
    let input: std::sync::Arc<[u8]> = input.to_vec().into();

    let per_thread = iterations / threads;
    let remainder = iterations % threads;

    let mut handles = Vec::with_capacity(threads);
    for t in 0..threads {
        let count = per_thread + if t < remainder { 1 } else { 0 };
        let op = std::sync::Arc::clone(&op);
        let input = std::sync::Arc::clone(&input);
        let start_iteration = (0..t)
            .map(|prior| per_thread + if prior < remainder { 1 } else { 0 })
            .sum::<usize>();
        let handle = std::thread::Builder::new()
            .name(format!("vyre-h7-gpu-worker-{t}"))
            .spawn(move || {
                let backend =
                    crate::pipeline::backend::wgpu::WgpuBackend::new().ok_or_else(|| {
                        "Fix: no GPU adapter available for gpu determinism suite.".to_string()
                    })?;
                let workgroup_size = op.workgroup_size.unwrap_or(64);
                let mut results = Vec::with_capacity(count);
                for local_index in 0..count {
                    match crate::pipeline::execution::execute_op(
                        &backend,
                        &op,
                        &input,
                        workgroup_size,
                    ) {
                        Ok((gpu, _cpu)) => results.push(TaggedResult {
                            iteration: start_iteration + local_index,
                            thread_id: t,
                            values: vec![Value::Bytes(gpu)],
                        }),
                        Err(e) => return Err(e),
                    }
                }
                Ok(results)
            })
            .map_err(|err| format!("failed to spawn gpu determinism worker {t}: {err}"))?;
        handles.push(handle);
    }

    let mut all_results: Vec<TaggedResult> = Vec::with_capacity(iterations);
    for handle in handles {
        match handle.join() {
            Ok(Ok(results)) => all_results.extend(results),
            Ok(Err(e)) => return Err(e),
            Err(_) => return Err("thread panicked during gpu determinism suite".to_string()),
        }
    }

    let mut divergences = Vec::new();
    if !all_results.is_empty() {
        all_results.sort_by_key(|r| r.iteration);
        let baseline = &all_results[0].values;
        for result in all_results.iter().skip(1) {
            if &result.values != baseline {
                divergences.push(ThreadDivergence {
                    iteration: result.iteration,
                    thread_id: result.thread_id,
                    expected: values_to_bytes(baseline),
                    actual: values_to_bytes(&result.values),
                });
            }
        }
    }

    Ok(ThreadDeterminismReport {
        iterations,
        threads,
        divergences,
    })
}

#[cfg(test)]
mod tests {

    use super::{run_determinism_suite, run_determinism_suite_with_runner};
    use crate::spec::value::Value;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use vyre::ir::Program;

    /// Smallest program used by these tests: no buffers, a `[1, 1, 1]`
    /// workgroup, and a body consisting of a single `Return` node.
    fn empty_program() -> Program {
        let body = vec![vyre::ir::Node::Return];
        Program::new(vec![], [1, 1, 1], body)
    }

    /// The reference interpreter must yield no divergences over 20 runs on 4 threads.
    #[test]
    fn determinism_suite_passes_for_reference() {
        let report = run_determinism_suite(&empty_program(), &[], 20, 4)
            .expect("suite should complete without error");
        let found = &report.divergences;
        assert!(
            found.is_empty(),
            "reference interpreter must be deterministic, got divergences: {:?}",
            found
        );
    }

    /// A runner whose fourth invocation returns a different value must be caught.
    #[test]
    fn determinism_suite_detects_divergence() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let runner = {
            let counter = Arc::clone(&call_count);
            move |_p: &Program, _i: &[Value]| -> Result<Vec<Value>, String> {
                // Exactly one call (the one that observes counter value 3) is the odd one out.
                let seen = counter.fetch_add(1, Ordering::SeqCst);
                let output = if seen == 3 { 99 } else { 42 };
                Ok(vec![Value::U32(output)])
            }
        };

        let report = run_determinism_suite_with_runner(&empty_program(), &[], 10, 4, runner)
            .expect("suite should complete without runner error");
        let has_divergence = !report.divergences.is_empty();
        assert!(
            has_divergence,
            "expected at least one divergence for non-deterministic runner"
        );
    }

    /// Zero worker threads is rejected before anything runs.
    #[test]
    fn determinism_suite_rejects_zero_threads() {
        let outcome = run_determinism_suite(&empty_program(), &[], 1, 0);
        assert!(
            outcome.is_err(),
            "expected error for zero threads, got: {:?}",
            outcome
        );
        let message = outcome.unwrap_err();
        assert!(message.contains("threads must be > 0"));
    }

    /// Zero iterations falls under the "at least two observations" guard.
    #[test]
    fn determinism_suite_rejects_zero_iterations() {
        let outcome = run_determinism_suite(&empty_program(), &[], 0, 2);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("iterations must be >= 2"));
    }
}