use std::collections::HashMap;
use thiserror::Error;
use vyre::ir::{BufferDecl, DataType, Expr as IrExpr, Node, Program};
use super::lex::tokens::{ANDAND, EQ, GE, GT, LE, LT, MINUS, NE, OROR, PERCENT, PLUS, SLASH, STAR};
use super::parse::{Expr, Module, Stmt, Type};
use super::sema::{BindingId, Resolution};
/// Errors returned while lowering a parsed Rust module to a Vyre IR `Program`.
#[derive(Debug, Clone, Error)]
pub enum RustLowerError {
/// The module contains no functions, so there is nothing to use as the entry kernel.
#[error("Rust lowering needs at least one function to use as the entry kernel")]
NoEntryFunction,
/// The input uses a construct this lowering does not handle. The payload names the
/// construct and is interpolated into the error message.
#[error("Rust to Vyre IR lowering does not support {0} yet; not emitting a miscompiled Program")]
Unsupported(String),
}
pub fn lower(module: &Module, resolution: &Resolution) -> Result<Program, RustLowerError> {
let entry_index = module
.functions
.len()
.checked_sub(1)
.ok_or(RustLowerError::NoEntryFunction)?;
let func = &module.functions[entry_index];
let def_to_id: HashMap<u32, BindingId> = resolution
.bindings
.iter()
.enumerate()
.map(|(id, b)| (b.def_offset, id))
.collect();
let mut buffers = Vec::with_capacity(func.params.len() + 1);
let mut entry_nodes = Vec::new();
for (i, (offset, ty)) in func.params.iter().enumerate() {
let dtype = scalar_dtype(ty)?;
let buf = format!("p{i}");
buffers.push(BufferDecl::read(&buf, i as u32, dtype).with_count(1));
let binding = def_to_id
.get(offset)
.copied()
.ok_or_else(|| RustLowerError::Unsupported("unresolved parameter".to_string()))?;
entry_nodes.push(Node::let_bind(
format!("v{binding}"),
IrExpr::load(buf, IrExpr::u32(0)),
));
}
let out_dtype = scalar_dtype(&func.ret)?;
buffers.push(BufferDecl::output("out", func.params.len() as u32, out_dtype).with_count(1));
let ctx = LowerCtx { module, resolution, def_to_id: &def_to_id };
entry_nodes.extend(ctx.lower_stmts(&func.body, None)?);
Ok(Program::wrapped(buffers, [1, 1, 1], entry_nodes))
}
/// Maps a source-level type to the IR scalar type used for its buffer element.
///
/// References are looked through, so `&T` maps to whatever `T` maps to.
///
/// # Errors
///
/// Returns [`RustLowerError::Unsupported`] for the unit type.
fn scalar_dtype(ty: &Type) -> Result<DataType, RustLowerError> {
    let dtype = match ty {
        Type::Ref { inner, .. } => return scalar_dtype(inner),
        Type::Unit => {
            return Err(RustLowerError::Unsupported(
                "unit-typed parameter or return".to_string(),
            ));
        }
        Type::Bool => DataType::Bool,
        Type::I32 => DataType::I32,
    };
    Ok(dtype)
}
/// Read-only state shared by the statement and expression lowering methods.
struct LowerCtx<'a> {
/// The parsed module; `module.functions` is indexed to find the callee when lowering a call.
module: &'a Module,
/// Name-resolution results: `uses` is consulted for variable uses and assignment targets,
/// `calls` for call expressions.
resolution: &'a Resolution,
/// Maps a definition offset (`def_offset`) to the index of that binding in `resolution.bindings`.
def_to_id: &'a HashMap<u32, BindingId>,
}
impl LowerCtx<'_> {
    /// Lowers a statement list to IR nodes.
    ///
    /// `subst`, when present, maps bindings to the IR expression that must be used in place
    /// of the default `v{binding}` local; `None` means every variable is its own local.
    ///
    /// # Errors
    ///
    /// Returns [`RustLowerError::Unsupported`] for unresolved names and for any construct the
    /// lowering cannot express faithfully.
    fn lower_stmts(
        &self,
        stmts: &[Stmt],
        subst: Option<&HashMap<BindingId, IrExpr>>,
    ) -> Result<Vec<Node>, RustLowerError> {
        self.lower_seq(stmts, &[], subst)
    }

    /// Lowers `stmts`, then the statement slices in `pending` (front first), stopping at the
    /// first `return` reached on the straight-line path.
    ///
    /// `pending` is how early returns are modelled. A `return` is lowered as a store to `out`
    /// followed by "emit nothing more on this path". When an `if` may return in one of its
    /// branches, the statements after the `if` are therefore lowered inside each branch that
    /// can fall through, instead of after the `if` where they would also run on the path
    /// that already returned.
    fn lower_seq<'s>(
        &self,
        stmts: &'s [Stmt],
        pending: &[&'s [Stmt]],
        subst: Option<&HashMap<BindingId, IrExpr>>,
    ) -> Result<Vec<Node>, RustLowerError> {
        let mut nodes = Vec::new();
        for (idx, stmt) in stmts.iter().enumerate() {
            match stmt {
                Stmt::Let { name, init, .. } => {
                    let binding = self.def_to_id.get(name).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved let binding".to_string())
                    })?;
                    nodes.push(Node::let_bind(
                        format!("v{binding}"),
                        self.lower_value(init, subst)?,
                    ));
                }
                Stmt::Return(Some(expr)) => {
                    nodes.push(Node::store(
                        "out",
                        IrExpr::u32(0),
                        self.lower_value(expr, subst)?,
                    ));
                    // Nothing after a return runs, including `pending`.
                    return Ok(nodes);
                }
                Stmt::Return(None) => return Ok(nodes),
                Stmt::Assign { name, value } => {
                    let binding = self.resolution.uses.get(name).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved assignment target".to_string())
                    })?;
                    nodes.push(Node::assign(
                        format!("v{binding}"),
                        self.lower_value(value, subst)?,
                    ));
                }
                Stmt::Expr(Expr::If { cond, then_block, else_block }) => {
                    let then_stmts = block_stmts(then_block);
                    let else_stmts: &[Stmt] = match else_block {
                        Some(block) => block_stmts(block),
                        None => &[],
                    };
                    let may_return =
                        Self::contains_return(then_stmts) || Self::contains_return(else_stmts);
                    if !may_return {
                        // Both branches always fall through: the following statements can
                        // simply be emitted after the `if`.
                        let then_nodes = self.lower_seq(then_stmts, &[], subst)?;
                        let else_nodes = self.lower_seq(else_stmts, &[], subst)?;
                        nodes.push(Node::if_then_else(
                            self.lower_value(cond, subst)?,
                            then_nodes,
                            else_nodes,
                        ));
                        continue;
                    }
                    // At least one path returns early. Everything that follows the `if` is
                    // handed to the branches as their continuation; a branch that returns
                    // stops before reaching it. When both branches always return there is
                    // nothing left to continue with.
                    // NOTE(review): a partially-returning branch duplicates the trailing
                    // statements into both branches (including their `let` bindings) —
                    // confirm the IR accepts identical bindings in sibling branches.
                    let mut follow: Vec<&'s [Stmt]> = Vec::new();
                    if !(stmts_diverge(then_stmts) && stmts_diverge(else_stmts)) {
                        follow.reserve(pending.len() + 1);
                        follow.push(&stmts[idx + 1..]);
                        follow.extend_from_slice(pending);
                    }
                    let then_nodes = self.lower_seq(then_stmts, &follow, subst)?;
                    let else_nodes = self.lower_seq(else_stmts, &follow, subst)?;
                    nodes.push(Node::if_then_else(
                        self.lower_value(cond, subst)?,
                        then_nodes,
                        else_nodes,
                    ));
                    return Ok(nodes);
                }
                Stmt::While { cond, body } => {
                    nodes.extend(self.lower_while(cond, body, subst)?);
                }
                Stmt::Expr(other) => {
                    // The supported expressions have no side effects, so an expression
                    // statement lowers to nothing. A block statement is different: it can
                    // hold assignments and returns, and dropping it would miscompile.
                    if matches!(other, Expr::Block(_)) && !block_stmts(other).is_empty() {
                        return Err(RustLowerError::Unsupported(
                            "bare block statement".to_string(),
                        ));
                    }
                }
            }
        }
        // Fell off the end without returning: carry on with whatever follows the enclosing `if`.
        if let Some((next, later)) = pending.split_first() {
            nodes.extend(self.lower_seq(*next, later, subst)?);
        }
        Ok(nodes)
    }

    /// Returns `true` when `stmts` contains a `return` at any nesting depth (inside `if`
    /// branches, `while` bodies or block statements).
    fn contains_return(stmts: &[Stmt]) -> bool {
        stmts.iter().any(|stmt| match stmt {
            Stmt::Return(_) => true,
            Stmt::While { body, .. } => Self::contains_return(body),
            Stmt::Expr(Expr::If { then_block, else_block, .. }) => {
                Self::contains_return(block_stmts(then_block))
                    || else_block
                        .as_ref()
                        .is_some_and(|block| Self::contains_return(block_stmts(block)))
            }
            // `block_stmts` yields an empty slice for anything that is not a block.
            Stmt::Expr(other) => Self::contains_return(block_stmts(other)),
            Stmt::Let { .. } | Stmt::Assign { .. } => false,
        })
    }

    /// Lowers a canonical counting loop `while i < BOUND { ...; i = i + 1; }`.
    ///
    /// The loop becomes a `loop_for` over a fresh variable `v{i}__w` running from the current
    /// value of `i` to `BOUND`, with uses of `i` inside the body renamed to that variable.
    /// After the loop, `i` is set to `BOUND` only if the loop ran, mirroring the source.
    ///
    /// # Errors
    ///
    /// Returns [`RustLowerError::Unsupported`] when the loop is not in canonical form: the
    /// condition is not `var < expr`, the last body statement is not `i = i + 1`, the body
    /// assigns `i` elsewhere, the body assigns a variable read by `BOUND`, or the body
    /// contains a `return`.
    fn lower_while(
        &self,
        cond: &Expr,
        body: &[Stmt],
        subst: Option<&HashMap<BindingId, IrExpr>>,
    ) -> Result<Vec<Node>, RustLowerError> {
        let bad = || {
            RustLowerError::Unsupported(
                "while loop that is not a canonical `while i < BOUND { ...; i = i + 1; }` counting loop"
                    .to_string(),
            )
        };
        let (i_off, bound) = match cond {
            Expr::Binary { op, lhs, rhs } if *op == LT => match lhs.as_ref() {
                Expr::Var(off) => (*off, rhs.as_ref()),
                _ => return Err(bad()),
            },
            _ => return Err(bad()),
        };
        let b_i = self.resolution.uses.get(&i_off).copied().ok_or_else(bad)?;
        let Some((last, init_stmts)) = body.split_last() else {
            return Err(bad());
        };
        let is_counter = |off: &u32| self.resolution.uses.get(off).copied() == Some(b_i);
        let inc_ok = matches!(last, Stmt::Assign { name, value }
            if is_counter(name)
                && matches!(value, Expr::Binary { op, lhs, rhs }
                    if *op == PLUS
                        && matches!(lhs.as_ref(), Expr::Var(o) if is_counter(o))
                        && matches!(rhs.as_ref(), Expr::LiteralInt(_, 1))));
        if !inc_ok {
            return Err(bad());
        }
        if stmts_assign_binding(init_stmts, b_i, self.resolution) {
            return Err(bad());
        }
        for v in expr_var_bindings(bound, self.resolution) {
            if stmts_assign_binding(body, v, self.resolution) {
                return Err(bad());
            }
        }
        // A `return` is lowered as a plain store, which cannot stop the remaining iterations
        // or the statements after the loop.
        if Self::contains_return(init_stmts) {
            return Err(RustLowerError::Unsupported(
                "`return` inside a while loop body".to_string(),
            ));
        }

        let counter = format!("v{b_i}");
        let loop_var = format!("v{b_i}__w");
        // Inside the body only the counter is renamed. Expression lowering treats a
        // substitution map as complete, so when there is no enclosing map every other
        // binding gets an identity entry; an enclosing loop's map already has them.
        let mut inner: HashMap<BindingId, IrExpr> = match subst {
            Some(outer) => outer.clone(),
            None => self
                .resolution
                .bindings
                .iter()
                .enumerate()
                .map(|(id, _)| (id, IrExpr::var(format!("v{id}"))))
                .collect(),
        };
        inner.insert(b_i, IrExpr::var(loop_var.clone()));

        let from = IrExpr::var(counter.clone());
        let to = self.lower_value(bound, subst)?;
        let loop_body = self.lower_stmts(init_stmts, Some(&inner))?;
        Ok(vec![
            Node::loop_for(loop_var, from.clone(), to.clone(), loop_body),
            // The source loop leaves `i` untouched when `i >= BOUND` on entry, so only move
            // it to `BOUND` when at least one iteration would have run.
            Node::if_then_else(
                IrExpr::lt(from, to.clone()),
                vec![Node::assign(counter, to)],
                Vec::new(),
            ),
        ])
    }

    /// Lowers an expression evaluated in the entry function to an IR expression.
    ///
    /// See [`Self::lower_stmts`] for the meaning of `subst`.
    fn lower_value(
        &self,
        expr: &Expr,
        subst: Option<&HashMap<BindingId, IrExpr>>,
    ) -> Result<IrExpr, RustLowerError> {
        self.lower_expr(expr, subst, &[])
    }

    /// Lowers an expression; `inlining` lists the functions whose bodies are currently being
    /// inlined around it (outermost first) so that recursive calls can be rejected.
    fn lower_expr(
        &self,
        expr: &Expr,
        subst: Option<&HashMap<BindingId, IrExpr>>,
        inlining: &[usize],
    ) -> Result<IrExpr, RustLowerError> {
        match expr {
            Expr::LiteralInt(_, value) => {
                // A plain `as i32` would silently wrap an out-of-range literal.
                let narrowed = i32::try_from(*value).map_err(|_| {
                    RustLowerError::Unsupported(
                        "integer literal outside the i32 range".to_string(),
                    )
                })?;
                Ok(IrExpr::i32(narrowed))
            }
            Expr::LiteralBool(_, value) => Ok(IrExpr::bool(*value)),
            Expr::Var(offset) => {
                let binding = self.resolution.uses.get(offset).copied().ok_or_else(|| {
                    RustLowerError::Unsupported("unresolved variable use".to_string())
                })?;
                match subst {
                    Some(map) => map.get(&binding).cloned().ok_or_else(|| {
                        RustLowerError::Unsupported("callee variable not substituted".to_string())
                    }),
                    None => Ok(IrExpr::var(format!("v{binding}"))),
                }
            }
            Expr::Binary { op, lhs, rhs } => {
                let l = self.lower_expr(lhs, subst, inlining)?;
                let r = self.lower_expr(rhs, subst, inlining)?;
                Ok(match *op {
                    PLUS => IrExpr::add(l, r),
                    MINUS => IrExpr::sub(l, r),
                    STAR => IrExpr::mul(l, r),
                    SLASH => IrExpr::div(l, r),
                    PERCENT => IrExpr::cast(DataType::I32, IrExpr::rem(l, r)),
                    EQ => IrExpr::eq(l, r),
                    NE => IrExpr::ne(l, r),
                    LT => IrExpr::lt(l, r),
                    GT => IrExpr::gt(l, r),
                    LE => IrExpr::le(l, r),
                    GE => IrExpr::ge(l, r),
                    ANDAND => IrExpr::and(l, r),
                    OROR => IrExpr::or(l, r),
                    other => {
                        return Err(RustLowerError::Unsupported(format!(
                            "binary operator {other}"
                        )))
                    }
                })
            }
            Expr::Call { name, args } => self.lower_call(name, args, subst, inlining),
            // Borrows and dereferences are transparent: references are lowered as the value
            // they point at.
            Expr::Borrow { expr, .. } => self.lower_expr(expr, subst, inlining),
            Expr::Deref(inner) => self.lower_expr(inner, subst, inlining),
            Expr::Not(inner) => Ok(IrExpr::not(self.lower_expr(inner, subst, inlining)?)),
            Expr::Block(_) | Expr::If { .. } => {
                Err(RustLowerError::Unsupported("block/if used as a value".to_string()))
            }
        }
    }

    /// Lowers a call by inlining the callee into a single expression.
    ///
    /// Arguments are lowered in the caller's context (`caller_subst`). The callee body may
    /// only consist of `let` statements followed by `return expr`; each `let` is folded into
    /// a substitution map so the returned expression is expressed purely in caller terms.
    ///
    /// # Errors
    ///
    /// Returns [`RustLowerError::Unsupported`] for unresolved or out-of-range call targets,
    /// arity mismatches, recursive calls (which can never be fully inlined), and callees
    /// containing anything other than `let` and a value-carrying `return`.
    fn lower_call(
        &self,
        name: &u32,
        args: &[Expr],
        caller_subst: Option<&HashMap<BindingId, IrExpr>>,
        inlining: &[usize],
    ) -> Result<IrExpr, RustLowerError> {
        let callee_index = self
            .resolution
            .calls
            .get(name)
            .copied()
            .ok_or_else(|| RustLowerError::Unsupported("unresolved call".to_string()))?;
        // Inlining a function into itself (directly or through other callees) never ends.
        if inlining.contains(&callee_index) {
            return Err(RustLowerError::Unsupported("recursive call".to_string()));
        }
        let callee = self.module.functions.get(callee_index).ok_or_else(|| {
            RustLowerError::Unsupported("call to an unknown function".to_string())
        })?;
        if args.len() != callee.params.len() {
            return Err(RustLowerError::Unsupported("call arity mismatch".to_string()));
        }

        let mut subst: HashMap<BindingId, IrExpr> = HashMap::with_capacity(callee.params.len());
        for ((offset, _), arg) in callee.params.iter().zip(args) {
            let binding = self.def_to_id.get(offset).copied().ok_or_else(|| {
                RustLowerError::Unsupported("unresolved callee parameter".to_string())
            })?;
            subst.insert(binding, self.lower_expr(arg, caller_subst, inlining)?);
        }

        // The callee body is lowered with the callee added to the active inlining chain.
        let mut chain = inlining.to_vec();
        chain.push(callee_index);
        for stmt in &callee.body {
            match stmt {
                Stmt::Let { name: offset, init, .. } => {
                    let value = self.lower_expr(init, Some(&subst), &chain)?;
                    let binding = self.def_to_id.get(offset).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved callee binding".to_string())
                    })?;
                    subst.insert(binding, value);
                }
                Stmt::Return(Some(expr)) => {
                    return self.lower_expr(expr, Some(&subst), &chain);
                }
                _ => {
                    return Err(RustLowerError::Unsupported(
                        "call to a callee with control flow or no terminal return".to_string(),
                    ))
                }
            }
        }
        Err(RustLowerError::Unsupported("call to a callee with no return".to_string()))
    }
}
/// Reports whether any statement in `stmts` assigns to binding `b`.
///
/// Assignments nested inside `if` branches and `while` bodies count; every other statement
/// kind is ignored.
fn stmts_assign_binding(stmts: &[Stmt], b: BindingId, res: &Resolution) -> bool {
    for stmt in stmts {
        let assigns = match stmt {
            Stmt::Assign { name, .. } => res.uses.get(name).copied() == Some(b),
            Stmt::Expr(Expr::If { then_block, else_block, .. }) => {
                if stmts_assign_binding(block_stmts(then_block), b, res) {
                    true
                } else {
                    match else_block {
                        Some(block) => stmts_assign_binding(block_stmts(block), b, res),
                        None => false,
                    }
                }
            }
            Stmt::While { body, .. } => stmts_assign_binding(body, b, res),
            _ => false,
        };
        if assigns {
            return true;
        }
    }
    false
}
/// Returns the bindings of every resolved variable that `expr` reads, in the order
/// `collect_var_bindings` visits them. A variable used more than once appears more than once.
fn expr_var_bindings(expr: &Expr, res: &Resolution) -> Vec<BindingId> {
    let mut found: Vec<BindingId> = Vec::new();
    collect_var_bindings(expr, res, &mut found);
    found
}
/// Appends to `out` the binding of every resolved variable use inside `expr`.
///
/// Operands are visited left to right, depth first. Unresolved variables are skipped, and
/// literals, blocks and `if` expressions contribute nothing.
fn collect_var_bindings(expr: &Expr, res: &Resolution, out: &mut Vec<BindingId>) {
    // Explicit work stack; children are pushed in reverse so they are popped left to right.
    let mut pending: Vec<&Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::Var(off) => out.extend(res.uses.get(off).copied()),
            Expr::Binary { lhs, rhs, .. } => {
                pending.push(rhs);
                pending.push(lhs);
            }
            Expr::Borrow { expr: inner, .. } => pending.push(inner),
            Expr::Deref(inner) => pending.push(inner),
            Expr::Not(inner) => pending.push(inner),
            Expr::Call { args, .. } => {
                for arg in args.iter().rev() {
                    pending.push(arg);
                }
            }
            _ => {}
        }
    }
}
/// Returns the statements of a block expression, or an empty slice when `expr` is not a block.
fn block_stmts(expr: &Expr) -> &[Stmt] {
    let Expr::Block(stmts) = expr else {
        return &[];
    };
    stmts
}
/// Reports whether executing `stmts` always ends in a `return`.
///
/// That is the case when the list contains a `return` statement, or an `if`/`else` whose
/// two branches both always return. An `if` without `else` never counts.
fn stmts_diverge(stmts: &[Stmt]) -> bool {
    for stmt in stmts {
        match stmt {
            Stmt::Return(_) => return true,
            Stmt::Expr(Expr::If { then_block, else_block: Some(else_block), .. }) => {
                let then_returns = stmts_diverge(block_stmts(then_block));
                if then_returns && stmts_diverge(block_stmts(else_block)) {
                    return true;
                }
            }
            _ => {}
        }
    }
    false
}