vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! K-dispatch determinism enforcer.
//!
//! Verifies that a given op produces byte-identical GPU output across a range of
//! workgroup sizes: a single workgroup-size-1 dispatch serves as the baseline, and
//! every larger size is dispatched repeatedly and compared against it. This catches
//! parallelism bugs that may be invisible to CPU-parity tests.

use crate::pipeline::execution::InputCase;
use crate::spec::program::program_for_spec_input;
use crate::spec::OpSpec;

/// Infrastructure-level failure that the determinism enforcer does not count
/// as a divergence.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum InfrastructureError {
    /// The backend could not compile the shader at the requested workgroup size.
    CompileFailure,
}

/// Classification of a failed dispatch during determinism testing.
#[derive(Clone, Debug, Eq, PartialEq)]
enum DispatchError {
    /// Infrastructure problem rather than a determinism violation.
    Infra(InfrastructureError),
    /// Any other failure, carrying a human-readable diagnostic.
    Other(String),
}

impl DispatchError {
    /// Converts the error into raw bytes suitable for `Divergence::bytes_b`.
    ///
    /// `Infra` errors are rendered with their `Debug` representation; `Other`
    /// errors yield their message text unchanged.
    fn into_bytes(self) -> Vec<u8> {
        let text = match self {
            Self::Infra(infra) => format!("{infra:?}"),
            Self::Other(message) => message,
        };
        text.into_bytes()
    }
}

/// Outcome of running the determinism enforcer against one op.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterminismReport {
    /// Identifier of the op under test.
    pub op_id: String,
    /// Every divergence found, across all inputs and workgroup sizes.
    pub divergences: Vec<Divergence>,
}

/// One observed mismatch (or failure) when comparing two dispatches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Divergence {
    /// Label of the input case involved.
    pub input_label: String,
    /// Workgroup size of the baseline dispatch; always 1.
    pub wg_a: u32,
    /// Workgroup size whose dispatch differed from (or failed against) the baseline.
    pub wg_b: u32,
    /// Output bytes of the baseline dispatch.
    pub bytes_a: Vec<u8>,
    /// Output bytes of the diverging dispatch.
    pub bytes_b: Vec<u8>,
    /// Zero-based repeat index at which the divergence first appeared, if any.
    pub run: Option<u32>,
    /// Actionable diagnostic text; empty when the divergence is a plain output mismatch.
    pub message: String,
}

/// Enforce K-dispatch determinism for `op` across `inputs`.
///
/// For each input:
/// 1. Dispatch at workgroup size 1 to establish a canonical single-threaded
///    baseline.
/// 2. For each workgroup size in `[8, 32, 64, 128, 256, 1024]`:
///    - Skip sizes that exceed `op.workgroup_size` (if declared).
///    - Repeat dispatch `repeats` times.
///    - If any repeat produces output that differs from the baseline, record a
///      `Divergence` and move to the next workgroup size.
///    - If dispatch fails to compile (e.g., unsupported workgroup size), skip
///      remaining repeats for that size. The skip is **not** treated as a
///      divergence.
///
/// This pass intentionally does **not** compare against CPU output. Its purpose
/// is to catch workgroup-size-dependent nondeterminism that might be invisible
/// to CPU-parity tests (e.g., a race that affects all parallel dispatches
/// uniformly).
#[inline]
pub fn enforce_determinism(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    inputs: &[InputCase],
    repeats: u32,
) -> DeterminismReport {
    let mut divergences = Vec::new();
    if repeats < 10 {
        divergences.push(Divergence {
            input_label: "configuration/repeats".to_string(),
            wg_a: 1,
            wg_b: 1,
            bytes_a: repeats.to_le_bytes().to_vec(),
            bytes_b: 10u32.to_le_bytes().to_vec(),
            run: None,
            message: format!(
                "Fix: determinism requires at least 10 repeats; caller requested {repeats}."
            ),
        });
    }
    let repeats = repeats.max(10);
    let workgroup_sizes = determinism_workgroup_sizes(op.workgroup_size);

    for case in inputs {
        // Establish single-threaded baseline.
        let baseline = match dispatch_op(backend, op, &case.bytes, 1) {
            Ok(bytes) => bytes,
            Err(err) => {
                divergences.push(Divergence {
                    input_label: case.report_label(),
                    wg_a: 1,
                    wg_b: 1,
                    bytes_a: Vec::new(),
                    bytes_b: err.into_bytes(),
                    run: None,
                    message: "Fix: backend must compile and run the canonical workgroup_size=1 baseline before determinism can be claimed.".to_string(),
                });
                continue;
            }
        };

        for wg in workgroup_sizes.iter().copied().filter(|&wg| wg != 1) {
            for run in 0..repeats {
                match dispatch_op(backend, op, &case.bytes, wg) {
                    Ok(output) => {
                        if output != baseline {
                            divergences.push(Divergence {
                                input_label: case.report_label(),
                                wg_a: 1,
                                wg_b: wg,
                                bytes_a: baseline.clone(),
                                bytes_b: output,
                                run: Some(run),
                                message: String::new(),
                            });
                            break; // Move to next workgroup size.
                        }
                    }
                    Err(DispatchError::Infra(InfrastructureError::CompileFailure)) => {
                        // Skip remaining repeats for this workgroup size.
                        // Compile failures for wg > 1 are not treated as divergence.
                        break;
                    }
                    Err(err) => {
                        divergences.push(Divergence {
                            input_label: case.report_label(),
                            wg_a: 1,
                            wg_b: wg,
                            bytes_a: baseline.clone(),
                            bytes_b: err.into_bytes(),
                            run: Some(run),
                            message: format!(
                                "Fix: backend dispatch failed at workgroup_size={wg} after the baseline succeeded; unsupported sizes must be reported explicitly or constrained by the op spec."
                            ),
                        });
                        break;
                    }
                }
            }
        }
    }

    DeterminismReport {
        op_id: op.id.to_string(),
        divergences,
    }
}

/// Dispatches `op` on `backend` once with `input`, using a 1-D workgroup of
/// `workgroup_size`, and returns the first output buffer.
///
/// The CPU reference (`op.cpu_fn`) is evaluated only to learn the expected
/// output length; its bytes are not compared here.
///
/// # Errors
///
/// - `DispatchError::Other` when `input` is shorter than the signature's
///   minimum, when the program cannot be built, when a size-1 dispatch fails,
///   when the backend returns no buffers, or when the first buffer's length
///   differs from the CPU reference length.
/// - `DispatchError::Infra(CompileFailure)` for *any* backend dispatch error
///   at a workgroup size above 1; the backend's own error is discarded.
fn dispatch_op(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<Vec<u8>, DispatchError> {
    let min_bytes = op.signature.min_input_bytes();
    // A zero minimum can never trip this comparison, so no separate guard is needed.
    if input.len() < min_bytes {
        return Err(DispatchError::Other(format!(
            "undersized input: {} bytes for {} (minimum {min_bytes}). \
             Fix: generator produced input smaller than the op's type signature requires.",
            input.len(),
            op.id,
        )));
    }

    // Only the length of the CPU reference output matters to this pass.
    let expected_len = (op.cpu_fn)(input).len();
    let mut program = program_for_spec_input(op, input).map_err(DispatchError::Other)?;
    program.set_workgroup_size([workgroup_size, 1, 1]);

    let outputs = match backend.dispatch(
        &program,
        &[input.to_vec()],
        &vyre::DispatchConfig::default(),
    ) {
        Ok(outputs) => outputs,
        // Failures above size 1 are assumed to be compile-time rejections of
        // the workgroup size; the caller skips those sizes.
        Err(_) if workgroup_size > 1 => {
            return Err(DispatchError::Infra(InfrastructureError::CompileFailure));
        }
        Err(err) => {
            return Err(DispatchError::Other(format!(
                "backend dispatch failed on {} with workgroup_size={workgroup_size}: {err}. \
                 Fix: execute the canonical vyre IR program and return {} bytes.",
                backend.id(),
                expected_len
            )));
        }
    };

    let Some(output) = outputs.into_iter().next() else {
        return Err(DispatchError::Other(
            "backend returned zero output buffers. Fix: return the operation result as outputs[0]."
                .to_string(),
        ));
    };
    if output.len() != expected_len {
        return Err(DispatchError::Other(format!(
            "backend returned {} bytes, expected {}. Fix: size the first output buffer from the program output declaration.",
            output.len(),
            expected_len
        )));
    }
    Ok(output)
}

/// Candidate workgroup sizes for the determinism sweep, in ascending order.
///
/// `preferred` acts as an inclusive upper bound. When it is a non-zero `Some`,
/// only candidates `<=` that value are kept. `None` or `Some(0)` keeps every
/// candidate. The baseline size `1` survives any non-zero cap.
fn determinism_workgroup_sizes(preferred: Option<u32>) -> Vec<u32> {
    const CANDIDATES: [u32; 7] = [1, 8, 32, 64, 128, 256, 1024];
    // A missing or zero cap means "no limit".
    let cap = preferred.filter(|&max| max != 0).unwrap_or(u32::MAX);
    CANDIDATES.iter().copied().filter(|&size| size <= cap).collect()
}

/// Registry entry for `determinism` enforcement.
///
/// Stateless gate whose `run` applies [`enforce_determinism`] to every spec in
/// the enforcement context and reports findings under the id `"determinism"`.
pub struct DeterminismEnforcer;

impl crate::enforce::EnforceGate for DeterminismEnforcer {
    fn id(&self) -> &'static str {
        "determinism"
    }

    fn name(&self) -> &'static str {
        "determinism"
    }

    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let Some(backend) = ctx.backend else {
            return vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![
                    "determinism: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string(),
                ],
            )];
        };
        let mut messages = Vec::new();
        for spec in ctx.specs {
            let input_len = spec.signature.min_input_bytes().max(4);
            let input = crate::pipeline::execution::InputCase::new(
                "registry",
                "zero".to_string(),
                vec![0; input_len],
            );
            messages.extend(enforce_determinism(backend, spec, &[input], 10).divergences.into_iter().map(|d| if d.message.is_empty() { format!("determinism({}): input={} wg_a={} wg_b={}. Fix: make dispatch output byte-identical.", spec.id, d.input_label, d.wg_a, d.wg_b) } else { d.message }));
        }
        crate::enforce::finding_result(self.id(), messages)
    }
}

/// Auto-registered `determinism` enforcer.
///
/// NOTE(review): presumably discovered by the enforcer registry through this
/// `REGISTERED` name — confirm against the registry's collection mechanism.
pub const REGISTERED: DeterminismEnforcer = DeterminismEnforcer;

#[cfg(test)]
mod tests {
    use super::*;

    /// Every candidate size, in sweep order.
    const ALL_SIZES: [u32; 7] = [1, 8, 32, 64, 128, 256, 1024];

    #[test]
    fn determinism_workgroup_sizes_unbounded() {
        let sizes = determinism_workgroup_sizes(None);
        assert_eq!(sizes, ALL_SIZES.to_vec());
    }

    #[test]
    fn determinism_workgroup_sizes_clamped() {
        let sizes = determinism_workgroup_sizes(Some(64));
        assert_eq!(sizes, vec![1, 8, 32, 64]);
    }

    #[test]
    fn determinism_workgroup_sizes_zero_max_is_unbounded() {
        let sizes = determinism_workgroup_sizes(Some(0));
        assert_eq!(sizes, ALL_SIZES.to_vec());
    }

    #[test]
    fn determinism_workgroup_sizes_cap_between_candidates_rounds_down() {
        let sizes = determinism_workgroup_sizes(Some(100));
        assert_eq!(sizes, vec![1, 8, 32, 64]);
    }

    #[test]
    fn determinism_workgroup_sizes_cap_below_first_parallel_size_keeps_only_baseline() {
        // With only the baseline left, no parallel size is ever compared.
        let sizes = determinism_workgroup_sizes(Some(7));
        assert_eq!(sizes, vec![1]);
    }

    #[test]
    fn determinism_workgroup_sizes_max_u32_is_unbounded() {
        let sizes = determinism_workgroup_sizes(Some(u32::MAX));
        assert_eq!(sizes, ALL_SIZES.to_vec());
    }

    #[test]
    fn dispatch_error_into_bytes_renders_infra_via_debug() {
        assert_eq!(
            DispatchError::Infra(InfrastructureError::CompileFailure).into_bytes(),
            b"CompileFailure".to_vec()
        );
        assert_eq!(
            DispatchError::Other("boom".to_string()).into_bytes(),
            b"boom".to_vec()
        );
    }
}