use super::CalleeExpander;
use crate::error::Result;
use crate::ir::{AtomicOp, BinOp, Expr, Node, UnOp};
impl CalleeExpander<'_> {
/// Expands a sequence of callee statements into caller statements.
///
/// Each statement is expanded with [`Self::node`] and the resulting statement lists are
/// concatenated in source order.
///
/// # Errors
/// Returns the first error produced by [`Self::node`]; later statements are not visited.
#[inline]
pub(crate) fn nodes(&mut self, nodes: &[Node]) -> Result<Vec<Node>> {
    nodes
        .iter()
        .try_fold(Vec::new(), |mut expanded, node| -> Result<Vec<Node>> {
            expanded.extend(self.node(node)?);
            Ok(expanded)
        })
}
/// Expands one callee statement into zero or more caller statements.
///
/// Each statement kind is delegated to its helper (`bind`, `assign`, `store`, `branch`,
/// `loop_for`). A `Block` is expanded recursively and re-wrapped so its grouping is kept.
/// `Return` and `Barrier` are copied unchanged.
///
/// NOTE(review): `Return` is emitted verbatim. Inside an inlined body it would presumably
/// return from the enclosing caller rather than just end the callee. Confirm callees cannot
/// contain early returns, or that they are rejected before expansion.
///
/// # Errors
/// Propagates any error from the helper handling the statement.
#[inline]
pub(crate) fn node(&mut self, node: &Node) -> Result<Vec<Node>> {
    match node {
        Node::Return => Ok(vec![Node::Return]),
        Node::Barrier => Ok(vec![Node::Barrier]),
        Node::Block(children) => {
            let inner = self.nodes(children)?;
            Ok(vec![Node::Block(inner)])
        }
        Node::Let { name, value } => self.bind(name, value),
        Node::Assign { name, value } => self.assign(name, value),
        Node::Store { buffer, index, value } => self.store(buffer, index, value),
        Node::If { cond, then, otherwise } => self.branch(cond, then, otherwise),
        Node::Loop { var, from, to, body } => self.loop_for(var, from, to, body),
    }
}
/// Expands a callee `let name = value;`.
///
/// `name` is registered under its prefixed caller-side name *before* `value` is expanded.
/// Any reference to `name` inside `value` therefore also resolves to the prefixed name.
/// Statements hoisted while expanding `value` are placed ahead of the new `let`.
///
/// # Errors
/// Propagates errors from expanding `value`.
#[inline]
pub(crate) fn bind(&mut self, name: &str, value: &Expr) -> Result<Vec<Node>> {
    let local = self.rename_decl(name);
    let (hoisted, lowered) = self.expr(value)?;
    let mut out = hoisted;
    out.push(Node::let_bind(&local, lowered));
    Ok(out)
}
/// Expands a callee `name = value;`, rewriting `name` through the rename table.
///
/// Statements hoisted while expanding `value` are placed ahead of the assignment.
///
/// # Errors
/// Propagates errors from expanding `value`.
#[inline]
pub(crate) fn assign(&mut self, name: &str, value: &Expr) -> Result<Vec<Node>> {
    let (mut out, lowered) = self.expr(value)?;
    let target = self.rename_use(name);
    out.push(Node::assign(&target, lowered));
    Ok(out)
}
/// Expands a callee `buffer[index] = value;`.
///
/// A store to the callee's output buffer becomes an assignment to `self.result_name` and
/// sets `self.saw_output`. In that case the expanded index is discarded, but statements
/// hoisted from it are still kept. A store to any other buffer is re-emitted with the
/// expanded index and value. Index prefixes come before value prefixes.
///
/// # Errors
/// Propagates errors from expanding `index` or `value`.
#[inline]
pub(crate) fn store(&mut self, buffer: &str, index: &Expr, value: &Expr) -> Result<Vec<Node>> {
    let (index_prefix, index) = self.expr(index)?;
    let (value_prefix, value) = self.expr(value)?;
    let mut out = index_prefix;
    out.extend(value_prefix);
    let stmt = if self.output_name != buffer {
        Node::store(buffer, index, value)
    } else {
        self.saw_output = true;
        Node::assign(&self.result_name, value)
    };
    out.push(stmt);
    Ok(out)
}
/// Expands a callee `if cond { then } else { otherwise }`.
///
/// Statements hoisted from `cond` are placed before the `if`. Both arms are expanded
/// recursively, then-arm first, into the rebuilt `if` node.
///
/// # Errors
/// Propagates errors from expanding the condition or either arm.
#[inline]
pub(crate) fn branch(
    &mut self,
    cond: &Expr,
    then: &[Node],
    otherwise: &[Node],
) -> Result<Vec<Node>> {
    let (mut out, cond) = self.expr(cond)?;
    let then_nodes = self.nodes(then)?;
    let else_nodes = self.nodes(otherwise)?;
    out.push(Node::if_then_else(cond, then_nodes, else_nodes));
    Ok(out)
}
/// Expands a callee `Loop` over `var` with bounds `from` and `to`.
///
/// The loop variable is registered under its prefixed name before the bounds are expanded,
/// so a reference to `var` inside `from`/`to` resolves to the prefixed name too.
/// Statements hoisted from `from` then `to` are placed before the loop node. The body is
/// expanded recursively.
///
/// # Errors
/// Propagates errors from expanding either bound or the body.
#[inline]
pub(crate) fn loop_for(
    &mut self,
    var: &str,
    from: &Expr,
    to: &Expr,
    body: &[Node],
) -> Result<Vec<Node>> {
    let counter = self.rename_decl(var);
    let (from_prefix, start) = self.expr(from)?;
    let (to_prefix, end) = self.expr(to)?;
    let body_nodes = self.nodes(body)?;
    let mut out = from_prefix;
    out.extend(to_prefix);
    out.push(Node::loop_for(&counter, start, end, body_nodes));
    Ok(out)
}
/// Expands a callee expression into caller terms.
///
/// Returns `(prefix, expr)`:
/// - `prefix` holds statements that must run before `expr` is evaluated; these are produced
///   when nested calls are inlined.
/// - `expr` is the rewritten value.
///
/// Rewrites applied:
/// - variables are mapped through [`Self::rename_use`];
/// - loads are handled by [`Self::load`];
/// - `BufLen` of the output buffer or of an input buffer becomes the constant `1u32`;
/// - invocation, workgroup and local ids become the constant `0u32`;
/// - calls go through [`Self::rename_expr_vars`] and then the context's `inline_expr`;
/// - every other expression is rebuilt from its expanded operands.
///
/// Operands are expanded left to right and their prefixes concatenated in that order.
///
/// NOTE(review): prefixes from both arms of a `Select` are hoisted unconditionally, as are
/// prefixes from the right operand of a binary op (which matters for short-circuiting
/// operators, if `BinOp` has any). Confirm inlined callees have no side effects that would
/// make this observable.
///
/// # Errors
/// Propagates errors from nested expansion and from `inline_expr`.
#[inline]
pub(crate) fn expr(&mut self, expr: &Expr) -> Result<(Vec<Node>, Expr)> {
    match expr {
        Expr::Var(name) => {
            let renamed = self.rename_use(name);
            Ok((Vec::new(), Expr::var(&renamed)))
        }
        Expr::Load { buffer, index } => self.load(buffer, index),
        Expr::BufLen { buffer }
            if self.output_name == buffer.as_str()
                || self.input_args.contains_key(buffer.as_str()) =>
        {
            Ok((Vec::new(), Expr::u32(1)))
        }
        Expr::Call { .. } => {
            let call = self.rename_expr_vars(expr);
            self.ctx.inline_expr(&call)
        }
        Expr::InvocationId { .. } | Expr::WorkgroupId { .. } | Expr::LocalId { .. } => {
            Ok((Vec::new(), Expr::u32(0)))
        }
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. } => Ok((Vec::new(), expr.clone())),
        Expr::BinOp { op, left, right } => self.binop(op.clone(), left, right),
        Expr::UnOp { op, operand } => self.unop(op.clone(), operand),
        Expr::Fma { a, b, c } => {
            let (mut prefix, a) = self.expr(a)?;
            let (b_prefix, b) = self.expr(b)?;
            let (c_prefix, c) = self.expr(c)?;
            prefix.extend(b_prefix);
            prefix.extend(c_prefix);
            let fused = Expr::Fma {
                a: Box::new(a),
                b: Box::new(b),
                c: Box::new(c),
            };
            Ok((prefix, fused))
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => self.select(cond, true_val, false_val),
        Expr::Cast { target, value } => {
            let (prefix, inner) = self.expr(value)?;
            let cast = Expr::Cast {
                target: target.clone(),
                value: Box::new(inner),
            };
            Ok((prefix, cast))
        }
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
        } => {
            let expected = expected.as_deref();
            self.atomic(op.clone(), buffer, index, expected, value)
        }
    }
}
/// Expands a callee `buffer[index]` read.
///
/// If `buffer` is one of the callee's input buffers, the load is replaced by the argument
/// expression bound to it. In that case `index` is not expanded at all. Otherwise the
/// load is rebuilt with an expanded index.
///
/// NOTE(review): unlike `store`, a load from the callee's output buffer is not redirected to
/// the caller's result variable. Confirm callees never read back their own output.
///
/// # Errors
/// Propagates errors from expanding `index`.
#[inline]
pub(crate) fn load(&mut self, buffer: &str, index: &Expr) -> Result<(Vec<Node>, Expr)> {
    match self.input_args.get(buffer) {
        Some(bound) => Ok((Vec::new(), bound.clone())),
        None => {
            let (prefix, index) = self.expr(index)?;
            let read = Expr::Load {
                buffer: buffer.into(),
                index: Box::new(index),
            };
            Ok((prefix, read))
        }
    }
}
/// Expands a binary operation.
///
/// The left operand is expanded first; its hoisted statements precede those of the right
/// operand.
///
/// # Errors
/// Propagates errors from expanding either operand.
#[inline]
pub(crate) fn binop(
    &mut self,
    op: BinOp,
    left: &Expr,
    right: &Expr,
) -> Result<(Vec<Node>, Expr)> {
    let (left_prefix, lhs) = self.expr(left)?;
    let (right_prefix, rhs) = self.expr(right)?;
    let mut prefix = left_prefix;
    prefix.extend(right_prefix);
    let combined = Expr::BinOp {
        op,
        left: Box::new(lhs),
        right: Box::new(rhs),
    };
    Ok((prefix, combined))
}
/// Expands a unary operation, passing the operand's hoisted statements through unchanged.
///
/// # Errors
/// Propagates errors from expanding `operand`.
#[inline]
pub(crate) fn unop(&mut self, op: UnOp, operand: &Expr) -> Result<(Vec<Node>, Expr)> {
    self.expr(operand).map(|(prefix, inner)| {
        let applied = Expr::UnOp {
            op,
            operand: Box::new(inner),
        };
        (prefix, applied)
    })
}
/// Expands `select(cond, true_val, false_val)`.
///
/// Operands are expanded in the order condition, true value, false value. Their hoisted
/// statements are concatenated in that same order.
///
/// NOTE(review): hoisted statements from both value operands run regardless of `cond`.
/// Confirm that is acceptable for whatever the inlined calls emit.
///
/// # Errors
/// Propagates errors from expanding any operand.
#[inline]
pub(crate) fn select(
    &mut self,
    cond: &Expr,
    true_val: &Expr,
    false_val: &Expr,
) -> Result<(Vec<Node>, Expr)> {
    let (mut hoisted, cond) = self.expr(cond)?;
    let (true_hoisted, if_true) = self.expr(true_val)?;
    let (false_hoisted, if_false) = self.expr(false_val)?;
    hoisted.extend(true_hoisted);
    hoisted.extend(false_hoisted);
    let chosen = Expr::Select {
        cond: Box::new(cond),
        true_val: Box::new(if_true),
        false_val: Box::new(if_false),
    };
    Ok((hoisted, chosen))
}
/// Expands an atomic operation on `buffer`.
///
/// The buffer name is kept as-is. `index`, the optional `expected` value and `value` are
/// expanded in that order, and their hoisted statements are concatenated in the same order.
///
/// # Errors
/// Propagates errors from expanding any operand.
#[inline]
pub(crate) fn atomic(
    &mut self,
    op: AtomicOp,
    buffer: &str,
    index: &Expr,
    expected: Option<&Expr>,
    value: &Expr,
) -> Result<(Vec<Node>, Expr)> {
    let (mut hoisted, index) = self.expr(index)?;
    let expected = match expected {
        None => None,
        Some(raw) => {
            let (expected_hoisted, lowered) = self.expr(raw)?;
            hoisted.extend(expected_hoisted);
            Some(Box::new(lowered))
        }
    };
    let (value_hoisted, value) = self.expr(value)?;
    hoisted.extend(value_hoisted);
    let rebuilt = Expr::Atomic {
        op,
        buffer: buffer.into(),
        index: Box::new(index),
        expected,
        value: Box::new(value),
    };
    Ok((hoisted, rebuilt))
}
/// Records a callee declaration and returns its caller-side name.
///
/// The caller-side name is `self.prefix` followed by `name`. The mapping is stored in
/// `self.vars`; nothing in this impl removes entries, so it stays visible after the
/// declaring block ends.
#[inline]
pub(crate) fn rename_decl(&mut self, name: &str) -> String {
    let renamed = format!("{}{}", self.prefix, name);
    self.vars.insert(name.to_owned(), renamed.clone());
    renamed
}
/// Maps a callee variable reference to its caller-side name.
///
/// Returns `name` unchanged if no declaration for it has been recorded in `self.vars`.
#[inline]
pub(crate) fn rename_use(&self, name: &str) -> String {
    match self.vars.get(name) {
        Some(renamed) => renamed.clone(),
        None => name.to_string(),
    }
}
/// Rewrites a callee expression into caller terms without inlining any calls it contains.
///
/// `expr` uses this on `Expr::Call` nodes before handing them to the context's
/// `inline_expr`. It therefore applies the same leaf rewrites as [`Self::expr`]:
/// - variables are mapped through [`Self::rename_use`];
/// - loads from callee input buffers are replaced by the bound argument expression, which is
///   already in caller terms (the index is not visited, matching [`Self::load`]);
/// - `BufLen` of the callee's output buffer or of an input buffer becomes `1u32`;
/// - invocation, workgroup and local ids become `0u32`.
///
/// Every other node, including nested calls, is rebuilt with rewritten children.
///
/// The traversal uses an explicit work stack (`frames`) and value stack (`values`) rather
/// than recursion, presumably to bound native stack use on deeply nested expressions.
/// Children are pushed in reverse so they are built left to right. Each exit frame then pops
/// its children from `values` in reverse order and pushes the rebuilt node.
///
/// # Panics
/// Only if the stack discipline below is broken (the `expect` calls), which would be a bug
/// in this function.
#[inline]
pub(crate) fn rename_expr_vars(&self, expr: &Expr) -> Expr {
    // `Enter` visits an expression. Every other variant is an exit frame that rebuilds one
    // node from children that are already on `values`.
    enum Frame<'a> {
        Enter(&'a Expr),
        Load {
            buffer: &'a str,
        },
        Bin {
            op: BinOp,
        },
        Un {
            op: UnOp,
        },
        Call {
            op_id: &'a str,
            args: usize,
        },
        Fma,
        Select,
        Cast {
            target: crate::ir::DataType,
        },
        Atomic {
            op: AtomicOp,
            buffer: &'a str,
            has_expected: bool,
        },
    }
    let mut frames = vec![Frame::Enter(expr)];
    let mut values: Vec<Expr> = Vec::new();
    while let Some(frame) = frames.pop() {
        match frame {
            Frame::Enter(expr) => match expr {
                Expr::Var(name) => values.push(Expr::var(&self.rename_use(name))),
                Expr::Load { buffer, index } => {
                    let buffer: &str = buffer;
                    if let Some(arg) = self.input_args.get(buffer) {
                        // Mirror `load`: the callee's input buffer does not exist in the
                        // caller, so substitute the bound argument (already caller terms).
                        values.push(arg.clone());
                    } else {
                        frames.push(Frame::Load { buffer });
                        frames.push(Frame::Enter(index));
                    }
                }
                Expr::BinOp { op, left, right } => {
                    frames.push(Frame::Bin { op: op.clone() });
                    frames.push(Frame::Enter(right));
                    frames.push(Frame::Enter(left));
                }
                Expr::UnOp { op, operand } => {
                    frames.push(Frame::Un { op: op.clone() });
                    frames.push(Frame::Enter(operand));
                }
                Expr::Call { op_id, args } => {
                    frames.push(Frame::Call {
                        op_id,
                        args: args.len(),
                    });
                    for arg in args.iter().rev() {
                        frames.push(Frame::Enter(arg));
                    }
                }
                Expr::Fma { a, b, c } => {
                    frames.push(Frame::Fma);
                    frames.push(Frame::Enter(c));
                    frames.push(Frame::Enter(b));
                    frames.push(Frame::Enter(a));
                }
                Expr::Select {
                    cond,
                    true_val,
                    false_val,
                } => {
                    frames.push(Frame::Select);
                    frames.push(Frame::Enter(false_val));
                    frames.push(Frame::Enter(true_val));
                    frames.push(Frame::Enter(cond));
                }
                Expr::Cast { target, value } => {
                    frames.push(Frame::Cast {
                        target: target.clone(),
                    });
                    frames.push(Frame::Enter(value));
                }
                Expr::Atomic {
                    op,
                    buffer,
                    index,
                    expected,
                    value,
                } => {
                    frames.push(Frame::Atomic {
                        op: op.clone(),
                        buffer,
                        has_expected: expected.is_some(),
                    });
                    frames.push(Frame::Enter(value));
                    if let Some(expected) = expected.as_deref() {
                        frames.push(Frame::Enter(expected));
                    }
                    frames.push(Frame::Enter(index));
                }
                Expr::InvocationId { .. } | Expr::WorkgroupId { .. } | Expr::LocalId { .. } => {
                    values.push(Expr::u32(0));
                }
                // Mirror `expr`: lengths of the callee's own output/input buffers are 1.
                // Must precede the general `BufLen` arm below.
                Expr::BufLen { buffer }
                    if self.output_name == buffer.as_str()
                        || self.input_args.contains_key(buffer.as_str()) =>
                {
                    values.push(Expr::u32(1));
                }
                Expr::LitU32(_)
                | Expr::LitI32(_)
                | Expr::LitF32(_)
                | Expr::LitBool(_)
                | Expr::BufLen { .. } => values.push(expr.clone()),
            },
            Frame::Load { buffer } => {
                let index = values.pop().expect("load index must be built");
                values.push(Expr::Load {
                    buffer: buffer.into(),
                    index: Box::new(index),
                });
            }
            Frame::Bin { op } => {
                let right = values.pop().expect("binop right must be built");
                let left = values.pop().expect("binop left must be built");
                values.push(Expr::BinOp {
                    op,
                    left: Box::new(left),
                    right: Box::new(right),
                });
            }
            Frame::Un { op } => {
                let operand = values.pop().expect("unop operand must be built");
                values.push(Expr::UnOp {
                    op,
                    operand: Box::new(operand),
                });
            }
            Frame::Call { op_id, args } => {
                // The call's arguments are the last `args` values, in source order.
                let split_at = values
                    .len()
                    .checked_sub(args)
                    .expect("call args must be built");
                let args = values.split_off(split_at);
                values.push(Expr::Call {
                    op_id: op_id.to_string(),
                    args,
                });
            }
            Frame::Fma => {
                let c = values.pop().expect("fma c must be built");
                let b = values.pop().expect("fma b must be built");
                let a = values.pop().expect("fma a must be built");
                values.push(Expr::Fma {
                    a: Box::new(a),
                    b: Box::new(b),
                    c: Box::new(c),
                });
            }
            Frame::Select => {
                let false_val = values.pop().expect("select false value must be built");
                let true_val = values.pop().expect("select true value must be built");
                let cond = values.pop().expect("select condition must be built");
                values.push(Expr::Select {
                    cond: Box::new(cond),
                    true_val: Box::new(true_val),
                    false_val: Box::new(false_val),
                });
            }
            Frame::Cast { target } => {
                let value = values.pop().expect("cast value must be built");
                values.push(Expr::Cast {
                    target,
                    value: Box::new(value),
                });
            }
            Frame::Atomic {
                op,
                buffer,
                has_expected,
            } => {
                // Children were pushed as index, [expected], value; pop in reverse.
                let value = values.pop().expect("atomic value must be built");
                let expected = if has_expected {
                    Some(Box::new(
                        values.pop().expect("atomic expected must be built"),
                    ))
                } else {
                    None
                };
                let index = values.pop().expect("atomic index must be built");
                values.push(Expr::Atomic {
                    op,
                    buffer: buffer.into(),
                    index: Box::new(index),
                    expected,
                    value: Box::new(value),
                });
            }
        }
    }
    values
        .pop()
        .expect("expression rename must produce a value")
}
}