vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use super::{expr_has_effect, CseCtx, ExprKey};
use crate::ir::{Expr, Node};

impl CseCtx {
    /// Creates a context for a nested scope, seeded with a copy of every
    /// binding currently recorded in `self`.
    ///
    /// The copy is independent: entries recorded in (or cleared from) the
    /// returned context never flow back into `self`.
    #[inline]
    pub(crate) fn child(&self) -> Self {
        Self {
            values: self.values.clone(),
        }
    }

    /// Forgets every recorded `expression → variable` binding.
    ///
    /// The rewriter calls this after assignments, stores, atomics, barriers,
    /// effectful `let` values, and nested control flow, so that nothing
    /// recorded before such a point is reused after it.
    #[inline]
    pub(crate) fn clear_observed_state(&mut self) {
        self.values.clear();
    }

    /// Rewrites a sequence of nodes in order, threading this context through
    /// each one, and returns the rewritten sequence.
    #[inline]
    pub(crate) fn nodes(&mut self, nodes: &[Node]) -> Vec<Node> {
        nodes.iter().map(|node| self.node(node)).collect()
    }

    /// Rewrites a single node, replacing repeated sub-expressions with
    /// references to the variable that first bound them.
    ///
    /// Only effect-free, non-literal `let` values are recorded as reusable.
    /// Every other node kind except `Return` clears the recorded bindings
    /// after its own operands have been rewritten.
    #[inline]
    pub(crate) fn node(&mut self, node: &Node) -> Node {
        match node {
            Node::Let { name, value } => {
                let value = self.expr(value);
                if expr_has_effect(&value) {
                    self.clear_observed_state();
                    return Node::let_bind(name, value);
                }

                // Do not CSE-alias literals through variables: `let state = 0u`
                // must not record `LitU32(0) → "state"` because `state` may be
                // reassigned later while the literal stays constant. This is
                // checked before building the key so literals skip that work.
                if is_literal_expr(&value) {
                    return Node::let_bind(name, value);
                }

                let key = ExprKey::from_expr(&value);
                // Reuse the first variable bound to this expression if there
                // is one; otherwise keep the rewritten value as written.
                let canonical = match self.values.get(&key) {
                    Some(existing) => Expr::var(existing),
                    None => value,
                };
                // The first binding wins: a later `let` of the same expression
                // does not replace the recorded variable.
                // NOTE(review): if the IR permits a second `let` of an already
                // bound name, entries pointing at that name would go stale —
                // confirm that shadowing is rejected before this pass runs.
                self.values.entry(key).or_insert_with(|| name.clone());
                Node::let_bind(name, canonical)
            }
            Node::Assign { name, value } => {
                let value = self.expr(value);
                self.clear_observed_state();
                Node::assign(name, value)
            }
            Node::Store {
                buffer,
                index,
                value,
            } => {
                let index = self.expr(index);
                let value = self.expr(value);
                self.clear_observed_state();
                Node::store(buffer, index, value)
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                let cond = self.expr(cond);
                // Each branch starts from the bindings visible at the `if`
                // and keeps its own additions private to that branch.
                let mut then_ctx = self.child();
                let then = then_ctx.nodes(then);
                let mut otherwise_ctx = self.child();
                let otherwise = otherwise_ctx.nodes(otherwise);
                self.clear_observed_state();
                Node::if_then_else(cond, then, otherwise)
            }
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let from = self.expr(from);
                let to = self.expr(to);
                // The body is rewritten with a default-constructed context
                // rather than a child, so nothing recorded before the loop is
                // reused inside it.
                let mut body_ctx = CseCtx::default();
                let body = body_ctx.nodes(body);
                self.clear_observed_state();
                Node::loop_for(var, from, to, body)
            }
            Node::Return => Node::Return,
            Node::Block(nodes) => {
                let mut block_ctx = self.child();
                let nodes = block_ctx.nodes(nodes);
                self.clear_observed_state();
                Node::block(nodes)
            }
            Node::Barrier => {
                self.clear_observed_state();
                Node::Barrier
            }
        }
    }

    /// Rewrites an expression bottom-up and, when the rewritten expression
    /// matches a recorded binding, returns a reference to that variable
    /// instead.
    ///
    /// Variables, literals, and expressions for which `expr_has_effect`
    /// returns `true` are always returned as rewritten, never substituted.
    /// Rewriting an `Expr::Atomic` clears the recorded bindings.
    #[inline]
    pub(crate) fn expr(&mut self, expr: &Expr) -> Expr {
        let rewritten = match expr {
            Expr::Load { buffer, index } => Expr::Load {
                buffer: buffer.clone(),
                index: Box::new(self.expr(index)),
            },
            Expr::BinOp { op, left, right } => Expr::BinOp {
                op: op.clone(),
                left: Box::new(self.expr(left)),
                right: Box::new(self.expr(right)),
            },
            Expr::UnOp { op, operand } => Expr::UnOp {
                op: op.clone(),
                operand: Box::new(self.expr(operand)),
            },
            Expr::Fma { a, b, c } => Expr::Fma {
                a: Box::new(self.expr(a)),
                b: Box::new(self.expr(b)),
                c: Box::new(self.expr(c)),
            },
            Expr::Call { op_id, args } => Expr::Call {
                op_id: op_id.clone(),
                args: args.iter().map(|arg| self.expr(arg)).collect(),
            },
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => Expr::Select {
                cond: Box::new(self.expr(cond)),
                true_val: Box::new(self.expr(true_val)),
                false_val: Box::new(self.expr(false_val)),
            },
            Expr::Cast { target, value } => Expr::Cast {
                target: target.clone(),
                value: Box::new(self.expr(value)),
            },
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
            } => {
                // Operands are rewritten against the current bindings first;
                // the bindings are dropped only once the atomic itself is
                // reached.
                let index = self.expr(index);
                let expected = expected.as_deref().map(|expr| Box::new(self.expr(expr)));
                let value = self.expr(value);
                self.clear_observed_state();
                Expr::Atomic {
                    op: op.clone(),
                    buffer: buffer.clone(),
                    index: Box::new(index),
                    expected,
                    value: Box::new(value),
                }
            }
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::Var(_)
            | Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. } => expr.clone(),
        };

        if matches!(rewritten, Expr::Var(_)) || expr_has_effect(&rewritten) {
            return rewritten;
        }

        // Never replace a literal with a variable reference — literals are
        // already minimal, and the variable may be mutated later (e.g. loop
        // counters or state accumulators). Substituting `0u` with `var state`
        // when `state` was initially `0u` is unsound if `state` is reassigned.
        if is_literal_expr(&rewritten) {
            return rewritten;
        }

        let key = ExprKey::from_expr(&rewritten);
        self.values
            .get(&key)
            .map_or(rewritten, |existing| Expr::var(existing))
    }
}

/// Returns `true` when `expr` is a literal constant (`LitU32`, `LitI32`,
/// `LitF32`, or `LitBool`).
///
/// Both literal guards in the rewriter go through this single predicate so
/// they cannot drift apart: `matches!` is not exhaustiveness-checked, so a
/// literal variant missing from a hand-written list goes unnoticed by the
/// compiler.
fn is_literal_expr(expr: &Expr) -> bool {
    matches!(
        expr,
        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
    )
}