vyre-foundation 0.4.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
use super::super::CalleeExpander;
use crate::error::Result;
use crate::ir::{AtomicOp, BinOp, Expr, Ident, UnOp};
use crate::memory_model::MemoryOrdering;

enum Frame<'a> {
    Enter(&'a Expr),
    Load {
        buffer: &'a Ident,
    },
    Bin {
        op: BinOp,
    },
    Un {
        op: UnOp,
    },
    Call {
        op_id: &'a str,
        args: usize,
    },
    Fma,
    Select,
    Cast {
        target: crate::ir::DataType,
    },
    Atomic {
        op: AtomicOp,
        buffer: &'a Ident,
        has_expected: bool,
        ordering: MemoryOrdering,
    },
}

impl CalleeExpander<'_> {
    #[inline]
    pub(crate) fn rename_decl(&mut self, name: &Ident) -> String {
        let renamed = format!("{}{name}", self.prefix);
        self.vars.insert(Ident::from(name), renamed.clone());
        renamed
    }

    #[inline]
    pub(crate) fn rename_use(&self, name: &Ident) -> String {
        self.vars
            .get(name)
            .cloned()
            .unwrap_or_else(|| name.to_string())
    }

    #[inline]
    pub(crate) fn rename_expr_vars(&self, expr: &Expr) -> Result<Expr> {
        let mut frames = vec![Frame::Enter(expr)];
        let mut values: Vec<Expr> = Vec::new();
        while let Some(frame) = frames.pop() {
            match frame {
                Frame::Enter(expr) => self.enter_expr_frame(expr, &mut frames, &mut values),
                Frame::Load { buffer } => push_load(buffer, &mut values)?,
                Frame::Bin { op } => push_bin(op, &mut values)?,
                Frame::Un { op } => push_un(op, &mut values)?,
                Frame::Call { op_id, args } => push_call(op_id, args, &mut values)?,
                Frame::Fma => push_fma(&mut values)?,
                Frame::Select => push_select(&mut values)?,
                Frame::Cast { target } => push_cast(target, &mut values)?,
                Frame::Atomic {
                    op,
                    buffer,
                    has_expected,
                    ordering,
                } => push_atomic(op, buffer, has_expected, ordering, &mut values)?,
            }
        }
        values.pop().ok_or_else(|| {
            crate::error::Error::lowering(
                "IR inline expansion: expression rename produced no value. Fix: ensure the input Expr is well-formed.",
            )
        })
    }

    fn enter_expr_frame<'a>(
        &self,
        expr: &'a Expr,
        frames: &mut Vec<Frame<'a>>,
        values: &mut Vec<Expr>,
    ) {
        match expr {
            Expr::Var(name) => values.push(Expr::var(self.rename_use(name))),
            Expr::Load { buffer, index } => {
                frames.push(Frame::Load { buffer });
                frames.push(Frame::Enter(index));
            }
            Expr::BinOp { op, left, right } => {
                frames.push(Frame::Bin { op: *op });
                frames.push(Frame::Enter(right));
                frames.push(Frame::Enter(left));
            }
            Expr::UnOp { op, operand } => {
                frames.push(Frame::Un { op: op.clone() });
                frames.push(Frame::Enter(operand));
            }
            Expr::Call { op_id, args } => {
                frames.push(Frame::Call {
                    op_id,
                    args: args.len(),
                });
                for arg in args.iter().rev() {
                    frames.push(Frame::Enter(arg));
                }
            }
            Expr::Fma { a, b, c } => {
                frames.push(Frame::Fma);
                frames.push(Frame::Enter(c));
                frames.push(Frame::Enter(b));
                frames.push(Frame::Enter(a));
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                frames.push(Frame::Select);
                frames.push(Frame::Enter(false_val));
                frames.push(Frame::Enter(true_val));
                frames.push(Frame::Enter(cond));
            }
            Expr::Cast { target, value } => {
                frames.push(Frame::Cast {
                    target: target.clone(),
                });
                frames.push(Frame::Enter(value));
            }
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
                ordering,
            } => {
                frames.push(Frame::Atomic {
                    op: *op,
                    buffer,
                    has_expected: expected.is_some(),
                    ordering: *ordering,
                });
                frames.push(Frame::Enter(value));
                if let Some(expected) = expected.as_deref() {
                    frames.push(Frame::Enter(expected));
                }
                frames.push(Frame::Enter(index));
            }
            Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. }
            | Expr::SubgroupLocalId
            | Expr::SubgroupSize => {
                values.push(Expr::u32(0));
            }
            Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::BufLen { .. }
            | Expr::SubgroupBallot { .. }
            | Expr::SubgroupShuffle { .. }
            | Expr::SubgroupAdd { .. } => values.push(expr.clone()),
            Expr::Opaque(_) => values.push(expr.clone()),
        }
    }
}

fn missing(what: &str) -> crate::error::Error {
    crate::error::Error::lowering(format!(
        "IR inline expansion: {what} missing from expression stack. Fix: ensure the input Expr is well-formed."
    ))
}

fn push_load(buffer: &Ident, values: &mut Vec<Expr>) -> Result<()> {
    let index = values.pop().ok_or_else(|| missing("load index"))?;
    values.push(Expr::Load {
        buffer: buffer.into(),
        index: Box::new(index),
    });
    Ok(())
}

fn push_bin(op: BinOp, values: &mut Vec<Expr>) -> Result<()> {
    let right = values.pop().ok_or_else(|| missing("binop right operand"))?;
    let left = values.pop().ok_or_else(|| missing("binop left operand"))?;
    values.push(Expr::BinOp {
        op,
        left: Box::new(left),
        right: Box::new(right),
    });
    Ok(())
}

fn push_un(op: UnOp, values: &mut Vec<Expr>) -> Result<()> {
    let operand = values.pop().ok_or_else(|| missing("unop operand"))?;
    values.push(Expr::UnOp {
        op,
        operand: Box::new(operand),
    });
    Ok(())
}

fn push_call(op_id: &str, args: usize, values: &mut Vec<Expr>) -> Result<()> {
    let split_at = values
        .len()
        .checked_sub(args)
        .ok_or_else(|| missing("call argument count mismatch"))?;
    let args = values.split_off(split_at);
    values.push(Expr::call(op_id, args));
    Ok(())
}

fn push_fma(values: &mut Vec<Expr>) -> Result<()> {
    let c = values.pop().ok_or_else(|| missing("fma operand c"))?;
    let b = values.pop().ok_or_else(|| missing("fma operand b"))?;
    let a = values.pop().ok_or_else(|| missing("fma operand a"))?;
    values.push(Expr::Fma {
        a: Box::new(a),
        b: Box::new(b),
        c: Box::new(c),
    });
    Ok(())
}

fn push_select(values: &mut Vec<Expr>) -> Result<()> {
    let false_val = values.pop().ok_or_else(|| missing("select false branch"))?;
    let true_val = values.pop().ok_or_else(|| missing("select true branch"))?;
    let cond = values.pop().ok_or_else(|| missing("select condition"))?;
    values.push(Expr::Select {
        cond: Box::new(cond),
        true_val: Box::new(true_val),
        false_val: Box::new(false_val),
    });
    Ok(())
}

fn push_cast(target: crate::ir::DataType, values: &mut Vec<Expr>) -> Result<()> {
    let value = values.pop().ok_or_else(|| missing("cast value"))?;
    values.push(Expr::Cast {
        target,
        value: Box::new(value),
    });
    Ok(())
}

fn push_atomic(
    op: AtomicOp,
    buffer: &Ident,
    has_expected: bool,
    ordering: MemoryOrdering,
    values: &mut Vec<Expr>,
) -> Result<()> {
    let value = values.pop().ok_or_else(|| missing("atomic value"))?;
    let expected = if has_expected {
        Some(Box::new(
            values
                .pop()
                .ok_or_else(|| missing("atomic expected value"))?,
        ))
    } else {
        None
    };
    let index = values.pop().ok_or_else(|| missing("atomic index"))?;
    values.push(Expr::Atomic {
        op,
        buffer: buffer.into(),
        index: Box::new(index),
        expected,
        value: Box::new(value),
        ordering,
    });
    Ok(())
}