vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use super::buffers::emit_buffer_decl;
use super::helper_usage::HelperUsage;
use super::node::emit_nodes;
use super::{ensure_output_within_limit, Error, LowerCtx};
use crate::ir::model::program::{BufferDecl, Program};
use crate::lower::wgsl::analysis::{atomic_buffers, TypeScope};
use crate::lower::wgsl::emit::{emit_buffer_access_helpers, emit_safe_arithmetic_helpers};
use std::collections::HashMap;

/// Lowers a complete [`Program`] into one WGSL compute-shader module.
///
/// The module is written in this order:
/// 1. the safe-arithmetic helper prelude (`emit_safe_arithmetic_helpers`);
/// 2. access helpers for every buffer, driven by the collected `HelperUsage`;
/// 3. a declaration for every buffer, flagged when the buffer is used atomically;
/// 4. the `@compute` entry point `main`, sized by `program.workgroup_size`.
///    Its body is produced by `emit_nodes` from `program.entry()`.
///
/// # Errors
///
/// Returns the first [`Error`] reported by any emitter above. It also returns an
/// error from `ensure_output_within_limit`, which is re-checked after each
/// emission step.
#[inline]
pub fn emit_wgsl(program: &Program) -> Result<String, Error> {
    use std::fmt::Write as _;

    let buffers = program.buffers();

    // L.2.3: size the output up front instead of growing from a fixed 4 KiB seed.
    // The rough audit figures are:
    // - a 4 KiB floor for the arithmetic prelude;
    // - about 128 bytes per buffer (decl + access helpers);
    // - about 64 bytes per top-level entry node.
    let capacity_hint = 4096 + 128 * buffers.len() + 64 * program.entry().len();
    let mut wgsl = String::with_capacity(capacity_hint);

    let atomic_set = atomic_buffers(program);
    let usage = HelperUsage::collect(program);
    let buffer_map: HashMap<&str, &BufferDecl> = buffers
        .iter()
        .map(|decl| (decl.name.as_str(), decl))
        .collect();

    emit_safe_arithmetic_helpers(&mut wgsl);
    ensure_output_within_limit(&wgsl)?;

    // First pass: per-buffer access helpers.
    for decl in buffers.iter() {
        let decl_usage = usage.buffer(decl.name.as_str());
        emit_buffer_access_helpers(&mut wgsl, decl, decl_usage)?;
        ensure_output_within_limit(&wgsl)?;
    }

    // Second pass: per-buffer declarations.
    for decl in buffers.iter() {
        let is_atomic = atomic_set.contains(decl.name.as_str());
        emit_buffer_decl(&mut wgsl, decl, is_atomic)?;
        ensure_output_within_limit(&wgsl)?;
    }

    // L.2.4: format the entry-point header straight into the output buffer
    // rather than through a temporary string.
    let [wx, wy, wz] = program.workgroup_size;
    write!(
        wgsl,
        "\n@compute @workgroup_size({wx}, {wy}, {wz})\nfn main(\n  @builtin(global_invocation_id) _vyre_gid: vec3<u32>,\n  @builtin(workgroup_id) _vyre_wgid: vec3<u32>,\n  @builtin(local_invocation_id) _vyre_lid: vec3<u32>,\n) {{\n"
    )
    .expect("Fix: writing into a String must be infallible; replace String with a fallible sink only if errors are propagated");
    ensure_output_within_limit(&wgsl)?;

    // The entry body starts one indentation level inside `main`.
    let mut ctx = LowerCtx {
        indent: 1,
        vars: TypeScope::new(),
        atomic_buffers: atomic_set,
        buffer_map,
    };
    emit_nodes(&mut wgsl, &mut ctx, program.entry(), program)?;

    wgsl.push_str("}\n");
    ensure_output_within_limit(&wgsl)?;
    Ok(wgsl)
}