use std::collections::HashMap;
use thiserror::Error;
use vyre::ir::{BufferDecl, DataType, Expr as IrExpr, Node, Program};
use super::lex::tokens::{ANDAND, EQ, GE, GT, LE, LT, MINUS, NE, OROR, PERCENT, PLUS, SLASH, STAR};
use super::parse::{Expr, Module, Stmt, Type};
use super::sema::{BindingId, Resolution};
/// Failures produced while lowering a parsed Rust module to a Vyre [`Program`].
#[derive(Debug, Clone, Error)]
pub enum RustLowerError {
    /// The module contains no functions, so there is nothing to use as the entry kernel.
    #[error("Rust lowering needs at least one function to use as the entry kernel")]
    NoEntryFunction,

    /// The source uses a construct this lowering cannot translate faithfully.
    /// The payload is a short description of that construct.
    #[error(
        "Rust to Vyre IR lowering does not support {0} yet; not emitting a miscompiled Program"
    )]
    Unsupported(String),
}
/// Lowers the module's entry function (the last function in `module`) in scalar mode.
///
/// Scalar mode runs a single invocation that reads and writes element 0 of each buffer.
///
/// # Errors
/// Returns [`RustLowerError`] when the module has no functions or uses a
/// construct the lowering does not support.
pub fn lower(module: &Module, resolution: &Resolution) -> Result<Program, RustLowerError> {
    let mode = LowerMode::Scalar;
    lower_entry(module, resolution, mode)
}
/// Lowers the module's entry function so that `lane_count` invocations each
/// evaluate it on their own element of every buffer.
///
/// # Errors
/// Returns [`RustLowerError::Unsupported`] when `lane_count` is zero, plus any
/// error that [`lower`] can produce.
pub fn lower_batched(
    module: &Module,
    resolution: &Resolution,
    lane_count: u32,
) -> Result<Program, RustLowerError> {
    match lane_count {
        0 => Err(RustLowerError::Unsupported(
            "batched Rust lowering with zero lanes".to_string(),
        )),
        lanes => lower_entry(
            module,
            resolution,
            LowerMode::Batched { lane_count: lanes },
        ),
    }
}
/// How the entry kernel is mapped onto invocations.
#[derive(Clone, Copy)]
enum LowerMode {
    /// One invocation, operating on element 0 of every buffer.
    Scalar,
    /// Up to `lane_count` invocations. Each one operates on the buffer element
    /// selected by its global x id.
    Batched { lane_count: u32 },
}
impl LowerMode {
    /// Number of elements declared for each parameter buffer and for the output buffer.
    fn buffer_count(self) -> u32 {
        if let Self::Batched { lane_count } = self {
            lane_count
        } else {
            1
        }
    }

    /// Workgroup dimensions.
    ///
    /// Scalar mode uses a single invocation; batched mode uses 256 invocations along x.
    fn workgroup_size(self) -> [u32; 3] {
        let x = match self {
            Self::Scalar => 1,
            Self::Batched { .. } => 256,
        };
        [x, 1, 1]
    }

    /// IR expression selecting the buffer element this invocation reads and writes.
    fn lane_index(self) -> IrExpr {
        if let Self::Batched { .. } = self {
            IrExpr::var(BATCH_LANE_VAR)
        } else {
            IrExpr::u32(0)
        }
    }
}

/// Name of the IR local that holds the invocation's global x id in batched mode.
const BATCH_LANE_VAR: &str = "__rust_lane";
/// Lowers the module's entry function into a Vyre program.
///
/// By convention, the entry function is the last entry of `module.functions`.
///
/// Buffer layout:
/// - Parameter `i` becomes read-only buffer `p{i}` at binding `i`.
/// - The return value is written to output buffer `out`, at the binding right
///   after the last parameter.
/// - Every buffer holds `mode.buffer_count()` elements.
/// - Each invocation touches the element at `mode.lane_index()`.
fn lower_entry(
    module: &Module,
    resolution: &Resolution,
    mode: LowerMode,
) -> Result<Program, RustLowerError> {
    let func = module
        .functions
        .last()
        .ok_or(RustLowerError::NoEntryFunction)?;

    // Definition offset -> binding id, so parameters and lets can be named `v{id}`.
    let def_to_id: HashMap<u32, BindingId> = resolution
        .bindings
        .iter()
        .enumerate()
        .map(|(id, binding)| (binding.def_offset, id))
        .collect();

    let lane = mode.lane_index();
    let count = mode.buffer_count();
    let param_count = func.params.len();
    let mut buffers = Vec::with_capacity(param_count + 1);
    let mut body = Vec::new();

    for (slot, (def_offset, ty)) in func.params.iter().enumerate() {
        let dtype = scalar_dtype(ty)?;
        let buffer_name = format!("p{slot}");
        buffers.push(BufferDecl::read(&buffer_name, slot as u32, dtype).with_count(count));
        let binding = match def_to_id.get(def_offset) {
            Some(&id) => id,
            None => {
                return Err(RustLowerError::Unsupported(
                    "unresolved parameter".to_string(),
                ))
            }
        };
        // Each parameter is loaded once into its own IR local.
        body.push(Node::let_bind(
            format!("v{binding}"),
            IrExpr::load(buffer_name, lane.clone()),
        ));
    }

    let out_dtype = scalar_dtype(&func.ret)?;
    buffers.push(BufferDecl::output("out", param_count as u32, out_dtype).with_count(count));

    let ctx = LowerCtx {
        module,
        resolution,
        def_to_id: &def_to_id,
        output_index: lane,
    };
    body.extend(ctx.lower_stmts(&func.body, Subst::Local(None))?);

    let entry = if let LowerMode::Batched { lane_count } = mode {
        // Invocations whose global x id is not below `lane_count` skip the body.
        vec![
            Node::let_bind(BATCH_LANE_VAR, IrExpr::gid_x()),
            Node::if_then(
                IrExpr::lt(IrExpr::var(BATCH_LANE_VAR), IrExpr::u32(lane_count)),
                body,
            ),
        ]
    } else {
        body
    };

    Ok(Program::wrapped(buffers, mode.workgroup_size(), entry))
}
/// Maps a source type to the scalar IR element type used for buffers.
///
/// References are looked through to their pointee. `()` is rejected because
/// it has no buffer representation.
fn scalar_dtype(ty: &Type) -> Result<DataType, RustLowerError> {
    let dtype = match ty {
        Type::Ref { inner, .. } => return scalar_dtype(inner),
        Type::I32 => DataType::I32,
        Type::Bool => DataType::Bool,
        Type::Unit => {
            return Err(RustLowerError::Unsupported(
                "unit-typed parameter or return".to_string(),
            ))
        }
    };
    Ok(dtype)
}
/// Read-only state shared while lowering the entry function and inlining its callees.
struct LowerCtx<'a> {
    /// The whole parsed module; callee bodies are looked up here for inlining.
    module: &'a Module,
    /// Name resolution results: variable uses, call targets and bindings.
    resolution: &'a Resolution,
    /// Definition offset -> binding id, derived from `resolution.bindings`.
    def_to_id: &'a HashMap<u32, BindingId>,
    /// Element index of the `out` buffer that `return` stores into.
    output_index: IrExpr,
}
/// How variable uses are turned into IR expressions during lowering.
#[derive(Clone, Copy)]
enum Subst<'a> {
    /// Lowering code of the entry function.
    ///
    /// A use of a binding present in the map is replaced by the mapped
    /// expression (e.g. a loop induction value). Any other use reads the IR
    /// local `v{id}`.
    Local(Option<&'a HashMap<BindingId, IrExpr>>),
    /// Lowering an inlined callee: every variable use must be present in the map.
    Inline(&'a HashMap<BindingId, IrExpr>),
}
impl LowerCtx<'_> {
    /// Lowers a statement list into IR nodes, with nothing running after it.
    ///
    /// `return expr` becomes a store of `expr` into the `out` buffer at
    /// `self.output_index`. Statements that cannot execute after a `return`
    /// are not lowered.
    fn lower_stmts(&self, stmts: &[Stmt], subst: Subst<'_>) -> Result<Vec<Node>, RustLowerError> {
        self.lower_seq(stmts, &[], subst)
    }

    /// Lowers `stmts` followed by `cont`: the statement lists that run, in
    /// order, when `stmts` finishes without returning.
    ///
    /// The IR has no early exit. A `return` is only a store to `out`, so any
    /// node emitted after it would still execute and overwrite the result.
    /// Therefore an `if` that contains a `return` absorbs everything that
    /// follows it (the rest of `stmts`, then `cont`) into its arms, and
    /// lowering of the enclosing list stops there.
    ///
    /// An `if` without any `return` is emitted in place. Code is only
    /// duplicated across arms for `if`s that can actually return.
    fn lower_seq(
        &self,
        stmts: &[Stmt],
        cont: &[&[Stmt]],
        subst: Subst<'_>,
    ) -> Result<Vec<Node>, RustLowerError> {
        let mut nodes = Vec::new();
        for (idx, stmt) in stmts.iter().enumerate() {
            match stmt {
                Stmt::Let { name, init, .. } => {
                    let binding = self.def_to_id.get(name).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved let binding".to_string())
                    })?;
                    nodes.push(Node::let_bind(
                        format!("v{binding}"),
                        self.lower_value(init, subst)?,
                    ));
                }
                Stmt::Return(Some(expr)) => {
                    nodes.push(Node::store(
                        "out",
                        self.output_index.clone(),
                        self.lower_value(expr, subst)?,
                    ));
                    return Ok(nodes);
                }
                Stmt::Return(None) => return Ok(nodes),
                Stmt::Assign { name, value } => {
                    let binding = self.resolution.uses.get(name).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved assignment target".to_string())
                    })?;
                    nodes.push(Node::assign(
                        format!("v{binding}"),
                        self.lower_value(value, subst)?,
                    ));
                }
                Stmt::Expr(Expr::If {
                    cond,
                    then_block,
                    else_block,
                }) => {
                    let then_stmts = Self::arm_stmts(then_block)?;
                    let else_stmts = match else_block {
                        Some(arm) => Self::arm_stmts(arm)?,
                        None => &[],
                    };
                    let cond = self.lower_value(cond, subst)?;
                    if !Self::stmts_may_return(then_stmts) && !Self::stmts_may_return(else_stmts) {
                        // Both arms fall through: emit the `if` and keep going.
                        let then_nodes = self.lower_seq(then_stmts, &[], subst)?;
                        let else_nodes = self.lower_seq(else_stmts, &[], subst)?;
                        nodes.push(Node::if_then_else(cond, then_nodes, else_nodes));
                        continue;
                    }
                    // Some path returns. Each arm carries whatever must run after the `if`,
                    // so a returning path never executes the code that follows.
                    let mut after: Vec<&[Stmt]> = Vec::with_capacity(cont.len() + 1);
                    after.push(&stmts[idx + 1..]);
                    after.extend_from_slice(cont);
                    let then_nodes = self.lower_arm(then_stmts, &after, subst)?;
                    let else_nodes = self.lower_arm(else_stmts, &after, subst)?;
                    nodes.push(Node::if_then_else(cond, then_nodes, else_nodes));
                    return Ok(nodes);
                }
                Stmt::Expr(Expr::Block(inner)) if !inner.is_empty() => {
                    // Skipping the block would silently drop any assignments or
                    // `return`s inside it.
                    return Err(RustLowerError::Unsupported(
                        "nested block statement".to_string(),
                    ));
                }
                Stmt::While { cond, body } => {
                    nodes.extend(self.lower_while(cond, body, subst)?);
                }
                Stmt::For {
                    name,
                    start,
                    end,
                    body,
                } => {
                    nodes.extend(self.lower_for_range(*name, start, end, body, subst)?);
                }
                // Remaining expression statements are discarded without being lowered.
                // NOTE(review): if the parser can encode a tail expression this way,
                // its value is lost here — confirm against the parser.
                Stmt::Expr(_) => {}
            }
        }
        // `stmts` fell through, so the continuation runs next.
        if let Some((&next, later)) = cont.split_first() {
            nodes.extend(self.lower_seq(next, later, subst)?);
        }
        Ok(nodes)
    }

    /// Lowers one `if` arm followed by `after`.
    ///
    /// An arm for which `stmts_diverge` holds always returns before reaching
    /// `after`, so the continuation is not lowered into it.
    fn lower_arm(
        &self,
        arm: &[Stmt],
        after: &[&[Stmt]],
        subst: Subst<'_>,
    ) -> Result<Vec<Node>, RustLowerError> {
        let after: &[&[Stmt]] = if stmts_diverge(arm) { &[] } else { after };
        self.lower_seq(arm, after, subst)
    }

    /// Statements of an `if` arm.
    ///
    /// Anything other than a plain block is rejected instead of being treated
    /// as empty, as `block_stmts` would do. One example would be an `else if`
    /// stored as a bare `if` expression.
    fn arm_stmts(arm: &Expr) -> Result<&[Stmt], RustLowerError> {
        match arm {
            Expr::Block(stmts) => Ok(&stmts[..]),
            _ => Err(RustLowerError::Unsupported(
                "`if` arm that is not a plain block".to_string(),
            )),
        }
    }

    /// Whether any statement in `stmts` is a `return`, at any nesting depth.
    ///
    /// The search covers `if` arms, nested blocks, and loop bodies.
    fn stmts_may_return(stmts: &[Stmt]) -> bool {
        stmts.iter().any(|stmt| match stmt {
            Stmt::Return(_) => true,
            Stmt::Expr(expr) => Self::expr_may_return(expr),
            Stmt::While { body, .. } => Self::stmts_may_return(body),
            Stmt::For { body, .. } => Self::stmts_may_return(body),
            Stmt::Let { .. } | Stmt::Assign { .. } => false,
        })
    }

    /// Whether a block or `if` expression contains a `return` statement.
    fn expr_may_return(expr: &Expr) -> bool {
        match expr {
            Expr::Block(stmts) => Self::stmts_may_return(stmts),
            Expr::If {
                then_block,
                else_block,
                ..
            } => {
                Self::expr_may_return(then_block)
                    || else_block
                        .as_ref()
                        .is_some_and(|arm| Self::expr_may_return(arm))
            }
            _ => false,
        }
    }

    /// Rejects loop bodies containing `return`.
    ///
    /// The counted-loop lowering cannot exit early. The store would happen
    /// mid-loop, and both later iterations and the code after the loop would
    /// still run.
    fn ensure_no_return_in_loop(body: &[Stmt]) -> Result<(), RustLowerError> {
        if Self::stmts_may_return(body) {
            Err(RustLowerError::Unsupported(
                "`return` inside a loop body".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    /// Copies the substitution map visible through `subst`, so a loop can add
    /// its induction variable to it.
    fn subst_map(subst: Subst<'_>) -> HashMap<BindingId, IrExpr> {
        match subst {
            Subst::Local(Some(map)) | Subst::Inline(map) => map.clone(),
            Subst::Local(None) => HashMap::new(),
        }
    }

    /// Wraps `map` in the same `Subst` variant as `subst`.
    fn rebind<'m>(subst: Subst<'_>, map: &'m HashMap<BindingId, IrExpr>) -> Subst<'m> {
        match subst {
            Subst::Inline(_) => Subst::Inline(map),
            Subst::Local(_) => Subst::Local(Some(map)),
        }
    }

    /// Trip count of the half-open range `lo..hi` as a `u32`.
    ///
    /// Computed as `hi - lo` in u32 arithmetic when `hi > lo` (signed
    /// comparison), and 0 otherwise.
    fn counted_loop_trip(lo: &IrExpr, hi: &IrExpr) -> IrExpr {
        let span_u32 = IrExpr::sub(
            IrExpr::cast(DataType::U32, hi.clone()),
            IrExpr::cast(DataType::U32, lo.clone()),
        );
        IrExpr::select(IrExpr::gt(hi.clone(), lo.clone()), span_u32, IrExpr::u32(0))
    }

    /// Source-level induction value `lo + (loop_var as i32)` for the IR loop counter.
    fn counted_loop_induction(lo: &IrExpr, loop_var: &str) -> IrExpr {
        IrExpr::add(
            lo.clone(),
            IrExpr::cast(DataType::I32, IrExpr::var(loop_var.to_string())),
        )
    }

    /// Lowers `while i < BOUND { ...; i = i + 1; }` to a counted IR loop.
    ///
    /// Inside the body, `i` reads as `v{i} + counter`. After the loop,
    /// `v{i}` is set to `max(BOUND, v{i})`, which is the value the source loop
    /// leaves behind.
    ///
    /// Rejected shapes:
    /// - loops that assign `i` anywhere except the final increment;
    /// - loops that assign any variable used in `BOUND`;
    /// - loops that contain a `return`.
    fn lower_while(
        &self,
        cond: &Expr,
        body: &[Stmt],
        subst: Subst<'_>,
    ) -> Result<Vec<Node>, RustLowerError> {
        Self::ensure_no_return_in_loop(body)?;
        let bad = || {
            RustLowerError::Unsupported(
                "while loop that is not a canonical `while i < BOUND { ...; i = i + 1; }` counting loop"
                    .to_string(),
            )
        };
        let (i_off, bound) = match cond {
            Expr::Binary { op, lhs, rhs } if *op == LT => match lhs.as_ref() {
                Expr::Var(off) => (*off, rhs.as_ref()),
                _ => return Err(bad()),
            },
            _ => return Err(bad()),
        };
        let b_i = self.resolution.uses.get(&i_off).copied().ok_or_else(bad)?;
        let Some((last, init_stmts)) = body.split_last() else {
            return Err(bad());
        };
        // The final statement must be exactly `i = i + 1`.
        let inc_ok = matches!(last, Stmt::Assign { name, value }
            if self.resolution.uses.get(name).copied() == Some(b_i)
                && matches!(value, Expr::Binary { op, lhs, rhs }
                    if *op == PLUS
                        && matches!(lhs.as_ref(), Expr::Var(o) if self.resolution.uses.get(o).copied() == Some(b_i))
                        && matches!(rhs.as_ref(), Expr::LiteralInt(_, 1))));
        if !inc_ok {
            return Err(bad());
        }
        if stmts_assign_binding(init_stmts, b_i, self.resolution) {
            return Err(bad());
        }
        // The bound is evaluated once, so nothing it reads may change inside the loop.
        for v in expr_var_bindings(bound, self.resolution) {
            if stmts_assign_binding(body, v, self.resolution) {
                return Err(bad());
            }
        }
        let loop_var = format!("v{b_i}__w");
        let from_i32 = IrExpr::var(format!("v{b_i}"));
        let mut inner = Self::subst_map(subst);
        inner.insert(b_i, Self::counted_loop_induction(&from_i32, &loop_var));
        let inner_subst = Self::rebind(subst, &inner);
        let to_i32 = self.lower_value(bound, subst)?;
        let trip = Self::counted_loop_trip(&from_i32, &to_i32);
        let loop_body = self.lower_stmts(init_stmts, inner_subst)?;
        let post = IrExpr::select(
            IrExpr::gt(to_i32.clone(), from_i32.clone()),
            to_i32,
            from_i32,
        );
        Ok(vec![
            Node::loop_for(loop_var, IrExpr::u32(0), trip, loop_body),
            Node::assign(format!("v{b_i}"), post),
        ])
    }

    /// Lowers `for name in start..end { body }` to a counted IR loop.
    ///
    /// `start` and `end` are evaluated once, into `v{id}__for_start` and
    /// `v{id}__for_end`. Bodies containing `return` are rejected.
    fn lower_for_range(
        &self,
        name: u32,
        start: &Expr,
        end: &Expr,
        body: &[Stmt],
        subst: Subst<'_>,
    ) -> Result<Vec<Node>, RustLowerError> {
        Self::ensure_no_return_in_loop(body)?;
        let b_i = self.def_to_id.get(&name).copied().ok_or_else(|| {
            RustLowerError::Unsupported("unresolved for-loop binding".to_string())
        })?;
        let start_name = format!("v{b_i}__for_start");
        let end_name = format!("v{b_i}__for_end");
        let loop_var = format!("v{b_i}__for");
        let start_i32 = IrExpr::var(start_name.clone());
        let end_i32 = IrExpr::var(end_name.clone());
        let trip = Self::counted_loop_trip(&start_i32, &end_i32);
        let mut inner = Self::subst_map(subst);
        inner.insert(b_i, Self::counted_loop_induction(&start_i32, &loop_var));
        let inner_subst = Self::rebind(subst, &inner);
        let loop_body = self.lower_stmts(body, inner_subst)?;
        Ok(vec![
            Node::let_bind(start_name, self.lower_value(start, subst)?),
            Node::let_bind(end_name, self.lower_value(end, subst)?),
            Node::loop_for(loop_var, IrExpr::u32(0), trip, loop_body),
        ])
    }

    /// Lowers a value expression.
    ///
    /// - Integer literals become `i32` constants (converted with `as i32`).
    /// - `%` results are cast to `I32`, and `-x` becomes `0 - x`.
    /// - Borrows and derefs are transparent.
    /// - Calls are inlined.
    /// - Block and `if` expressions used as values are rejected.
    fn lower_value(&self, expr: &Expr, subst: Subst<'_>) -> Result<IrExpr, RustLowerError> {
        match expr {
            Expr::LiteralInt(_, value) => Ok(IrExpr::i32(*value as i32)),
            Expr::LiteralBool(_, value) => Ok(IrExpr::bool(*value)),
            Expr::Var(offset) => {
                let binding = self.resolution.uses.get(offset).copied().ok_or_else(|| {
                    RustLowerError::Unsupported("unresolved variable use".to_string())
                })?;
                match subst {
                    Subst::Local(Some(map)) => Ok(map
                        .get(&binding)
                        .cloned()
                        .unwrap_or_else(|| IrExpr::var(format!("v{binding}")))),
                    Subst::Local(None) => Ok(IrExpr::var(format!("v{binding}"))),
                    Subst::Inline(map) => map.get(&binding).cloned().ok_or_else(|| {
                        RustLowerError::Unsupported("callee variable not substituted".to_string())
                    }),
                }
            }
            Expr::Binary { op, lhs, rhs } => {
                let l = self.lower_value(lhs, subst)?;
                let r = self.lower_value(rhs, subst)?;
                Ok(match *op {
                    PLUS => IrExpr::add(l, r),
                    MINUS => IrExpr::sub(l, r),
                    STAR => IrExpr::mul(l, r),
                    SLASH => IrExpr::div(l, r),
                    PERCENT => IrExpr::cast(DataType::I32, IrExpr::rem(l, r)),
                    EQ => IrExpr::eq(l, r),
                    NE => IrExpr::ne(l, r),
                    LT => IrExpr::lt(l, r),
                    GT => IrExpr::gt(l, r),
                    LE => IrExpr::le(l, r),
                    GE => IrExpr::ge(l, r),
                    ANDAND => IrExpr::and(l, r),
                    OROR => IrExpr::or(l, r),
                    other => {
                        return Err(RustLowerError::Unsupported(format!(
                            "binary operator {other}"
                        )))
                    }
                })
            }
            Expr::Call { name, args } => self.lower_call(name, args, subst),
            Expr::Borrow { expr, .. } => self.lower_value(expr, subst),
            Expr::Deref(inner) => self.lower_value(inner, subst),
            Expr::Not(inner) => Ok(IrExpr::not(self.lower_value(inner, subst)?)),
            Expr::Neg(inner) => Ok(IrExpr::sub(IrExpr::i32(0), self.lower_value(inner, subst)?)),
            Expr::Block(_) | Expr::If { .. } => Err(RustLowerError::Unsupported(
                "block/if used as a value".to_string(),
            )),
        }
    }

    /// Inlines a call to a straight-line callee.
    ///
    /// The callee may contain only `let` statements followed by
    /// `return expr;`. Its parameters and lets are substituted by the
    /// lowered argument and initializer expressions.
    ///
    /// Recursive callees are rejected up front. Inlining them would never
    /// terminate and would overflow the stack.
    fn lower_call(
        &self,
        name: &u32,
        args: &[Expr],
        caller_subst: Subst<'_>,
    ) -> Result<IrExpr, RustLowerError> {
        let unresolved = || RustLowerError::Unsupported("unresolved call".to_string());
        let callee_index = self
            .resolution
            .calls
            .get(name)
            .copied()
            .ok_or_else(unresolved)?;
        let callee = self
            .module
            .functions
            .get(callee_index)
            .ok_or_else(unresolved)?;
        if self.is_recursive(callee_index) {
            return Err(RustLowerError::Unsupported("recursive call".to_string()));
        }
        if args.len() != callee.params.len() {
            return Err(RustLowerError::Unsupported(
                "call arity mismatch".to_string(),
            ));
        }
        let mut callee_subst: HashMap<BindingId, IrExpr> = HashMap::new();
        for ((offset, _), arg) in callee.params.iter().zip(args) {
            let binding = self.def_to_id.get(offset).copied().ok_or_else(|| {
                RustLowerError::Unsupported("unresolved callee parameter".to_string())
            })?;
            callee_subst.insert(binding, self.lower_value(arg, caller_subst)?);
        }
        for stmt in &callee.body {
            match stmt {
                Stmt::Let {
                    name: offset, init, ..
                } => {
                    let value = self.lower_value(init, Subst::Inline(&callee_subst))?;
                    let binding = self.def_to_id.get(offset).copied().ok_or_else(|| {
                        RustLowerError::Unsupported("unresolved callee binding".to_string())
                    })?;
                    callee_subst.insert(binding, value);
                }
                Stmt::Return(Some(expr)) => {
                    return self.lower_value(expr, Subst::Inline(&callee_subst))
                }
                _ => {
                    return Err(RustLowerError::Unsupported(
                        "call to a callee with control flow or no terminal return".to_string(),
                    ))
                }
            }
        }
        Err(RustLowerError::Unsupported(
            "call to a callee with no return".to_string(),
        ))
    }

    /// Whether function `root` can reach a call to itself through resolved calls.
    ///
    /// Each function body is scanned at most once.
    fn is_recursive(&self, root: usize) -> bool {
        let mut scanned = vec![false; self.module.functions.len()];
        let mut pending = Vec::new();
        if let Some(func) = self.module.functions.get(root) {
            Self::collect_call_names(&func.body, &mut pending);
        }
        while let Some(name) = pending.pop() {
            let Some(index) = self.resolution.calls.get(&name).copied() else {
                continue;
            };
            if index == root {
                return true;
            }
            if scanned.get(index).copied().unwrap_or(true) {
                continue;
            }
            scanned[index] = true;
            if let Some(func) = self.module.functions.get(index) {
                Self::collect_call_names(&func.body, &mut pending);
            }
        }
        false
    }

    /// Appends the name offset of every call expression found anywhere in `stmts`.
    fn collect_call_names(stmts: &[Stmt], out: &mut Vec<u32>) {
        for stmt in stmts {
            match stmt {
                Stmt::Let { init, .. } => Self::collect_expr_call_names(init, out),
                Stmt::Return(Some(expr)) => Self::collect_expr_call_names(expr, out),
                Stmt::Return(None) => {}
                Stmt::Assign { value, .. } => Self::collect_expr_call_names(value, out),
                Stmt::Expr(expr) => Self::collect_expr_call_names(expr, out),
                Stmt::While { cond, body } => {
                    Self::collect_expr_call_names(cond, out);
                    Self::collect_call_names(body, out);
                }
                Stmt::For {
                    start, end, body, ..
                } => {
                    Self::collect_expr_call_names(start, out);
                    Self::collect_expr_call_names(end, out);
                    Self::collect_call_names(body, out);
                }
            }
        }
    }

    /// Appends the name offset of every call expression found anywhere in `expr`.
    fn collect_expr_call_names(expr: &Expr, out: &mut Vec<u32>) {
        match expr {
            Expr::Call { name, args } => {
                out.push(*name);
                for arg in args {
                    Self::collect_expr_call_names(arg, out);
                }
            }
            Expr::Binary { lhs, rhs, .. } => {
                Self::collect_expr_call_names(lhs, out);
                Self::collect_expr_call_names(rhs, out);
            }
            Expr::Borrow { expr, .. } => Self::collect_expr_call_names(expr, out),
            Expr::Deref(inner) => Self::collect_expr_call_names(inner, out),
            Expr::Not(inner) => Self::collect_expr_call_names(inner, out),
            Expr::Neg(inner) => Self::collect_expr_call_names(inner, out),
            Expr::Block(stmts) => Self::collect_call_names(stmts, out),
            Expr::If {
                cond,
                then_block,
                else_block,
            } => {
                Self::collect_expr_call_names(cond, out);
                Self::collect_expr_call_names(then_block, out);
                if let Some(arm) = else_block {
                    Self::collect_expr_call_names(arm, out);
                }
            }
            Expr::LiteralInt(..) | Expr::LiteralBool(..) | Expr::Var(_) => {}
        }
    }
}
/// Whether any statement in `stmts` assigns to binding `b`.
///
/// `if` arms and `while`/`for` bodies are searched recursively. Other
/// statement kinds are not searched.
fn stmts_assign_binding(stmts: &[Stmt], b: BindingId, res: &Resolution) -> bool {
    for stmt in stmts {
        let assigns = match stmt {
            Stmt::Assign { name, .. } => res.uses.get(name).copied() == Some(b),
            Stmt::Expr(Expr::If {
                then_block,
                else_block,
                ..
            }) => {
                if stmts_assign_binding(block_stmts(then_block), b, res) {
                    true
                } else if let Some(arm) = else_block {
                    stmts_assign_binding(block_stmts(arm), b, res)
                } else {
                    false
                }
            }
            Stmt::While { body, .. } => stmts_assign_binding(body, b, res),
            Stmt::For { body, .. } => stmts_assign_binding(body, b, res),
            _ => false,
        };
        if assigns {
            return true;
        }
    }
    false
}
/// Binding ids of the resolved variables read in `expr`, in visit order, with
/// duplicates kept.
///
/// See `collect_var_bindings` for which sub-expressions are searched.
fn expr_var_bindings(expr: &Expr, res: &Resolution) -> Vec<BindingId> {
    let mut found = Vec::new();
    collect_var_bindings(expr, res, &mut found);
    found
}
/// Appends to `out` the binding id of every resolved variable read in `expr`.
///
/// Searches binary operands, borrows, derefs, `!`/`-` operands and call
/// arguments. Block and `if` expressions are not searched, and unresolved
/// variables are skipped.
fn collect_var_bindings(expr: &Expr, res: &Resolution, out: &mut Vec<BindingId>) {
    match expr {
        Expr::Var(offset) => out.extend(res.uses.get(offset).copied()),
        Expr::Binary { lhs, rhs, .. } => {
            for side in [lhs, rhs] {
                collect_var_bindings(side, res, out);
            }
        }
        Expr::Borrow { expr: inner, .. } => collect_var_bindings(inner, res, out),
        Expr::Deref(inner) => collect_var_bindings(inner, res, out),
        Expr::Not(inner) => collect_var_bindings(inner, res, out),
        Expr::Neg(inner) => collect_var_bindings(inner, res, out),
        Expr::Call { args, .. } => args
            .iter()
            .for_each(|arg| collect_var_bindings(arg, res, out)),
        _ => {}
    }
}
/// Statements of a block expression; any other expression yields an empty slice.
fn block_stmts(expr: &Expr) -> &[Stmt] {
    if let Expr::Block(stmts) = expr {
        stmts
    } else {
        &[]
    }
}
/// Whether `stmts` always ends in a `return`.
///
/// This holds when some statement is a `return`, or when some statement is
/// an `if`/`else` whose two arms both satisfy this check recursively. Loops
/// and `if`s without an `else` never count as diverging.
fn stmts_diverge(stmts: &[Stmt]) -> bool {
    for stmt in stmts {
        let always_returns = match stmt {
            Stmt::Return(_) => true,
            Stmt::Expr(Expr::If {
                then_block,
                else_block: Some(else_block),
                ..
            }) => stmts_diverge(block_stmts(then_block)) && stmts_diverge(block_stmts(else_block)),
            _ => false,
        };
        if always_returns {
            return true;
        }
    }
    false
}