use super::buffers::emit_buffer_decl;
use super::helper_usage::HelperUsage;
use super::node::emit_nodes;
use super::{ensure_output_within_limit, Error, LowerCtx};
use crate::ir::model::program::{BufferDecl, Program};
use crate::lower::wgsl::analysis::{atomic_buffers, TypeScope};
use crate::lower::wgsl::emit::{emit_buffer_access_helpers, emit_safe_arithmetic_helpers};
use std::collections::HashMap;
/// Lowers `program` into a complete WGSL compute-shader module.
///
/// The module text is produced in four stages, always in this order:
/// 1. the safe-arithmetic helper functions (`emit_safe_arithmetic_helpers`);
/// 2. per-buffer access helpers, one pass over `program.buffers()`, each call
///    given that buffer's entry from [`HelperUsage::collect`];
/// 3. one declaration per buffer (second pass over `program.buffers()`), told
///    whether the buffer's name is in the set returned by `atomic_buffers`;
/// 4. the `@compute` entry point `main`, using `program.workgroup_size` and
///    binding the global / workgroup / local invocation ids as builtins, whose
///    body is emitted by `emit_nodes` from `program.entry()`.
///
/// # Errors
///
/// Propagates any error from the buffer and node emitters, and from
/// `ensure_output_within_limit`. That check is re-run after every stage, and
/// after each individual buffer in stages 2 and 3.
#[inline]
pub fn emit_wgsl(program: &Program) -> Result<String, Error> {
    use std::fmt::Write as _;

    let buffers = program.buffers();

    // Heuristic pre-size: a fixed base plus a per-buffer and a per-entry-node
    // allowance, so typical modules are built with few reallocations.
    let capacity = 4096 + 128 * buffers.len() + 64 * program.entry().len();
    let mut wgsl = String::with_capacity(capacity);

    let atomics = atomic_buffers(program);
    let usage = HelperUsage::collect(program);

    // Name -> declaration lookup handed to the node emitter via `LowerCtx`.
    // With duplicate names, the later declaration wins (same as `extend`).
    let mut by_name: HashMap<&str, &BufferDecl> = HashMap::with_capacity(buffers.len());
    for decl in buffers {
        by_name.insert(decl.name.as_str(), decl);
    }

    // Stage 1: shared arithmetic helpers.
    emit_safe_arithmetic_helpers(&mut wgsl);
    ensure_output_within_limit(&wgsl)?;

    // Stage 2: access helpers for every buffer.
    for decl in buffers {
        let name = decl.name.as_str();
        emit_buffer_access_helpers(&mut wgsl, decl, usage.buffer(name))?;
        ensure_output_within_limit(&wgsl)?;
    }

    // Stage 3: buffer declarations, in a separate pass so every helper
    // precedes every declaration in the output.
    for decl in buffers {
        let is_atomic = atomics.contains(decl.name.as_str());
        emit_buffer_decl(&mut wgsl, decl, is_atomic)?;
        ensure_output_within_limit(&wgsl)?;
    }

    // Stage 4: entry-point header, body, and closing brace.
    let [wx, wy, wz] = program.workgroup_size;
    write!(
        wgsl,
        "\n@compute @workgroup_size({wx}, {wy}, {wz})\nfn main(\n @builtin(global_invocation_id) _vyre_gid: vec3<u32>,\n @builtin(workgroup_id) _vyre_wgid: vec3<u32>,\n @builtin(local_invocation_id) _vyre_lid: vec3<u32>,\n) {{\n"
    )
    .expect("Fix: writing into a String must be infallible; replace String with a fallible sink only if errors are propagated");
    ensure_output_within_limit(&wgsl)?;

    let mut ctx = LowerCtx {
        // Statements are emitted inside `main`, one level below module scope.
        indent: 1,
        vars: TypeScope::new(),
        atomic_buffers: atomics,
        buffer_map: by_name,
    };
    emit_nodes(&mut wgsl, &mut ctx, program.entry(), program)?;
    wgsl.push_str("}\n");
    ensure_output_within_limit(&wgsl)?;

    Ok(wgsl)
}