// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library
// Documentation
use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use crate::ir::model::types::{BufferAccess, DataType};
use crate::lower::wgsl::analysis::expr_type;
use crate::lower::wgsl::expr::emit_expr_string;
use crate::lower::wgsl::impl_lowerctx::write_indent;
use crate::lower::wgsl::{ensure_output_within_limit, Error, LowerCtx};
use std::collections::HashSet;
use std::fmt::Write as _;

/// Emit `nodes` into `out`, starting from an empty [`BarrierState`].
///
/// This is a convenience entry point over [`emit_nodes_with_barriers`]: it
/// creates a fresh, default barrier tracker and forwards every argument
/// unchanged. Nodes that follow a `Return` in `nodes` are not emitted.
///
/// # Errors
///
/// Returns whatever error [`emit_nodes_with_barriers`] reports.
#[inline]
pub(crate) fn emit_nodes<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    nodes: &'a [Node],
    program: &'a Program,
) -> Result<(), Error> {
    emit_nodes_with_barriers(out, ctx, nodes, program, &mut BarrierState::default())
}

/// Emit `nodes` in order into `out`, threading `barriers` through every node.
///
/// Emission stops after the first `Node::Return` has been written; any nodes
/// that follow it in the slice are unreachable and produce no output.
///
/// # Errors
///
/// Propagates the first error returned by [`emit_node`].
#[inline]
pub fn emit_nodes_with_barriers<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    nodes: &'a [Node],
    program: &'a Program,
    barriers: &mut BarrierState,
) -> Result<(), Error> {
    for node in nodes {
        emit_node(out, ctx, node, program, barriers)?;
        // Everything after a `Return` is dead code, so stop here.
        if matches!(node, Node::Return) {
            break;
        }
    }
    Ok(())
}

/// Emit the WGSL text for a single IR `node` into `out`.
///
/// The indentation level is read from `ctx.indent` on entry. Nested bodies
/// (`If`, `Loop`, `Block`, and the guarded workgroup store) temporarily bump
/// `ctx.indent` and swap `ctx.vars` for a child scope while they are emitted.
/// `barriers` tracks buffers that have been touched by atomic expressions so
/// that a synchronization barrier pair can be written before a later store to
/// the same buffer.
///
/// The results of the `write!`/`writeln!` calls are discarded; formatting
/// into a `String` does not fail.
///
/// # Errors
///
/// Returns an error when:
/// - a `Store` names a buffer that `program.buffer` does not know,
/// - `emit_expr_string` or a nested `emit_nodes_with_barriers` call fails,
/// - `ensure_output_within_limit` rejects the accumulated output.
///
/// NOTE(review): on an early error return `ctx.indent` / `ctx.vars` are left
/// in whatever intermediate state they were in — presumably acceptable because
/// lowering is abandoned on error; confirm against callers.
#[inline]
pub fn emit_node<'a>(
    out: &mut String,
    ctx: &mut LowerCtx<'a>,
    node: &'a Node,
    program: &'a Program,
    barriers: &mut BarrierState,
) -> Result<(), Error> {
    // Snapshot the indentation of this statement; nested bodies change
    // `ctx.indent`, but the closing lines of this node use the snapshot.
    let indent = ctx.indent;
    match node {
        // `var <name> = <value>;`
        Node::Let { name, value } => {
            // Remember the variable's type when it can be inferred.
            if let Some(ty) = expr_type(value, &ctx.buffer_map, &ctx.vars) {
                ctx.vars.insert(name.as_str(), ty);
            }
            let safe_name = LowerCtx::safe_var(name);
            let _ = write_indent(out, indent);
            let _ = write!(out, "var {safe_name} = ");
            emit_expr_string(out, value, program, ctx)?;
            let _ = writeln!(out, ";");
            // Any atomic inside the initializer marks its buffer as pending.
            barriers.record_expr(value);
        }
        // `<name> = <value>;`
        Node::Assign { name, value } => {
            let safe_name = LowerCtx::safe_var(name);
            let _ = write_indent(out, indent);
            let _ = write!(out, "{safe_name} = ");
            emit_expr_string(out, value, program, ctx)?;
            let _ = writeln!(out, ";");
            barriers.record_expr(value);
        }
        Node::Store {
            buffer,
            index,
            value,
        } => {
            // A pending atomic on this buffer forces a barrier pair before the
            // store; the pending entry is consumed by the check.
            if barriers.needs_before_store(buffer) {
                emit_synchronization_barriers(out, indent);
            }
            if let Some(b) = program.buffer(buffer) {
                if b.access == BufferAccess::Workgroup {
                    // Workgroup buffers are written directly, wrapped in a
                    // bounds guard: `if (<index> < <count>u) { buf[<index>] = <value>; }`.
                    // NOTE(review): the index expression is emitted twice (guard
                    // and subscript); if it can contain an atomic or other
                    // side-effecting expression it would be evaluated twice —
                    // confirm whether validation rules that out.
                    let _ = write_indent(out, indent);
                    let _ = write!(out, "if (");
                    emit_expr_string(out, index, program, ctx)?;
                    let _ = writeln!(out, " < {count}u) {{", count = b.count);
                    ctx.indent += 1;
                    // Presumably pads to the current `ctx.indent` — confirm
                    // against `LowerCtx::write_pad`.
                    let _ = ctx.write_pad(out);
                    let _ = write!(out, "{buffer}[");
                    emit_expr_string(out, index, program, ctx)?;
                    let _ = write!(out, "] = ");
                    emit_expr_string(out, value, program, ctx)?;
                    let _ = writeln!(out, ";");
                    ctx.indent -= 1;
                    let _ = write_indent(out, indent);
                    let _ = writeln!(out, "}}");
                } else {
                    // Every other access kind goes through a generated helper:
                    // `_vyre_store_<buffer>(<index>, <value>);`
                    let _ = write_indent(out, indent);
                    let _ = write!(out, "_vyre_store_{buffer}(");
                    emit_expr_string(out, index, program, ctx)?;
                    let _ = write!(out, ", ");
                    emit_expr_string(out, value, program, ctx)?;
                    let _ = writeln!(out, ");");
                }
            } else {
                return Err(Error::lowering(format!(
                    "store references unknown buffer `{buffer}`. Fix: declare the buffer before lowering or reject the program in validation."
                )));
            }
            // Atomics used while computing the index or value become pending
            // only after this store has been emitted.
            barriers.record_expr(index);
            barriers.record_expr(value);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            barriers.record_expr(cond);
            // The condition is compared against `0u` in the emitted text.
            let _ = write_indent(out, indent);
            let _ = write!(out, "if (");
            emit_expr_string(out, cond, program, ctx)?;
            let _ = writeln!(out, " != 0u) {{");
            ctx.indent += 1;
            // Each branch is emitted with its own child of `outer_vars`.
            // NOTE(review): `outer_vars` is itself `ctx.vars.child()`, so the
            // value restored at the end is a child of the original scope, not
            // the original — confirm this is the intended `child()` semantics.
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            // Each branch starts from a copy of the barrier state on entry.
            let mut then_barriers = barriers.clone();
            emit_nodes_with_barriers(out, ctx, then, program, &mut then_barriers)?;
            ctx.vars = outer_vars.child();
            ctx.indent -= 1;
            let mut otherwise_barriers = barriers.clone();
            if otherwise.is_empty() {
                // No else branch: just close the `if`.
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}}");
            } else {
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}} else {{");
                ctx.indent += 1;
                ctx.vars = outer_vars.child();
                emit_nodes_with_barriers(out, ctx, otherwise, program, &mut otherwise_barriers)?;
                ctx.indent -= 1;
                let _ = write_indent(out, indent);
                let _ = writeln!(out, "}}");
            }
            ctx.vars = outer_vars;
            // After the `if`, a buffer is pending if it is pending at the end
            // of either branch (an absent else contributes the entry state).
            barriers.replace_with_union(then_barriers, otherwise_barriers);
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            barriers.record_expr(from);
            barriers.record_expr(to);
            // `for (var <v> = <from>; <v> < <to>; <v> = <v> + 1u) {`
            let safe_var = LowerCtx::safe_var(var);
            let _ = write_indent(out, indent);
            let _ = write!(out, "for (var {safe_var} = ");
            emit_expr_string(out, from, program, ctx)?;
            let _ = write!(out, "; {safe_var} < ");
            emit_expr_string(out, to, program, ctx)?;
            let _ = writeln!(out, "; {safe_var} = {safe_var} + 1u) {{");
            ctx.indent += 1;
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            // The loop counter is registered as `u32` inside the body scope.
            ctx.vars.insert(var.as_str(), DataType::U32);
            // The body works on a copy of the barrier state; what is pending
            // at its end is added to the outer state after the loop.
            let mut body_barriers = barriers.clone();
            emit_nodes_with_barriers(out, ctx, body, program, &mut body_barriers)?;
            ctx.vars = outer_vars;
            ctx.indent -= 1;
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "}}");
            barriers.merge(body_barriers);
        }
        Node::Return => {
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "return;");
        }
        // A bare `{ ... }` block; unlike `If`/`Loop` it shares the caller's
        // barrier state directly instead of working on a copy.
        Node::Block(nodes) => {
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "{{");
            ctx.indent += 1;
            let outer_vars = ctx.vars.child();
            ctx.vars = outer_vars.child();
            emit_nodes_with_barriers(out, ctx, nodes, program, barriers)?;
            ctx.vars = outer_vars;
            ctx.indent -= 1;
            let _ = write_indent(out, indent);
            let _ = writeln!(out, "}}");
        }
        // An explicit barrier emits the barrier pair and clears every pending
        // buffer.
        Node::Barrier => {
            emit_synchronization_barriers(out, indent);
            barriers.clear();
        }
    }
    // Checked after every node so oversized output is rejected as it grows.
    ensure_output_within_limit(out)?;
    Ok(())
}

#[inline]
pub fn emit_synchronization_barriers(out: &mut String, indent: usize) {
    let _ = write_indent(out, indent);
    let _ = writeln!(out, "storageBarrier();");
    let _ = write_indent(out, indent);
    let _ = writeln!(out, "workgroupBarrier();");
}

/// Tracks which buffers have been touched by atomic expressions and may
/// therefore need a synchronization barrier before the next store to them.
///
/// The default value has no pending buffers.
#[derive(Clone, Debug, Default)]
pub struct BarrierState {
    /// Names of buffers referenced by an atomic expression that have not yet
    /// been consumed by a store check or cleared by an explicit barrier.
    pending_atomic_buffers: HashSet<String>,
}

impl BarrierState {
    /// Report whether a barrier is required before storing to `buffer`.
    ///
    /// Returns `true` when `buffer` is currently pending. The entry is removed
    /// as part of the check, so a second call for the same buffer returns
    /// `false` until it is recorded again.
    #[inline]
    pub(crate) fn needs_before_store(&mut self, buffer: &str) -> bool {
        self.pending_atomic_buffers.remove(buffer)
    }

    /// Mark every buffer targeted by an atomic inside `expr` as pending.
    #[inline]
    pub(crate) fn record_expr(&mut self, expr: &crate::ir::Expr) {
        collect_atomic_buffers(expr, &mut self.pending_atomic_buffers);
    }

    /// Replace the pending set with the union of `left` and `right`.
    ///
    /// The previous contents of `self` are discarded. Both arguments are
    /// owned, so their buffer names are moved into the result rather than
    /// cloned.
    #[inline]
    pub(crate) fn replace_with_union(&mut self, left: Self, right: Self) {
        let mut combined = left.pending_atomic_buffers;
        combined.extend(right.pending_atomic_buffers);
        self.pending_atomic_buffers = combined;
    }

    /// Add every pending buffer of `other` to this state.
    #[inline]
    pub(crate) fn merge(&mut self, other: Self) {
        self.pending_atomic_buffers
            .extend(other.pending_atomic_buffers);
    }

    /// Forget all pending buffers.
    #[inline]
    pub(crate) fn clear(&mut self) {
        self.pending_atomic_buffers.clear();
    }
}

/// Insert into `buffers` the name of every buffer that is the target of an
/// `Expr::Atomic` anywhere inside `expr`, including nested sub-expressions.
///
/// Existing entries in `buffers` are kept. For `Expr::Load` only the index
/// expression is searched; the loaded buffer itself is not recorded.
#[inline]
pub fn collect_atomic_buffers(expr: &crate::ir::Expr, buffers: &mut HashSet<String>) {
    use crate::ir::model::expr::Expr;

    // Walk the expression tree with an explicit work list. Visit order does
    // not matter because the only effect is insertion into a set.
    let mut worklist: Vec<&Expr> = vec![expr];
    while let Some(current) = worklist.pop() {
        match current {
            Expr::Atomic {
                buffer,
                index,
                expected,
                value,
                ..
            } => {
                buffers.insert(buffer.to_string());
                worklist.push(index);
                if let Some(expected) = expected {
                    worklist.push(expected);
                }
                worklist.push(value);
            }
            Expr::Load { index, .. } => worklist.push(index),
            Expr::BinOp { left, right, .. } => {
                worklist.push(left);
                worklist.push(right);
            }
            Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
                worklist.push(operand);
            }
            Expr::Call { args, .. } => {
                for arg in args {
                    worklist.push(arg);
                }
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                worklist.push(cond);
                worklist.push(true_val);
                worklist.push(false_val);
            }
            Expr::Fma { a, b, c } => {
                worklist.push(a);
                worklist.push(b);
                worklist.push(c);
            }
            // Leaves: nothing nested to inspect.
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::Var(_)
            | Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. } => {}
        }
    }
}