use std::collections::{HashSet, VecDeque};
use itertools::Itertools;
use petgraph::algo::tarjan_scc;
use hugr_core::hugr::{hugrmut::HugrMut, patch::inline_call::InlineCall};
use hugr_core::module_graph::{ModuleGraph, StaticNode};
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
#[non_exhaustive]
pub enum InlineFuncsError {}
/// Inlines calls to non-recursive functions, subject to `call_predicate`.
///
/// The candidate callees are the `FuncDefn` children of the module root that are not
/// part of any cycle in the [`ModuleGraph`]. A function counts as cyclic when it lies
/// in a strongly-connected component with more than one member, or when it has a
/// static edge to itself. Calls to cyclic functions are never inlined.
///
/// Only the subtree beneath the Hugr's entrypoint is visited, breadth-first. For every
/// `Call` node whose static source is a candidate and for which
/// `call_predicate(h, call)` returns `true`, an [`InlineCall`] patch is applied. After
/// each inlining the traversal continues into the children of that same node. This
/// assumes, as `InlineCall` is designed to do, that the call node is replaced in place
/// by a container holding a copy of the callee's body. Calls inside freshly inlined
/// bodies are therefore also considered. This terminates because candidates are
/// acyclic.
///
/// `call_predicate` sees the Hugr in its current state, including any inlining
/// already performed.
///
/// # Errors
///
/// Never returns `Err` at present: [`InlineFuncsError`] has no variants.
///
/// # Panics
///
/// - If a cycle in the module graph contains a node that is not a `FuncDefn`.
/// - If an `InlineCall` patch fails on a `Call` whose target was checked to be a
///   candidate `FuncDefn` (an internal invariant violation).
pub fn inline_acyclic<H: HugrMut>(
    h: &mut H,
    call_predicate: impl Fn(&H, H::Node) -> bool,
) -> Result<(), InlineFuncsError> {
    let cg = ModuleGraph::new(&*h);
    let g = cg.graph();
    // Functions that are (mutually) recursive: inlining calls to them could go on forever.
    let all_funcs_in_cycles = tarjan_scc(g)
        .into_iter()
        .filter(|scc| match scc.iter().exactly_one() {
            // A lone node is only a cycle if it has an edge to itself.
            Ok(n) => g.edges_connecting(*n, *n).next().is_some(),
            Err(_) => true,
        })
        .flatten()
        .map(|n| match g.node_weight(n) {
            // Cycles are expected to consist solely of function definitions.
            Some(StaticNode::FuncDefn(fd)) => *fd,
            _ => panic!("Expected only FuncDefns in sccs"),
        })
        .collect::<HashSet<_>>();
    // Module-level function definitions whose calls may be inlined.
    let target_funcs: HashSet<H::Node> = h
        .children(h.module_root())
        .filter(|n| h.get_optype(*n).is_func_defn() && !all_funcs_in_cycles.contains(n))
        .collect();
    let mut q = VecDeque::from([h.entrypoint()]);
    while let Some(n) = q.pop_front() {
        if h.get_optype(n).is_call()
            && let Some(t) = h.static_source(n)
            && target_funcs.contains(&t)
            && call_predicate(h, n)
        {
            // `n` is a Call and `t` a module-level FuncDefn, so inlining cannot fail.
            h.apply_patch(InlineCall::new(n))
                .expect("inlining a Call to a FuncDefn should always succeed");
        }
        // Enqueue children *after* any inlining so the copied body is traversed too.
        q.extend(h.children(n));
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use itertools::Itertools;
    use rstest::rstest;

    use hugr_core::HugrView;
    use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
    use hugr_core::core::HugrNode;
    use hugr_core::module_graph::{ModuleGraph, StaticNode};
    use hugr_core::ops::OpType;
    use hugr_core::{Hugr, extension::prelude::qb_t, types::Signature};

    use super::inline_acyclic;

    /// Builds a module in which every function has signature
    /// `Signature::new_endo([qb_t()])` and whose call graph is:
    ///
    /// ```text
    /// f -> g, a    (f is first declared, later given a body)
    /// g -> f, b    (so f and g are mutually recursive)
    /// a -> x       (x is only a declaration)
    /// b -> c
    /// c -> (none)
    /// ```
    fn make_test_hugr() -> Hugr {
        let sig = || Signature::new_endo([qb_t()]);
        let mut mb = ModuleBuilder::new();
        let x = mb.declare("x", sig().into()).unwrap();
        let a = {
            let mut fb = mb.define_function("a", sig()).unwrap();
            let ins = fb.input_wires();
            let res = fb.call(&x, &[], ins).unwrap();
            fb.finish_with_outputs(res.outputs()).unwrap()
        };
        let c = {
            let fb = mb.define_function("c", sig()).unwrap();
            let ins = fb.input_wires();
            fb.finish_with_outputs(ins).unwrap()
        };
        let b = {
            let mut fb = mb.define_function("b", sig()).unwrap();
            let ins = fb.input_wires();
            let res = fb.call(c.handle(), &[], ins).unwrap().outputs();
            fb.finish_with_outputs(res).unwrap()
        };
        let f = mb.declare("f", sig().into()).unwrap();
        let g = {
            let mut fb = mb.define_function("g", sig()).unwrap();
            let ins = fb.input_wires();
            let c1 = fb.call(&f, &[], ins).unwrap();
            let c2 = fb.call(b.handle(), &[], c1.outputs()).unwrap();
            fb.finish_with_outputs(c2.outputs()).unwrap()
        };
        let _f = {
            let mut fb = mb.define_declaration(&f).unwrap();
            let ins = fb.input_wires();
            let c1 = fb.call(g.handle(), &[], ins).unwrap();
            let c2 = fb.call(a.handle(), &[], c1.outputs()).unwrap();
            fb.finish_with_outputs(c2.outputs()).unwrap()
        };
        mb.finish_hugr().unwrap()
    }

    /// Returns the module-level `FuncDefn` named `name`; panics if there is none.
    fn find_func<H: HugrView>(h: &H, name: &str) -> H::Node {
        h.children(h.module_root())
            .find(|n| {
                h.get_optype(*n)
                    .as_func_defn()
                    .is_some_and(|fd| fd.func_name() == name)
            })
            .unwrap()
    }

    // Columns: the functions whose calls the predicate allows to be inlined; the
    // functions that must have no remaining callers afterwards; and the expected
    // callees of f, g, a, b, c (in that order) after inlining.
    #[rstest]
    #[case(["a", "b", "c"], ["a", "b", "c"], [vec!["g", "x"], vec!["f"], vec!["x"], vec![], vec![]])]
    #[case(["a", "b"], ["a", "b"], [vec!["g", "x"], vec!["f", "c"], vec!["x"], vec!["c"], vec![]])]
    #[case(["c"], ["c"], [vec!["g", "a"], vec!["f", "b"], vec!["x"], vec![], vec![]])]
    fn test_inline(
        #[case] req: impl IntoIterator<Item = &'static str>,
        #[case] check_not_called: impl IntoIterator<Item = &'static str>,
        #[case] calls_fgabc: [Vec<&'static str>; 5],
    ) {
        let mut h = make_test_hugr();
        let target_funcs = req
            .into_iter()
            .map(|name| find_func(&h, name))
            .collect::<HashSet<_>>();
        inline_acyclic(&mut h, |h, call| {
            let tgt = h.static_source(call).unwrap();
            // f and g are mutually recursive, so the predicate must never be asked about them.
            assert!(["a", "b", "c"].contains(&func_name(h, tgt)));
            target_funcs.contains(&tgt)
        })
        .unwrap();
        let cg = ModuleGraph::new(&h);
        for fname in check_not_called {
            let fnode = find_func(&h, fname);
            let fnode = cg.node_index(fnode).unwrap();
            assert_eq!(
                None,
                cg.graph()
                    .edges_directed(fnode, petgraph::Direction::Incoming)
                    .next()
            );
        }
        for (fname, tgts) in ["f", "g", "a", "b", "c"].into_iter().zip_eq(calls_fgabc) {
            let fnode = find_func(&h, fname);
            assert_eq!(
                outgoing_calls(&cg, fnode)
                    .into_iter()
                    .map(|n| func_name(&h, n))
                    .collect::<HashSet<_>>(),
                HashSet::from_iter(tgts),
                "Calls from {fname}"
            );
        }
    }

    /// Returns the function nodes targeted by the static out-edges of `src`.
    fn outgoing_calls<N: HugrNode>(cg: &ModuleGraph<N>, src: N) -> Vec<N> {
        cg.out_edges(src).map(|(_, tgt)| func_node(tgt)).collect()
    }

    /// Inlines only calls located (transitively) inside `g`. Calls elsewhere, such as
    /// `b -> c` within `b` itself, must be left untouched.
    #[test]
    fn test_filter_caller() {
        let mut h = make_test_hugr();
        let [g, b, c] = ["g", "b", "c"].map(|n| find_func(&h, n));
        inline_acyclic(&mut h, |h, mut call| {
            // Walk up the hierarchy looking for `g` as an ancestor.
            loop {
                if call == g {
                    return true;
                }
                let Some(parent) = h.get_parent(call) else {
                    return false;
                };
                call = parent;
            }
        })
        .unwrap();
        let cg = ModuleGraph::new(&h);
        assert_eq!(outgoing_calls(&cg, g), [find_func(&h, "f")]);
        assert_eq!(outgoing_calls(&cg, b), [c]);
    }

    /// Extracts the Hugr node of a `FuncDecl`/`FuncDefn` module-graph node.
    fn func_node<N: Copy>(cgn: &StaticNode<N>) -> N {
        match cgn {
            StaticNode::FuncDecl(n) | StaticNode::FuncDefn(n) => *n,
            _ => panic!("Expected a FuncDecl or FuncDefn module-graph node"),
        }
    }

    /// Returns the name of the `FuncDecl`/`FuncDefn` at `n`.
    fn func_name<H: HugrView>(h: &H, n: H::Node) -> &str {
        match h.get_optype(n) {
            OpType::FuncDecl(fd) => fd.func_name(),
            OpType::FuncDefn(fd) => fd.func_name(),
            _ => panic!("Expected a FuncDecl or FuncDefn node"),
        }
    }
}