use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use crate::ir::model::types::{BufferAccess, DataType};
use crate::lower::wgsl::analysis::expr_type;
use crate::lower::wgsl::expr::emit_expr_string;
use crate::lower::wgsl::impl_lowerctx::write_indent;
use crate::lower::wgsl::{ensure_output_within_limit, Error, LowerCtx};
use std::collections::HashSet;
use std::fmt::Write as _;
/// Lowers `nodes` to WGSL, appending the text to `out`.
///
/// Starts from an empty barrier state, i.e. no atomic accesses are assumed to
/// be pending when the first node is lowered.
///
/// # Errors
///
/// Propagates any error from [`emit_nodes_with_barriers`].
#[inline]
pub(crate) fn emit_nodes<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    nodes: &'a [Node],
    program: &'a Program,
) -> Result<(), Error> {
    emit_nodes_with_barriers(out, ctx, nodes, program, &mut BarrierState::default())
}
/// Lowers each node of `nodes` in order, threading one shared `barriers`
/// state through the whole sequence.
///
/// Lowering stops right after the first `Node::Return` in this list; nodes
/// that follow it are never emitted.
///
/// # Errors
///
/// Propagates the first error returned by [`emit_node`].
#[inline]
pub fn emit_nodes_with_barriers<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    nodes: &'a [Node],
    program: &'a Program,
    barriers: &mut BarrierState,
) -> Result<(), Error> {
    // Length of the reachable prefix: everything up to and including the
    // first `return`, or the whole list when there is none.
    let reachable = nodes
        .iter()
        .position(|node| matches!(node, Node::Return))
        .map_or(nodes.len(), |at| at + 1);
    for node in nodes.iter().take(reachable) {
        emit_node(out, ctx, node, program, barriers)?;
    }
    Ok(())
}
/// Lowers one IR statement to WGSL, appending it to `out`.
///
/// Text is indented at `ctx.indent` levels. Nested bodies (`If`, `Loop`,
/// `Block`) are lowered one level deeper, and `ctx.indent` is restored
/// afterwards on the success path.
///
/// `barriers` tracks buffers that have been accessed by atomics on the
/// current path since the last barrier:
/// - every lowered expression is recorded into it;
/// - a `Store` to a pending buffer is preceded by
///   `storageBarrier(); workgroupBarrier();`;
/// - an explicit `Node::Barrier` emits the same pair and clears the state.
///
/// `If` branches and `Loop` bodies are lowered with clones of the state,
/// which are joined back afterwards.
///
/// # Errors
///
/// Returns an error if:
/// - a `Store` names a buffer that `program` does not declare;
/// - lowering a sub-expression or nested statement fails;
/// - `out` exceeds the limit enforced by `ensure_output_within_limit`.
#[inline]
pub fn emit_node<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    node: &'a Node,
    program: &'a Program,
    barriers: &mut BarrierState,
) -> Result<(), Error> {
    let indent = ctx.indent;
    match node {
        Node::Let { name, value } => {
            // Remember the inferred type (when one can be computed) so later
            // `expr_type` queries against `ctx.vars` can see this variable.
            if let Some(ty) = expr_type(value, &ctx.buffer_map, &ctx.vars) {
                ctx.vars.insert(name.as_str(), ty);
            }
            let safe_name = LowerCtx::safe_var(name);
            let _ = write_indent(out, indent);
            let _ = write!(out, "var {safe_name} = ");
            emit_expr_string(out, value, program, ctx)?;
            let _ = writeln!(out, ";");
            barriers.record_expr(value);
        }
        Node::Assign { name, value } => {
            let safe_name = LowerCtx::safe_var(name);
            let _ = write_indent(out, indent);
            let _ = write!(out, "{safe_name} = ");
            emit_expr_string(out, value, program, ctx)?;
            let _ = writeln!(out, ";");
            barriers.record_expr(value);
        }
        Node::Store {
            buffer,
            index,
            value,
        } => {
            if barriers.needs_before_store(buffer) {
                emit_synchronization_barriers(out, indent);
            }
            if let Some(b) = program.buffer(buffer) {
                if b.access == BufferAccess::Workgroup {
                    // The index is used twice: once for the bounds check and
                    // once for the store. Bind it to a local so it is evaluated
                    // exactly once. It may have side effects, e.g. an atomic
                    // slot allocation `buf[atomicAdd(&n, 1u)] = x`, and
                    // evaluating it twice would check one slot but write
                    // another. The enclosing `{ }` gives every store its own
                    // scope, so the fixed name never collides. The `: u32`
                    // annotation keeps abstract-int indices comparable with
                    // the `u` count.
                    let idx = "_vyre_wg_store_idx";
                    let _ = write_indent(out, indent);
                    let _ = writeln!(out, "{{");
                    ctx.indent += 1;
                    let _ = ctx.write_pad(out);
                    let _ = write!(out, "let {idx}: u32 = ");
                    emit_expr_string(out, index, program, ctx)?;
                    let _ = writeln!(out, ";");
                    let _ = ctx.write_pad(out);
                    let _ = writeln!(out, "if ({idx} < {count}u) {{", count = b.count);
                    ctx.indent += 1;
                    let _ = ctx.write_pad(out);
                    let _ = write!(out, "{buffer}[{idx}] = ");
                    emit_expr_string(out, value, program, ctx)?;
                    let _ = writeln!(out, ";");
                    ctx.indent -= 1;
                    let _ = ctx.write_pad(out);
                    let _ = writeln!(out, "}}");
                    ctx.indent -= 1;
                    let _ = write_indent(out, indent);
                    let _ = writeln!(out, "}}");
                } else {
                    // Non-workgroup buffers go through a generated store
                    // helper, which receives the index as a single argument.
                    let _ = write_indent(out, indent);
                    let _ = write!(out, "_vyre_store_{buffer}(");
                    emit_expr_string(out, index, program, ctx)?;
                    let _ = write!(out, ", ");
                    emit_expr_string(out, value, program, ctx)?;
                    let _ = writeln!(out, ");");
                }
            } else {
                return Err(Error::lowering(format!(
                    "store references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
                )));
            }
            barriers.record_expr(index);
            barriers.record_expr(value);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            barriers.record_expr(cond);
            let _ = write_indent(out, indent);
            let _ = write!(out, "if (");
            emit_expr_string(out, cond, program, ctx)?;
            let _ = writeln!(out, " != 0u) {{");
            ctx.indent += 1;
            // Each branch is lowered in a child of `outer_vars`, so variable
            // types declared inside a branch do not leak out of it.
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            let mut then_barriers = barriers.clone();
            emit_nodes_with_barriers(out, ctx, then, program, &mut then_barriers)?;
            ctx.vars = outer_vars.child();
            ctx.indent -= 1;
            let mut otherwise_barriers = barriers.clone();
            if otherwise.is_empty() {
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}}");
            } else {
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}} else {{");
                ctx.indent += 1;
                ctx.vars = outer_vars.child();
                emit_nodes_with_barriers(out, ctx, otherwise, program, &mut otherwise_barriers)?;
                ctx.indent -= 1;
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}}");
            }
            ctx.vars = outer_vars;
            // Either branch may have run, so anything pending on either path
            // is still pending afterwards.
            barriers.replace_with_union(then_barriers, otherwise_barriers);
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            barriers.record_expr(from);
            barriers.record_expr(to);
            let safe_var = LowerCtx::safe_var(var);
            let _ = write_indent(out, indent);
            let _ = write!(out, "for (var {safe_var} = ");
            emit_expr_string(out, from, program, ctx)?;
            let _ = write!(out, "; {safe_var} < ");
            emit_expr_string(out, to, program, ctx)?;
            let _ = writeln!(out, "; {safe_var} = {safe_var} + 1u) {{");
            ctx.indent += 1;
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            ctx.vars.insert(var.as_str(), DataType::U32);
            // NOTE(review): `body_barriers` starts from the pre-loop state
            // only. An atomic late in one iteration is therefore not seen by
            // a store earlier in the body on the next iteration. Confirm
            // whether that loop back-edge needs a barrier.
            let mut body_barriers = barriers.clone();
            emit_nodes_with_barriers(out, ctx, body, program, &mut body_barriers)?;
            ctx.vars = outer_vars;
            ctx.indent -= 1;
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "}}");
            // The body may run zero or more times: keep both states' entries.
            barriers.merge(body_barriers);
        }
        Node::Return => {
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "return;");
        }
        Node::Block(nodes) => {
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "{{");
            ctx.indent += 1;
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            // A plain block always executes, so it shares the caller's
            // barrier state directly.
            emit_nodes_with_barriers(out, ctx, nodes, program, barriers)?;
            ctx.vars = outer_vars;
            ctx.indent -= 1;
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "}}");
        }
        Node::Barrier => {
            emit_synchronization_barriers(out, indent);
            barriers.clear();
        }
    }
    ensure_output_within_limit(out)?;
    Ok(())
}
#[inline]
pub fn emit_synchronization_barriers(out: &mut String, indent: usize) {
let _ = write_indent(out, indent);
let _ = writeln!(out, "storageBarrier();");
let _ = write_indent(out, indent);
let _ = writeln!(out, "workgroupBarrier();");
}
/// Per-control-flow-path record of buffers accessed by atomics since the
/// last synchronization barrier.
///
/// The WGSL emitter records atomics found in lowered expressions. A store to
/// a recorded buffer is preceded by a `storageBarrier(); workgroupBarrier();`
/// pair.
#[derive(Clone, Debug, Default)]
pub struct BarrierState {
    /// Names of buffers touched by an atomic since the last barrier.
    pending_atomic_buffers: HashSet<String>,
}
impl BarrierState {
    /// Reports whether a store to `buffer` must be preceded by a barrier.
    ///
    /// Returns `true` when `buffer` has a pending atomic access. The caller
    /// then emits a full storage + workgroup barrier. That barrier orders
    /// *every* pending atomic access, not just the ones on `buffer`, so the
    /// whole pending set is consumed. This matches what an explicit barrier
    /// node does via [`BarrierState::clear`] and avoids emitting redundant
    /// back-to-back barriers for later stores to other pending buffers.
    #[inline]
    pub(crate) fn needs_before_store(&mut self, buffer: &str) -> bool {
        let needed = self.pending_atomic_buffers.contains(buffer);
        if needed {
            self.clear();
        }
        needed
    }

    /// Adds every buffer accessed atomically anywhere inside `expr` to the
    /// pending set.
    #[inline]
    pub(crate) fn record_expr(&mut self, expr: &crate::ir::Expr) {
        collect_atomic_buffers(expr, &mut self.pending_atomic_buffers);
    }

    /// Replaces this state with the union of two branch states.
    ///
    /// Used after an `if`/`else`: a buffer pending on either path is still
    /// pending after the join. Both inputs are consumed, so no names are
    /// re-allocated.
    #[inline]
    pub(crate) fn replace_with_union(&mut self, left: Self, right: Self) {
        let mut union = left.pending_atomic_buffers;
        union.extend(right.pending_atomic_buffers);
        self.pending_atomic_buffers = union;
    }

    /// Adds every pending buffer from `other` to this state.
    #[inline]
    pub(crate) fn merge(&mut self, other: Self) {
        self.pending_atomic_buffers
            .extend(other.pending_atomic_buffers);
    }

    /// Forgets all pending atomic accesses (a barrier has just been emitted).
    #[inline]
    pub(crate) fn clear(&mut self) {
        self.pending_atomic_buffers.clear();
    }
}
/// Inserts into `buffers` the name of every buffer that `expr`, or any of
/// its nested sub-expressions, accesses through an `Expr::Atomic`.
///
/// Existing entries in `buffers` are kept; the set only grows.
#[inline]
pub fn collect_atomic_buffers(expr: &crate::ir::Expr, buffers: &mut HashSet<String>) {
    use crate::ir::model::expr::Expr;
    // Explicit work-list instead of recursion. Visiting order is irrelevant
    // because the result is a set.
    let mut worklist: Vec<&Expr> = vec![expr];
    while let Some(current) = worklist.pop() {
        match current {
            Expr::Atomic {
                buffer,
                index,
                expected,
                value,
                ..
            } => {
                buffers.insert(buffer.to_string());
                worklist.push(index);
                if let Some(expected) = expected {
                    worklist.push(expected);
                }
                worklist.push(value);
            }
            Expr::Load { index, .. } => worklist.push(index),
            Expr::BinOp { left, right, .. } => {
                worklist.push(left);
                worklist.push(right);
            }
            Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
                worklist.push(operand);
            }
            Expr::Call { args, .. } => {
                for arg in args {
                    worklist.push(arg);
                }
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                worklist.push(cond);
                worklist.push(true_val);
                worklist.push(false_val);
            }
            Expr::Fma { a, b, c } => {
                worklist.push(a);
                worklist.push(b);
                worklist.push(c);
            }
            // Leaves: nothing nested to inspect.
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::Var(_)
            | Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. } => {}
        }
    }
}