use std::collections::VecDeque;
use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use hugr_core::hugr::SimpleReplacementError;
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::views::SiblingSubgraph;
use hugr_core::hugr::views::sibling_subgraph::TopoConvexChecker;
use hugr_core::ops::{OpTrait, OpType};
use hugr_core::types::Type;
use hugr_core::{HugrView, Node, PortIndex, SimpleReplacement};
use itertools::Itertools;
use crate::passes::composable::WithScope;
use crate::passes::{ComposablePass, PassScope};
/// A pass that removes redundant tuple construction/destruction pairs.
///
/// For every `MakeTuple` node in scope, the pass bypasses the `UnpackTuple` nodes that
/// consume its tuple in the same parent. It wires their used outputs directly to the
/// values that were packed into the tuple. If, besides those unpacks, the tuple has
/// other consumers, a fresh `MakeTuple` is rebuilt to feed them. A `MakeTuple` with no
/// consumers at all is deleted. Nodes with linked order edges are left untouched.
#[derive(Debug, Clone, Default)]
pub struct UntuplePass {
    /// The region of the HUGR the pass visits; set via `WithScope::with_scope`.
    scope: PassScope,
}
/// Errors that can be produced while running [`UntuplePass`].
#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
#[non_exhaustive]
pub enum UntupleError {
    /// Applying one of the computed [`SimpleReplacement`]s to the HUGR failed.
    RewriteError(SimpleReplacementError),
}
/// Summary of a successful [`UntuplePass`] run.
///
/// The only field is a `usize`, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UntupleResult {
    /// Number of pack/unpack elimination rewrites that were applied to the HUGR.
    pub rewrites_applied: usize,
}
impl UntuplePass {
    /// Collects the tuple pack/unpack elimination rewrites for the pass scope.
    ///
    /// At most one rewrite is produced per `MakeTuple` node found under the scope's root.
    /// Child regions are searched too when the scope is recursive. If the scope does not
    /// resolve to a root node in `hugr`, an empty list is returned.
    pub fn all_rewrites<H: HugrView<Node = Node>>(
        &self,
        hugr: &H,
    ) -> Vec<SimpleReplacement<H::Node>> {
        match self.scope.root(hugr) {
            Some(root) => find_rewrites(hugr, root, self.scope.recursive()),
            None => Vec::new(),
        }
    }
}
/// Walks the children of `parent` breadth-first and gathers a rewrite for every
/// eligible `MakeTuple` node.
///
/// When `recursive` is set, container children are queued so that their own children
/// are visited as well. Each visited region gets its own lazily created convexity
/// checker, which is shared by all rewrites built in that region.
fn find_rewrites<H: HugrView>(
    hugr: &H,
    parent: H::Node,
    recursive: bool,
) -> Vec<SimpleReplacement<H::Node>> {
    let mut rewrites = Vec::new();
    let mut pending_regions = VecDeque::from([parent]);

    while let Some(region) = pending_regions.pop_front() {
        // Created on first use by `make_rewrite`, then reused within this region only.
        let mut region_checker: Option<TopoConvexChecker<H>> = None;
        for child in hugr.children(region) {
            let child_op = hugr.get_optype(child);
            rewrites.extend(make_rewrite(hugr, &mut region_checker, child, child_op));
            if recursive && child_op.is_container() {
                pending_regions.push_back(child);
            }
        }
    }

    rewrites
}
impl<H: HugrMut<Node = Node>> ComposablePass<H> for UntuplePass {
    type Error = UntupleError;
    type Result = UntupleResult;

    /// Computes every rewrite up front with [`UntuplePass::all_rewrites`], then applies
    /// them in order.
    ///
    /// # Errors
    ///
    /// Returns [`UntupleError::RewriteError`] as soon as one rewrite fails to apply.
    /// Rewrites applied before the failure are not rolled back.
    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
        let pending = self.all_rewrites(hugr);
        let rewrites_applied = pending.len();
        pending
            .into_iter()
            .try_for_each(|rewrite| hugr.apply_patch(rewrite).map(|_outcome| ()))?;
        Ok(UntupleResult { rewrites_applied })
    }
}
impl WithScope for UntuplePass {
    /// Returns this pass restricted to `scope`.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        let mut scoped = self;
        scoped.scope = scope.into();
        scoped
    }
}
/// Returns `true` if `optype` casts to the prelude `MakeTuple` operation.
fn is_make_tuple(optype: &OpType) -> bool {
    let as_make_tuple: Option<MakeTuple> = optype.cast();
    as_make_tuple.is_some()
}
/// Returns `true` if `optype` casts to the prelude `UnpackTuple` operation.
fn is_unpack_tuple(optype: &OpType) -> bool {
    let as_unpack_tuple: Option<UnpackTuple> = optype.cast();
    as_unpack_tuple.is_some()
}
/// Tries to build a rewrite that eliminates the `MakeTuple` node `node`, whose optype is `op`.
///
/// Returns `None` when any of the following holds:
/// - `op` is not a `MakeTuple`;
/// - `node` has a linked order edge on either side;
/// - the tuple is consumed, but by no `UnpackTuple` that both lives in the same parent
///   and has no linked order edges.
///
/// Otherwise the rewrite removes `node` and every eligible `UnpackTuple` consumer. The
/// number of remaining consumer ports is forwarded so that a tuple can be rebuilt for
/// them. A `MakeTuple` with no consumers at all is also rewritten, which deletes it.
fn make_rewrite<'h, T: HugrView>(
    hugr: &'h T,
    convex_checker: &mut Option<TopoConvexChecker<'h, T>>,
    node: T::Node,
    op: &OpType,
) -> Option<SimpleReplacement<T::Node>> {
    if !is_make_tuple(op) || has_order_edges(hugr, node) {
        return None;
    }

    let tuple_types = op.dataflow_signature().unwrap().input_types().to_vec();
    let pack_parent = hugr.get_parent(node);

    // One entry per input port fed by the tuple output (port 0) of the pack node.
    let consumers: Vec<T::Node> = hugr
        .linked_inputs(node, 0)
        .map(|(consumer, _port)| consumer)
        .collect();
    let removable_unpacks: Vec<T::Node> = consumers
        .iter()
        .copied()
        .filter(|&consumer| {
            hugr.get_parent(consumer) == pack_parent
                && is_unpack_tuple(hugr.get_optype(consumer))
                && !has_order_edges(hugr, consumer)
        })
        .collect();

    // The tuple is used, but nothing using it can be simplified away.
    let used_without_unpack = !consumers.is_empty() && removable_unpacks.is_empty();
    if used_without_unpack {
        return None;
    }

    let other_consumers = consumers.len() - removable_unpacks.len();
    Some(remove_pack_unpack(
        hugr,
        convex_checker,
        &tuple_types,
        node,
        removable_unpacks,
        other_consumers,
    ))
}

/// Returns `true` if `node` has a linked edge on its `other` (order) input or output port.
fn has_order_edges<T: HugrView>(hugr: &T, node: T::Node) -> bool {
    let optype = hugr.get_optype(node);
    let ordered_before = optype
        .other_input_port()
        .is_some_and(|port| hugr.linked_outputs(node, port).next().is_some());
    let ordered_after = optype
        .other_output_port()
        .is_some_and(|port| hugr.linked_inputs(node, port).next().is_some());
    ordered_before || ordered_after
}
/// Builds the [`SimpleReplacement`] that eliminates `pack_node` together with the
/// `UnpackTuple` nodes in `unpack_nodes`.
///
/// The replaced subgraph is `pack_node` plus all of `unpack_nodes`. In the replacement,
/// every linked output of an unpack node is wired straight to the corresponding tuple
/// element fed into `pack_node`. When `num_other_outputs > 0`, the tuple also has
/// consumers other than `unpack_nodes`. In that case a single fresh `MakeTuple` is
/// emitted, and its output stands in for the pack node's tuple output.
///
/// `convex_checker` is created on first use for the parent of `pack_node` and reused by
/// later calls within the same region.
///
/// # Panics
///
/// Panics if `pack_node` has no parent, or if the subgraph or its replacement cannot be
/// constructed. Both indicate a broken invariant in the caller.
fn remove_pack_unpack<'h, T: HugrView>(
    hugr: &'h T,
    convex_checker: &mut Option<TopoConvexChecker<'h, T>>,
    tuple_types: &[Type],
    pack_node: T::Node,
    unpack_nodes: Vec<T::Node>,
    num_other_outputs: usize,
) -> SimpleReplacement<T::Node> {
    let parent = hugr.get_parent(pack_node).expect("pack_node has no parent");
    let checker = convex_checker.get_or_insert_with(|| TopoConvexChecker::new(hugr, parent));

    let mut nodes = unpack_nodes.clone();
    nodes.push(pack_node);
    let subcirc = SiblingSubgraph::try_from_nodes_with_checker(nodes, hugr, checker).unwrap();
    let subcirc_signature = subcirc.signature(hugr);

    let mut replacement = DFGBuilder::new(subcirc_signature).unwrap();
    let keep_tuple = num_other_outputs > 0;
    // One output per linked unpack output (at most one per tuple element), plus at most
    // one rebuilt tuple.
    let mut replacement_outputs = Vec::with_capacity(
        unpack_nodes.len() * tuple_types.len() + usize::from(keep_tuple),
    );
    // NOTE: relies on the subgraph's inputs being exactly the pack node's inputs, in port
    // order. The unpack nodes' inputs are internal to the subgraph, since they come from
    // `pack_node`.
    let replacement_inputs = replacement.input_wires().collect_vec();

    for unpack_node in unpack_nodes {
        for out_port in hugr.node_outputs(unpack_node) {
            // Only linked ports are part of the subgraph boundary, so skip the rest to keep
            // the replacement outputs aligned with the subgraph signature.
            if hugr.is_linked(unpack_node, out_port) {
                // The i-th unpacked value is the i-th element that was packed.
                replacement_outputs.push(replacement_inputs[out_port.index()]);
            }
        }
    }

    if keep_tuple {
        // The pack node's tuple output is a single outgoing port of the subgraph. The
        // signature therefore contains exactly one tuple output, however many external
        // ports consume it, so emit the rebuilt tuple exactly once. Repeating it once per
        // consumer would overflow the signature and make `finish_hugr_with_outputs` fail.
        let op = MakeTuple::new(tuple_types.to_vec().into());
        let [tuple] = replacement
            .add_dataflow_op(op, replacement.input_wires())
            .unwrap()
            .outputs_arr();
        replacement_outputs.push(tuple);
    }

    let replacement = replacement
        .finish_hugr_with_outputs(replacement_outputs)
        .unwrap_or_else(|e| {
            panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}")
        });
    subcirc
        .create_simple_replacement(hugr, replacement)
        .unwrap_or_else(|e| {
            panic!("Failed to create rewrite for removing tuple pack/unpack operations. {e}")
        })
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::passes::composable::WithScope;
    use hugr_core::Hugr;
    use hugr_core::builder::FunctionBuilder;
    use hugr_core::extension::prelude::{UnpackTuple, bool_t, qb_t};
    use hugr_core::ops::handle::NodeHandle;
    use hugr_core::std_extensions::arithmetic::float_types::float64_type;
    use hugr_core::types::Signature;
    use rstest::{fixture, rstest};

    /// A `MakeTuple` whose tuple output is never consumed.
    #[fixture]
    fn unused_pack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![])).unwrap();
        let [lhs, rhs] = dfg.input_wires_arr();
        let _unused_tuple = dfg.make_tuple([lhs, rhs]).unwrap();
        dfg.finish_hugr_with_outputs([]).unwrap()
    }

    /// A `MakeTuple` immediately undone by a single `UnpackTuple`.
    #[fixture]
    fn simple_pack_unpack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new_endo([qb_t(), bool_t()])).unwrap();
        let [qubit, flag] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([qubit, flag]).unwrap();
        let unpack_op = UnpackTuple::new(vec![qb_t(), bool_t()].into());
        let [qubit, flag] = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap().outputs_arr();
        dfg.finish_hugr_with_outputs([qubit, flag]).unwrap()
    }

    /// Pack and unpack chained together by order edges on both sides.
    #[fixture]
    fn ordered_pack_unpack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap();
        let [qubit, flag] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([qubit, flag]).unwrap();
        let input = dfg.input();
        dfg.set_order(&input, &tuple.node());
        let unpack_op = UnpackTuple::new(vec![qb_t(), bool_t()].into());
        let unpack = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap();
        let [qubit, flag] = unpack.outputs_arr();
        dfg.set_order(&tuple.node(), &unpack.node());
        let output = dfg.output();
        dfg.set_order(&unpack.node(), &output);
        dfg.finish_hugr_with_outputs([qubit, flag]).unwrap()
    }

    /// The `MakeTuple` has an outgoing order edge to the output node.
    #[fixture]
    fn outgoing_ordered_pack_unpack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap();
        let [qubit, flag] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([qubit, flag]).unwrap();
        let unpack_op = UnpackTuple::new(vec![qb_t(), bool_t()].into());
        let unpack = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap();
        let [qubit, flag] = unpack.outputs_arr();
        let output = dfg.output();
        dfg.set_order(&tuple.node(), &output);
        dfg.finish_hugr_with_outputs([qubit, flag]).unwrap()
    }

    /// The `UnpackTuple` has an incoming order edge from the input node.
    #[fixture]
    fn incoming_ordered_pack_unpack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap();
        let [qubit, flag] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([qubit, flag]).unwrap();
        let unpack_op = UnpackTuple::new(vec![qb_t(), bool_t()].into());
        let unpack = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap();
        let [qubit, flag] = unpack.outputs_arr();
        let input = dfg.input();
        dfg.set_order(&input, &unpack.node());
        dfg.finish_hugr_with_outputs([qubit, flag]).unwrap()
    }

    /// One tuple consumed by two independent `UnpackTuple` nodes.
    #[fixture]
    fn multi_unpack() -> Hugr {
        let mut dfg = DFGBuilder::new(Signature::new(
            vec![bool_t(), bool_t()],
            vec![bool_t(), bool_t(), bool_t(), bool_t()],
        ))
        .unwrap();
        let [first, second] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([first, second]).unwrap();
        let mut unpacked = Vec::with_capacity(4);
        for _ in 0..2 {
            let unpack_op = UnpackTuple::new(vec![bool_t(), bool_t()].into());
            let [left, right] = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap().outputs_arr();
            unpacked.extend([left, right]);
        }
        dfg.finish_hugr_with_outputs(unpacked).unwrap()
    }

    /// The tuple is both unpacked and returned as-is.
    #[fixture]
    fn partial_unpack() -> Hugr {
        let pair_type = Type::new_tuple(vec![bool_t(), bool_t()]);
        let mut dfg = DFGBuilder::new(Signature::new(
            vec![bool_t(), bool_t()],
            vec![bool_t(), bool_t(), pair_type],
        ))
        .unwrap();
        let [first, second] = dfg.input_wires_arr();
        let tuple = dfg.make_tuple([first, second]).unwrap();
        let unpack_op = UnpackTuple::new(vec![bool_t(), bool_t()].into());
        let [first, second] = dfg.add_dataflow_op(unpack_op, [tuple]).unwrap().outputs_arr();
        dfg.finish_hugr_with_outputs([first, second, tuple]).unwrap()
    }

    /// Only the second unpacked element is used.
    #[fixture]
    fn unpack_discard_first() -> Hugr {
        let mut func = FunctionBuilder::new(
            "test",
            Signature::new(vec![bool_t(), float64_type()], vec![float64_type()]),
        )
        .unwrap();
        let [flag, value] = func.input_wires_arr();
        let tuple = func.make_tuple([flag, value]).unwrap();
        let unpack_op = UnpackTuple::new(vec![bool_t(), float64_type()].into());
        let [_flag, value] = func.add_dataflow_op(unpack_op, [tuple]).unwrap().outputs_arr();
        func.finish_hugr_with_outputs([value]).unwrap()
    }

    #[rstest]
    #[case::unused(unused_pack(), 1, 2)]
    #[case::simple(simple_pack_unpack(), 1, 2)]
    #[case::multi(multi_unpack(), 1, 2)]
    #[case::partial(partial_unpack(), 1, 3)]
    #[case::unpack_discard_first(unpack_discard_first(), 1, 2)]
    #[case::ordered(ordered_pack_unpack(), 0, 4)]
    #[case::outgoing_ordered(outgoing_ordered_pack_unpack(), 0, 4)]
    #[case::incoming_ordered(incoming_ordered_pack_unpack(), 0, 4)]
    fn test_pack_unpack(
        #[case] mut hugr: Hugr,
        #[case] expected_rewrites: usize,
        #[case] remaining_nodes: usize,
    ) {
        let entry = hugr.entrypoint();
        let outcome = UntuplePass::default()
            .with_scope(PassScope::EntrypointFlat)
            .run(&mut hugr)
            .unwrap_or_else(|e| panic!("{e}"));
        assert_eq!(outcome.rewrites_applied, expected_rewrites);
        assert_eq!(hugr.children(entry).count(), remaining_nodes);
    }
}