use core::panic;
use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::{MakeTuple, TupleOpDef};
use hugr::extension::simple_op::MakeExtensionOp;
use hugr::ops::{OpTrait, OpType};
use hugr::types::Type;
use hugr::{HugrView, Node};
use itertools::Itertools;
use crate::circuit::Command;
use crate::rewrite::{CircuitRewrite, Subcircuit};
use crate::Circuit;
/// Finds rewrites that eliminate redundant tuple packing in `circ`.
///
/// Each command of the circuit is inspected. Every `MakeTuple` whose output is
/// consumed by at least one `UnpackTuple` contributes one [`CircuitRewrite`].
/// Commands that do not match this pattern are skipped.
///
/// # Panics
///
/// Panics if a matching `MakeTuple` does not have exactly one output wire, or
/// if building the rewrite fails.
pub fn find_tuple_unpack_rewrites(
    circ: &Circuit<impl HugrView<Node = Node>>,
) -> impl Iterator<Item = CircuitRewrite> + '_ {
    circ.commands()
        .flat_map(move |command| make_rewrite(circ, command))
}
fn is_make_tuple(optype: &OpType) -> bool {
optype.to_string() == format!("prelude.{}", TupleOpDef::MakeTuple.op_id())
}
fn is_unpack_tuple(optype: &OpType) -> bool {
optype.to_string() == format!("prelude.{}", TupleOpDef::UnpackTuple.op_id())
}
/// Builds a rewrite removing the `MakeTuple` at `cmd` and the `UnpackTuple`
/// nodes that consume its output.
///
/// Returns `None` if `cmd` is not a `MakeTuple`, or if none of the consumers of
/// its output is an `UnpackTuple`. Consumers that are not unpack operations are
/// counted and handed to [`remove_pack_unpack`], so the tuple can be rebuilt
/// for them.
///
/// # Panics
///
/// Panics if the `MakeTuple` does not have exactly one output wire.
fn make_rewrite<T: HugrView<Node = Node>>(
    circ: &Circuit<T>,
    cmd: Command<T>,
) -> Option<CircuitRewrite> {
    let optype = cmd.optype();
    let pack_node = cmd.node();
    if !is_make_tuple(optype) {
        return None;
    }

    // The element types of the tuple are the inputs of the `MakeTuple`.
    let tuple_types = optype.dataflow_signature().unwrap().input_types().to_vec();

    let Ok((_, tuple_wire)) = cmd.output_wires().exactly_one() else {
        panic!("MakeTuple at node {pack_node} should have a single output wire.");
    };

    // Split the consumers of the tuple into unpack nodes (to be removed) and a
    // count of every other consumer (which will still need the tuple value).
    let hugr = circ.hugr();
    let mut unpack_nodes = Vec::new();
    let mut num_other_outputs = 0;
    for (consumer, _) in hugr.linked_inputs(pack_node, tuple_wire.source()) {
        if is_unpack_tuple(hugr.get_optype(consumer)) {
            unpack_nodes.push(consumer);
        } else {
            num_other_outputs += 1;
        }
    }

    // No unpack consumer (this includes the case of no consumer at all) means
    // there is nothing to simplify.
    if unpack_nodes.is_empty() {
        return None;
    }

    Some(remove_pack_unpack(
        circ,
        &tuple_types,
        pack_node,
        unpack_nodes,
        num_other_outputs,
    ))
}
/// Builds a rewrite that removes the `MakeTuple` at `pack_node` together with
/// the `UnpackTuple` nodes in `unpack_nodes`.
///
/// In the replacement, the element wires from the subcircuit inputs are
/// forwarded once per unpack node, taking the place of the values those unpack
/// nodes produced. If `num_other_outputs > 0`, a new `MakeTuple` is added and
/// its output is repeated `num_other_outputs` times. These outputs serve the
/// consumers of the tuple that were not unpack operations.
///
/// # Panics
///
/// Panics if the nodes cannot be turned into a [`Subcircuit`], if the
/// replacement DFG cannot be built, or if the rewrite cannot be created.
fn remove_pack_unpack<T: HugrView<Node = Node>>(
    circ: &Circuit<T>,
    tuple_types: &[Type],
    pack_node: Node,
    unpack_nodes: Vec<Node>,
    num_other_outputs: usize,
) -> CircuitRewrite {
    let num_unpack_outputs = tuple_types.len() * unpack_nodes.len();

    // Unpack nodes are listed first and the pack node last. In debug builds,
    // the assertion below checks that the resulting subcircuit outputs put the
    // unpacked values before any tuple-valued outputs.
    let subcirc_nodes: Vec<Node> = unpack_nodes
        .into_iter()
        .chain(std::iter::once(pack_node))
        .collect();
    let subcirc = Subcircuit::try_from_nodes(subcirc_nodes, circ).unwrap();
    let signature = subcirc.signature(circ);
    debug_assert!(
        itertools::equal(
            signature.output().iter(),
            tuple_types
                .iter()
                .cycle()
                .take(num_unpack_outputs)
                .chain(itertools::repeat_n(
                    &Type::new_tuple(tuple_types.to_vec()),
                    num_other_outputs
                ))
        ),
        "Unpacked tuple values must come before tupled values"
    );

    let mut builder = DFGBuilder::new(signature).unwrap();

    // Every removed unpack node produced the tuple elements in order, and those
    // are exactly the replacement's input wires. Repeat them once per unpack node.
    let mut out_wires: Vec<_> = builder
        .input_wires()
        .cycle()
        .take(num_unpack_outputs)
        .collect();

    // Consumers that were not unpack operations still expect the tuple itself,
    // so rebuild it once and share the result.
    if num_other_outputs > 0 {
        let repack = MakeTuple::new(tuple_types.to_vec().into());
        let [tuple_wire] = builder
            .add_dataflow_op(repack, builder.input_wires())
            .unwrap()
            .outputs_arr();
        out_wires.extend(std::iter::repeat_n(tuple_wire, num_other_outputs));
    }

    let replacement_hugr = builder
        .finish_hugr_with_outputs(out_wires)
        .unwrap_or_else(|e| {
            panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}")
        });
    subcirc
        .create_rewrite(circ, replacement_hugr.into())
        .unwrap_or_else(|e| {
            panic!("Failed to create rewrite for removing tuple pack/unpack operations. {e}")
        })
}
#[cfg(test)]
mod test {
    use super::*;
    use hugr::extension::prelude::{bool_t, qb_t, UnpackTuple};
    use hugr::types::Signature;
    use rstest::{fixture, rstest};

    /// The `(bool, bool)` row used as tuple contents by several fixtures.
    fn bool_pair() -> Vec<Type> {
        vec![bool_t(), bool_t()]
    }

    /// A `(qubit, bool)` tuple that is packed and then unpacked once.
    #[fixture]
    fn simple_pack_unpack() -> Circuit {
        let row = vec![qb_t(), bool_t()];
        let mut builder = DFGBuilder::new(Signature::new_endo(row.clone())).unwrap();
        let mut inputs = builder.input_wires();
        let qb = inputs.next().unwrap();
        let flag = inputs.next().unwrap();
        let tuple = builder.make_tuple([qb, flag]).unwrap();
        let unpack = UnpackTuple::new(row.into());
        let [qb, flag] = builder.add_dataflow_op(unpack, [tuple]).unwrap().outputs_arr();
        builder.finish_hugr_with_outputs([qb, flag]).unwrap().into()
    }

    /// A `(bool, bool)` tuple consumed by two separate `UnpackTuple` operations.
    #[fixture]
    fn multi_unpack() -> Circuit {
        let mut builder = DFGBuilder::new(Signature::new(
            bool_pair(),
            vec![bool_t(), bool_t(), bool_t(), bool_t()],
        ))
        .unwrap();
        let mut inputs = builder.input_wires();
        let first = inputs.next().unwrap();
        let second = inputs.next().unwrap();
        let tuple = builder.make_tuple([first, second]).unwrap();
        let [b1, b2] = builder
            .add_dataflow_op(UnpackTuple::new(bool_pair().into()), [tuple])
            .unwrap()
            .outputs_arr();
        let [b3, b4] = builder
            .add_dataflow_op(UnpackTuple::new(bool_pair().into()), [tuple])
            .unwrap()
            .outputs_arr();
        builder
            .finish_hugr_with_outputs([b1, b2, b3, b4])
            .unwrap()
            .into()
    }

    /// A `(bool, bool)` tuple that is both unpacked and also returned as a tuple.
    #[fixture]
    fn partial_unpack() -> Circuit {
        let outputs = vec![bool_t(), bool_t(), Type::new_tuple(bool_pair())];
        let mut builder = DFGBuilder::new(Signature::new(bool_pair(), outputs)).unwrap();
        let mut inputs = builder.input_wires();
        let first = inputs.next().unwrap();
        let second = inputs.next().unwrap();
        let tuple = builder.make_tuple([first, second]).unwrap();
        let [b1, b2] = builder
            .add_dataflow_op(UnpackTuple::new(bool_pair().into()), [tuple])
            .unwrap()
            .outputs_arr();
        builder
            .finish_hugr_with_outputs([b1, b2, tuple])
            .unwrap()
            .into()
    }

    /// Repeatedly applies the first rewrite found until none remain. Then it
    /// checks how many rewrites were applied and how many commands are left.
    #[rstest]
    #[case::simple(simple_pack_unpack(), 1, 0)]
    #[case::multi(multi_unpack(), 1, 0)]
    #[case::partial(partial_unpack(), 1, 1)]
    fn test_pack_unpack(
        #[case] mut circ: Circuit,
        #[case] expected_rewrites: usize,
        #[case] remaining_commands: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut applied = 0;
        loop {
            // Bind the result first, so the iterator borrowing `circ` is dropped
            // before `circ` is mutably borrowed by `apply`.
            let next_rewrite = find_tuple_unpack_rewrites(&circ).next();
            let Some(rewrite) = next_rewrite else {
                break;
            };
            applied += 1;
            rewrite.apply(&mut circ)?;
        }
        assert_eq!(applied, expected_rewrites);
        assert_eq!(circ.commands().count(), remaining_commands);
        Ok(())
    }
}