use rlx_ir::shape::Dim;
use rlx_ir::{Graph, Node, NodeId, Op, Shape};
use std::collections::HashMap;
use rlx_fusion::pass::Pass;
pub struct LegalizeBroadcast;
impl LegalizeBroadcast {
pub fn new() -> Self {
Self
}
}
impl Default for LegalizeBroadcast {
fn default() -> Self {
Self::new()
}
}
impl Pass for LegalizeBroadcast {
fn name(&self) -> &str {
"legalize_broadcast"
}
fn run(&self, graph: Graph) -> Graph {
run(graph)
}
}
/// Runs broadcast legalization on `graph` and returns only the rewritten graph.
///
/// Convenience wrapper over [`run_with_remap`] that drops the old→new id map.
pub fn run(graph: Graph) -> Graph {
    let (legalized, _remap) = run_with_remap(graph);
    legalized
}
/// Rebuilds `graph` node by node into a fresh graph with the same name,
/// inserting `Op::Expand` nodes in front of binary-op operands that need one
/// (see `maybe_expand`).
///
/// Returns the rebuilt graph together with a map from every original node id
/// to the id of its rebuilt counterpart. For binary ops, that is the rebuilt
/// op itself, not any Expand inserted before it.
///
/// # Panics
///
/// Panics if a node (or a graph output) refers to a node that has not been
/// rebuilt yet. The map lookups are plain indexing, so this assumes
/// `graph.nodes()` yields nodes in dependency (topological) order.
/// NOTE(review): confirm `Graph` guarantees that ordering.
pub fn run_with_remap(graph: Graph) -> (Graph, HashMap<NodeId, NodeId>) {
    let mut rebuilt = Graph::new(&graph.name);
    let mut remap: HashMap<NodeId, NodeId> = HashMap::new();

    for old in graph.nodes() {
        let fresh = legalize_node(old, &graph, &remap, &mut rebuilt);
        remap.insert(old.id, fresh);
    }

    let outputs = graph
        .outputs
        .iter()
        .map(|old_id| remap[old_id])
        .collect::<Vec<NodeId>>();
    rebuilt.set_outputs(outputs);

    (rebuilt, remap)
}
/// Copies `node` into `out`, translating its inputs through `id_map`, and
/// returns the id of the new node.
///
/// For two-input `Op::Binary` nodes, each operand is first passed through
/// `maybe_expand`. The operand's shape is taken from `fwd_graph` (the graph
/// being rewritten) and compared against the op's own output shape, so an
/// `Op::Expand` may be added to `out` ahead of the binary node. Operands are
/// processed lhs first, then rhs.
fn legalize_node(
    node: &Node,
    fwd_graph: &Graph,
    id_map: &HashMap<NodeId, NodeId>,
    out: &mut Graph,
) -> NodeId {
    let mut inputs: Vec<NodeId> = node.inputs.iter().map(|id| id_map[id]).collect();

    let is_binary = matches!(node.op, Op::Binary(_)) && node.inputs.len() == 2;
    if is_binary {
        for (slot, old_id) in node.inputs.iter().enumerate() {
            inputs[slot] = maybe_expand(
                inputs[slot],
                &fwd_graph.node(*old_id).shape,
                &node.shape,
                out,
            );
        }
    }

    out.add_node(node.op.clone(), inputs, node.shape.clone())
}
/// Returns `id` unchanged when an operand of shape `src` is already acceptable
/// against `target`. Otherwise it appends an `Op::Expand` node to `out` that
/// broadcasts `id` to `target`, and returns the new node's id.
///
/// "Acceptable" means any of: identical dims and dtype (`shape_eq`), a
/// one-element operand (`is_scalar`), or a pure leading-dim broadcast
/// (`is_clean_trailing_broadcast`). In the Expand's `target_shape`, static
/// dims are written as their size and dynamic dims as `-1`.
fn maybe_expand(id: NodeId, src: &Shape, target: &Shape, out: &mut Graph) -> NodeId {
    let already_legal =
        shape_eq(src, target) || is_scalar(src) || is_clean_trailing_broadcast(src, target);
    if already_legal {
        return id;
    }

    let mut target_shape = Vec::with_capacity(target.dims().len());
    for dim in target.dims().iter() {
        target_shape.push(match dim {
            Dim::Static(n) => *n as i64,
            Dim::Dynamic(_) => -1,
        });
    }

    // NOTE(review): the Expand node is typed with `target`, i.e. the binary op's
    // *output* shape including its dtype. If any BinaryOp yields a dtype that
    // differs from its operands (e.g. a comparison producing bool), this
    // mislabels the expanded operand; it should then carry `src`'s dtype.
    out.add_node(Op::Expand { target_shape }, vec![id], target.clone())
}
/// Returns `true` when `a` and `b` agree on both dtype and every dim.
fn shape_eq(a: &Shape, b: &Shape) -> bool {
    let same_dtype = a.dtype() == b.dtype();
    same_dtype && a.dims() == b.dims()
}
/// Returns `true` when `s` is statically known to hold exactly one element:
/// rank 0, or every dim is `Static(1)`. `maybe_expand` leaves such operands
/// unexpanded.
///
/// A `Dim::Dynamic` makes the element count unknown, so it always yields
/// `false`. Skipping dynamic dims instead would misclassify shapes such as
/// `[?, 1]` or `[?]` as scalars and leave their broadcast implicit.
fn is_scalar(s: &Shape) -> bool {
    // For non-negative sizes, "every factor is 1" is the same as "product is 1".
    // This form also cannot overflow.
    s.dims().iter().all(|d| matches!(d, Dim::Static(1)))
}
/// Returns `true` when `src` is exactly the trailing suffix of `target`, i.e.
/// the broadcast only prepends leading dims (e.g. `[8]` into `[4, 8]`).
///
/// Two dims match only when both are `Static` with the same size, or both are
/// `Dynamic` with equal payloads. A size-1 dim that would have to stretch
/// (e.g. `[1, 8]` into `[4, 8]`) is not considered clean. A `src` with higher
/// rank than `target` is never clean; an empty (rank-0) `src` always is.
fn is_clean_trailing_broadcast(src: &Shape, target: &Shape) -> bool {
    let src_dims = src.dims();
    let target_dims = target.dims();

    let lead = match target_dims.len().checked_sub(src_dims.len()) {
        Some(lead) => lead,
        None => return false,
    };

    src_dims
        .iter()
        .zip(target_dims[lead..].iter())
        .all(|pair| match pair {
            (Dim::Static(a), Dim::Static(b)) => a == b,
            (Dim::Dynamic(a), Dim::Dynamic(b)) => a == b,
            _ => false,
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Collects the `target_shape` of every `Op::Expand` node in `g`, in node order.
    fn expand_targets(g: &Graph) -> Vec<Vec<i64>> {
        g.nodes()
            .iter()
            .filter_map(|n| match &n.op {
                Op::Expand { target_shape } => Some(target_shape.clone()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn passthrough_for_equal_shapes() {
        let f = DType::F32;
        let mut g = Graph::new("eq");
        let a = g.input("a", Shape::new(&[4, 8], f));
        let b = g.input("b", Shape::new(&[4, 8], f));
        let c = g.binary(BinaryOp::Add, a, b, Shape::new(&[4, 8], f));
        g.set_outputs(vec![c]);
        let n_before = g.len();
        let g2 = run(g);
        assert_eq!(g2.len(), n_before, "no Expand inserted for equal shapes");
        assert!(expand_targets(&g2).is_empty());
    }

    #[test]
    fn passthrough_for_trailing_bias() {
        let f = DType::F32;
        let mut g = Graph::new("trail");
        let a = g.input("a", Shape::new(&[4, 8], f));
        let b = g.input("b", Shape::new(&[8], f));
        let c = g.binary(BinaryOp::Add, a, b, Shape::new(&[4, 8], f));
        g.set_outputs(vec![c]);
        let n_before = g.len();
        let g2 = run(g);
        assert_eq!(
            g2.len(),
            n_before,
            "no Expand inserted for trailing-bias broadcast"
        );
        assert!(expand_targets(&g2).is_empty());
    }

    #[test]
    fn passthrough_for_unit_numel_operand() {
        let f = DType::F32;
        let mut g = Graph::new("unit");
        let a = g.input("a", Shape::new(&[4, 8], f));
        let b = g.input("b", Shape::new(&[1, 1], f));
        let c = g.binary(BinaryOp::Add, a, b, Shape::new(&[4, 8], f));
        g.set_outputs(vec![c]);
        let n_before = g.len();
        let g2 = run(g);
        assert_eq!(
            g2.len(),
            n_before,
            "one-element operands are left for the backend to splat"
        );
        assert!(expand_targets(&g2).is_empty());
    }

    #[test]
    fn inserts_expand_for_channel_broadcast() {
        let f = DType::F32;
        let mut g = Graph::new("chan");
        let a = g.input("a", Shape::new(&[1, 2, 3, 3], f));
        let b = g.input("b", Shape::new(&[1, 2, 1, 1], f));
        let c = g.binary(BinaryOp::Add, a, b, Shape::new(&[1, 2, 3, 3], f));
        g.set_outputs(vec![c]);
        let n_before = g.len();
        let g2 = run(g);
        assert_eq!(
            g2.len(),
            n_before + 1,
            "exactly one Expand should be inserted for [1,2,1,1] → [1,2,3,3]"
        );
        assert_eq!(expand_targets(&g2), vec![vec![1, 2, 3, 3]]);

        // The binary op must consume the Expand on its rhs; the lhs already
        // has the output shape and must be left untouched.
        let add = g2
            .nodes()
            .iter()
            .find(|n| matches!(n.op, Op::Binary(_)))
            .expect("binary node survives the pass");
        assert!(
            matches!(g2.node(add.inputs[1]).op, Op::Expand { .. }),
            "rhs should be rewired through the Expand"
        );
        assert!(
            !matches!(g2.node(add.inputs[0]).op, Op::Expand { .. }),
            "lhs already matches the output shape"
        );
    }
}