use std::collections::{HashMap, HashSet, VecDeque};
use hugr::HugrView;
use itertools::Itertools;
use petgraph::algo::tarjan_scc;
use hugr_core::hugr::{hugrmut::HugrMut, patch::inline_call::InlineCall};
use hugr_core::module_graph::{ModuleGraph, StaticNode};
use crate::metadata::InlineAnnotation;
use crate::passes::{ComposablePass, PassScope, WithScope};
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
#[non_exhaustive]
pub enum InlineFuncsError {}
/// Policy deciding whether a function that carries no [`InlineAnnotation`] should be inlined.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum InlineFuncsHeuristic {
    /// Accept a function only if `descendants` of its definition node yields at most
    /// this many nodes.
    MaxSize(usize),
    /// Accept every function regardless of size.
    All,
}

impl InlineFuncsHeuristic {
    /// Node limit used by the [`Default`] heuristic.
    const DEFAULT_MAX_SIZE: usize = 64;

    /// Applies this heuristic to the function definition `func` within `hugr`.
    fn should_inline<H: HugrView>(&self, func: H::Node, hugr: &H) -> bool {
        match *self {
            Self::All => true,
            Self::MaxSize(limit) => hugr.descendants(func).count() <= limit,
        }
    }
}

impl Default for InlineFuncsHeuristic {
    /// Defaults to `MaxSize(64)`.
    fn default() -> Self {
        Self::MaxSize(Self::DEFAULT_MAX_SIZE)
    }
}
/// A [`ComposablePass`] that inlines calls to non-recursive functions.
///
/// Each callee's [`InlineAnnotation`] metadata decides whether its calls are inlined;
/// callees without an annotation are judged by an [`InlineFuncsHeuristic`].
#[derive(Debug, Default, Clone)]
pub struct InlineFunctionsPass {
    /// Consulted for callees that carry no [`InlineAnnotation`].
    heuristic: InlineFuncsHeuristic,
    /// Which part of the Hugr the pass operates on.
    scope: PassScope,
}

impl InlineFunctionsPass {
    /// Returns this pass reconfigured to judge unannotated callees with `heuristic`.
    pub fn with_heuristic(self, heuristic: InlineFuncsHeuristic) -> Self {
        Self { heuristic, ..self }
    }
}
impl<H: HugrMut> ComposablePass<H> for InlineFunctionsPass {
    type Error = InlineFuncsError;
    type Result = ();

    /// Inlines calls within `self.scope` via [`inline_acyclic_scoped`].
    ///
    /// A callee annotated [`InlineAnnotation::BestEffort`] is accepted and one annotated
    /// [`InlineAnnotation::Never`] is rejected. An unannotated callee is judged by
    /// `self.heuristic`. Calls with no static source are never inlined.
    fn run(&self, h: &mut H) -> Result<(), Self::Error> {
        // Memoise one verdict per callee. It is computed the first time that callee
        // is offered, so every later call to it gets the same answer.
        let mut verdicts: HashMap<H::Node, bool> = HashMap::new();
        inline_acyclic_scoped(h, self.scope.clone(), |hugr, call| {
            hugr.static_source(call).is_some_and(|callee| {
                *verdicts.entry(callee).or_insert_with(|| {
                    match hugr.get_metadata::<InlineAnnotation>(callee) {
                        Some(InlineAnnotation::BestEffort) => true,
                        Some(InlineAnnotation::Never) => false,
                        None => self.heuristic.should_inline(callee, hugr),
                    }
                })
            })
        })
    }
}
impl WithScope for InlineFunctionsPass {
    /// Returns this pass restricted to `scope`.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        Self {
            scope: scope.into(),
            ..self
        }
    }
}
/// Inlines non-recursive calls approved by `call_predicate`.
///
/// Thin wrapper that delegates to [`inline_acyclic_scoped`] with
/// [`PassScope::EntrypointRecursive`].
#[deprecated(
    since = "0.18.1",
    note = "Use `inline_acyclic_scoped` with an appropriate `PassScope` instead. For module hugrs, use `PassScope::Global(Preserve::Entrypoint)`."
)]
pub fn inline_acyclic<H: HugrMut>(
    h: &mut H,
    call_predicate: impl FnMut(&H, H::Node) -> bool,
) -> Result<(), InlineFuncsError> {
    let legacy_scope = PassScope::EntrypointRecursive;
    inline_acyclic_scoped(h, legacy_scope, call_predicate)
}
/// Inlines calls to non-recursive functions within `scope`, as approved by `call_predicate`.
///
/// Nodes are visited breadth-first, starting from `scope.root(h)`. The root's children
/// (the region being processed) are always visited. Deeper descendants are visited only
/// if `scope.recursive()`. A visited node is inlined when all of these hold:
///
/// * it is a `Call`;
/// * its static target is a `FuncDefn` child of the module root;
/// * that function is not part of a cycle of the [`ModuleGraph`] (a strongly connected
///   component of several functions, or one function with an edge to itself), so that
///   inlining always terminates;
/// * `call_predicate(h, call)` returns `true`.
///
/// The predicate is consulted only for calls meeting the other conditions. It sees the
/// Hugr as already modified by earlier inlinings. A node's children are read after it has
/// been inlined. In a recursive scope, assuming `InlineCall` places the copied body under
/// the call node, calls within an inlined body are therefore visited too.
///
/// Does nothing (and returns `Ok`) if `scope.root(h)` is `None`.
///
/// # Errors
/// None at present: [`InlineFuncsError`] has no variants.
pub fn inline_acyclic_scoped<H: HugrMut>(
    h: &mut H,
    scope: impl Into<PassScope>,
    mut call_predicate: impl FnMut(&H, H::Node) -> bool,
) -> Result<(), InlineFuncsError> {
    let scope: PassScope = scope.into();
    let Some(scope_root) = scope.root(h) else {
        return Ok(());
    };

    // Functions that can (transitively) reach themselves must never be inlined,
    // otherwise inlining would not terminate.
    let cg = ModuleGraph::new(&*h);
    let g = cg.graph();
    let all_funcs_in_cycles = tarjan_scc(g)
        .into_iter()
        .flat_map(|mut scc| {
            // A lone function is only recursive if it has an edge to itself.
            if let Ok(n) = scc.iter().exactly_one()
                && g.edges_connecting(*n, *n).next().is_none()
            {
                scc.clear();
            }
            scc.into_iter().map(|n| {
                let StaticNode::FuncDefn(fd) = g.node_weight(n).unwrap() else {
                    panic!("Expected only FuncDefns in sccs")
                };
                *fd
            })
        })
        .collect::<HashSet<_>>();

    let target_funcs: HashSet<H::Node> = h
        .children(h.module_root())
        .filter(|n| h.get_optype(*n).is_func_defn() && !all_funcs_in_cycles.contains(n))
        .collect();

    let mut queue = VecDeque::from([scope_root]);
    while let Some(node) = queue.pop_front() {
        if h.get_optype(node).is_call()
            && let Some(callee) = h.static_source(node)
            && target_funcs.contains(&callee)
            && call_predicate(h, node)
        {
            // `node` is a Call and `callee` a FuncDefn (all of `target_funcs` are),
            // which is exactly what InlineCall requires.
            h.apply_patch(InlineCall::new(node))
                .expect("a Call whose static source is a FuncDefn must be inlinable");
        }
        // The scope root's children form the region being processed, so they are always
        // visited. Without this, a non-recursive scope would inspect only the root
        // container itself and never inline anything.
        if node == scope_root || scope.recursive() {
            queue.extend(h.children(node));
        }
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use itertools::Itertools;
    use rstest::rstest;

    use hugr_core::HugrView;
    use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
    use hugr_core::core::HugrNode;
    use hugr_core::hugr::hugrmut::HugrMut;
    use hugr_core::module_graph::{ModuleGraph, StaticNode};
    use hugr_core::ops::OpType;
    use hugr_core::{Hugr, extension::prelude::qb_t, types::Signature};

    use super::{InlineFunctionsPass, inline_acyclic_scoped};
    use crate::metadata::InlineAnnotation;
    use crate::passes::composable::test::run_validating;
    use crate::passes::inline_funcs::InlineFuncsHeuristic;
    use crate::passes::{PassScope, composable::Preserve};

    /// Builds a module whose functions all have signature `qb -> qb`, with call structure:
    ///
    /// * `x`: declaration only (cannot be inlined);
    /// * `a` calls `x`;
    /// * `c` passes its input straight through;
    /// * `b` calls `c`;
    /// * `g` calls `f`, then `b`;
    /// * `f` (declared first, defined last) calls `g`, then `a`.
    ///
    /// `f` and `g` are therefore mutually recursive, while `a`, `b` and `c` are not.
    fn make_test_hugr() -> Hugr {
        let sig = || Signature::new_endo([qb_t()]);
        let mut mb = ModuleBuilder::new();
        let x = mb.declare("x", sig().into()).unwrap();
        let a = {
            let mut fb = mb.define_function("a", sig()).unwrap();
            let ins = fb.input_wires();
            let res = fb.call(&x, &[], ins).unwrap();
            fb.finish_with_outputs(res.outputs()).unwrap()
        };
        let c = {
            let fb = mb.define_function("c", sig()).unwrap();
            let ins = fb.input_wires();
            fb.finish_with_outputs(ins).unwrap()
        };
        let b = {
            let mut fb = mb.define_function("b", sig()).unwrap();
            let ins = fb.input_wires();
            let res = fb.call(c.handle(), &[], ins).unwrap().outputs();
            fb.finish_with_outputs(res).unwrap()
        };
        let f = mb.declare("f", sig().into()).unwrap();
        let g = {
            let mut fb = mb.define_function("g", sig()).unwrap();
            let ins = fb.input_wires();
            let c1 = fb.call(&f, &[], ins).unwrap();
            let c2 = fb.call(b.handle(), &[], c1.outputs()).unwrap();
            fb.finish_with_outputs(c2.outputs()).unwrap()
        };
        let _f = {
            let mut fb = mb.define_declaration(&f).unwrap();
            let ins = fb.input_wires();
            let c1 = fb.call(g.handle(), &[], ins).unwrap();
            let c2 = fb.call(a.handle(), &[], c1.outputs()).unwrap();
            fb.finish_with_outputs(c2.outputs()).unwrap()
        };
        mb.finish_hugr().unwrap()
    }

    /// Returns the `FuncDefn` child of the module root named `name`; panics if absent.
    fn find_func<H: HugrView>(h: &H, name: &str) -> H::Node {
        h.children(h.module_root())
            .find(|n| {
                h.get_optype(*n)
                    .as_func_defn()
                    .is_some_and(|fd| fd.func_name() == name)
            })
            .unwrap()
    }

    /// Inlines exactly the calls to functions in `req` (over the whole module).
    ///
    /// Afterwards it checks two things: nothing calls any function in
    /// `check_not_called`, and the callees of `f`, `g`, `a`, `b`, `c` (in that order)
    /// are exactly `calls_fgabc`.
    #[rstest]
    #[case(["a", "b", "c"], ["a", "b", "c"], [vec!["g", "x"], vec!["f"], vec!["x"], vec![], vec![]])]
    #[case(["a", "b"], ["a", "b"], [vec!["g", "x"], vec!["f", "c"], vec!["x"], vec!["c"], vec![]])]
    #[case(["c"], ["c"], [vec!["g", "a"], vec!["f", "b"], vec!["x"], vec![], vec![]])]
    fn test_inline(
        #[case] req: impl IntoIterator<Item = &'static str>,
        #[case] check_not_called: impl IntoIterator<Item = &'static str>,
        #[case] calls_fgabc: [Vec<&'static str>; 5],
    ) {
        let mut h = make_test_hugr();
        let target_funcs = req
            .into_iter()
            .map(|name| find_func(&h, name))
            .collect::<HashSet<_>>();
        inline_acyclic_scoped(
            &mut h,
            PassScope::Global(Preserve::Entrypoint),
            |h, call| {
                let tgt = h.static_source(call).unwrap();
                // Calls to the declaration `x` and to the recursive `f`/`g` must never
                // reach the predicate.
                assert!(["a", "b", "c"].contains(&func_name(h, tgt)));
                target_funcs.contains(&tgt)
            },
        )
        .unwrap();
        let cg = ModuleGraph::new(&h);
        for fname in check_not_called {
            let fnode = find_func(&h, fname);
            let fnode = cg.node_index(fnode).unwrap();
            assert_eq!(
                None,
                cg.graph()
                    .edges_directed(fnode, petgraph::Direction::Incoming)
                    .next()
            );
        }
        for (fname, tgts) in ["f", "g", "a", "b", "c"].into_iter().zip_eq(calls_fgabc) {
            let fnode = find_func(&h, fname);
            assert_eq!(
                outgoing_calls(&cg, fnode)
                    .into_iter()
                    .map(|n| func_name(&h, n))
                    .collect::<HashSet<_>>(),
                HashSet::from_iter(tgts),
                "Calls from {fname}"
            );
        }
    }

    /// Returns the function nodes targeted by the outgoing module-graph edges of `src`.
    fn outgoing_calls<N: HugrNode>(cg: &ModuleGraph<N>, src: N) -> Vec<N> {
        cg.out_edges(src).map(|(_, tgt)| func_node(tgt)).collect()
    }

    /// The predicate may restrict inlining by caller: only calls nested inside `g` are
    /// inlined, so `b` keeps its call to `c`.
    #[test]
    fn test_filter_caller() {
        let mut h = make_test_hugr();
        let [g, b, c] = ["g", "b", "c"].map(|n| find_func(&h, n));
        inline_acyclic_scoped(
            &mut h,
            PassScope::Global(Preserve::Entrypoint),
            |h, mut call| {
                loop {
                    if call == g {
                        return true;
                    };
                    let Some(parent) = h.get_parent(call) else {
                        return false;
                    };
                    call = parent;
                }
            },
        )
        .unwrap();
        let cg = ModuleGraph::new(&h);
        assert_eq!(outgoing_calls(&cg, g), [find_func(&h, "f")]);
        assert_eq!(outgoing_calls(&cg, b), [c]);
    }

    /// Extracts the node of a function (declaration or definition) from the module graph.
    fn func_node<N: Copy>(cgn: &StaticNode<N>) -> N {
        match cgn {
            StaticNode::FuncDecl(n) | StaticNode::FuncDefn(n) => *n,
            _ => panic!("module-graph node is not a function"),
        }
    }

    /// Returns the name of the function (declaration or definition) at `n`.
    fn func_name<H: HugrView>(h: &H, n: H::Node) -> &str {
        match h.get_optype(n) {
            OpType::FuncDecl(fd) => fd.func_name(),
            OpType::FuncDefn(fd) => fd.func_name(),
            _ => panic!("node is not a FuncDecl or FuncDefn"),
        }
    }

    /// With the default (empty) annotations, `g`'s call to `b` is inlined iff the heuristic
    /// accepts `b`. The call to `f` is never inlined because `f` and `g` are recursive.
    #[rstest]
    #[case::size_zero(InlineFuncsHeuristic::MaxSize(0), vec!["f", "b"])]
    #[case::size_unlimited(InlineFuncsHeuristic::MaxSize(usize::MAX), vec!["f"])]
    #[case::all(InlineFuncsHeuristic::All, vec!["f"])]
    fn inline_functions_pass_heuristic(
        #[case] heuristic: InlineFuncsHeuristic,
        #[case] g_targets: Vec<&'static str>,
    ) {
        let mut h = make_test_hugr();
        run_validating(
            InlineFunctionsPass::default().with_heuristic(heuristic),
            &mut h,
        )
        .unwrap();
        let cg = ModuleGraph::new(&h);
        let g = find_func(&h, "g");
        assert_eq!(
            outgoing_calls(&cg, g)
                .into_iter()
                .map(|n| func_name(&h, n))
                .collect::<HashSet<_>>(),
            HashSet::from_iter(g_targets),
        );
    }

    /// Annotations override the heuristic. `b` (BestEffort) is inlined into `g` even with
    /// `MaxSize(0)`, and `c` (Never) stays a call, so `g` now calls `c`. `f` (BestEffort)
    /// is still not inlined because it is recursive.
    #[test]
    fn inline_functions_pass_hints() {
        let g_targets = vec!["f", "c"];
        let mut h = make_test_hugr();
        let b = find_func(&h, "b");
        let c = find_func(&h, "c");
        let f = find_func(&h, "f");
        h.set_metadata::<InlineAnnotation>(b, InlineAnnotation::BestEffort);
        h.set_metadata::<InlineAnnotation>(c, InlineAnnotation::Never);
        h.set_metadata::<InlineAnnotation>(f, InlineAnnotation::BestEffort);
        run_validating(
            InlineFunctionsPass::default().with_heuristic(InlineFuncsHeuristic::MaxSize(0)),
            &mut h,
        )
        .unwrap();
        let cg = ModuleGraph::new(&h);
        let g = find_func(&h, "g");
        assert_eq!(
            outgoing_calls(&cg, g)
                .into_iter()
                .map(|n| func_name(&h, n))
                .collect::<HashSet<_>>(),
            HashSet::from_iter(g_targets),
        );
    }
}