use std::collections::HashMap;
use rlx_ir::*;
/// Copies every node of `source` into `target` and returns the `target` ids
/// that correspond to `source.outputs`, in the same order.
///
/// Nodes are visited in the order `source.nodes()` yields them. That order must
/// be topological: every node's inputs must already have been visited.
///
/// * `Op::Input { name }` nodes are not copied. They are replaced by the
///   `target` node bound to `name` in `input_bindings`.
/// * `Op::Param { name }` nodes are replaced by the node bound in
///   `param_bindings` when it is `Some`. When it is `None`, a new parameter with
///   the same name and shape is created in `target` via `Graph::param`.
/// * `Op::Constant` nodes are re-added with their payload and no inputs.
/// * Every other op is re-added with its inputs remapped to the already-inlined
///   `target` ids.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when:
/// * an `Op::Input` name is missing from `input_bindings`;
/// * `param_bindings` is `Some` and an `Op::Param` name is missing from it;
/// * a node references an input id that has not been inlined yet (the source
///   graph is not in topological order);
/// * a source output id does not correspond to any inlined node.
///
/// On error, `target` may already contain the nodes copied before the failure.
pub fn inline_into(
    target: &mut Graph,
    source: &Graph,
    input_bindings: &HashMap<String, NodeId>,
    param_bindings: Option<&HashMap<String, NodeId>>,
) -> Result<Vec<NodeId>, String> {
    // Maps each source node id to the id of its counterpart in `target`.
    let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(source.nodes().len());
    for node in source.nodes() {
        let new_id = match &node.op {
            Op::Input { name } => *input_bindings.get(name).ok_or_else(|| {
                format!("inline_into: source Op::Input '{name}' not in input_bindings")
            })?,
            Op::Param { name } => match param_bindings {
                Some(pm) => *pm.get(name).ok_or_else(|| {
                    format!("inline_into: source Op::Param '{name}' not in param_bindings")
                })?,
                None => target.param(name.clone(), node.shape.clone()),
            },
            Op::Constant { data } => target.add_node(
                Op::Constant { data: data.clone() },
                vec![],
                node.shape.clone(),
            ),
            _ => {
                // A missing mapping means an input was not visited before this
                // node. Report it as an error instead of panicking, because the
                // caller supplied the malformed graph.
                let new_inputs = node
                    .inputs
                    .iter()
                    .map(|i| {
                        id_map.get(i).copied().ok_or_else(|| {
                            "inline_into: input NodeId not yet mapped — \
                             source graph isn't in topo order?"
                                .to_string()
                        })
                    })
                    .collect::<Result<Vec<NodeId>, String>>()?;
                target.add_node(node.op.clone(), new_inputs, node.shape.clone())
            }
        };
        id_map.insert(node.id, new_id);
    }
    source
        .outputs
        .iter()
        .map(|o| {
            id_map
                .get(o)
                .copied()
                .ok_or_else(|| "inline_into: output NodeId missing from map".to_string())
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::BinaryOp;

    /// A bound `Op::Input` is replaced by the target node it is bound to, and
    /// the downstream op is rebuilt on top of it.
    #[test]
    fn inline_replaces_inputs_with_target_nodes() {
        let s = Shape::new(&[1], DType::F32);
        let mut src = Graph::new("src");
        let x = src.input("x", s.clone());
        let two = src.add_node(
            Op::Constant {
                data: 2.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s.clone(),
        );
        let y = src.binary(BinaryOp::Mul, x, two, s.clone());
        src.set_outputs(vec![y]);

        let mut tgt = Graph::new("tgt");
        let five = tgt.add_node(
            Op::Constant {
                data: 5.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s.clone(),
        );
        let mut bindings: HashMap<String, NodeId> = HashMap::new();
        bindings.insert("x".to_string(), five);

        let outs = inline_into(&mut tgt, &src, &bindings, None).expect("inline");
        assert_eq!(outs.len(), 1);
        let out_node = tgt.node(outs[0]);
        assert!(matches!(out_node.op, Op::Binary(BinaryOp::Mul)));
        assert_eq!(out_node.inputs[0], five);
    }

    /// An `Op::Input` with no binding is reported as an error naming the input.
    #[test]
    fn inline_errors_on_missing_input_binding() {
        let s = Shape::new(&[1], DType::F32);
        let mut src = Graph::new("src");
        let x = src.input("x", s.clone());
        src.set_outputs(vec![x]);

        let mut tgt = Graph::new("tgt");
        let bindings: HashMap<String, NodeId> = HashMap::new();
        let result = inline_into(&mut tgt, &src, &bindings, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("'x'"));
    }

    /// With `param_bindings = Some(..)`, an unbound `Op::Param` is an error
    /// naming the parameter.
    #[test]
    fn inline_errors_on_missing_param_binding() {
        let s = Shape::new(&[1], DType::F32);
        let mut src = Graph::new("src");
        let w = src.param("w".to_string(), s.clone());
        src.set_outputs(vec![w]);

        let mut tgt = Graph::new("tgt");
        let inputs: HashMap<String, NodeId> = HashMap::new();
        let params: HashMap<String, NodeId> = HashMap::new();
        let result = inline_into(&mut tgt, &src, &inputs, Some(&params));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("'w'"));
    }

    /// With `param_bindings = None`, each `Op::Param` becomes a new parameter
    /// node in the target graph.
    #[test]
    fn inline_creates_params_when_unbound() {
        let s = Shape::new(&[1], DType::F32);
        let mut src = Graph::new("src");
        let w = src.param("w".to_string(), s.clone());
        src.set_outputs(vec![w]);

        let mut tgt = Graph::new("tgt");
        let inputs: HashMap<String, NodeId> = HashMap::new();
        let outs = inline_into(&mut tgt, &src, &inputs, None).expect("inline");
        assert_eq!(outs.len(), 1);
        assert!(matches!(tgt.node(outs[0]).op, Op::Param { .. }));
    }
}