use hugr_core::hugr::internal::HugrInternals;
use hugr_core::{HugrView, Node, hugr::hugrmut::HugrMut, ops::OpType};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use crate::passes::composable::WithScope;
use crate::passes::{ComposablePass, PassScope};
/// Configuration for a dead-code-elimination pass over a Hugr of type `H`.
///
/// Running the pass removes every node beneath the pass root that the search
/// for needed nodes did not reach.
#[derive(Clone)]
pub struct DeadCodeElimPass<H: HugrView> {
/// Extra nodes from which the search for needed nodes is seeded, in addition
/// to the Hugr entrypoint (no scope) or the scope's preserved interface.
entry_points: Vec<H::Node>,
/// Optional scope of the pass. When `None`, the Hugr entrypoint is used both
/// as a seed of the search and as the root below which nodes may be removed.
scope: Option<PassScope>,
/// Callback consulted for each child of a needed node to decide whether that
/// child must be kept even if nothing else demands it. The `Default`
/// implementation installs [`PreserveNode::default_for`].
preserve_callback: Arc<PreserveCallback<H>>,
}
impl<H: HugrView + 'static> Default for DeadCodeElimPass<H> {
fn default() -> Self {
Self {
entry_points: Default::default(),
scope: None,
preserve_callback: Arc::new(PreserveNode::default_for),
}
}
}
impl<H: HugrView> Debug for DeadCodeElimPass<H> {
    /// Formats the printable configuration (`entry_points` and `scope`).
    ///
    /// The preserve callback is a `dyn Fn` and has no `Debug` representation,
    /// so it is left out. The struct name emitted is `DCEDebug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("DCEDebug")
            .field("entry_points", &self.entry_points)
            .field("scope", &self.scope)
            .finish()
    }
}
/// Signature of the callback used by [`DeadCodeElimPass`] to decide whether a
/// node must be preserved: it is given the Hugr and a node of that Hugr, and
/// returns a [`PreserveNode`] verdict.
pub type PreserveCallback<H> = dyn Fn(&H, <H as HugrInternals>::Node) -> PreserveNode;
/// Verdict returned by a [`PreserveCallback`] for a single node.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum PreserveNode {
/// The node must be kept, even if nothing else requires it.
MustKeep,
/// The node need not be kept for its own sake; its children are not
/// consulted when making this decision.
CanRemoveIgnoringChildren,
/// The node must be kept if (and only if) at least one of its children
/// must be kept; the callback is consulted for those children in turn.
DeferToChildren,
}
impl PreserveNode {
    /// Default preservation policy, decided purely from the node's [`OpType`].
    ///
    /// `CFG`, `TailLoop` and `Call` nodes yield [`PreserveNode::MustKeep`];
    /// every other node yields [`PreserveNode::DeferToChildren`].
    pub fn default_for<H: HugrView>(h: &H, n: H::Node) -> PreserveNode {
        let always_kept = matches!(
            h.get_optype(n),
            OpType::CFG(_) | OpType::TailLoop(_) | OpType::Call(_)
        );
        if always_kept {
            Self::MustKeep
        } else {
            Self::DeferToChildren
        }
    }
}
/// Errors that can be returned when running a [`DeadCodeElimPass`].
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
#[non_exhaustive]
pub enum DeadCodeElimError<N: Display = Node> {
/// A node reached while searching for needed nodes (for example one supplied
/// as an entry point) is not contained in the Hugr.
#[error("Node {_0} does not exist in the Hugr")]
NodeNotFound(N),
}
impl<H: HugrView> DeadCodeElimPass<H> {
    /// Replaces the callback used to decide which nodes must be preserved
    /// regardless of whether anything else needs them.
    pub fn set_preserve_callback(self, cb: Arc<PreserveCallback<H>>) -> Self {
        Self {
            preserve_callback: cb,
            ..self
        }
    }

    /// Adds further nodes from which the search for needed nodes is seeded.
    /// Entry points accumulate across calls.
    pub fn with_entry_points(mut self, entry_points: impl IntoIterator<Item = H::Node>) -> Self {
        let extra = entry_points.into_iter();
        self.entry_points.extend(extra);
        self
    }

    /// Computes the set of nodes that must survive the pass.
    ///
    /// The search is a worklist traversal seeded with the configured entry
    /// points plus either the Hugr entrypoint (no scope) or the scope's
    /// preserved interface.
    ///
    /// # Errors
    ///
    /// Returns [`DeadCodeElimError::NodeNotFound`] if a node taken from the
    /// worklist is not contained in `h`.
    fn find_needed_nodes(&self, h: &H) -> Result<HashSet<H::Node>, DeadCodeElimError<H::Node>> {
        let mut preserve_cache = HashMap::new();
        let mut needed = HashSet::new();
        let mut worklist: VecDeque<H::Node> = self.entry_points.iter().copied().collect();
        if let Some(scope) = &self.scope {
            worklist.extend(scope.preserve_interface(h));
        } else {
            worklist.push_back(h.entrypoint());
        }

        while let Some(node) = worklist.pop_front() {
            if !h.contains_node(node) {
                return Err(DeadCodeElimError::NodeNotFound(node));
            }
            // Each node is expanded only the first time it is reached.
            let first_visit = needed.insert(node);
            if !first_visit {
                continue;
            }

            // A needed node keeps its parent alive too.
            if let Some(parent) = h.get_parent(node) {
                worklist.push_back(parent);
            }

            let node_op = h.get_optype(node);
            for (idx, child) in h.children(node).enumerate() {
                // The callback is always consulted first; the structural
                // rules below only apply when it does not force the child.
                let keep_child = self.must_preserve(h, &mut preserve_cache, child)
                    || match h.get_optype(child) {
                        // Only the first block of a CFG is kept unconditionally.
                        OpType::DataflowBlock(_) => idx == 0 && node_op.is_cfg(),
                        other => matches!(
                            other,
                            OpType::Case(_)
                                | OpType::ExitBlock(_)
                                | OpType::AliasDecl(_)
                                | OpType::AliasDefn(_)
                                | OpType::Input(_)
                                | OpType::Output(_)
                        ),
                    };
                if keep_child {
                    worklist.push_back(child);
                }
            }

            match node_op {
                // For basic blocks, follow edges forwards to their successors.
                OpType::DataflowBlock(_) | OpType::ExitBlock(_) => {
                    worklist.extend(h.output_neighbours(node));
                }
                // For everything else, follow edges backwards to the sources.
                _ => worklist.extend(h.input_neighbours(node)),
            }

            // Consumers of non-copyable outputs are needed as well.
            let Some(sig) = h.signature(node) else {
                continue;
            };
            for port in sig.output_ports() {
                if sig.out_port_type(port).unwrap().copyable() {
                    continue;
                }
                worklist.extend(h.linked_inputs(node, port).map(|(consumer, _)| consumer));
            }
        }
        Ok(needed)
    }

    /// Returns whether node `n` must be kept according to the preserve
    /// callback, memoising the answer for each node in `cache`.
    ///
    /// A [`PreserveNode::DeferToChildren`] verdict recurses into the children
    /// of `n` and stops at the first child that must be kept.
    fn must_preserve(&self, h: &H, cache: &mut HashMap<H::Node, bool>, n: H::Node) -> bool {
        if let Some(&cached) = cache.get(&n) {
            return cached;
        }
        let callback = &*self.preserve_callback;
        let keep = match callback(h, n) {
            PreserveNode::MustKeep => true,
            PreserveNode::CanRemoveIgnoringChildren => false,
            PreserveNode::DeferToChildren => {
                h.children(n).any(|ch| self.must_preserve(h, cache, ch))
            }
        };
        cache.insert(n, keep);
        keep
    }
}
impl<H: HugrMut> ComposablePass<H> for DeadCodeElimPass<H> {
    type Error = DeadCodeElimError<H::Node>;
    type Result = ();

    /// Removes every descendant of the pass root that is not needed.
    ///
    /// The root is the Hugr entrypoint when no scope is set, otherwise the
    /// scope's root; if the scope yields no root the Hugr is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates [`DeadCodeElimError::NodeNotFound`] from the search for
    /// needed nodes; in that case nothing has been removed.
    fn run(&self, hugr: &mut H) -> Result<(), Self::Error> {
        let root = if let Some(scope) = &self.scope {
            let Some(scope_root) = scope.root(hugr) else {
                return Ok(());
            };
            scope_root
        } else {
            hugr.entrypoint()
        };

        let needed = self.find_needed_nodes(&*hugr)?;
        // Collect first: the Hugr cannot be mutated while it is being traversed.
        let unneeded: Vec<_> = hugr
            .descendants(root)
            .filter(|candidate| !needed.contains(candidate))
            .collect();
        for node in unneeded {
            hugr.remove_node(node);
        }
        Ok(())
    }
}
impl<H: HugrMut> WithScope for DeadCodeElimPass<H> {
    /// Returns the pass with its scope set to `scope`, replacing any scope
    /// configured earlier.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        Self {
            scope: Some(scope.into()),
            ..self
        }
    }
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use hugr_core::builder::{
CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, endo_sig, inout_sig,
};
use hugr_core::extension::prelude::{ConstUsize, bool_t, qb_t, usize_t};
use hugr_core::extension::{ExtensionId, Version};
use hugr_core::ops::{ExtensionOp, OpType};
use hugr_core::ops::{OpTag, OpTrait, handle::NodeHandle};
use hugr_core::types::Signature;
use hugr_core::{Extension, Hugr};
use hugr_core::{HugrView, ops::Value, type_row};
use itertools::Itertools;
use crate::passes::ComposablePass;
use super::{DeadCodeElimPass, PreserveNode};
// Checks how different preserve callbacks change what the pass removes from a
// small CFG containing an unused DFG and several constants.
#[test]
fn test_cfg_callback() {
// Build a CFG holding three constants and a single entry block that branches
// to the exit block.
let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap();
let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3)));
let cst_used_in_dfg = cb.add_constant(Value::from(ConstUsize::new(5)));
let cst_used = cb.add_constant(Value::unary_unit_sum());
let mut block = cb.entry_builder([type_row![]], type_row![]).unwrap();
// A DFG nested in the entry block: it loads two constants and outputs one of
// them (`lc1`), but the DFG's own output is not wired to anything.
let mut dfg_unused = block
.dfg_builder(Signature::new(type_row![], [usize_t()]), [])
.unwrap();
let lc_unused = dfg_unused.load_const(&cst_unused);
let lc1 = dfg_unused.load_const(&cst_used_in_dfg);
let dfg_unused = dfg_unused.finish_with_outputs([lc1]).unwrap().node();
// The block's branch predicate is the only value the block itself outputs.
let pred = block.load_const(&cst_used);
let block = block.finish_with_outputs(pred, []).unwrap();
let exit = cb.exit_block();
cb.branch(&block, 0, &exit).unwrap();
let orig = cb.finish_hugr().unwrap();
// Case 1: the default callback, and a callback that allows removing
// `dfg_unused` and every Const, must both strip the unused DFG and the two
// constants that only it references.
for dce in [
DeadCodeElimPass::<Hugr>::default(),
DeadCodeElimPass::default().set_preserve_callback(Arc::new(move |h, n| {
if n == dfg_unused || h.get_optype(n).is_const() {
PreserveNode::CanRemoveIgnoringChildren
} else {
PreserveNode::MustKeep
}
})),
] {
let mut h = orig.clone();
dce.run(&mut h).unwrap();
assert_eq!(
h.children(h.entrypoint()).collect_vec(),
[block.node(), exit.node(), cst_used.node()]
);
assert_eq!(
h.children(block.node())
.map(|n| h.get_optype(n).tag())
.collect_vec(),
[OpTag::Input, OpTag::Output, OpTag::LoadConst]
);
}
// Helper: force-keep a node when `b` holds, otherwise defer to its children.
fn keep_if(b: bool) -> PreserveNode {
if b {
PreserveNode::MustKeep
} else {
PreserveNode::DeferToChildren
}
}
// Case 2: keeping everything, or keeping just the unused load inside the
// DFG, must leave the Hugr unchanged.
for dce in [
DeadCodeElimPass::<Hugr>::default()
.set_preserve_callback(Arc::new(|_, _| PreserveNode::MustKeep)),
DeadCodeElimPass::default()
.set_preserve_callback(Arc::new(move |_, n| keep_if(n == lc_unused.node()))),
] {
let mut h = orig.clone();
dce.run(&mut h).unwrap();
assert_eq!(orig, h);
}
// Case 3: keeping the DFG itself, or the load feeding its output, retains the
// DFG and `cst_used_in_dfg`; `cst_unused` and its load are still removed.
for dce in [
DeadCodeElimPass::<Hugr>::default()
.set_preserve_callback(Arc::new(move |_, n| keep_if(n == dfg_unused))),
DeadCodeElimPass::default()
.set_preserve_callback(Arc::new(move |_, n| keep_if(n == lc1.node()))),
] {
let mut h = orig.clone();
dce.run(&mut h).unwrap();
assert_eq!(
h.children(h.entrypoint()).collect_vec(),
[
block.node(),
exit.node(),
cst_used_in_dfg.node(),
cst_used.node()
]
);
// Skip the block's first two children (its Input and Output nodes).
assert_eq!(
h.children(block.node()).skip(2).collect_vec(),
[dfg_unused, pred.node()]
);
assert_eq!(
h.children(dfg_unused.node())
.map(|n| h.get_optype(n).tag())
.collect_vec(),
[OpTag::Input, OpTag::Output, OpTag::LoadConst]
);
}
// Case 4: force-keeping only the `cst_unused` constant retains that constant,
// but nothing that loads it.
{
let cst_unused = cst_unused.node();
let mut h = orig.clone();
DeadCodeElimPass::<Hugr>::default()
.set_preserve_callback(Arc::new(move |_, n| keep_if(n == cst_unused)))
.run(&mut h)
.unwrap();
assert_eq!(
h.children(h.entrypoint()).collect_vec(),
[block.node(), exit.node(), cst_unused, cst_used.node()]
);
assert_eq!(
h.children(block.node())
.map(|n| h.get_optype(n).tag())
.collect_vec(),
[OpTag::Input, OpTag::Output, OpTag::LoadConst]
);
}
}
// Checks which extension ops survive the pass in a DFG that threads `qb_t`
// values through them.
// NOTE(review): presumably `qb_t` is a non-copyable (linear) type, which is
// what makes the pass keep consumers of such outputs - confirm.
#[test]
fn preserve_linear() {
// Test extension with four ops:
//   new:     ()     -> qb_t
//   gate:    qb_t   -> qb_t
//   measure: qb_t   -> bool_t
//   not:     bool_t -> bool_t
let test_ext = Extension::new_arc(
ExtensionId::new_unchecked("test_qext"),
Version::new(0, 0, 0),
|e, w| {
e.add_op("new".into(), "".into(), inout_sig(vec![], [qb_t()]), w)
.unwrap();
e.add_op("gate".into(), "".into(), endo_sig([qb_t()]), w)
.unwrap();
e.add_op(
"measure".into(),
"".into(),
inout_sig([qb_t()], [bool_t()]),
w,
)
.unwrap();
e.add_op("not".into(), "".into(), endo_sig([bool_t()]), w)
.unwrap();
},
);
let [new, gate, measure, not] = ["new", "gate", "measure", "not"]
.map(|n| ExtensionOp::new(test_ext.get_op(n).unwrap().clone(), []).unwrap());
let mut dfb = DFGBuilder::new(endo_sig([qb_t()])).unwrap();
// Chain 1: new -> measure, whose result is discarded.
let qn = dfb.add_dataflow_op(new.clone(), []).unwrap().outputs();
let [_] = dfb
.add_dataflow_op(measure.clone(), qn)
.unwrap()
.outputs_arr();
// Chain 2: DFG input -> gate -> measure -> not, whose result is discarded.
let [q_in] = dfb.input_wires_arr();
let [h_in] = dfb
.add_dataflow_op(gate.clone(), [q_in])
.unwrap()
.outputs_arr();
let [b] = dfb.add_dataflow_op(measure, [h_in]).unwrap().outputs_arr();
dfb.add_dataflow_op(not, [b]).unwrap();
// Chain 3: new -> gate, feeding the DFG output.
let q = dfb.add_dataflow_op(new, []).unwrap().outputs();
let outs = dfb.add_dataflow_op(gate, q).unwrap().outputs();
let mut h = dfb.finish_hugr_with_outputs(outs).unwrap();
DeadCodeElimPass::default().run(&mut h).unwrap();
h.validate().unwrap();
// Collect the names of the extension ops still present after the pass.
let ext_ops = h
.nodes()
.filter_map(|n| h.get_optype(n).as_extension_op())
.map(ExtensionOp::unqualified_id);
// Of the seven ops added, one `new`, two `gate`s and one `measure` must
// remain; `not`, one `new` and one `measure` must be gone.
assert_eq!(
ext_ops.sorted().collect_vec(),
["gate", "gate", "measure", "new"]
);
}
// Checks that a basic block with no incoming branch from the entry block is
// removed, together with the nodes only it uses, and that the Hugr still
// validates afterwards.
#[test]
fn remove_unreachable_bb() {
let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap();
// `cst_unused` and `b2_pred` are loaded only inside the unreachable block;
// `b1_pred` is loaded by the entry block and the reachable block.
let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3)));
let b1_pred = cb.add_constant(Value::unary_unit_sum());
let b2_pred = cb.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let mut entry = cb.entry_builder([type_row![]], type_row![]).unwrap();
let pred1 = entry.load_const(&b1_pred);
let entry = entry.finish_with_outputs(pred1, []).unwrap();
let mut block_reachable = cb
.simple_block_builder(Signature::new(type_row![], type_row![]), 1)
.unwrap();
let pred2 = block_reachable.load_const(&b1_pred);
let block_reachable = block_reachable.finish_with_outputs(pred2, []).unwrap();
let mut block_unreachable = cb
.simple_block_builder(Signature::new(type_row![], type_row![]), 2)
.unwrap();
let _ = block_unreachable.load_const(&cst_unused);
let pred3 = block_unreachable.load_const(&b2_pred);
let block_unreachable = block_unreachable.finish_with_outputs(pred3, []).unwrap();
let exit = cb.exit_block();
// Control flow: entry -> block_reachable -> exit. `block_unreachable` only
// has outgoing branches (to exit and to itself); no other block targets it.
cb.branch(&entry, 0, &block_reachable).unwrap();
cb.branch(&block_reachable, 0, &exit).unwrap();
cb.branch(&block_unreachable, 0, &exit).unwrap();
cb.branch(&block_unreachable, 1, &block_unreachable)
.unwrap();
let mut h = cb.finish_hugr().unwrap();
h.validate().unwrap();
// Record node counts before the pass so the removals can be measured.
let num_nodes_before = h.nodes().count();
let cfg_node = h.entrypoint();
let num_cfg_children_before: usize = h
.children(cfg_node)
.filter(|child| matches!(h.get_optype(*child), OpType::DataflowBlock(_)))
.count();
DeadCodeElimPass::default().run(&mut h).unwrap();
h.validate().unwrap();
let num_nodes_after = h.nodes().count();
// NOTE(review): the 7 removed nodes are presumably the unreachable block, its
// Input and Output, its two constant loads, and the two constants only it
// uses - confirm against the builder's node layout.
assert_eq!(num_nodes_before - num_nodes_after, 7);
assert!(!h.contains_node(block_unreachable.node()));
assert!(h.get_optype(cfg_node).is_cfg());
let num_cfg_children_after: usize = h
.children(cfg_node)
.filter(|child| matches!(h.get_optype(*child), OpType::DataflowBlock(_)))
.count();
// Exactly one DataflowBlock disappeared from the CFG.
assert_eq!(num_cfg_children_after, num_cfg_children_before - 1);
// With the unreachable block gone, the exit block has a single predecessor.
let exit_preds = h.input_neighbours(exit.node()).collect_vec();
assert_eq!(exit_preds.len(), 1);
}
}