use hugr_core::ops::{OpType, ValidateOp};
use hugr_core::{HugrView, Visibility};
use itertools::{Either, Itertools};
/// Which part of a Hugr a pass is allowed to act upon.
///
/// The default ([`Default`] impl) is [`PassScope::Global`] with the default
/// [`Preserve`] level.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq)]
#[derive(derive_more::From, derive_more::Display)]
#[non_exhaustive]
pub enum PassScope {
    /// Rooted at the entrypoint, non-recursively: [`PassScope::regions`] yields only the
    /// entrypoint itself. Nothing is in scope if the entrypoint is the module root.
    EntrypointFlat,
    /// Rooted at the entrypoint, recursively: [`PassScope::regions`] yields the entrypoint
    /// and every container node beneath it. Nothing is in scope if the entrypoint is the
    /// module root.
    EntrypointRecursive,
    /// The whole Hugr, rooted at the module root. The [`Preserve`] value selects which
    /// children of the module root must keep their interface.
    Global(#[from] Preserve),
}
/// For a [`PassScope::Global`] scope, which children of the module root a pass must keep
/// the interface of. The module root and the entrypoint are always included.
#[derive(Debug, Clone, Default, Hash)]
#[derive(PartialEq, Eq, derive_more::Display)]
pub enum Preserve {
    /// Every child of the module root.
    All,
    /// Children of the module root that are function definitions or declarations with
    /// public visibility. This is the default.
    #[default]
    Public,
    /// Only the entrypoint. If the entrypoint is the module root, public functions are
    /// additionally preserved, exactly as for [`Preserve::Public`].
    Entrypoint,
}
impl Default for PassScope {
fn default() -> Self {
Self::Global(Preserve::default())
}
}
/// How a node relates to a [`PassScope`], as reported by [`PassScope::in_scope`].
#[derive(Copy, Clone, Debug)]
#[derive(PartialEq, Eq)]
pub enum InScope {
    /// The node is within the scope and is not one of the interface-preserved nodes.
    Yes,
    /// The node is within the scope and is one of the nodes yielded by
    /// [`PassScope::preserve_interface`].
    PreserveInterface,
    /// The node is outside the scope.
    No,
}
impl PassScope {
    /// Returns the root node of the scope, if any.
    ///
    /// * For [`PassScope::EntrypointFlat`] and [`PassScope::EntrypointRecursive`] this is
    ///   the entrypoint, or `None` if the entrypoint is the module root.
    /// * For [`PassScope::Global`] this is always the module root.
    pub fn root<'a, H: HugrView>(&'a self, hugr: &'a H) -> Option<H::Node> {
        let ep = hugr.entrypoint();
        match self {
            Self::EntrypointFlat | Self::EntrypointRecursive => {
                (ep != hugr.module_root()).then_some(ep)
            }
            Self::Global(_) => Some(hugr.module_root()),
        }
    }

    /// Returns the nodes in scope whose interface a pass must preserve.
    ///
    /// Yields nothing if [`Self::root`] is `None`. Otherwise it yields the root and the
    /// entrypoint (once, if they coincide). These are followed by the children of the
    /// module root, other than the entrypoint, that the scope selects:
    /// * [`Preserve::All`]: every child;
    /// * [`Preserve::Public`]: public [`FuncDefn`](OpType::FuncDefn)s and
    ///   [`FuncDecl`](OpType::FuncDecl)s;
    /// * [`Preserve::Entrypoint`]: as for `Public` if the entrypoint is the module root,
    ///   otherwise none;
    /// * the entrypoint-rooted scopes: none.
    pub fn preserve_interface<'a, H: HugrView>(
        &'a self,
        hugr: &'a H,
    ) -> impl Iterator<Item = H::Node> + 'a {
        self.root(hugr).into_iter().flat_map(move |r| {
            let ep = hugr.entrypoint();
            // `ep` is already emitted above, so exclude it from the children.
            [r, ep].into_iter().unique().chain(
                hugr.children(hugr.module_root())
                    .filter(move |&n| n != ep && self.preserves_module_child(hugr, n)),
            )
        })
    }

    /// Returns the container nodes (those whose op requires children) a pass should
    /// process.
    ///
    /// Yields nothing if [`Self::root`] is `None`. For recursive scopes it yields the root
    /// and its descendants whose op requires children, in [`HugrView::descendants`] order.
    /// The module root itself is always skipped. For [`PassScope::EntrypointFlat`] it
    /// yields only the root.
    pub fn regions<'a, H: HugrView>(
        &'a self,
        hugr: &'a H,
    ) -> impl Iterator<Item = H::Node> + 'a {
        self.root(hugr).into_iter().flat_map(move |r| {
            if self.recursive() {
                let mut iter = hugr.descendants(r);
                if r == hugr.module_root() {
                    // `descendants` yields its argument first; drop the module root.
                    assert_eq!(iter.next(), Some(hugr.module_root()));
                }
                Either::Left(iter.filter(move |n| {
                    hugr.get_optype(*n)
                        .validity_flags::<H::Node>()
                        .requires_children
                }))
            } else {
                // `root` never returns the module root for `EntrypointFlat`.
                assert_ne!(r, hugr.module_root());
                Either::Right(std::iter::once(r))
            }
        })
    }

    /// Classifies `node` relative to this scope.
    ///
    /// Returns [`InScope::No`] if the scope has no [root](Self::root). It also returns
    /// [`InScope::No`] if the root is not the module root and `node` is neither the root nor
    /// one of its descendants. Otherwise it returns [`InScope::PreserveInterface`] if `node`
    /// is yielded by [`Self::preserve_interface`], and [`InScope::Yes`] if not.
    pub fn in_scope<H: HugrView>(&self, hugr: &H, node: H::Node) -> InScope {
        let Some(r) = self.root(hugr) else {
            return InScope::No;
        };
        let module_root = hugr.module_root();
        // The ancestry check is skipped when the root is the module root. Otherwise walk up
        // the parent chain (starting at `node` itself) looking for the root.
        if r != module_root {
            let mut ancestors = std::iter::successors(Some(node), |&n| hugr.get_parent(n));
            if !ancestors.any(|n| n == r) {
                return InScope::No;
            }
        }
        // This is equivalent to `self.preserve_interface(hugr).contains(&node)`. It avoids
        // scanning every child of the module root, and allocating for `unique`, on each call.
        let ep = hugr.entrypoint();
        let preserved = node == r
            || node == ep
            || (hugr.get_parent(node) == Some(module_root)
                && self.preserves_module_child(hugr, node));
        if preserved {
            InScope::PreserveInterface
        } else {
            InScope::Yes
        }
    }

    /// Returns whether the scope recurses into regions nested below its root.
    /// Only [`PassScope::EntrypointFlat`] is non-recursive.
    pub fn recursive(&self) -> bool {
        !matches!(self, Self::EntrypointFlat)
    }

    /// Whether `child`, assumed to be a child of the module root other than the
    /// entrypoint, must have its interface preserved under this scope.
    fn preserves_module_child<H: HugrView>(&self, hugr: &H, child: H::Node) -> bool {
        let check_public = match self {
            Self::Global(Preserve::All) => return true,
            Self::Global(Preserve::Public) => true,
            Self::Global(Preserve::Entrypoint) => hugr.entrypoint() == hugr.module_root(),
            Self::EntrypointFlat | Self::EntrypointRecursive => false,
        };
        check_public && Self::is_public_function(hugr, child)
    }

    /// Whether `node` is a [`FuncDefn`](OpType::FuncDefn) or
    /// [`FuncDecl`](OpType::FuncDecl) with [`Visibility::Public`].
    fn is_public_function<H: HugrView>(hugr: &H, node: H::Node) -> bool {
        let vis = match hugr.get_optype(node) {
            OpType::FuncDecl(fd) => fd.visibility(),
            OpType::FuncDefn(fd) => fd.visibility(),
            _ => return false,
        };
        vis == &Visibility::Public
    }
}
// Unit tests for `PassScope` (`root`, `regions`, `in_scope`, `recursive` and
// `preserve_interface`), run against the small module built by the `th` fixture.
#[cfg(test)]
mod test {
use std::collections::HashSet;
use hugr_core::hugr::hugrmut::HugrMut;
use rstest::{fixture, rstest};
use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer};
use hugr_core::ops::Value;
use hugr_core::ops::handle::NodeHandle;
use hugr_core::std_extensions::arithmetic::int_types::ConstInt;
use hugr_core::types::Signature;
use hugr_core::{Hugr, Node};
use super::*;
// A test Hugr together with handles to each of its interesting nodes.
#[derive(Debug)]
struct TestHugr {
hugr: Hugr,
// The module root (the `ModuleBuilder`'s container node).
module_root: Node,
// Function defined with `Visibility::Public` directly under the module root.
public_func: Node,
// A DFG nested inside `public_func`.
public_func_nested: Node,
// Function defined with `Visibility::Private` directly under the module root.
private_func: Node,
// Function declared with `Visibility::Public`.
public_func_decl: Node,
// Function declared with `Visibility::Private`.
private_func_decl: Node,
// Node added by `add_constant` for an integer constant.
const_def: Node,
}
// Builds a module containing: a public function (with a nested DFG), a private
// function, a public and a private function declaration, and a constant. All use
// empty signatures.
#[fixture]
fn th() -> TestHugr {
let mut h = ModuleBuilder::new();
let module_root = h.container_node();
let (public_func, public_func_nested) = {
let mut pub_f = h
.define_function_vis(
"public_func",
Signature::new_endo(vec![]),
Visibility::Public,
)
.unwrap();
let public_func_nested = {
let pub_f_nested = pub_f.dfg_builder(Signature::new_endo(vec![]), []).unwrap();
pub_f_nested.finish_sub_container().unwrap().node()
};
(
pub_f.finish_sub_container().unwrap().node(),
public_func_nested,
)
};
let private_func = {
let priv_f = h
.define_function_vis(
"private_func",
Signature::new_endo(vec![]),
Visibility::Private,
)
.unwrap();
priv_f.finish_sub_container().unwrap().node()
};
let public_func_decl = h
.declare_vis(
"public_func_decl",
Signature::new_endo(vec![]).into(),
Visibility::Public,
)
.unwrap()
.node();
let private_func_decl = h
.declare_vis(
"private_func_decl",
Signature::new_endo(vec![]).into(),
Visibility::Private,
)
.unwrap()
.node();
let const_def = h
.add_constant(Value::from(ConstInt::new_u(5, 7).unwrap()))
.node();
TestHugr {
hugr: h.finish_hugr().unwrap(),
module_root,
public_func,
public_func_nested,
private_func,
public_func_decl,
private_func_decl,
const_def,
}
}
// Exercises both entrypoint-rooted scopes, moving the entrypoint in turn to the
// module root, a top-level function, and the DFG nested inside that function.
#[rstest]
#[case(PassScope::EntrypointFlat, false)]
#[case(PassScope::EntrypointRecursive, true)]
fn scope_entrypoint(mut th: TestHugr, #[case] scope: PassScope, #[case] recursive: bool) {
assert_eq!(scope.recursive(), recursive);
// Entrypoint at the module root: no root, no regions, nothing in scope.
th.hugr.set_entrypoint(th.module_root);
assert_eq!(scope.root(&th.hugr), None);
assert_eq!(scope.regions(&th.hugr).next(), None);
for n in th.hugr.nodes() {
assert_eq!(scope.in_scope(&th.hugr, n), InScope::No);
}
// Entrypoint at a top-level function: only the recursive scope lists the nested DFG
// as a region, but both report it as in scope.
th.hugr.set_entrypoint(th.public_func);
assert_eq!(scope.root(&th.hugr), Some(th.public_func));
let expected_regions = match recursive {
true => vec![th.public_func, th.public_func_nested],
false => vec![th.public_func],
};
assert_eq!(scope.regions(&th.hugr).collect_vec(), expected_regions);
assert_eq!(scope.in_scope(&th.hugr, th.module_root), InScope::No);
assert_eq!(
scope.in_scope(&th.hugr, th.public_func),
InScope::PreserveInterface
);
assert_eq!(
scope.in_scope(&th.hugr, th.public_func_nested),
InScope::Yes
);
// Nodes outside the entrypoint's subtree are out of scope.
for n in [
th.module_root,
th.private_func,
th.public_func_decl,
th.private_func_decl,
th.const_def,
] {
assert_eq!(scope.in_scope(&th.hugr, n), InScope::No);
}
// Entrypoint at the nested DFG: it is the only region and the only in-scope node
// checked, and its interface is preserved.
th.hugr.set_entrypoint(th.public_func_nested);
assert_eq!(scope.root(&th.hugr), Some(th.public_func_nested));
assert_eq!(
scope.regions(&th.hugr).collect_vec(),
[th.public_func_nested]
);
for n in [
th.module_root,
th.public_func,
th.private_func,
th.public_func_decl,
th.private_func_decl,
th.const_def,
] {
assert_eq!(scope.in_scope(&th.hugr, n), InScope::No)
}
assert_eq!(
scope.in_scope(&th.hugr, th.public_func_nested),
InScope::PreserveInterface
);
}
// With `Preserve::All`, every child of the module root is preserved.
#[rstest]
fn preserve_all(th: TestHugr) {
let preserve = [
th.public_func,
th.private_func,
th.public_func_decl,
th.private_func_decl,
th.const_def,
];
check_preserve(&th, Preserve::All, preserve)
}
// Checks a `PassScope::Global(preserve)` scope against `th`:
// - the scope is recursive and rooted at the module root;
// - the regions are the two functions and the nested DFG;
// - the module root and every node in `expected_chs` are `PreserveInterface`;
// - every other listed node is `Yes`;
// - `preserve_interface` yields exactly `expected_chs` plus the module root.
fn check_preserve(
th: &TestHugr,
preserve: Preserve,
expected_chs: impl IntoIterator<Item = Node>,
) {
let scope = PassScope::Global(preserve);
assert!(scope.recursive());
let expected_chs = expected_chs.into_iter().collect::<HashSet<_>>();
assert_eq!(scope.root(&th.hugr), Some(th.module_root));
assert_eq!(
scope.regions(&th.hugr).collect_vec(),
[th.public_func, th.private_func, th.public_func_nested]
);
assert_eq!(
scope.in_scope(&th.hugr, th.module_root),
InScope::PreserveInterface
);
for n in [
th.public_func,
th.private_func,
th.public_func_decl,
th.public_func_nested,
th.private_func_decl,
th.const_def,
] {
let expected = if expected_chs.contains(&n) {
InScope::PreserveInterface
} else {
InScope::Yes
};
assert_eq!(
scope.in_scope(&th.hugr, n),
expected,
"{:?} among {:?}",
n,
th
);
}
let mut preserve = expected_chs;
preserve.insert(th.module_root);
assert_eq!(preserve, scope.preserve_interface(&th.hugr).collect());
}
// With `Preserve::Public`, only the public function definition and declaration are
// preserved.
#[rstest]
fn preserve_public(th: TestHugr) {
let preserve = [th.public_func, th.public_func_decl];
check_preserve(&th, Preserve::Public, preserve)
}
// With `Preserve::Entrypoint`: when the entrypoint is the module root, it behaves like
// `Preserve::Public`. Otherwise only the entrypoint (plus the module root) is preserved.
#[rstest]
fn preserve_entrypoint(mut th: TestHugr) {
th.hugr.set_entrypoint(th.hugr.module_root());
let preserve = [th.public_func, th.public_func_decl];
check_preserve(&th, Preserve::Entrypoint, preserve);
th.hugr.set_entrypoint(th.public_func_nested);
let preserve = [th.public_func_nested];
check_preserve(&th, Preserve::Entrypoint, preserve)
}
}