use std::collections::{HashMap, HashSet, VecDeque};
use hugr_core::core::HugrNode;
use hugr_core::hugr::internal::PortgraphNodeMap;
use hugr_core::hugr::{HugrError, hugrmut::HugrMut};
use hugr_core::ops::{OpTag, OpTrait};
use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort};
use itertools::Itertools;
use petgraph::visit::Walker;
use crate::passes::composable::WithScope;
use crate::passes::{ComposablePass, PassScope};
/// A pass that removes order edges from dataflow regions when the ordering they
/// impose is already implied by other value, static or order links.
///
/// The regions to process are selected by the configured [`PassScope`].
#[derive(Debug, Default, Clone)]
pub struct RedundantOrderEdgesPass {
/// Selects the starting node of the traversal and whether nested regions are
/// visited as well.
scope: PassScope,
}
/// Summary of a [`RedundantOrderEdgesPass`] run.
///
/// Implements `AddAssign` so that the results of several regions can be
/// accumulated into a single total.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, derive_more::AddAssign)]
pub struct RedundantOrderEdgesResult {
/// Number of order edges that were disconnected.
pub edges_removed: usize,
}
impl RedundantOrderEdgesPass {
    /// Removes redundant order edges between the direct children of `parent`.
    ///
    /// The children are visited in a topological order of the region graph.
    /// While walking, every visited node forwards to its successors the set of
    /// order edges that originate at one of its predecessors (or at the node
    /// itself). When such a set reaches the target of one of its own edges,
    /// that target is reachable from the edge's source through some other
    /// chain of links, so the order edge adds no constraint and is removed.
    ///
    /// Only order edges whose target is a direct child of `parent` *and* has no
    /// children of its own are tracked; all other order edges are left alone.
    ///
    /// # Parameters
    ///
    /// - `hugr`: the Hugr to modify.
    /// - `parent`: the node whose children form the region to process. This
    ///   function does not check that `parent` is a dataflow parent.
    /// - `region_candidates`: work queue of further regions. When the scope is
    ///   recursive, every child of `parent` that has children is appended.
    ///
    /// # Returns
    ///
    /// The number of order edges that were disconnected.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` is part of
    /// the public signature.
    pub fn run_on_df_region<H: HugrMut>(
        &self,
        hugr: &mut H,
        parent: H::Node,
        region_candidates: &mut VecDeque<H::Node>,
    ) -> Result<RedundantOrderEdgesResult, HugrError> {
        /// A fully specified order edge whose source has already been visited
        /// and whose target has not been reached yet.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        struct PredecessorOrderEdges<N: HugrNode> {
            from_node: N,
            from_port: OutgoingPort,
            to_node: N,
            to_port: IncomingPort,
        }

        // For every not-yet-visited node: the order edges that start at one of
        // its (transitive) predecessors and have been forwarded to it.
        let mut predecessor_order_edges: HashMap<H::Node, HashSet<PredecessorOrderEdges<H::Node>>> =
            HashMap::new();
        // Edges are only collected during the walk; the Hugr is immutably
        // borrowed by the region view until the walk is finished.
        let mut to_remove = Vec::new();

        let (region, node_map) = hugr.region_portgraph(parent);
        let topo_order = petgraph::visit::Topo::new(&region);
        for pg_child in topo_order.iter(&region) {
            let child = node_map.from_portgraph(pg_child);
            let op = hugr.get_optype(child);

            // Nested regions are queued for the caller to process later.
            if self.scope.recursive() && hugr.first_child(child).is_some() {
                region_candidates.push_back(child);
            }

            // Split the edges forwarded to this node: the ones that end here are
            // redundant, the rest keep travelling to the successors.
            let (removable_edges, predecessor_edges): (HashSet<_>, HashSet<_>) =
                predecessor_order_edges
                    .remove(&child)
                    .unwrap_or_default()
                    .into_iter()
                    .partition(|edge| edge.to_node == child);

            // Order edges leaving this node that are candidates for removal.
            let new_edges = match op.other_output_port() {
                Some(out_order_port) => hugr
                    .linked_inputs(child, out_order_port)
                    .filter(|(to_node, _)| {
                        hugr.get_parent(*to_node) == Some(parent)
                            && hugr.first_child(*to_node).is_none()
                    })
                    .map(|(to_node, to_port)| PredecessorOrderEdges {
                        from_node: child,
                        from_port: out_order_port,
                        to_node,
                        to_port,
                    })
                    .collect_vec(),
                None => vec![],
            };

            // Value and static successors receive both the inherited edges and
            // the order edges that start at this node.
            for out_port in op.value_output_ports().chain(op.static_output_port()) {
                for (to_node, _) in hugr.linked_inputs(child, out_port) {
                    if hugr.get_parent(to_node) != Some(parent) {
                        continue;
                    }
                    let neigh_predecessor_order_edges =
                        predecessor_order_edges.entry(to_node).or_default();
                    neigh_predecessor_order_edges.extend(predecessor_edges.iter().copied());
                    neigh_predecessor_order_edges.extend(new_edges.iter().copied());
                }
            }

            // Order successors only receive the inherited edges. Forwarding this
            // node's own order edges along themselves would immediately mark
            // each of them as redundant.
            if let Some(out_port) = op.other_output_port() {
                for (to_node, _) in hugr.linked_inputs(child, out_port) {
                    if hugr.get_parent(to_node) != Some(parent) {
                        continue;
                    }
                    let neigh_predecessor_order_edges =
                        predecessor_order_edges.entry(to_node).or_default();
                    neigh_predecessor_order_edges.extend(predecessor_edges.iter().copied());
                }
            }

            to_remove.extend(removable_edges);
        }
        // Release the region view before mutating the Hugr.
        drop(region);

        let edges_removed = to_remove.len();
        for edge in to_remove {
            hugr.disconnect_edge(edge.from_node, edge.from_port, edge.to_node, edge.to_port);
        }

        Ok(RedundantOrderEdgesResult { edges_removed })
    }
}
impl<H: HugrMut<Node = Node>> ComposablePass<H> for RedundantOrderEdgesPass {
    type Error = HugrError;
    type Result = RedundantOrderEdgesResult;

    /// Runs the pass on every dataflow region reachable under the configured scope.
    ///
    /// A FIFO work queue is seeded with whatever `scope.root` yields for `hugr`.
    /// Each dequeued node whose operation tag falls under
    /// `OpTag::DataflowParent` is handed to
    /// [`RedundantOrderEdgesPass::run_on_df_region`], which may enqueue nested
    /// regions itself. For any other node, its children are enqueued when the
    /// scope is recursive.
    ///
    /// Returns the accumulated result over all processed regions, or the first
    /// error reported while processing a region.
    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
        let mut worklist: VecDeque<_> = self.scope.root(hugr).into_iter().collect();
        let mut total = RedundantOrderEdgesResult::default();

        while let Some(candidate) = worklist.pop_front() {
            let is_dataflow_parent =
                OpTag::DataflowParent.is_superset(hugr.get_optype(candidate).tag());

            if is_dataflow_parent {
                total += self.run_on_df_region(hugr, candidate, &mut worklist)?;
                continue;
            }

            // Not a dataflow region itself: descend only if asked to.
            if self.scope.recursive() {
                worklist.extend(hugr.children(candidate));
            }
        }

        Ok(total)
    }
}
impl WithScope for RedundantOrderEdgesPass {
    /// Consumes the pass and returns it with its scope replaced by `scope`.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        // `scope` is the only field of the pass, so the replacement is built
        // from scratch.
        Self {
            scope: scope.into(),
        }
    }
}
#[cfg(test)]
mod tests {
use hugr_core::builder::{Dataflow, DataflowHugr, FunctionBuilder, SubContainer};
use hugr_core::extension::prelude::{Noop, bool_t};
use hugr_core::ops::handle::NodeHandle;
use hugr_core::types::Signature;
use super::*;
// Builds a function containing five `Noop`s and a nested DFG, adds five order
// edges between them, and checks that the pass removes exactly the two edges
// that are implied by other links while leaving the remaining three in place.
#[test]
fn test_redundant_order_edges() {
let mut hugr = FunctionBuilder::new("f", Signature::new_endo([bool_t()])).unwrap();
let op = Noop::new(bool_t());
let [input, output] = hugr.io();
let [b1] = hugr.input_wires_arr();
// First value chain: input -> noop1 -> noop2 -> noop3.
let noop1 = hugr.add_dataflow_op(Noop::new(bool_t()), [b1]).unwrap();
let noop2 = hugr
.add_dataflow_op(op.clone(), [noop1.out_wire(0)])
.unwrap();
let noop3 = hugr
.add_dataflow_op(op.clone(), [noop2.out_wire(0)])
.unwrap();
// Second value chain: input -> noop4 -> noop5.
let noop4 = hugr.add_dataflow_op(op.clone(), [b1]).unwrap();
let noop5 = hugr
.add_dataflow_op(op.clone(), [noop4.out_wire(0)])
.unwrap();
// A nested DFG that consumes noop5's value and produces no outputs.
let nested_op = hugr
.dfg_builder(Signature::new(vec![bool_t()], vec![]), [noop5.out_wire(0)])
.unwrap()
.finish_sub_container()
.unwrap();
// Redundant: input already reaches noop2 through the value chain via noop1.
hugr.set_order(&input, &noop2);
// Expected to stay: noop1 has no other link path to the output node here.
hugr.set_order(&noop1, &output);
// Redundant: noop4 -> noop5 (value), noop5 -> noop2 (order), noop2 -> noop3 (value).
hugr.set_order(&noop4, &noop3);
// Expected to stay: nothing else links noop5 to noop2.
hugr.set_order(&noop5, &noop2);
// Expected to stay: nothing else links noop3 to the nested DFG.
// NOTE(review): the pass also skips order edges into nodes that have children,
// which presumably applies to the nested DFG — confirm.
hugr.set_order(&noop3, &nested_op.node());
// noop5's value is also the function output.
let mut hugr = hugr.finish_hugr_with_outputs([noop5.out_wire(0)]).unwrap();
let result = RedundantOrderEdgesPass::default().run(&mut hugr).unwrap();
assert_eq!(result.edges_removed, 2);
// Every node inspected below has a single value port at index 0; index 1 is
// assumed to be its order port — TODO confirm against the op definitions.
let order_in = IncomingPort::from(1);
let order_out = OutgoingPort::from(1);
// input -> noop2 was removed.
assert_eq!(hugr.single_linked_input(input, order_out), None);
// noop1 -> output is still present.
assert_eq!(
hugr.single_linked_input(noop1.node(), order_out),
Some((output, order_in))
);
// noop4 -> noop3 was removed.
assert_eq!(hugr.single_linked_input(noop4.node(), order_out), None);
// noop5 -> noop2 is still present.
assert_eq!(
hugr.single_linked_input(noop5.node(), order_out),
Some((noop2.node(), order_in))
);
// noop3 -> nested DFG is still present.
assert_eq!(
hugr.single_linked_input(noop3.node(), order_out),
Some((nested_op.node(), order_in))
);
}
}