vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Layer 3 - reference interpreter wrapper.

use crate::{spec::program::program_for_spec_input, spec::types::BufferInitPolicy, spec::OpSpec};
use vyre_reference::value::Value;

pub use vyre_reference::run;

/// Result of a reference-interpreter parity check against a backend.
///
/// One report is produced per comparison performed by the `parity_check*`
/// helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParityCheckReport {
    /// Backend name reported by the implementation (the string returned by
    /// the backend's `id()`).
    pub backend: String,
    /// True when backend output exactly matches the L3 interpreter.
    pub passed: bool,
    /// Actionable findings. Empty means the parity check passed.
    pub findings: Vec<String>,
}

/// Run the H7 reference-diff hook for a program with no external inputs.
#[inline]
pub(crate) fn parity_check(
    program: &vyre::ir::Program,
    backend: &dyn vyre::VyreBackend,
) -> ParityCheckReport {
    parity_check_with_inputs(program, &[], backend)
}

/// Run the H7 reference-diff hook for a program and explicit interpreter inputs.
///
/// The dispatch configuration is derived from the program itself via
/// `default_config`, then the check is delegated to
/// `parity_check_with_config`.
#[inline]
pub(crate) fn parity_check_with_inputs(
    program: &vyre::ir::Program,
    inputs: &[Value],
    backend: &dyn vyre::VyreBackend,
) -> ParityCheckReport {
    let config = default_config(program);
    parity_check_with_config(program, inputs, backend, config)
}

/// Run parity with an explicit dispatch configuration.
///
/// The program is first executed by the reference interpreter ([`run`]); the
/// concatenation of its output values' bytes is the expected result. The
/// backend is then dispatched on the program returned by
/// `program_with_output_size` (sized from the expected byte count), and the
/// concatenation of its output buffers is compared byte-for-byte.
///
/// The returned report fails with exactly one finding when:
/// - the reference interpreter rejects the program,
/// - the backend's `dispatch` returns an error, or
/// - the backend bytes differ from the reference bytes. This finding carries
///   both byte lengths and the offset of the first differing byte, so a
///   mismatch between equally sized outputs is still actionable.
#[inline]
pub(crate) fn parity_check_with_config(
    program: &vyre::ir::Program,
    inputs: &[Value],
    backend: &dyn vyre::VyreBackend,
    config: vyre::DispatchConfig,
) -> ParityCheckReport {
    // Single construction point for every outcome: a report passes exactly
    // when it carries no findings.
    let report = |findings: Vec<String>| ParityCheckReport {
        backend: backend.id().to_string(),
        passed: findings.is_empty(),
        findings,
    };

    let reference_outputs = match run(program, inputs) {
        Ok(outputs) => outputs,
        Err(err) => {
            return report(vec![format!(
                "{err}. Fix: repair the IR program before backend parity."
            )]);
        }
    };

    let expected = flatten_outputs(&reference_outputs);
    let input_buffers = inputs.iter().map(Value::to_bytes).collect::<Vec<_>>();
    let dispatch_program = program_with_output_size(program, expected.len());

    // Flatten the backend output once; it is needed for both the comparison
    // and the diagnostic below.
    let actual = match backend.dispatch(&dispatch_program, &input_buffers, &config) {
        Ok(buffers) => flatten_buffers(&buffers),
        Err(err) => {
            return report(vec![format!(
                "{err}. Fix: implement VyreBackend::dispatch for this IR program."
            )]);
        }
    };

    if actual == expected {
        return report(Vec::new());
    }

    // When one side is a strict prefix of the other, the first mismatch is
    // the end of the shorter side.
    let first_mismatch = expected
        .iter()
        .zip(&actual)
        .position(|(want, got)| want != got)
        .unwrap_or_else(|| expected.len().min(actual.len()));

    report(vec![format!(
        "backend output differed from L3 reference: expected {} bytes, got {} bytes, first mismatch at byte offset {}. Fix: diff backend lowering against vyre_reference::run.",
        expected.len(),
        actual.len(),
        first_mismatch
    )])
}

/// Run L3 parity for every registered operation using deterministic IR
/// probes derived from each spec's CPU reference.
///
/// Each spec's probe is checked with workgroup sizes 1 and 64, each combined
/// with the `Zero` and `Poison` buffer-init policies, against the
/// `ReferenceParityBackend`. Every returned finding is prefixed with the spec
/// id; findings from a failed parity check also carry the workgroup size and
/// init policy. A spec whose probe cannot be built contributes one finding
/// and is skipped.
#[inline]
pub fn enforce_registry(specs: &[OpSpec]) -> Vec<String> {
    let backend = ReferenceParityBackend;
    let mut findings = Vec::new();

    for spec in specs {
        let probe = probe_program_for_spec(spec);
        let (program, inputs) = match probe {
            Ok(built) => built,
            Err(finding) => {
                findings.push(format!("{}: {finding}", spec.id));
                continue;
            }
        };

        for workgroup_size in [1, 64] {
            for buffer_init in [BufferInitPolicy::Zero, BufferInitPolicy::Poison] {
                let mut dispatch_program = program.clone();
                dispatch_program.set_workgroup_size([workgroup_size, 1, 1]);

                let config = l3_dispatch_config(workgroup_size, buffer_init);
                let report =
                    parity_check_with_config(&dispatch_program, &inputs, &backend, config);
                if report.passed {
                    continue;
                }

                for finding in report.findings {
                    findings.push(format!(
                        "{}: wg={} init={:?}: {}",
                        spec.id, workgroup_size, buffer_init, finding
                    ));
                }
            }
        }
    }

    findings
}

/// Build the dispatch configuration used when the caller supplies none.
///
/// Uses the first component of the program's workgroup size, raised to at
/// least 1, together with the `Zero` buffer-init policy.
fn default_config(program: &vyre::ir::Program) -> vyre::DispatchConfig {
    let declared_width = program.workgroup_size()[0];
    l3_dispatch_config(declared_width.max(1), BufferInitPolicy::Zero)
}

/// Build a default `DispatchConfig` whose `profile` string records the L3
/// conformance parameters (workgroup size and buffer-init policy).
fn l3_dispatch_config(workgroup_size: u32, buffer_init: BufferInitPolicy) -> vyre::DispatchConfig {
    let profile = format!(
        "conform-l3:workgroup_size={workgroup_size};buffer_init={buffer_init:?}"
    );
    let mut config = vyre::DispatchConfig::default();
    config.profile = Some(profile);
    config
}

/// Build the probe program and interpreter inputs for one spec.
///
/// The probe input is all zero bytes, sized by the spec signature's minimum
/// input length. The CPU reference is run on it only to learn how many output
/// bytes the probe must be able to hold.
///
/// Returns `Err` with a finding message when the program or its inputs cannot
/// be constructed.
fn probe_program_for_spec(spec: &OpSpec) -> Result<(vyre::ir::Program, Vec<Value>), String> {
    let probe_input = vec![0u8; spec.signature.min_input_bytes()];
    let expected_len = (spec.cpu_fn)(&probe_input).len();
    let program = program_for_spec_input(spec, &probe_input)?;
    probe_inputs_for_program(&program, &probe_input, expected_len)
        .map(|inputs| (program, inputs))
}

/// Produce one interpreter input value per non-workgroup buffer of `program`,
/// in declaration order.
///
/// - Output or read-write buffers are zero-filled, at least `output_size`
///   bytes long.
/// - The first remaining buffer receives a copy of `input`.
/// - Any further buffers are zero-filled at their declared size.
///
/// Returns `Err` when a buffer's size cannot be computed by `zeroed_buffer`.
fn probe_inputs_for_program(
    program: &vyre::ir::Program,
    input: &[u8],
    output_size: usize,
) -> Result<Vec<Value>, String> {
    // Only the first plain input buffer gets the probe bytes.
    let mut primary_input_pending = true;

    program
        .buffers()
        .iter()
        .filter(|buffer| buffer.access() != vyre::ir::BufferAccess::Workgroup)
        .map(|buffer| -> Result<Value, String> {
            let is_written =
                buffer.is_output() || buffer.access() == vyre::ir::BufferAccess::ReadWrite;
            let bytes = if is_written {
                zeroed_buffer(buffer, output_size)?
            } else if primary_input_pending {
                primary_input_pending = false;
                input.to_vec()
            } else {
                zeroed_buffer(buffer, 0)?
            };
            Ok(Value::Bytes(bytes))
        })
        .collect()
}

/// Allocate a zero-filled byte vector for `buffer`.
///
/// The length is the larger of the declared size (element count times element
/// size, saturating on overflow) and `fallback_size`.
///
/// Returns `Err` when the declared element count does not fit in `usize`.
fn zeroed_buffer(buffer: &vyre::ir::BufferDecl, fallback_size: usize) -> Result<Vec<u8>, String> {
    let Ok(element_count) = usize::try_from(buffer.count()) else {
        return Err(format!(
            "buffer `{}` declares an unrepresentable element count. Fix: reduce the probe output size.",
            buffer.name()
        ));
    };

    let declared_bytes = element_count.saturating_mul(element_size_bytes(buffer.element()));
    let len = std::cmp::max(declared_bytes, fallback_size);
    Ok(vec![0; len])
}

/// Size in bytes of one element of `data_type`.
///
/// `Bytes` is 1, `U64` and `Vec2U32` are 8, `Vec4U32` is 16; every other data
/// type is treated as 4 bytes wide.
fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    let width = match data_type {
        vyre::ir::DataType::Bytes => 1,
        vyre::ir::DataType::Vec4U32 => 16,
        vyre::ir::DataType::Vec2U32 | vyre::ir::DataType::U64 => 8,
        _ => 4,
    };
    width
}

struct ReferenceParityBackend;

impl vyre::VyreBackend for ReferenceParityBackend {
    fn id(&self) -> &'static str {
        "l3-reference-parity"
    }

    fn dispatch(
        &self,
        program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        let values = inputs.iter().cloned().map(Value::Bytes).collect::<Vec<_>>();
        let outputs = run(program, &values).map_err(|err| {
            vyre::BackendError::new(format!(
                "{err}. Fix: repair the generated L3 probe before backend parity."
            ))
        })?;
        Ok(outputs.into_iter().map(|value| value.to_bytes()).collect())
    }
}

fn flatten_outputs(values: &[Value]) -> Vec<u8> {
    values.iter().flat_map(Value::to_bytes).collect()
}

/// Concatenate `values`, in order, into a single byte vector.
fn flatten_buffers(values: &[Vec<u8>]) -> Vec<u8> {
    values.concat()
}

/// Clone `program`, resizing its first read-write buffer to hold
/// `output_size` bytes.
///
/// That buffer is marked as an output and its element count becomes
/// `output_size` divided by the element size, rounded up and saturated at
/// `u32::MAX`. Programs without a read-write buffer keep their buffers
/// unchanged. The workgroup size, entry nodes and `entry_op_id` are carried
/// over from the original.
fn program_with_output_size(program: &vyre::Program, output_size: usize) -> vyre::Program {
    let mut buffers = program.buffers().to_vec();

    let first_read_write = buffers
        .iter_mut()
        .find(|buffer| buffer.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(target) = first_read_write {
        target.is_output = true;
        let elements = output_size.div_ceil(element_size_bytes(target.element));
        target.count = u32::try_from(elements).unwrap_or(u32::MAX);
    }

    let mut resized =
        vyre::Program::new(buffers, program.workgroup_size(), program.entry().to_vec());
    resized.entry_op_id = program.entry_op_id.clone();
    resized
}

/// Registry entry for `layer3_reference_interp` enforcement.
pub struct Layer3ReferenceInterpEnforcer;

impl crate::enforce::EnforceGate for Layer3ReferenceInterpEnforcer {
    /// Identifier of this gate.
    fn id(&self) -> &'static str {
        "layer3_reference_interp"
    }

    /// Display name of this gate; currently the same string as the id.
    fn name(&self) -> &'static str {
        "layer3_reference_interp"
    }

    /// Convert this gate's messages into findings via
    /// `crate::enforce::finding_result`.
    fn run(&self, _ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        // NOTE(review): the message list is always empty and the context is
        // unused, so `enforce_registry` is never invoked from this gate —
        // confirm whether that is intentional.
        let gate_id = self.id();
        crate::enforce::finding_result(gate_id, Vec::new())
    }
}

/// Auto-registered `layer3_reference_interp` enforcer.
///
/// A constant instance of the unit struct [`Layer3ReferenceInterpEnforcer`].
/// NOTE(review): the mechanism that picks up this constant for registration
/// is not visible in this file.
pub const REGISTERED: Layer3ReferenceInterpEnforcer = Layer3ReferenceInterpEnforcer;

#[cfg(test)]
mod tests {
    use super::*;

    /// The xor probe must be a real program (executable nodes, a read-only
    /// input buffer, an output buffer) whose interpreter output equals the
    /// spec's CPU reference on the all-zero probe input.
    #[test]
    fn probe_program_for_spec_executes_real_ir_body() {
        let spec = crate::spec::primitive::xor::spec();
        let (program, inputs) = probe_program_for_spec(&spec).expect("xor probe must build");

        let has_executable_node = program
            .entry()
            .iter()
            .any(|node| !matches!(node, vyre::ir::Node::Return));
        assert!(
            has_executable_node,
            "probe must contain executable IR, not only Return"
        );

        let has_read_only_buffer = program
            .buffers()
            .iter()
            .any(|buffer| buffer.access() == vyre::ir::BufferAccess::ReadOnly);
        assert!(has_read_only_buffer, "probe must load from an input buffer");

        let has_output_buffer = program
            .buffers()
            .iter()
            .any(|buffer| buffer.is_output());
        assert!(has_output_buffer, "probe must write an output buffer");

        let outputs = run(&program, &inputs).expect("probe must execute in reference interpreter");
        let zero_input = vec![0; spec.signature.min_input_bytes()];
        assert_eq!(flatten_outputs(&outputs), (spec.cpu_fn)(&zero_input));
    }
}