vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Generated vyre IR program expressions.

use crate::spec::types::OpSpec;

pub(super) fn invalid_program_expr(op: &OpSpec, rule_id: &str) -> Option<(String, String)> {
    match rule_id {
        "V001" => duplicate_buffer_program(op),
        "V021" => wrong_argument_count_program(op),
        _ => None,
    }
}

/// Builds a program expression that violates rule `V001` by declaring the same
/// buffer name twice.
///
/// The buffer list starts like the valid program's: a read buffer `input{i}` at
/// binding `i` for every signature input. An extra `U32` declaration at binding
/// 999 follows. It reuses the name `out0` (read/write) when the op has no inputs,
/// and `input0` (read) otherwise. Finally, the read/write `out0` buffer is added
/// at the next binding. Binding 999 presumably keeps the binding numbers distinct
/// so that only the name collides; confirm against the validator.
///
/// Always returns `Some((program, "duplicate buffer name"))`.
fn duplicate_buffer_program(op: &OpSpec) -> Option<(String, String)> {
    let inputs = &op.signature.inputs;
    let mut decls: Vec<String> = inputs
        .iter()
        .enumerate()
        .map(|(binding, dt)| {
            format!(
                "vyre::ir::BufferDecl::read(\"input{binding}\", {binding}, {})",
                ir_data_type_expr(dt)
            )
        })
        .collect();

    let clashing_decl = if inputs.is_empty() {
        "vyre::ir::BufferDecl::read_write(\"out0\", 999, vyre::ir::DataType::U32)"
    } else {
        "vyre::ir::BufferDecl::read(\"input0\", 999, vyre::ir::DataType::U32)"
    };
    decls.push(String::from(clashing_decl));
    decls.push(format!(
        "vyre::ir::BufferDecl::read_write(\"out0\", {}, {})",
        inputs.len(),
        ir_data_type_expr(&op.signature.output)
    ));

    // The trailing `\` continuations drop the newline and leading indentation, so
    // the emitted text is a single line.
    let program = format!(
        "vyre::ir::Program::new(vec![{buffers}], {workgroup}, \
         vec![vyre::ir::Node::if_then(\
         vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\")), \
         vec![vyre::ir::Node::store(\"out0\", vyre::ir::Expr::gid_x(), \
         vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{args}] }})])])",
        buffers = decls.join(", "),
        workgroup = workgroup_array(op),
        args = call_args(op),
    );
    Some((program, String::from("duplicate buffer name")))
}

/// Builds a program expression that violates rule `V021` by calling the op with
/// the wrong number of arguments.
///
/// Buffers are declared as for a valid program: a read buffer `input{i}` at
/// binding `i` per signature input, then a read/write `out0` at the next binding.
/// The `Call` node receives zero arguments when the op has inputs. When the op has
/// none, it receives the single argument `vyre::ir::Expr::u32(0)`.
///
/// Always returns `Some((program, "V021"))`.
fn wrong_argument_count_program(op: &OpSpec) -> Option<(String, String)> {
    let arity = op.signature.inputs.len();

    let mut decls = Vec::with_capacity(arity + 1);
    for (slot, dt) in op.signature.inputs.iter().enumerate() {
        let ty = ir_data_type_expr(dt);
        decls.push(format!("vyre::ir::BufferDecl::read(\"input{slot}\", {slot}, {ty})"));
    }
    let out_ty = ir_data_type_expr(&op.signature.output);
    decls.push(format!("vyre::ir::BufferDecl::read_write(\"out0\", {arity}, {out_ty})"));

    // Mismatch the call's arity against the signature: one stray literal for a
    // nullary op, no arguments at all otherwise.
    let wrong_args = if arity == 0 { "vyre::ir::Expr::u32(0)" } else { "" };

    // Assemble the program text from the inside out.
    let call = format!("vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{wrong_args}] }}");
    let store = format!("vyre::ir::Node::store(\"out0\", vyre::ir::Expr::gid_x(), {call})");
    let guard = "vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\"))";
    let body = format!("vec![vyre::ir::Node::if_then({guard}, vec![{store}])]");
    let program = format!(
        "vyre::ir::Program::new(vec![{}], {}, {body})",
        decls.join(", "),
        workgroup_array(op)
    );
    Some((program, "V021".to_owned()))
}

/// Builds the Rust source text of a valid `vyre::ir::Program` for `op`.
///
/// The program declares one read buffer `input{i}` at binding `i` per signature
/// input and a read/write `out0` buffer at the next binding. Its body is guarded
/// by `gid_x < buf_len("out0")`. It loads element `gid_x` from each input, passes
/// the loads to a `Call` node, and stores the result into `out0` at `gid_x`.
///
/// The emitted text refers to an identifier `op_id_str` that it does not define.
/// Presumably the surrounding generated code provides it; verify at the call site.
pub(super) fn program_expr(op: &OpSpec) -> String {
    let inputs = &op.signature.inputs;

    let mut decls = Vec::new();
    decls.extend(inputs.iter().enumerate().map(|(i, dt)| {
        format!(
            "vyre::ir::BufferDecl::read(\"input{i}\", {i}, {})",
            ir_data_type_expr(dt)
        )
    }));
    decls.push(format!(
        "vyre::ir::BufferDecl::read_write(\"out0\", {}, {})",
        inputs.len(),
        ir_data_type_expr(&op.signature.output)
    ));

    // Each `\` continuation removes the newline and the next line's indentation.
    format!(
        "vyre::ir::Program::new(vec![{buffers}], {workgroup}, vec![\
         vyre::ir::Node::if_then(\
         vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\")), \
         vec![vyre::ir::Node::store(\
         \"out0\", vyre::ir::Expr::gid_x(), \
         vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{args}] }}\
         )])])",
        buffers = decls.join(", "),
        workgroup = workgroup_array(op),
        args = call_args(op),
    )
}

/// Renders the comma-separated argument list for the op call.
///
/// The list holds one `vyre::ir::Expr::load("input{i}", vyre::ir::Expr::gid_x())`
/// per signature input, in input order. It is empty when the op has no inputs.
fn call_args(op: &OpSpec) -> String {
    let mut rendered = String::new();
    for index in 0..op.signature.inputs.len() {
        if index > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&format!(
            "vyre::ir::Expr::load(\"input{index}\", vyre::ir::Expr::gid_x())"
        ));
    }
    rendered
}

/// Renders the `[x, 1, 1]` workgroup-size array literal for the generated program.
///
/// `x` is `op.workgroup_size`, or `1` when it is unset.
fn workgroup_array(op: &OpSpec) -> String {
    format!("[{}, 1, 1]", op.workgroup_size.unwrap_or(1))
}

/// Maps a spec-level data type to the source text of the matching
/// `vyre::ir::DataType` variant.
///
/// # Panics
///
/// Panics for any data type other than `U32`, `I32`, `U64`, `Vec2U32`, `Vec4U32`,
/// `Bool` and `Bytes`, because this function has no IR mapping for them.
fn ir_data_type_expr(dt: &crate::spec::types::DataType) -> &'static str {
    use crate::spec::types::DataType as SpecType;

    match dt {
        SpecType::U32 => "vyre::ir::DataType::U32",
        SpecType::I32 => "vyre::ir::DataType::I32",
        SpecType::U64 => "vyre::ir::DataType::U64",
        SpecType::Vec2U32 => "vyre::ir::DataType::Vec2U32",
        SpecType::Vec4U32 => "vyre::ir::DataType::Vec4U32",
        SpecType::Bool => "vyre::ir::DataType::Bool",
        SpecType::Bytes => "vyre::ir::DataType::Bytes",
        unsupported => panic!(
            "Fix: unsupported DataType {:?} for IR program generation",
            unsupported
        ),
    }
}