use std::collections::{HashMap, HashSet};
use rlx_ir::*;
/// Returns a copy of `graph` in which every `Op::Param` whose name is listed in
/// `names` is rebuilt as a graph input (via `Graph::input`) with the same name
/// and shape.
///
/// Every other node is copied with the same op and shape, with its inputs
/// rewired to the corresponding nodes of the new graph. The graph outputs are
/// remapped the same way. Entries in `names` that match no parameter are
/// silently ignored.
///
/// # Panics
///
/// The rewrite is a single forward pass over `graph.nodes()`. It panics if a
/// node lists an input that does not appear earlier in that sequence, or if a
/// graph output does not refer to any node in it. In other words, nodes are
/// expected to be stored after all of their inputs.
pub fn promote_params_to_inputs(graph: &Graph, names: &[&str]) -> Graph {
    let promote: HashSet<&str> = names.iter().copied().collect();
    let nodes = graph.nodes();
    let mut out = Graph::new(graph.name.clone());
    // Old node id -> id of its counterpart in `out`.
    let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(nodes.len());

    for node in nodes {
        let new_id = match &node.op {
            Op::Param { name } if promote.contains(name.as_str()) => {
                out.input(name.clone(), node.shape.clone())
            }
            _ => {
                let new_inputs: Vec<NodeId> = node
                    .inputs
                    .iter()
                    .map(|i| {
                        *id_map
                            .get(i)
                            .expect("node input must appear earlier in graph.nodes()")
                    })
                    .collect();
                out.add_node(node.op.clone(), new_inputs, node.shape.clone())
            }
        };
        id_map.insert(node.id, new_id);
    }

    let new_outputs: Vec<NodeId> = graph
        .outputs
        .iter()
        .map(|o| *id_map.get(o).expect("graph output must refer to a node in graph.nodes()"))
        .collect();
    out.set_outputs(new_outputs);
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::BinaryOp;

    /// Returns true if node `id` in `g` is an `Op::Input` named `want`.
    fn is_input_named(g: &Graph, id: NodeId, want: &str) -> bool {
        g.nodes()
            .iter()
            .any(|n| n.id == id && matches!(&n.op, Op::Input { name } if name == want))
    }

    /// Returns the input ids of node `id` in `g`, or an empty list if `g` has no such node.
    fn inputs_of(g: &Graph, id: NodeId) -> Vec<NodeId> {
        g.nodes()
            .iter()
            .filter(|n| n.id == id)
            .flat_map(|n| n.inputs.iter().copied())
            .collect()
    }

    #[test]
    fn promote_swaps_listed_param_only() {
        let s = Shape::new(&[1], DType::F32);
        let mut g = Graph::new("t");
        let x = g.input("x", s.clone());
        let w = g.param("w", s.clone());
        let b = g.param("b", s.clone());
        let xw = g.binary(BinaryOp::Mul, x, w, s.clone());
        let y = g.binary(BinaryOp::Add, xw, b, s.clone());
        g.set_outputs(vec![y]);

        let g2 = promote_params_to_inputs(&g, &["w"]);

        let mut input_names: Vec<String> = Vec::new();
        let mut param_names: Vec<String> = Vec::new();
        for n in g2.nodes() {
            match &n.op {
                Op::Input { name } => input_names.push(name.clone()),
                Op::Param { name } => param_names.push(name.clone()),
                _ => {}
            }
        }
        input_names.sort();
        param_names.sort();
        assert_eq!(input_names, vec!["w".to_string(), "x".to_string()]);
        assert_eq!(param_names, vec!["b".to_string()]);
        assert_eq!(g2.outputs.len(), 1);
        assert_eq!(g2.nodes().len(), g.nodes().len());

        // Edges must be remapped as well: y = (x * w) + b, so one of the output's
        // inputs must itself consume the promoted `w` input.
        let y2 = g2.outputs[0];
        let reaches_w = inputs_of(&g2, y2)
            .into_iter()
            .any(|mid| inputs_of(&g2, mid).into_iter().any(|id| is_input_named(&g2, id, "w")));
        assert!(reaches_w, "promoted `w` is no longer wired into the output");
    }

    #[test]
    fn promote_silently_ignores_unknown_names() {
        let s = Shape::new(&[1], DType::F32);
        let mut g = Graph::new("t");
        let _ = g.input("x", s.clone());
        let p = g.param("p", s.clone());
        g.set_outputs(vec![p]);

        let g2 = promote_params_to_inputs(&g, &["missing", "p"]);

        let promoted = g2
            .nodes()
            .iter()
            .filter(|n| matches!(&n.op, Op::Input { name } if name == "p"))
            .count();
        assert_eq!(promoted, 1);

        // The unknown name must not create any node, `p` must no longer be a
        // param, and the single output must now be the promoted input.
        assert!(!g2.nodes().iter().any(|n| matches!(&n.op, Op::Param { .. })));
        assert!(!g2.nodes().iter().any(
            |n| matches!(&n.op, Op::Input { name } | Op::Param { name } if name == "missing")
        ));
        assert_eq!(g2.nodes().len(), g.nodes().len());
        assert!(is_input_named(&g2, g2.outputs[0], "p"));
    }
}