vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Operation-spec to vyre IR program construction.

use crate::OpSpec;

/// Build the IR program used to certify one operation on one concrete input.
///
/// The returned program is tagged with the operation id. Its output buffer
/// counts are sized from the byte length of the CPU reference output for
/// `input`. This gives `VyreBackend::dispatch` enough IR metadata to allocate
/// an exact output buffer without conform reaching through a WGSL-specific
/// dispatch API.
///
/// The IR program is resolved from the `OpSpec` first, falling back to the core
/// vyre registry. The CPU reference is only evaluated once a program has been
/// found, so a missing registration is reported without running `cpu_fn`.
///
/// # Errors
///
/// Returns an actionable error when neither the `OpSpec` nor the core vyre
/// registry can provide an IR program for the operation.
pub fn program_for_spec_input(spec: &OpSpec, input: &[u8]) -> Result<vyre::Program, String> {
    // Resolve the IR first: a missing program is a registration problem, and
    // reporting it should neither pay for nor be masked by (e.g. a panic in)
    // the CPU reference evaluation below.
    let program = spec
        .program()
        .or_else(|| vyre::ops::registry::lookup_program(spec.id))
        .ok_or_else(|| {
            format!(
                "Fix: OpSpec `{}` has no vyre IR program; register `ir_program` in the spec or add the op to vyre::ops::registry before certification.",
                spec.id
            )
        })?;
    // Only the byte length of the reference output is needed to size buffers.
    let output_len = (spec.cpu_fn)(input).len();
    Ok(size_output_buffers(
        program.with_entry_op_id(spec.id),
        output_len,
    ))
}

/// Rebuild `program` so that every output buffer is sized for `output_size` bytes.
///
/// For each buffer whose `is_output()` is true, the element count becomes
/// `output_size` divided by the byte width of its element type, rounded up so
/// a trailing partial element still fits. A count that does not fit in `u32`
/// saturates at `u32::MAX`. Non-output buffers, the workgroup size, a clone of
/// the entry node, and the entry op id are carried over into the new program.
fn size_output_buffers(mut program: vyre::Program, output_size: usize) -> vyre::Program {
    let sized_buffers: Vec<_> = program
        .buffers()
        .iter()
        .cloned()
        .map(|mut decl| {
            if decl.is_output() {
                // Guard against a zero width so the division below is always defined.
                let width = element_size_bytes(decl.element()).max(1);
                let elements = output_size.div_ceil(width);
                decl.count = u32::try_from(elements).unwrap_or(u32::MAX);
            }
            decl
        })
        .collect();

    let entry = (*program.entry).clone();
    let mut rebuilt = vyre::Program::new(sized_buffers, program.workgroup_size, entry);
    // `Program::new` does not receive the op id, so move it over explicitly.
    rebuilt.entry_op_id = program.entry_op_id.take();
    rebuilt
}

/// Byte width of one element of `data_type` when it is laid out in a buffer.
///
/// `Bytes` is packed at 1 byte per element, while `Bool` occupies a full
/// 4-byte word like `U32`/`I32`/`F32`. Any variant not listed explicitly
/// falls back to 4 bytes.
fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    use vyre::ir::DataType as Dt;

    match data_type {
        Dt::Bytes => 1,
        Dt::Bool | Dt::U32 | Dt::I32 | Dt::F32 => 4,
        Dt::U64 | Dt::Vec2U32 => 8,
        Dt::Vec4U32 => 16,
        // NOTE(review): assumes every unlisted variant is one 32-bit word;
        // confirm when new `DataType` variants are added upstream.
        _ => 4,
    }
}