use rlx_ir::shape::Dim;
use rlx_ir::{Graph, NodeId, Op, Shape};
pub fn legalize_multi_axis_reduce(g: Graph) -> Graph {
use std::collections::HashMap;
let any_multi = g
.nodes()
.iter()
.any(|n| matches!(&n.op, Op::Reduce { axes, .. } if axes.len() > 1));
if !any_multi {
return g;
}
let mut out = Graph::new(g.name.clone());
let mut remap: HashMap<NodeId, NodeId> = HashMap::new();
for node in g.nodes() {
let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| remap[i]).collect();
let final_new_id = match &node.op {
Op::Reduce { op, axes, keep_dim } if axes.len() > 1 => {
let mut cur = new_inputs[0];
let mut shape = out.node(cur).shape.clone();
let dtype = shape.dtype();
let mut sorted = axes.clone();
sorted.sort_unstable_by(|a, b| b.cmp(a));
for &ax in &sorted {
let mut dims: Vec<Dim> = shape.dims().to_vec();
dims[ax] = Dim::Static(1);
let step_shape = Shape::from_dims(&dims, dtype);
cur = out.add_node(
Op::Reduce {
op: *op,
axes: vec![ax],
keep_dim: true,
},
vec![cur],
step_shape,
);
shape = out.node(cur).shape.clone();
}
if !*keep_dim {
let final_shape = node.shape.clone();
let new_shape_dims: Vec<i64> = final_shape
.dims()
.iter()
.map(|d| match d {
Dim::Static(n) => *n as i64,
Dim::Dynamic(_) => -1,
})
.collect();
cur = out.add_node(
Op::Reshape {
new_shape: new_shape_dims,
},
vec![cur],
final_shape,
);
}
cur
}
_ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
};
remap.insert(node.id, final_new_id);
}
let new_outputs: Vec<NodeId> = g.outputs.iter().map(|id| remap[id]).collect();
out.set_outputs(new_outputs);
out
}