// vyre-conform 0.1.0
//
// Conformance suite for vyre backends — proves byte-identical output to CPU reference.
// Documentation
//! Execution primitives for conformance dispatch.
//!
//! Handles CPU reference computation, GPU backend dispatch, and result
//! comparison for individual ops and composition chains.

use crate::pipeline::backend::ConformDispatchConfig;
use crate::spec::program::program_for_spec_input;
use crate::spec::types::{ChainSpec, OpSpec, ParityFailure};
use crate::verify::regression;

/// A single test input case with its generator lineage and human-readable label.
///
/// Derives `Debug` so a case can be printed in assertion and panic output,
/// and `PartialEq`/`Eq` so two cases can be compared field by field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InputCase {
    /// Name of the generator that produced this input.
    pub generator: String,
    /// Human-readable label for failure reports.
    pub label: String,
    /// Raw input bytes to feed to the op.
    pub bytes: Vec<u8>,
}

impl InputCase {
    /// Assemble an input case from a generator name, a label, and raw bytes.
    ///
    /// The generator name is copied into an owned `String`; `label` and
    /// `bytes` are moved in unchanged.
    #[inline]
    pub fn new(generator: &str, label: String, bytes: Vec<u8>) -> Self {
        let generator = generator.to_owned();
        Self {
            generator,
            label,
            bytes,
        }
    }

    /// Produce the [`ParityFailure`] record for this input case.
    ///
    /// The generator name, label, and input bytes are cloned from `self`;
    /// every other field is taken directly from the arguments.
    #[inline]
    pub fn failure(
        &self,
        op_id: &str,
        gpu: Vec<u8>,
        cpu: Vec<u8>,
        message: String,
        spec_version: u32,
        workgroup_size: u32,
    ) -> ParityFailure {
        let Self {
            generator,
            label,
            bytes,
        } = self;
        ParityFailure {
            op_id: op_id.to_owned(),
            generator: generator.clone(),
            input_label: label.clone(),
            input: bytes.clone(),
            gpu_output: gpu,
            cpu_output: cpu,
            message,
            spec_version,
            workgroup_size,
        }
    }

    /// Label used by reporters and failure reports, formatted as `generator/label`.
    #[inline]
    pub fn report_label(&self) -> String {
        [self.generator.as_str(), self.label.as_str()].join("/")
    }
}

/// Run a single op through its CPU reference and through `backend`.
///
/// On success returns `(gpu, cpu)`: the backend's output bytes followed by
/// the bytes produced by `op.cpu_fn` for the same `input`.
///
/// # Errors
///
/// Returns a message string when `workgroup_size` is zero, when `input` is
/// shorter than `op.signature.min_input_bytes()`, when recording the CPU
/// execution time via `exec_budget_record` fails, when the dispatch
/// configuration cannot be built, or when the backend dispatch fails.
#[inline]
pub(crate) fn execute_op(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<(Vec<u8>, Vec<u8>), String> {
    validate_workgroup_size(workgroup_size)?;

    // Refuse inputs too short to hold the op's declared input types.
    // Otherwise a 0-byte input to a binary_u32 op would pass silently,
    // because both the CPU guard clause and the GPU zero-padding
    // return [0,0,0,0].
    let min_bytes = op.signature.min_input_bytes();
    let supplied = input.len();
    if min_bytes > 0 && supplied < min_bytes {
        return Err(format!(
            "undersized input: {} bytes for {} (minimum {min_bytes}). \
             Fix: generator produced input smaller than the op's type signature requires.",
            supplied, op.id,
        ));
    }

    // Time only the CPU reference call, then report it to the execution budget.
    let started = std::time::Instant::now();
    let cpu = (op.cpu_fn)(input);
    let elapsed = started.elapsed();
    crate::verify::budget::exec_budget_record(elapsed).map_err(|bomb| bomb.to_string())?;

    // The backend is asked for exactly as many bytes as the CPU reference produced.
    let expected_len = cpu.len();
    let config = checked_dispatch_config(op, expected_len, workgroup_size)?;
    match dispatch_backend(backend, op, input, expected_len, config) {
        Ok(gpu) => Ok((gpu, cpu)),
        Err(err) => Err(err.to_string()),
    }
}

/// Run every op of `chain` in order on both the CPU reference and `backend`.
///
/// On success returns `(gpu, cpu)`. The GPU bytes are the backend output of
/// the last step. The CPU bytes come from `chain.cpu_chain` applied to the
/// original `input` when that function is present, otherwise from the
/// step-by-step CPU path.
///
/// # Errors
///
/// Returns a message string when `workgroup_size` is zero, when a step's
/// dispatch configuration cannot be built, or when a backend dispatch fails
/// (the message then names the chain and the failing step).
#[inline]
pub(crate) fn execute_chain(
    backend: &dyn vyre::VyreBackend,
    chain: &ChainSpec,
    input: &[u8],
    workgroup_size: u32,
) -> Result<(Vec<u8>, Vec<u8>), String> {
    validate_workgroup_size(workgroup_size)?;

    // The two paths never feed each other: deriving the next CPU reference
    // from GPU output would let an early GPU error corrupt every later
    // CPU reference.
    let mut cpu_state = input.to_vec();
    let mut gpu_state = input.to_vec();

    for step in &chain.specs {
        // CPU side consumes only its own previous output.
        let cpu_step = (step.cpu_fn)(&cpu_state);
        let expected_len = cpu_step.len();

        // GPU side consumes its own previous output, while the output
        // buffer size is taken from the CPU path (the spec).
        let config = checked_dispatch_config(step, expected_len, workgroup_size)?;
        gpu_state = match dispatch_backend(backend, step, &gpu_state, expected_len, config) {
            Ok(bytes) => bytes,
            Err(err) => {
                return Err(format!(
                    "backend dispatch failed in chain {} at {}: {err}. Fix: make every chained op accept the previous output bytes.",
                    chain.id, step.id
                ));
            }
        };

        cpu_state = cpu_step;
    }

    // A chain-level CPU reference, when supplied, replaces the stepwise result.
    let cpu_final = match chain.cpu_chain {
        Some(reference) => reference(input),
        None => cpu_state,
    };

    Ok((gpu_state, cpu_final))
}

/// Load the stored regression entries for `op_id` as [`InputCase`] values.
///
/// Every entry returned by `regression::load` becomes a case whose
/// generator name is `"regression"`.
#[inline]
pub(crate) fn regression_inputs(op_id: &str) -> Vec<InputCase> {
    let mut cases = Vec::new();
    for (label, bytes) in regression::load(op_id) {
        cases.push(InputCase::new("regression", label, bytes));
    }
    cases
}

/// Save `failure` through `regression::save`, best effort.
///
/// A save error is not propagated: it is reported on stderr and the
/// function returns normally.
#[inline]
pub(crate) fn persist_failure(failure: &ParityFailure) {
    let Err(err) = regression::save(failure) else {
        return;
    };
    eprintln!(
        "vyre-conform: could not persist regression for {}: {err}. Fix: ensure regressions/ is writable.",
        failure.op_id
    );
}

/// Derive a deterministic 64-bit seed from `text`.
///
/// This is the 64-bit FNV-1a hash of the UTF-8 bytes of `text`: start from
/// the offset basis, then for each byte XOR it in and multiply by the FNV
/// prime with wrapping arithmetic. An empty string yields the offset basis.
#[inline]
pub(crate) fn seed_from(text: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    text.bytes().fold(FNV_OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}

#[inline]
pub(crate) fn dispatch_config(
    op: &OpSpec,
    output_size: usize,
    workgroup_size: u32,
) -> ConformDispatchConfig {
    let workgroup_size_usize = usize::try_from(workgroup_size)
        .ok()
        .filter(|size| *size > 0)
        .unwrap_or(1);
    let output_words = output_size.div_ceil(4).max(1);
    let workgroup_count = output_words
        .div_ceil(workgroup_size_usize)
        .try_into()
        .unwrap_or(u32::MAX);
    ConformDispatchConfig {
        workgroup_size,
        workgroup_count,
        convention: op.convention,
        lookup_data: None,
        buffer_init: crate::spec::types::BufferInitPolicy::default(),
    }
}

/// Build a dispatch configuration for `op`, rejecting invalid requests.
///
/// The output is counted in 4-byte words (rounded up, at least one) and
/// divided, rounding up, by `workgroup_size` to get `workgroup_count`.
///
/// # Errors
///
/// Returns a message string when `workgroup_size` is zero, or when the
/// resulting workgroup count does not fit in `u32`.
fn checked_dispatch_config(
    op: &OpSpec,
    output_size: usize,
    workgroup_size: u32,
) -> Result<ConformDispatchConfig, String> {
    // A zero workgroup size is rejected here, so the division below cannot
    // divide by zero.
    validate_workgroup_size(workgroup_size)?;

    let output_words = output_size.div_ceil(4).max(1);
    // Widening cast: u32 -> usize.
    let groups = output_words.div_ceil(workgroup_size as usize);

    match u32::try_from(groups) {
        Ok(workgroup_count) => Ok(ConformDispatchConfig {
            workgroup_size,
            workgroup_count,
            convention: op.convention,
            lookup_data: None,
            buffer_init: crate::spec::types::BufferInitPolicy::default(),
        }),
        Err(_) => Err(format!(
            "dispatch workgroup_count overflow: output_size={output_size}, output_words={output_words}, workgroup_size={workgroup_size}. Fix: split the output into multiple dispatches or reduce the requested output size."
        )),
    }
}

/// Reject a workgroup size of zero.
///
/// # Errors
///
/// Returns a message string when `workgroup_size` is `0`; any other value
/// is accepted.
fn validate_workgroup_size(workgroup_size: u32) -> Result<(), String> {
    match workgroup_size {
        0 => Err(String::from(
            "invalid workgroup_size=0. Fix: configure at least one worker per workgroup.",
        )),
        _ => Ok(()),
    }
}

/// Error produced while dispatching a program to a backend.
#[derive(Debug)]
enum ExecutionError {
    /// The dispatch could not be prepared, failed, or returned an unusable result.
    BackendDispatch {
        /// Identifier of the backend, as text.
        backend: String,
        /// Number of output bytes that was expected.
        output_size: usize,
        /// Workgroup size the dispatch was configured with.
        workgroup_size: u32,
        /// Text describing the underlying failure.
        source: String,
    },
}

impl std::fmt::Display for ExecutionError {
    /// Render the failure together with a `Fix:` hint naming the expected byte count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this pattern is irrefutable; adding a
        // variant turns it into a compile error at this spot.
        let Self::BackendDispatch {
            backend,
            output_size,
            workgroup_size,
            source,
        } = self;
        f.write_fmt(format_args!(
            "backend dispatch failed on {backend} with workgroup_size={workgroup_size}: {source}. Fix: execute the canonical vyre IR program and return {output_size} bytes."
        ))
    }
}

impl std::error::Error for ExecutionError {}

/// Build the program for `op` and `input`, dispatch it on `backend`, and
/// return the first output buffer.
///
/// # Errors
///
/// Every failure is reported as [`ExecutionError::BackendDispatch`] carrying
/// the backend id, `output_size`, and the configured workgroup size:
///
/// * the program could not be built from `op` and `input`;
/// * the backend's `dispatch` call returned an error;
/// * the backend returned no output buffers;
/// * the first output buffer's length differs from `output_size`.
fn dispatch_backend(
    backend: &dyn vyre::VyreBackend,
    op: &OpSpec,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
) -> Result<Vec<u8>, ExecutionError> {
    let workgroup_size = config.workgroup_size;
    // One place that wraps a failure description; the backend id is only
    // queried when an error is actually constructed.
    let fail = |source: String| ExecutionError::BackendDispatch {
        backend: backend.id().to_string(),
        output_size,
        workgroup_size,
        source,
    };

    let program = program_for_spec_input(op, input).map_err(&fail)?;

    let mut outputs = backend
        .dispatch(&program, &[input.to_vec()], &config.to_core())
        .map_err(|error| fail(error.message))?;

    if outputs.is_empty() {
        return Err(fail(String::from(
            "backend returned zero output buffers, expected one. Fix: return the operation result as outputs[0].",
        )));
    }

    // Only the first buffer is used.
    let output = outputs.remove(0);
    if output.len() != output_size {
        return Err(fail(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the program output declaration.",
            output.len()
        )));
    }

    Ok(output)
}

#[cfg(test)]
mod tests {
    // Unit tests for `InputCase`, `seed_from`, and the two dispatch-config builders.

    use super::{checked_dispatch_config, dispatch_config, seed_from, InputCase};

    #[test]
    fn input_case_report_label() {
        let case = InputCase::new("random", "case_42".into(), vec![0xDE, 0xAD]);
        assert_eq!(case.report_label(), "random/case_42");
    }

    #[test]
    fn input_case_failure_preserves_fields() {
        let case = InputCase::new("edge", "max_val".into(), vec![0xFF; 4]);
        let f = case.failure("test.op", vec![0x00], vec![0xFF], "mismatch".into(), 2, 64);
        assert_eq!(f.op_id, "test.op");
        assert_eq!(f.generator, "edge");
        assert_eq!(f.input_label, "max_val");
        assert_eq!(f.input, vec![0xFF; 4]);
        assert_eq!(f.gpu_output, vec![0x00]);
        assert_eq!(f.cpu_output, vec![0xFF]);
        // The message is a preserved field too.
        assert_eq!(f.message, "mismatch");
        assert_eq!(f.spec_version, 2);
        assert_eq!(f.workgroup_size, 64);
    }

    #[test]
    fn seed_from_is_deterministic() {
        let a = seed_from("primitive.bitwise.xor");
        let b = seed_from("primitive.bitwise.xor");
        assert_eq!(a, b);
    }

    #[test]
    fn seed_from_differs_for_different_ops() {
        let a = seed_from("primitive.bitwise.xor");
        let b = seed_from("primitive.bitwise.and");
        assert_ne!(a, b);
    }

    #[test]
    fn seed_from_empty_string() {
        // With no bytes to mix in, the seed is exactly the initial hash state.
        assert_eq!(seed_from(""), 0xcbf2_9ce4_8422_2325);
    }

    #[test]
    fn seed_from_single_byte_reference_value() {
        // (0xcbf2_9ce4_8422_2325 ^ 0x61) * 0x0000_0100_0000_01B3, wrapping.
        assert_eq!(seed_from("a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[test]
    fn dispatch_config_single_word() {
        let op = crate::spec::primitive::xor::spec();
        let config = dispatch_config(&op, 4, 1);
        assert_eq!(config.workgroup_size, 1);
        assert_eq!(config.workgroup_count, 1);
    }

    #[test]
    fn dispatch_config_multi_word() {
        let op = crate::spec::primitive::xor::spec();
        // 256 bytes = 64 words → workgroup_count = 64/64 = 1 with wg_size=64
        let config = dispatch_config(&op, 256, 64);
        assert_eq!(config.workgroup_count, 1);
    }

    #[test]
    fn dispatch_config_rounds_up_partial_workgroup() {
        let op = crate::spec::primitive::xor::spec();
        // 260 bytes = 65 words → ceil(65/64) = 2 with wg_size=64
        let config = dispatch_config(&op, 260, 64);
        assert_eq!(config.workgroup_count, 2);
    }

    #[test]
    fn dispatch_config_zero_output_clamps_to_one() {
        let op = crate::spec::primitive::xor::spec();
        let config = dispatch_config(&op, 0, 1);
        // output_words = max(ceil(0/4), 1) = 1
        assert_eq!(config.workgroup_count, 1);
    }

    #[test]
    fn checked_dispatch_config_matches_unchecked_for_valid_input() -> Result<(), String> {
        let op = crate::spec::primitive::xor::spec();
        let checked = checked_dispatch_config(&op, 260, 64)?;
        let unchecked = dispatch_config(&op, 260, 64);
        assert_eq!(checked.workgroup_size, unchecked.workgroup_size);
        assert_eq!(checked.workgroup_count, unchecked.workgroup_count);
        Ok(())
    }

    #[test]
    fn checked_dispatch_config_rejects_zero_workgroup_size() {
        let op = crate::spec::primitive::xor::spec();
        let err = checked_dispatch_config(&op, 4, 0).unwrap_err();
        assert!(err.contains("workgroup_size=0"));
        assert!(err.contains("Fix:"));
    }

    // Only meaningful where `usize` is wider than `u32`: on narrower targets
    // the overflowing fixture cannot be represented, so the test is compiled
    // out instead of failing while building its own input.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn checked_dispatch_config_rejects_workgroup_count_overflow() -> Result<(), String> {
        let op = crate::spec::primitive::xor::spec();
        let output_size = (u32::MAX as usize)
            .checked_add(1)
            .and_then(|words| words.checked_mul(4))
            .ok_or_else(|| {
                "Fix: overflow fixture must fit usize on this target before dispatch validation"
                    .to_string()
            })?;
        let err = checked_dispatch_config(&op, output_size, 1).unwrap_err();
        assert!(err.contains("workgroup_count overflow"));
        assert!(err.contains("Fix:"));
        Ok(())
    }
}