use crate::ir::{Expr, Ident, Node};
use std::sync::Arc;
pub(super) fn substitute_nodes(nodes: &[Node], var: &Ident, replacement: &Expr) -> Vec<Node> {
nodes
.iter()
.map(|node| substitute_node(node, var, replacement))
.collect()
}
pub(super) fn substitute_node(node: &Node, var: &Ident, replacement: &Expr) -> Node {
match node {
Node::Let { name, value } => Node::let_bind(name, substitute_expr(value, var, replacement)),
Node::Assign { name, value } => {
Node::assign(name, substitute_expr(value, var, replacement))
}
Node::Store {
buffer,
index,
value,
} => Node::store(
buffer,
substitute_expr(index, var, replacement),
substitute_expr(value, var, replacement),
),
Node::If {
cond,
then,
otherwise,
} => Node::if_then_else(
substitute_expr(cond, var, replacement),
substitute_nodes(then, var, replacement),
substitute_nodes(otherwise, var, replacement),
),
Node::Loop {
var: inner,
from,
to,
body,
} => {
let from = substitute_expr(from, var, replacement);
let to = substitute_expr(to, var, replacement);
let body = if inner == var {
body.clone()
} else {
substitute_nodes(body, var, replacement)
};
Node::loop_for(inner, from, to, body)
}
Node::Block(body) => Node::block(substitute_nodes(body, var, replacement)),
Node::Region {
generator,
source_region,
body,
} => Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: Arc::new(substitute_nodes(body, var, replacement)),
},
Node::AsyncLoad {
source,
destination,
offset,
size,
tag,
} => Node::AsyncLoad {
source: source.clone(),
destination: destination.clone(),
offset: Box::new(substitute_expr(offset, var, replacement)),
size: Box::new(substitute_expr(size, var, replacement)),
tag: tag.clone(),
},
Node::AsyncStore {
source,
destination,
offset,
size,
tag,
} => Node::AsyncStore {
source: source.clone(),
destination: destination.clone(),
offset: Box::new(substitute_expr(offset, var, replacement)),
size: Box::new(substitute_expr(size, var, replacement)),
tag: tag.clone(),
},
Node::Trap { address, tag } => Node::Trap {
address: Box::new(substitute_expr(address, var, replacement)),
tag: tag.clone(),
},
Node::IndirectDispatch { .. }
| Node::AllReduce { .. }
| Node::AllGather { .. }
| Node::ReduceScatter { .. }
| Node::Broadcast { .. }
| Node::AsyncWait { .. }
| Node::Resume { .. }
| Node::Return
| Node::Barrier { .. }
| Node::Opaque(_) => node.clone(),
}
}
fn substitute_expr(expr: &Expr, var: &Ident, replacement: &Expr) -> Expr {
match expr {
Expr::Var(name) if name == var => replacement.clone(),
Expr::Load { buffer, index } => {
Expr::load(buffer, substitute_expr(index, var, replacement))
}
Expr::BinOp { op, left, right } => Expr::BinOp {
op: *op,
left: Box::new(substitute_expr(left, var, replacement)),
right: Box::new(substitute_expr(right, var, replacement)),
},
Expr::UnOp { op, operand } => Expr::UnOp {
op: op.clone(),
operand: Box::new(substitute_expr(operand, var, replacement)),
},
Expr::Call { op_id, args } => Expr::call(
op_id,
args.iter()
.map(|arg| substitute_expr(arg, var, replacement))
.collect(),
),
Expr::Select {
cond,
true_val,
false_val,
} => Expr::select(
substitute_expr(cond, var, replacement),
substitute_expr(true_val, var, replacement),
substitute_expr(false_val, var, replacement),
),
Expr::Cast { target, value } => {
Expr::cast(target.clone(), substitute_expr(value, var, replacement))
}
Expr::Fma { a, b, c } => Expr::fma(
substitute_expr(a, var, replacement),
substitute_expr(b, var, replacement),
substitute_expr(c, var, replacement),
),
Expr::Atomic {
op,
buffer,
index,
expected,
value,
ordering,
} => Expr::Atomic {
op: *op,
buffer: buffer.clone(),
index: Box::new(substitute_expr(index, var, replacement)),
expected: expected
.as_ref()
.map(|expr| Box::new(substitute_expr(expr, var, replacement))),
value: Box::new(substitute_expr(value, var, replacement)),
ordering: *ordering,
},
Expr::SubgroupBallot { cond } => {
Expr::subgroup_ballot(substitute_expr(cond, var, replacement))
}
Expr::SubgroupShuffle { value, lane } => Expr::subgroup_shuffle(
substitute_expr(value, var, replacement),
substitute_expr(lane, var, replacement),
),
Expr::SubgroupAdd { value } => Expr::subgroup_add(substitute_expr(value, var, replacement)),
_ => expr.clone(),
}
}