use fxhash::FxHashSet;
use itertools::{Either, Itertools};
use std::{
collections::{HashMap, HashSet, VecDeque},
iter, mem,
};
pub mod array_modify;
pub mod call_modify;
pub mod dfg_modify;
pub mod global_phase_modify;
pub mod tket_op_modify;
use super::{CombinedModifier, ModifierFlags};
use crate::passes::utils::unpack_container::TypeUnpacker;
use crate::passes::{InScope, PassScope};
use crate::{TketOp, extension::global_phase::GlobalPhase, modifier::Modifier};
use global_phase_modify::delete_phase;
use hugr::{
HugrView, IncomingPort, Node, OutgoingPort, Port, PortIndex, Wire,
builder::{BuildError, CFGBuilder, Container, Dataflow, SubContainer},
core::HugrNode,
extension::{prelude::qb_t, simple_op::MakeExtensionOp},
hugr::hugrmut::HugrMut,
ops::{CFG, Const, OpType},
std_extensions::collections::array::array_type,
types::{EdgeKind, FuncTypeBase, Signature, Type},
};
/// A port of a node together with its direction (incoming or outgoing).
///
/// Used both as the key and as the value type of the old-to-new port
/// correspondence table kept by [`ModifierResolver`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct DirWire<N = Node>(N, Port);
impl<N: HugrNode> std::fmt::Display for DirWire<N> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let dir = match self.1.as_directed() {
Either::Left(_) => "In",
Either::Right(_) => "Out",
};
write!(f, "DirWire({}, {}({}))", self.0, dir, self.1.index())
}
}
impl<N> DirWire<N> {
fn new(node: N, port: Port) -> Self {
DirWire(node, port)
}
pub(crate) fn reverse(self) -> Self {
let index = self.1.index();
let port = match self.1.as_directed() {
Either::Left(_in) => OutgoingPort::from(index).into(),
Either::Right(_out) => IncomingPort::from(index).into(),
};
DirWire::new(self.0, port)
}
}
impl<N: HugrNode> From<Wire<N>> for DirWire<N> {
fn from(wire: Wire<N>) -> Self {
DirWire(wire.node(), wire.source().into())
}
}
impl<N: HugrNode> From<(N, OutgoingPort)> for DirWire<N> {
fn from((node, port): (N, OutgoingPort)) -> Self {
DirWire(node, port.into())
}
}
impl<N: HugrNode> From<(N, IncomingPort)> for DirWire<N> {
fn from((node, port): (N, IncomingPort)) -> Self {
DirWire(node, port.into())
}
}
impl<N: HugrNode> TryFrom<DirWire<N>> for Wire<N> {
type Error = hugr::hugr::HugrError;
fn try_from(value: DirWire<N>) -> Result<Self, Self::Error> {
let out_port = value.1.as_outgoing()?;
Ok(Wire::new(value.0, out_port))
}
}
impl<N: HugrNode> TryFrom<DirWire<N>> for (N, IncomingPort) {
type Error = hugr::hugr::HugrError;
fn try_from(value: DirWire<N>) -> Result<Self, Self::Error> {
let in_port = value.1.as_incoming()?;
Ok((value.0, in_port))
}
}
/// Adds an edge in `new_dfg` between two ports of opposite direction.
///
/// The wires may be passed in either order: the outgoing one becomes the edge
/// source and the incoming one its target. Two wires with the same direction
/// produce an `Unreachable` error and no edge is added.
fn connect<N>(
    new_dfg: &mut impl Container,
    w1: &DirWire<Node>,
    w2: &DirWire<Node>,
) -> Result<(), ModifierResolverErrors<N>> {
    let ((src_node, src_port), (dst_node, dst_port)) =
        match (w1.1.as_directed(), w2.1.as_directed()) {
            (Either::Right(out), Either::Left(inp)) => ((w1.0, out), (w2.0, inp)),
            (Either::Left(inp), Either::Right(out)) => ((w2.0, out), (w1.0, inp)),
            (Either::Left(_), Either::Left(_)) | (Either::Right(_), Either::Right(_)) => {
                return Err(ModifierResolverErrors::unreachable(format!(
                    "Cannot connect the wires with the same direction: {} -> {}",
                    w1, w2
                )));
            }
        };
    new_dfg
        .hugr_mut()
        .connect(src_node, src_port, dst_node, dst_port);
    Ok(())
}
fn connect_by_num(
new_dfg: &mut impl Dataflow,
dw: &DirWire<Node>,
node: Node,
num: usize,
) -> DirWire<Node> {
let dw_node = dw.0;
match dw.1.as_directed() {
Either::Left(incoming) => {
new_dfg.hugr_mut().connect(node, num, dw_node, incoming);
(node, IncomingPort::from(num)).into()
}
Either::Right(outgoing) => {
new_dfg.hugr_mut().connect(dw_node, outgoing, node, num);
(node, OutgoingPort::from(num)).into()
}
}
}
trait PortExt {
fn shift(self, offset: usize) -> Self;
}
impl PortExt for Port {
fn shift(self, offset: usize) -> Self {
Port::new(self.direction(), self.index() + offset)
}
}
impl PortExt for IncomingPort {
fn shift(self, offset: usize) -> Self {
IncomingPort::from(self.index() + offset)
}
}
impl PortExt for OutgoingPort {
fn shift(self, offset: usize) -> Self {
OutgoingPort::from(self.index() + offset)
}
}
impl<N> PortExt for DirWire<N> {
fn shift(self, offset: usize) -> Self {
DirWire(self.0, self.1.shift(offset))
}
}
/// New-graph counterparts of an old node's ports, listed in port order.
///
/// `add_edge_from_pv` zips `incoming` with the old node's input ports and
/// `outgoing` with its output ports.
pub struct PortVector<N = Node> {
    // Counterpart of each old input port (may point either direction when daggered).
    incoming: Vec<DirWire<N>>,
    // Counterpart of each old output port (may point either direction when daggered).
    outgoing: Vec<DirWire<N>>,
}
impl<N: HugrNode> PortVector<N> {
fn from_single_node(
n: N,
inputs: impl Iterator<Item = usize>,
outputs: impl Iterator<Item = usize>,
) -> Self {
let incoming = inputs.map(|p| (n, IncomingPort::from(p)).into()).collect();
let outgoing = outputs.map(|p| (n, OutgoingPort::from(p)).into()).collect();
PortVector { incoming, outgoing }
}
fn port_vector_rev(
n: N,
inputs: impl Iterator<Item = usize>,
outputs: impl Iterator<Item = usize>,
iter: impl Iterator<Item = usize>,
) -> Self {
let iter = iter.collect::<Vec<_>>();
let incoming = inputs
.map(|p| {
if iter.contains(&p) {
(n, OutgoingPort::from(p)).into()
} else {
(n, IncomingPort::from(p)).into()
}
})
.collect();
let outgoing = outputs
.map(|p| {
if iter.contains(&p) {
(n, IncomingPort::from(p)).into()
} else {
(n, OutgoingPort::from(p)).into()
}
})
.collect();
PortVector { incoming, outgoing }
}
}
/// State for rewriting functions under a chain of modifiers.
pub struct ModifierResolver<N = Node> {
    // Modifier currently being applied; swapped out by `with_modifiers`.
    modifiers: CombinedModifier,
    // Old port -> its counterparts in the graph being built; an empty vector
    // means the port deliberately has no counterpart (see `forget_node`).
    corresp_map: HashMap<DirWire<N>, Vec<DirWire>>,
    // Currently open control wires; kept in step with `modifiers.control` by
    // `push_control`/`pop_control`.
    controls: Vec<Wire>,
    // Pending nodes; NOTE(review): populated by code outside this view.
    worklist: VecDeque<N>,
    // Per source node, the (node, input port) targets recorded via
    // `call_map_insert`; NOTE(review): exact meaning defined outside this view.
    call_map: HashMap<N, Vec<(Node, IncomingPort)>>,
    // Functions touched by resolution; cleanup candidates for
    // `remove_unused_modified_functions`.
    modified_functions: HashSet<N>,
    // Decides which types contain qubits (built with `TypeUnpacker::for_qubits`).
    qubit_finder: TypeUnpacker,
}
impl<N> ModifierResolver<N> {
    /// Creates a resolver with default modifiers, no controls and empty
    /// bookkeeping tables.
    fn new() -> Self {
        Self {
            modifiers: CombinedModifier::default(),
            corresp_map: HashMap::new(),
            controls: Vec::new(),
            worklist: VecDeque::new(),
            call_map: HashMap::new(),
            modified_functions: HashSet::new(),
            qubit_finder: TypeUnpacker::for_qubits(),
        }
    }
}
impl<N> Default for ModifierResolver<N> {
fn default() -> Self {
Self::new()
}
}
/// Reasons a node cannot be used as the starting point of a modifier rewrite.
///
/// When wrapped in `ModifierResolverErrors::ModifierError`, the driver
/// `resolve_modifier_with_entrypoints_and_scope` skips the node instead of
/// aborting.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum ModifierError<N = Node> {
    /// The node is not a modifier; carries the node and its op type.
    #[display("Node to modify {_0} expected to be a modifier but actually {_1}")]
    NotModifier(N, OpType),
    /// The modifier's output is not consumed by exactly one input.
    #[display("No caller of the modified function exists for node {_0}")]
    #[error(ignore)]
    NoCaller(N),
    /// No target could be found for the modifier at this node.
    #[display("No target of the modifier exists for node {_0}")]
    #[error(ignore)]
    NoTarget(N),
    /// The node (with the given op type) is itself a modifier consuming
    /// another modifier, so the rewrite must not start from the inner one.
    #[display("Node {_0} of type {_1} is not the first modifier in a chain")]
    NotInitialModifier(N, OpType),
    /// The modifier cannot be applied to a node of this op type.
    #[display("Modifier cannot be applied to the node {_0} of type {_1}")]
    ModifierNotApplicable(N, OpType),
}
impl<N> ModifierError<N> {
fn node(self) -> N {
match self {
ModifierError::NotModifier(n, _)
| ModifierError::NoCaller(n)
| ModifierError::NoTarget(n)
| ModifierError::NotInitialModifier(n, _)
| ModifierError::ModifierNotApplicable(n, _) => n,
}
}
}
/// All errors the modifier resolution pass can report.
#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
pub enum ModifierResolverErrors<N = Node> {
    /// The node is not a valid starting point for a rewrite (non-fatal for the driver).
    #[display("{_0}")]
    #[from]
    ModifierError(ModifierError<N>),
    /// Building the replacement graph failed.
    #[display("{_0}")]
    #[from]
    BuildError(BuildError),
    /// An internal invariant was violated.
    #[display("Unreachable error: {msg}")]
    Unreachable {
        msg: String,
    },
    /// The modifier cannot be pushed through `node` (of type `optype`); `msg`
    /// explains why.
    #[display("Modifier applied to the node {node} of type {optype} cannot be modified: {msg}")]
    UnResolvable {
        node: N,
        msg: String,
        optype: OpType,
    },
    /// No modification rule exists for this modifier and op type.
    #[display("Modification by {_0:?} is not defined for the node {_1}")]
    Unimplemented(Modifier, OpType),
}
impl<N> ModifierResolverErrors<N> {
    /// Builds an [`ModifierResolverErrors::Unreachable`] with the given message.
    fn unreachable(msg: impl Into<String>) -> Self {
        let msg: String = msg.into();
        Self::Unreachable { msg }
    }
    /// Builds an [`ModifierResolverErrors::UnResolvable`] for `node` of type `optype`.
    fn unresolvable(node: N, msg: impl Into<String>, optype: OpType) -> Self {
        let msg: String = msg.into();
        Self::UnResolvable { node, msg, optype }
    }
}
impl<N: HugrNode> ModifierResolver<N> {
    /// Mutable access to the modifier currently being applied.
    fn modifiers_mut(&mut self) -> &mut CombinedModifier {
        &mut self.modifiers
    }
    /// The modifier currently being applied.
    fn modifiers(&self) -> &CombinedModifier {
        &self.modifiers
    }
    /// Number of control qubits requested by the current modifier.
    fn control_num(&self) -> usize {
        self.modifiers.control
    }
    /// Mutable access to the currently open control wires.
    fn controls(&mut self) -> &mut Vec<Wire> {
        &mut self.controls
    }
    /// The currently open control wires.
    fn controls_ref(&self) -> &Vec<Wire> {
        &self.controls
    }
    /// Mutable access to the pending-node worklist.
    fn worklist(&mut self) -> &mut VecDeque<N> {
        &mut self.worklist
    }
    /// Mutable access to the old-port -> new-ports correspondence table.
    fn corresp_map(&mut self) -> &mut HashMap<DirWire<N>, Vec<DirWire>> {
        &mut self.corresp_map
    }
    /// Mutable access to the call map.
    fn call_map(&mut self) -> &mut HashMap<N, Vec<(Node, IncomingPort)>> {
        &mut self.call_map
    }
    /// Appends `target` to the list recorded for `source`.
    fn call_map_insert(&mut self, source: N, target: (Node, IncomingPort)) {
        self.call_map().entry(source).or_default().push(target);
    }
    /// Runs `f` with `worklist` installed, then restores the previous worklist.
    fn with_worklist<T>(&mut self, worklist: VecDeque<N>, f: impl FnOnce(&mut Self) -> T) -> T {
        let worklist = mem::replace(self.worklist(), worklist);
        let r = f(self);
        *self.worklist() = worklist;
        r
    }
    /// Runs `f` with `modifiers` installed, then restores the previous modifier.
    fn with_modifiers<T>(
        &mut self,
        modifiers: CombinedModifier,
        f: impl FnOnce(&mut Self) -> T,
    ) -> T {
        let modifiers = mem::replace(self.modifiers_mut(), modifiers);
        let r = f(self);
        *self.modifiers_mut() = modifiers;
        r
    }
    /// Pushes `wire` onto `ancilla`, runs `f`, then pops the (possibly
    /// updated) last ancilla back into `wire`.
    ///
    /// `f` must leave `ancilla` at least as long as it found it; otherwise the
    /// final `pop` panics.
    fn with_ancilla<T>(
        &mut self,
        wire: &mut Wire<Node>,
        ancilla: &mut Vec<Wire<Node>>,
        f: impl FnOnce(&mut Self, &mut Vec<Wire<Node>>) -> T,
    ) -> T {
        ancilla.push(*wire);
        let r = f(self, ancilla);
        *wire = ancilla.pop().unwrap();
        r
    }
    /// Removes the last control wire and decrements the control count, or
    /// returns `None` (leaving the count untouched) when there is none.
    fn pop_control(&mut self) -> Option<Wire<Node>> {
        let c = self.controls().pop()?;
        self.modifiers.control -= 1;
        Some(c)
    }
    /// Adds a control wire and increments the control count.
    fn push_control(&mut self, c: Wire<Node>) {
        self.controls().push(c);
        self.modifiers.control += 1;
    }
    /// Records that `old` corresponds to the single port `new`.
    ///
    /// # Errors
    /// `Unreachable` if `old` already has a non-empty correspondence.
    fn map_insert(
        &mut self,
        old: DirWire<N>,
        new: DirWire,
    ) -> Result<(), ModifierResolverErrors<N>> {
        match self.corresp_map().entry(old) {
            std::collections::hash_map::Entry::Vacant(entry) => {
                entry.insert(vec![new]);
                Ok(())
            }
            // An empty entry (e.g. from `map_insert_none`) may be filled in.
            std::collections::hash_map::Entry::Occupied(mut entry) if entry.get().is_empty() => {
                entry.insert(vec![new]);
                Ok(())
            }
            std::collections::hash_map::Entry::Occupied(entry) => {
                let former = entry.get();
                Err(ModifierResolverErrors::unreachable(format!(
                    "Wire already registered for node {}. Former [{},...], Latter {}.",
                    old.0, former[0], new
                )))
            }
        }
    }
    /// Ensures `old` has an entry, leaving it empty if it had none: the port
    /// then exists in the table but produces no edges in `connect_all`.
    fn map_insert_none(&mut self, old: DirWire<N>) -> Result<(), ModifierResolverErrors<N>> {
        self.corresp_map().entry(old).or_default();
        Ok(())
    }
    /// Looks up the counterparts of `key`.
    ///
    /// # Errors
    /// `Unreachable` if `key` was never registered.
    fn map_get(&self, key: &DirWire<N>) -> Result<&Vec<DirWire>, ModifierResolverErrors<N>> {
        // `ok_or_else`: only format the message on the (rare) failure path;
        // this is called for every edge end in `connect_all`.
        self.corresp_map.get(key).ok_or_else(|| {
            ModifierResolverErrors::unreachable(format!("No correspondence for the wire: {key}"))
        })
    }
    /// Registers every port of `n` with no counterpart, so none of its edges
    /// are reproduced in the new graph.
    fn forget_node(
        &mut self,
        h: &impl HugrView<Node = N>,
        n: N,
    ) -> Result<(), ModifierResolverErrors<N>> {
        for port in h.all_node_ports(n) {
            let dw = DirWire(n, port);
            self.map_insert_none(dw)?;
        }
        Ok(())
    }
    /// Copies `op` into `new_dfg` unchanged and maps each port of `old_n` to
    /// the port with the same direction and index on the new node.
    fn add_node_no_modification(
        &mut self,
        h: &impl HugrMut<Node = N>,
        old_n: N,
        op: impl Into<OpType>,
        new_dfg: &mut impl Container,
    ) -> Result<Node, ModifierResolverErrors<N>> {
        let node = new_dfg.add_child_node(op);
        for port in h.all_node_ports(old_n) {
            self.map_insert(DirWire(old_n, port), DirWire(node, port))?;
        }
        Ok(node)
    }
    /// Builds the port vector for `n`; under dagger, indices listed in `iter`
    /// are mapped to the opposite side of the node.
    fn port_vector_dagger(
        &self,
        n: Node,
        inputs: impl Iterator<Item = usize>,
        outputs: impl Iterator<Item = usize>,
        iter: impl Iterator<Item = usize>,
    ) -> PortVector<Node> {
        if self.modifiers.dagger {
            PortVector::port_vector_rev(n, inputs, outputs, iter)
        } else {
            PortVector::from_single_node(n, inputs, outputs)
        }
    }
    /// Registers `pv` as the counterparts of `n`'s inputs and outputs, in
    /// port order. Surplus entries on either side are ignored (`zip`).
    fn add_edge_from_pv(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        pv: PortVector<Node>,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let PortVector { incoming, outgoing } = pv;
        for (old_in, new) in (0..h.num_inputs(n)).map(IncomingPort::from).zip(incoming) {
            self.map_insert((n, old_in).into(), new)?
        }
        for (old_out, new) in (0..h.num_outputs(n)).map(OutgoingPort::from).zip(outgoing) {
            self.map_insert((n, old_out).into(), new)?
        }
        Ok(())
    }
    /// Adds `op` to `new_dfg` and threads every control wire through it:
    /// control `i` feeds input `i` and continues from output `i`.
    fn add_node_control(&mut self, new_dfg: &mut impl Container, op: impl Into<OpType>) -> Node {
        let node = new_dfg.add_child_node(op);
        for (i, ctrl) in self.controls().iter_mut().enumerate() {
            new_dfg
                .hugr_mut()
                .connect(ctrl.node(), ctrl.source(), node, i);
            *ctrl = Wire::new(node, i);
        }
        node
    }
    /// Reproduces, in `new_dfg`, every non-state-order edge between children
    /// of `parent`, connecting each counterpart of the target to each
    /// counterpart of the source.
    fn connect_all(
        &mut self,
        h: &impl HugrView<Node = N>,
        new_dfg: &mut impl Container,
        parent: N,
    ) -> Result<(), ModifierResolverErrors<N>> {
        for out_node in h.children(parent) {
            for out_port in h.node_outputs(out_node) {
                if let Some(EdgeKind::StateOrder) = h.get_optype(out_node).port_kind(out_port) {
                    continue;
                }
                for (in_node, in_port) in h.linked_inputs(out_node, out_port) {
                    for w1 in self.map_get(&(in_node, in_port).into())? {
                        for w2 in self.map_get(&(out_node, out_port).into())? {
                            connect(new_dfg, w1, w2)?
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
impl<N: HugrNode> ModifierResolver<N> {
    /// Checks that a rewrite can start at `n`.
    ///
    /// `n` must be a modifier, its output port 0 must feed exactly one input,
    /// and that consumer must not itself be a modifier.
    fn verify(&self, h: &impl HugrView<Node = N>, n: N) -> Result<(), ModifierError<N>> {
        let optype = h.get_optype(n);
        if Modifier::from_optype(optype).is_none() {
            return Err(ModifierError::NotModifier(n, optype.clone()));
        }
        let Ok((caller, _)) = h.linked_inputs(n, 0).exactly_one() else {
            return Err(ModifierError::NoCaller(n));
        };
        let optype = h.get_optype(caller);
        if Modifier::from_optype(optype).is_some() {
            return Err(ModifierError::NotInitialModifier(caller, optype.clone()));
        }
        Ok(())
    }
    /// Resolves the modifier chain ending at `modifier_node`.
    ///
    /// Records the consumers of every output of `modifier_node`, builds the
    /// modified function via `apply_modifier_chain_to_loaded_fn` (starting from
    /// a default `CombinedModifier`), then rewires each consumer to the same
    /// output port of the node that call returns.
    fn try_rewrite(
        &mut self,
        hugr: &mut impl HugrMut<Node = N>,
        modifier_node: N,
    ) -> Result<(), ModifierResolverErrors<N>> {
        self.verify(hugr, modifier_node)?;
        let modified_fn_loader: Vec<(_, Vec<_>)> = hugr
            .node_outputs(modifier_node)
            .map(|p| (p, hugr.linked_inputs(modifier_node, p).collect()))
            .collect();
        let modifiers = CombinedModifier::default();
        let new_load = self.with_modifiers(modifiers, |this| {
            this.apply_modifier_chain_to_loaded_fn(hugr, modifier_node)
        })?;
        for (out_port, inputs) in modified_fn_loader {
            for (recv, recv_port) in inputs {
                hugr.disconnect(recv, recv_port);
                hugr.connect(new_load, out_port, recv, recv_port);
            }
        }
        Ok(())
    }
    /// Prepends control-qubit types to both sides of `signature`.
    ///
    /// With `flatten`, `control_num()` bare qubits are added. Otherwise one
    /// qubit array per entry of `accum_ctrl` is added, sized by that entry.
    fn modify_signature(&self, signature: &mut Signature, flatten: bool) {
        let FuncTypeBase { input, output } = signature;
        if flatten {
            let n = self.control_num();
            input.to_mut().splice(0..0, iter::repeat_n(qb_t(), n));
            output.to_mut().splice(0..0, iter::repeat_n(qb_t(), n));
        } else {
            let control_types = self
                .modifiers
                .accum_ctrl
                .iter()
                .map(|ctrls| array_type(*ctrls as u64, qb_t()))
                .collect::<Vec<_>>();
            input.to_mut().splice(0..0, control_types.iter().cloned());
            output.to_mut().splice(0..0, control_types);
        }
    }
    /// Emits the modified version of `target_node` into `new_dfg`, dispatching
    /// on its op type. Input/Output nodes are skipped here. Module-level,
    /// `Case`, alias and block nodes are rejected.
    fn modify_op(
        &mut self,
        h: &mut impl HugrMut<Node = N>,
        target_node: N,
        new_dfg: &mut impl Dataflow,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let optype = &h.get_optype(target_node).clone();
        match optype {
            OpType::Input(_) | OpType::Output(_) => {}
            OpType::CFG(cfg) => self.modify_cfg(h, target_node, cfg, new_dfg)?,
            OpType::DFG(dfg) => self.modify_dfg(h, target_node, dfg, new_dfg)?,
            OpType::TailLoop(tail_loop) => {
                self.modify_tail_loop(h, target_node, tail_loop, new_dfg)?
            }
            OpType::Conditional(conditional) => {
                self.modify_conditional(h, target_node, conditional, new_dfg)?
            }
            OpType::Call(_) => self.modify_call(h, target_node, optype, new_dfg)?,
            OpType::CallIndirect(indir_call) => {
                self.modify_indirect_call(h, target_node, indir_call, new_dfg)?
            }
            OpType::LoadFunction(load) => {
                self.modify_load_function(h, target_node, load, new_dfg)?
            }
            OpType::ExtensionOp(_) => {
                self.modify_extension_op(h, target_node, optype, new_dfg)?;
            }
            OpType::Const(constant) => {
                self.modify_constant(target_node, constant, new_dfg)?;
            }
            OpType::LoadConstant(_) | OpType::OpaqueOp(_) | OpType::Tag(_) => {
                self.add_node_no_modification(h, target_node, optype.clone(), new_dfg)?;
            }
            OpType::FuncDefn(_) | OpType::FuncDecl(_) | OpType::Module(_) => {
                return Err(ModifierResolverErrors::unreachable(format!(
                    "Invalid node found inside modified function (OpType = {})",
                    optype.clone()
                )));
            }
            OpType::Case(_) => {
                return Err(ModifierResolverErrors::unreachable(
                    "Case cannot be directly modified.".to_string(),
                ));
            }
            OpType::AliasDecl(_)
            | OpType::AliasDefn(_)
            | OpType::ExitBlock(_)
            | OpType::DataflowBlock(_) => {
                return Err(ModifierResolverErrors::unresolvable(
                    target_node,
                    "Unmodifiable node found".to_string(),
                    optype.clone(),
                ));
            }
            _ => {
                return Err(ModifierResolverErrors::unresolvable(
                    target_node,
                    "Unknown operation".to_string(),
                    optype.clone(),
                ));
            }
        }
        Ok(())
    }
    /// [`Self::wire_inout`] where the old and new sides are each a single node.
    fn wire_node_inout<'a>(
        &mut self,
        old_node: N,
        new_node: Node,
        (inputs, outputs): (
            impl Iterator<Item = &'a Type>,
            impl Iterator<Item = &'a Type>,
        ),
        (input_offset, output_offset, new_offset): (usize, usize, usize),
    ) -> Result<(), ModifierResolverErrors<N>> {
        self.wire_inout(
            (old_node, old_node),
            (new_node, new_node),
            (inputs, outputs),
            (input_offset, output_offset, new_offset),
        )
    }
    /// Registers counterparts for the ports of `old_in` (inputs) and
    /// `old_out` (outputs) on `new_in` / `new_out`.
    ///
    /// Old inputs start at `input_offset` and old outputs at `output_offset`.
    /// The matching new ports start `new_offset` further along.
    ///
    /// Ports whose type the qubit finder reports as containing qubits:
    /// - without dagger, they map to the same side;
    /// - under dagger, inputs map to `new_out` outputs and outputs to
    ///   `new_in` inputs.
    ///
    /// All other ports always map to the same side. New ports on each side
    /// are allocated sequentially, so under dagger classical and qubit ports
    /// share the same counters.
    fn wire_inout<'a>(
        &mut self,
        (old_in, old_out): (N, N),
        (new_in, new_out): (Node, Node),
        (mut inputs, mut outputs): (
            impl Iterator<Item = &'a Type>,
            impl Iterator<Item = &'a Type>,
        ),
        (input_offset, output_offset, new_offset): (usize, usize, usize),
    ) -> Result<(), ModifierResolverErrors<N>> {
        let mut old_in_wire = (old_in, IncomingPort::from(input_offset)).into();
        let mut old_out_wire = (old_out, OutgoingPort::from(output_offset)).into();
        let mut new_in_wire = (new_in, IncomingPort::from(input_offset + new_offset)).into();
        let mut new_out_wire = (new_out, OutgoingPort::from(output_offset + new_offset)).into();
        let mut in_ty = inputs.next();
        let mut out_ty = outputs.next();
        // Alternate between runs of non-qubit and qubit ports on each side
        // until both type lists are exhausted.
        loop {
            while let Some(ty) = in_ty {
                if self.qubit_finder.contains_element_type(ty) {
                    break;
                }
                self.map_insert(old_in_wire, new_in_wire)?;
                old_in_wire = old_in_wire.shift(1);
                new_in_wire = new_in_wire.shift(1);
                in_ty = inputs.next();
            }
            while let Some(ty) = out_ty {
                if self.qubit_finder.contains_element_type(ty) {
                    break;
                }
                self.map_insert(old_out_wire, new_out_wire)?;
                old_out_wire = old_out_wire.shift(1);
                new_out_wire = new_out_wire.shift(1);
                out_ty = outputs.next();
            }
            while let Some(ty) = in_ty {
                if !self.qubit_finder.contains_element_type(ty) {
                    break;
                }
                let new_in = if !self.modifiers.dagger {
                    let new_in = new_in_wire;
                    new_in_wire = new_in_wire.shift(1);
                    new_in
                } else {
                    let new_in = new_out_wire;
                    new_out_wire = new_out_wire.shift(1);
                    new_in
                };
                self.map_insert(old_in_wire, new_in)?;
                old_in_wire = old_in_wire.shift(1);
                in_ty = inputs.next();
            }
            while let Some(ty) = out_ty {
                if !self.qubit_finder.contains_element_type(ty) {
                    break;
                }
                let new_out = if !self.modifiers.dagger {
                    let new_out = new_out_wire;
                    new_out_wire = new_out_wire.shift(1);
                    new_out
                } else {
                    let new_out = new_in_wire;
                    new_in_wire = new_in_wire.shift(1);
                    new_out
                };
                self.map_insert(old_out_wire, new_out)?;
                old_out_wire = old_out_wire.shift(1);
                out_ty = outputs.next();
            }
            if in_ty.is_none() && out_ty.is_none() {
                break;
            }
        }
        Ok(())
    }
    /// Maps the "other" (non-dataflow) input and output ports of `n` to those
    /// of `node`, when both op types have them.
    fn _wire_others(
        &mut self,
        n: N,
        n_optype: &OpType,
        node: Node,
        node_optype: &OpType,
    ) -> Result<(), ModifierResolverErrors<N>> {
        if let (Some(old), Some(new)) =
            (n_optype.other_input_port(), node_optype.other_input_port())
        {
            self.map_insert((n, old).into(), (node, new).into())?;
        }
        if let (Some(old), Some(new)) = (
            n_optype.other_output_port(),
            node_optype.other_output_port(),
        ) {
            self.map_insert((n, old).into(), (node, new).into())?;
        }
        Ok(())
    }
    /// Copies a constant unchanged and maps its output 0.
    fn modify_constant(
        &mut self,
        n: N,
        constant: &Const,
        new_dfg: &mut impl Container,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let output = new_dfg.add_child_node(constant.clone());
        self.map_insert(Wire::new(n, 0).into(), Wire::new(output, 0).into())
    }
    /// Copies a dataflow op unchanged and wires it by its signature via
    /// [`Self::wire_node_inout`] (so qubit ports swap sides under dagger).
    ///
    /// # Errors
    /// `Unreachable` if `n` has no dataflow signature. In that case nothing is
    /// added to `new_dfg`.
    fn modify_dataflow_op(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        optype: &OpType,
        new_dfg: &mut impl Container,
    ) -> Result<(), ModifierResolverErrors<N>> {
        // Look the signature up before touching `new_dfg`, so a failure does
        // not leave a dangling, unwired node behind.
        let Some(signature) = h.signature(n) else {
            return Err(ModifierResolverErrors::unreachable(format!(
                "Node {n} has no dataflow signature."
            )));
        };
        let node = new_dfg.add_child_node(optype.clone());
        let inputs = signature.input.iter();
        let outputs = signature.output.iter();
        self.wire_node_inout(n, node, (inputs, outputs), (0, 0, 0))?;
        Ok(())
    }
    /// Modifies an extension op.
    ///
    /// Dispatch order: `TketOp`, global phase, nested modifier (forgotten),
    /// array ops, then a generic signature-based copy.
    fn modify_extension_op(
        &mut self,
        h: &impl HugrMut<Node = N>,
        op_node: N,
        optype: &OpType,
        new_dfg: &mut impl Dataflow,
    ) -> Result<(), ModifierResolverErrors<N>> {
        if self.controls().len() != self.control_num() {
            return Err(ModifierResolverErrors::unreachable(
                "Control qubits are not set correctly.".to_string(),
            ));
        }
        if let Some(tket_op) = TketOp::from_optype(optype) {
            let pv = self.modify_tket_op(op_node, tket_op, new_dfg, &mut vec![])?;
            self.add_edge_from_pv(h, op_node, pv)
        } else if GlobalPhase::from_optype(optype).is_some() {
            let inputs = self.modify_global_phase(op_node, new_dfg, &mut vec![])?;
            // Input 0 may correspond to several new ports; this overwrites any
            // existing entry rather than going through `map_insert`.
            self.corresp_map().insert(
                (op_node, IncomingPort::from(0)).into(),
                inputs.into_iter().map(Into::into).collect(),
            );
            Ok(())
        } else if Modifier::from_optype(optype).is_some() {
            self.forget_node(h, op_node)
        } else if self.modify_array_op(h, op_node, optype, new_dfg)?
            || self.try_array_convert(h, op_node, optype, new_dfg)?
        {
            Ok(())
        } else {
            self.modify_dataflow_op(h, op_node, optype, new_dfg)
        }
    }
    /// Appends `control_num()` qubit types to the end of `row` (unchanged if zero).
    fn cfg_control_types(&self, mut row: hugr::types::TypeRow) -> hugr::types::TypeRow {
        let control_num = self.control_num();
        if control_num == 0 {
            return row;
        }
        let types = row.to_mut();
        types.reserve(control_num);
        types.extend(iter::repeat_n(qb_t(), control_num));
        row
    }
    /// Rebuilds a CFG under the current modifier.
    ///
    /// Control qubits are appended to the CFG's and every block's
    /// inputs/other-outputs. Each block body is modified. Branches are
    /// reconnected, with the first dataflow-block child used as the entry.
    /// Finally, the current control wires are threaded through the new CFG's
    /// trailing ports. Daggering is only supported for a single-block CFG.
    fn modify_cfg(
        &mut self,
        h: &mut impl HugrMut<Node = N>,
        cfg_node: N,
        cfg: &CFG,
        new_dfg: &mut impl Container,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let children: Vec<N> = h
            .children(cfg_node)
            .filter(|child| h.get_optype(*child).is_dataflow_block())
            .collect();
        if children.len() != 1 && self.modifiers().dagger {
            return Err(ModifierResolverErrors::unresolvable(
                cfg_node,
                "CFG with more than one node cannot be daggered.".to_string(),
                cfg.clone().into(),
            ));
        }
        let signature = Signature::new(
            self.cfg_control_types(cfg.signature.input.clone()),
            self.cfg_control_types(cfg.signature.output.clone()),
        );
        let mut new_cfg = CFGBuilder::new(signature)?;
        let mut bb_map = HashMap::new();
        for (i, old_bb) in children.iter().copied().enumerate() {
            let OpType::DataflowBlock(old_block) = h.get_optype(old_bb).clone() else {
                return Err(ModifierResolverErrors::unreachable(
                    "Non-basic-block node found while modifying CFG.".to_string(),
                ));
            };
            let input = self.cfg_control_types(old_block.inputs.clone());
            let other_outputs = self.cfg_control_types(old_block.other_outputs.clone());
            let mut new_bb = if i == 0 {
                new_cfg.entry_builder(old_block.sum_rows.clone(), other_outputs)?
            } else {
                new_cfg.block_builder(input, old_block.sum_rows.clone(), other_outputs)?
            };
            self.modify_dfg_body(h, old_bb, &mut new_bb)?;
            let new_bb_id = new_bb.finish_sub_container()?;
            bb_map.insert(old_bb, new_bb_id);
        }
        for old_bb in children.iter().copied() {
            let OpType::DataflowBlock(old_block) = h.get_optype(old_bb) else {
                return Err(ModifierResolverErrors::unreachable(
                    "Non-basic-block node found while connecting CFG branches.".to_string(),
                ));
            };
            let new_bb = bb_map.get(&old_bb).ok_or_else(|| {
                ModifierResolverErrors::unreachable("Missing modified basic block.".to_string())
            })?;
            for branch in 0..old_block.sum_rows.len() {
                let (successor, _) = h
                    .linked_inputs(old_bb, OutgoingPort::from(branch))
                    .exactly_one()
                    .map_err(|_| {
                        ModifierResolverErrors::unreachable(format!(
                            "Expected one successor for CFG block branch {branch}."
                        ))
                    })?;
                let new_successor = if let Some(successor) = bb_map.get(&successor) {
                    *successor
                } else if matches!(h.get_optype(successor), OpType::ExitBlock(_)) {
                    new_cfg.exit_block()
                } else {
                    return Err(ModifierResolverErrors::unreachable(
                        "CFG branch successor is neither a basic block nor the exit block."
                            .to_string(),
                    ));
                };
                new_cfg.branch(new_bb, branch, &new_successor)?;
            }
        }
        let new_node = self.insert_sub_dfg(new_dfg, new_cfg)?;
        self.wire_node_inout(
            cfg_node,
            new_node,
            (cfg.signature.input.iter(), cfg.signature.output.iter()),
            (0, 0, 0),
        )?;
        // Control qubits occupy the ports after the original inputs/outputs.
        let input_offset = cfg.signature.input.len();
        let output_offset = cfg.signature.output.len();
        for (i, c) in self.controls().iter_mut().enumerate() {
            new_dfg
                .hugr_mut()
                .connect(c.node(), c.source(), new_node, input_offset + i);
            *c = Wire::new(new_node, OutgoingPort::from(output_offset + i));
        }
        Ok(())
    }
}
/// Returns the direct child of the module root that (transitively) contains
/// `node`. That is `node` itself if it is a module child, or `None` if `node`
/// is not below the module root.
fn module_child_containing<N: HugrNode>(h: &impl HugrView<Node = N>, node: N) -> Option<N> {
    let root = h.module_root();
    let mut current = node;
    loop {
        let parent = h.get_parent(current)?;
        if parent == root {
            break Some(current);
        }
        current = parent;
    }
}
/// Reports whether `func` is statically referenced from anywhere other than
/// the candidate functions themselves.
///
/// Uses whose owning module child is unknown count as outside uses, and a
/// function without static targets is conservatively treated as used.
fn has_static_use_outside_candidates<N: HugrNode>(
    h: &impl HugrView<Node = N>,
    func: N,
    candidates: &HashSet<N>,
) -> bool {
    match h.static_targets(func) {
        None => true,
        Some(mut targets) => targets.any(|(target, _)| {
            let owned_by_candidate = module_child_containing(h, target)
                .is_some_and(|owner| candidates.contains(&owner));
            !owned_by_candidate
        }),
    }
}
/// Lists the candidate functions statically referenced by any descendant of
/// `func`, in descendant order (duplicates kept).
fn candidate_static_dependencies<N: HugrNode>(
    h: &impl HugrView<Node = N>,
    func: N,
    candidates: &HashSet<N>,
) -> Vec<N> {
    let mut deps = Vec::new();
    for source in h.descendants(func).filter_map(|node| h.static_source(node)) {
        if candidates.contains(&source) {
            deps.push(source);
        }
    }
    deps
}
/// Deletes modified functions that are no longer reachable.
///
/// Candidates are the in-scope `FuncDefn`s in `modified_functions` that still
/// exist, excluding the module child containing the entrypoint. A candidate
/// is live if it is referenced from outside the candidate set, or referenced
/// by another live candidate. Every other candidate's subtree is removed.
fn remove_unused_modified_functions<N: HugrNode>(
    h: &mut impl HugrMut<Node = N>,
    modified_functions: &HashSet<N>,
    scope: &PassScope,
) {
    let mut candidates: HashSet<N> = HashSet::new();
    for &func in modified_functions {
        let eligible = h.contains_node(func)
            && h.get_optype(func).as_func_defn().is_some()
            && scope.in_scope(h, func) == InScope::Yes;
        if eligible {
            candidates.insert(func);
        }
    }
    // The function holding the entrypoint must never be removed.
    if let Some(owner) = module_child_containing(h, h.entrypoint()) {
        candidates.remove(&owner);
    }
    // Seed liveness with candidates used from outside the candidate set ...
    let mut live: HashSet<N> = HashSet::new();
    let mut pending: Vec<N> = Vec::new();
    for &func in &candidates {
        if has_static_use_outside_candidates(h, func, &candidates) {
            live.insert(func);
            pending.push(func);
        }
    }
    // ... then propagate through references made by live candidates.
    while let Some(func) = pending.pop() {
        for dep in candidate_static_dependencies(h, func, &candidates) {
            if live.insert(dep) {
                pending.push(dep);
            }
        }
    }
    let dead: Vec<N> = candidates
        .into_iter()
        .filter(|func| !live.contains(func))
        .collect();
    for func in dead {
        if h.contains_node(func) {
            h.remove_subtree(func);
        }
    }
}
/// Resolves all modifiers reachable from `entry_points` using the default
/// [`PassScope`].
///
/// See [`resolve_modifier_with_entrypoints_and_scope`] for details.
pub fn resolve_modifier_with_entrypoints(
    h: &mut impl HugrMut<Node = Node>,
    entry_points: impl IntoIterator<Item = Node>,
) -> Result<(), ModifierResolverErrors<Node>> {
    let default_scope = PassScope::default();
    resolve_modifier_with_entrypoints_and_scope(h, entry_points, &default_scope)
}
/// Resolves all modifiers reachable from `entry_points`.
///
/// The pass runs in four phases:
/// 1. Walk every node reachable via the hierarchy and graph edges from the
///    entry points, attempting a rewrite at each one. `ModifierError`s (node
///    is not a resolvable chain start) are skipped; any other error aborts.
/// 2. Walk the graph again and delete every remaining modifier node together
///    with the modifiers downstream of it.
/// 3. Remove global-phase ops via `delete_phase`, then delete modified
///    functions that became unused within `scope`.
/// 4. Validate the result.
///
/// # Errors
/// Any non-`ModifierError` failure from a rewrite, from `delete_phase`, or
/// from final validation (wrapped as `BuildError`).
pub fn resolve_modifier_with_entrypoints_and_scope(
    h: &mut impl HugrMut<Node = Node>,
    entry_points: impl IntoIterator<Item = Node>,
    scope: &PassScope,
) -> Result<(), ModifierResolverErrors<Node>> {
    use ModifierResolverErrors::*;
    let entry_points: VecDeque<_> = entry_points.into_iter().collect();
    let mut resolver = ModifierResolver::new();

    // Phase 1: rewrite modifier chains.
    let mut worklist = entry_points.clone();
    let mut visited = FxHashSet::default();
    while let Some(node) = worklist.pop_front() {
        if !h.contains_node(node) || visited.contains(&node) {
            continue;
        }
        worklist.extend(h.children(node).filter(|n| !visited.contains(n)));
        worklist.extend(h.all_neighbours(node).filter(|n| !visited.contains(n)));
        visited.insert(node);
        if let Err(e) = resolver.try_rewrite(h, node) {
            if !matches!(e, ModifierError(_)) {
                return Err(e);
            }
        }
    }

    // Phase 2: delete the now-unused modifier nodes.
    let mut deletelist = entry_points.clone();
    let mut visited = FxHashSet::default();
    while let Some(node) = deletelist.pop_front() {
        // A node may be queued several times, or removed by an earlier
        // cascade; handle each live node once, like phase 1 does.
        if !h.contains_node(node) || !visited.insert(node) {
            continue;
        }
        deletelist.extend(h.children(node).filter(|n| !visited.contains(n)));
        deletelist.extend(h.all_neighbours(node).filter(|n| !visited.contains(n)));
        if Modifier::from_optype(h.get_optype(node)).is_none() {
            continue;
        }
        // Remove this modifier and the modifiers fed by it. Only modifiers are
        // followed, so consumers are never deleted. Already-removed or
        // duplicate entries are skipped, so `remove_node` never sees a dead node.
        let mut chain = vec![node];
        while let Some(n) = chain.pop() {
            if !h.contains_node(n) || Modifier::from_optype(h.get_optype(n)).is_none() {
                continue;
            }
            chain.extend(h.output_neighbours(n));
            h.remove_node(n);
        }
    }

    // Phase 3: clean-up.
    delete_phase(h, entry_points)?;
    remove_unused_modified_functions(h, &resolver.modified_functions, scope);

    // Phase 4: validation.
    h.validate()
        .map_err(|e| ModifierResolverErrors::BuildError(e.into()))?;
    Ok(())
}
#[cfg(test)]
mod tests {
use std::{fs, io::BufReader, path::Path};
use cool_asserts::assert_matches;
use hugr::{
Hugr,
builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder},
ops::{
CallIndirect, ExtensionOp,
handle::{FuncID, NodeHandle},
},
std_extensions::collections::array::ArrayOpBuilder,
type_row,
types::Term,
};
use hugr_core::Visibility;
use crate::{
TketOp,
extension::modifier::{CONTROL_OP_ID, DAGGER_OP_ID, MODIFIER_EXTENSION},
metadata,
passes::composable::Preserve,
};
use super::*;
pub(crate) trait SetUnitary {
fn set_unitary(&mut self);
}
impl<T: Container> SetUnitary for T {
fn set_unitary(&mut self) {
let node = self.container_node();
self.hugr_mut()
.set_metadata::<metadata::UnitaryFlags>(node, 7);
}
}
pub(crate) fn test_modifier_resolver(
target_num: usize,
ctrl_num: u64,
foo: impl FnOnce(&mut ModuleBuilder<Hugr>, usize) -> FuncID<true>,
dagger: bool,
) {
let _ = resolved_modifier_test_hugr(target_num, ctrl_num, foo, dagger);
}
pub(crate) fn resolved_modifier_test_hugr(
target_num: usize,
ctrl_num: u64,
foo: impl FnOnce(&mut ModuleBuilder<Hugr>, usize) -> FuncID<true>,
dagger: bool,
) -> Hugr {
let (mut h, foo_node) = modifier_test_hugr(target_num, ctrl_num, foo, dagger);
let entrypoint = h.entrypoint();
resolve_modifier_with_entrypoints(&mut h, [entrypoint]).unwrap();
assert!(!h.contains_node(foo_node));
assert!(
h.nodes()
.all(|node| Modifier::from_optype(h.get_optype(node)).is_none())
);
assert_matches!(h.validate(), Ok(()));
h
}
pub(crate) fn modifier_test_hugr(
target_num: usize,
ctrl_num: u64,
foo: impl FnOnce(&mut ModuleBuilder<Hugr>, usize) -> FuncID<true>,
dagger: bool,
) -> (Hugr, Node) {
let mut module = ModuleBuilder::new();
let call_sig = Signature::new_endo(
[array_type(ctrl_num, qb_t())]
.into_iter()
.chain(iter::repeat_n(qb_t(), target_num))
.collect::<Vec<_>>(),
);
let main_sig = Signature::new(
type_row![],
vec![array_type(ctrl_num, qb_t())]
.into_iter()
.chain(iter::repeat_n(qb_t(), target_num))
.collect::<Vec<_>>(),
);
let dagger_op: ExtensionOp = {
MODIFIER_EXTENSION
.instantiate_extension_op(
&DAGGER_OP_ID,
[
iter::repeat_n(qb_t().into(), target_num)
.collect::<Vec<_>>()
.into(),
vec![].into(),
],
)
.unwrap()
};
let control_op: ExtensionOp = {
MODIFIER_EXTENSION
.instantiate_extension_op(
&CONTROL_OP_ID,
[
Term::BoundedNat(ctrl_num),
iter::repeat_n(qb_t().into(), target_num)
.collect::<Vec<_>>()
.into(),
vec![].into(),
],
)
.unwrap()
};
let foo = foo(&mut module, target_num);
let foo_node = foo.node();
let _main = {
let mut func = module.define_function("main", main_sig).unwrap();
let mut call = func.load_func(&foo, &[]).unwrap();
if dagger {
call = func
.add_dataflow_op(dagger_op, vec![call])
.unwrap()
.out_wire(0);
}
call = func
.add_dataflow_op(control_op, vec![call])
.unwrap()
.out_wire(0);
let mut controls = Vec::new();
for _ in 0..ctrl_num {
controls.push(
func.add_dataflow_op(TketOp::QAlloc, vec![])
.unwrap()
.out_wire(0),
);
}
let mut targ = Vec::new();
for _ in 0..target_num {
targ.push(
func.add_dataflow_op(TketOp::QAlloc, vec![])
.unwrap()
.out_wire(0),
)
}
let control_arr = func.add_new_array(qb_t(), controls).unwrap();
let fn_outs = func
.add_dataflow_op(
CallIndirect {
signature: call_sig,
},
[call, control_arr].into_iter().chain(targ),
)
.unwrap()
.outputs();
func.finish_with_outputs(fn_outs).unwrap()
};
let h = module.finish_hugr().unwrap();
assert_matches!(h.validate(), Ok(()));
(h, foo_node)
}
/// One `LoadFunction` wire that feeds both a `Control` modifier and a direct
/// `CallIndirect` must survive modifier resolution, and so must the `foo`
/// definition it loads.
#[test]
fn shared_loaded_function_is_not_removed() {
    // Fixtures: a 1-qubit control modifier and the signatures that go with it.
    let ctrl_num = 1;
    let control_op: ExtensionOp = MODIFIER_EXTENSION
        .instantiate_extension_op(
            &CONTROL_OP_ID,
            [
                Term::BoundedNat(ctrl_num),
                vec![qb_t().into()].into(),
                vec![].into(),
            ],
        )
        .unwrap();
    let qubit_endo = Signature::new_endo(vec![qb_t()]);
    let controlled_sig = Signature::new_endo(vec![array_type(ctrl_num, qb_t()), qb_t()]);
    let main_sig = Signature::new(
        type_row![],
        vec![array_type(ctrl_num, qb_t()), qb_t(), qb_t()],
    );

    let mut module = ModuleBuilder::new();

    // `foo`: unitary qubit -> qubit applying a single X.
    let foo = {
        let mut foo_fn = module.define_function("foo", qubit_endo.clone()).unwrap();
        foo_fn.set_unitary();
        let [q] = foo_fn.input_wires_arr();
        let q = foo_fn.add_dataflow_op(TketOp::X, [q]).unwrap().out_wire(0);
        foo_fn.finish_with_outputs([q]).unwrap()
    };
    let foo_node = foo.node();

    // `main` loads `foo` once and consumes that single wire twice: through the
    // control modifier, and as a plain indirect call.
    let shared_load_node = {
        let mut main_fn = module.define_function("main", main_sig).unwrap();
        let foo_ptr = main_fn.load_func(foo.handle(), &[]).unwrap();
        let controlled_foo = main_fn
            .add_dataflow_op(control_op, [foo_ptr])
            .unwrap()
            .out_wire(0);
        let mut alloc_qubit = || {
            main_fn
                .add_dataflow_op(TketOp::QAlloc, vec![])
                .unwrap()
                .out_wire(0)
        };
        let ctrl_qubit = alloc_qubit();
        let ctrl_target = alloc_qubit();
        let direct_target = alloc_qubit();
        let ctrl_array = main_fn.add_new_array(qb_t(), [ctrl_qubit]).unwrap();
        let [ctrl_array, ctrl_target] = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: controlled_sig,
                },
                [controlled_foo, ctrl_array, ctrl_target],
            )
            .unwrap()
            .outputs_arr();
        let direct_target = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: qubit_endo,
                },
                [foo_ptr, direct_target],
            )
            .unwrap()
            .out_wire(0);
        main_fn
            .finish_with_outputs([ctrl_array, ctrl_target, direct_target])
            .unwrap();
        foo_ptr.node()
    };

    let mut h = module.finish_hugr().unwrap();
    assert_matches!(h.validate(), Ok(()));
    let roots = [h.entrypoint()];
    resolve_modifier_with_entrypoints(&mut h, roots).unwrap();
    assert!(h.contains_node(shared_load_node));
    assert!(h.contains_node(foo_node));
    assert_matches!(h.validate(), Ok(()));
}
/// Resolving a `Control` modifier on `foo` rewrites the only place `foo` is
/// used, after which the original `foo` definition is removed; a separate
/// function that nothing references (`unused`) must still be kept.
#[test]
fn unused_unmodified_function_is_preserved() {
    // Fixtures: a 1-qubit control modifier and the signatures that go with it.
    let ctrl_num = 1;
    let control_op: ExtensionOp = MODIFIER_EXTENSION
        .instantiate_extension_op(
            &CONTROL_OP_ID,
            [
                Term::BoundedNat(ctrl_num),
                vec![qb_t().into()].into(),
                vec![].into(),
            ],
        )
        .unwrap();
    let controlled_sig = Signature::new_endo(vec![array_type(ctrl_num, qb_t()), qb_t()]);
    let main_sig = Signature::new(type_row![], vec![array_type(ctrl_num, qb_t()), qb_t()]);

    let mut module = ModuleBuilder::new();

    // `foo`: unitary qubit -> qubit applying a single X.
    let foo = {
        let mut foo_fn = module
            .define_function("foo", Signature::new_endo(vec![qb_t()]))
            .unwrap();
        foo_fn.set_unitary();
        let [q] = foo_fn.input_wires_arr();
        let q = foo_fn.add_dataflow_op(TketOp::X, [q]).unwrap().out_wire(0);
        foo_fn.finish_with_outputs([q]).unwrap()
    };
    let foo_node = foo.node();

    // `unused`: qubit identity that `main` never references.
    let unused_node = {
        let identity = module
            .define_function("unused", Signature::new_endo(vec![qb_t()]))
            .unwrap();
        let [q] = identity.input_wires_arr();
        identity.finish_with_outputs([q]).unwrap().node()
    };

    // `main`: allocate a control and a target qubit, then call the controlled `foo`.
    {
        let mut main_fn = module.define_function("main", main_sig).unwrap();
        let foo_ptr = main_fn.load_func(foo.handle(), &[]).unwrap();
        let controlled_foo = main_fn
            .add_dataflow_op(control_op, [foo_ptr])
            .unwrap()
            .out_wire(0);
        let mut alloc_qubit = || {
            main_fn
                .add_dataflow_op(TketOp::QAlloc, vec![])
                .unwrap()
                .out_wire(0)
        };
        let ctrl_qubit = alloc_qubit();
        let target = alloc_qubit();
        let ctrl_array = main_fn.add_new_array(qb_t(), [ctrl_qubit]).unwrap();
        let results = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: controlled_sig,
                },
                [controlled_foo, ctrl_array, target],
            )
            .unwrap()
            .outputs();
        main_fn.finish_with_outputs(results).unwrap();
    }

    let mut h = module.finish_hugr().unwrap();
    assert_matches!(h.validate(), Ok(()));
    let roots = [h.entrypoint()];
    resolve_modifier_with_entrypoints(&mut h, roots).unwrap();
    assert!(!h.contains_node(foo_node));
    assert!(h.contains_node(unused_node));
    assert_matches!(h.validate(), Ok(()));
}
/// A *public* function targeted by a `Control` modifier is still present after
/// modifier resolution rooted at the HUGR's entrypoint.
#[test]
fn modified_public_function_is_not_removed_after_passes() {
    // Fixtures: a 1-qubit control modifier and the signatures that go with it.
    let ctrl_num = 1;
    let control_op: ExtensionOp = MODIFIER_EXTENSION
        .instantiate_extension_op(
            &CONTROL_OP_ID,
            [
                Term::BoundedNat(ctrl_num),
                vec![qb_t().into()].into(),
                vec![].into(),
            ],
        )
        .unwrap();
    let controlled_sig = Signature::new_endo(vec![array_type(ctrl_num, qb_t()), qb_t()]);
    let main_sig = Signature::new(type_row![], vec![array_type(ctrl_num, qb_t()), qb_t()]);

    let mut module = ModuleBuilder::new();

    // `foo`: public, unitary qubit -> qubit applying a single X.
    let foo = {
        let mut foo_fn = module
            .define_function_vis(
                "foo",
                Signature::new_endo(vec![qb_t()]),
                Visibility::Public,
            )
            .unwrap();
        foo_fn.set_unitary();
        let [q] = foo_fn.input_wires_arr();
        let q = foo_fn.add_dataflow_op(TketOp::X, [q]).unwrap().out_wire(0);
        foo_fn.finish_with_outputs([q]).unwrap()
    };
    let foo_node = foo.node();

    // `main`: allocate a control and a target qubit, then call the controlled `foo`.
    {
        let mut main_fn = module.define_function("main", main_sig).unwrap();
        let foo_ptr = main_fn.load_func(foo.handle(), &[]).unwrap();
        let controlled_foo = main_fn
            .add_dataflow_op(control_op, [foo_ptr])
            .unwrap()
            .out_wire(0);
        let mut alloc_qubit = || {
            main_fn
                .add_dataflow_op(TketOp::QAlloc, vec![])
                .unwrap()
                .out_wire(0)
        };
        let ctrl_qubit = alloc_qubit();
        let target = alloc_qubit();
        let ctrl_array = main_fn.add_new_array(qb_t(), [ctrl_qubit]).unwrap();
        let results = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: controlled_sig,
                },
                [controlled_foo, ctrl_array, target],
            )
            .unwrap()
            .outputs();
        main_fn.finish_with_outputs(results).unwrap();
    }

    let mut h = module.finish_hugr().unwrap();
    assert_matches!(h.validate(), Ok(()));
    let roots = [h.entrypoint()];
    resolve_modifier_with_entrypoints(&mut h, roots).unwrap();
    assert!(h.contains_node(foo_node));
    assert_matches!(h.validate(), Ok(()));
}
/// With `main` as the entrypoint and a `PassScope::Global(Preserve::Entrypoint)`
/// scope, a public `foo` whose only use is rewritten by the control modifier
/// is not preserved and is removed by resolution.
#[test]
fn modified_public_function_is_removed_when_not_preserved_by_scope() {
    // Fixtures: a 1-qubit control modifier and the signatures that go with it.
    let ctrl_num = 1;
    let control_op: ExtensionOp = MODIFIER_EXTENSION
        .instantiate_extension_op(
            &CONTROL_OP_ID,
            [
                Term::BoundedNat(ctrl_num),
                vec![qb_t().into()].into(),
                vec![].into(),
            ],
        )
        .unwrap();
    let controlled_sig = Signature::new_endo(vec![array_type(ctrl_num, qb_t()), qb_t()]);
    let main_sig = Signature::new(type_row![], vec![array_type(ctrl_num, qb_t()), qb_t()]);

    let mut module = ModuleBuilder::new();

    // `foo`: public, unitary qubit -> qubit applying a single X.
    let foo = {
        let mut foo_fn = module
            .define_function_vis(
                "foo",
                Signature::new_endo(vec![qb_t()]),
                Visibility::Public,
            )
            .unwrap();
        foo_fn.set_unitary();
        let [q] = foo_fn.input_wires_arr();
        let q = foo_fn.add_dataflow_op(TketOp::X, [q]).unwrap().out_wire(0);
        foo_fn.finish_with_outputs([q]).unwrap()
    };
    let foo_node = foo.node();

    // `main`: allocate a control and a target qubit, then call the controlled
    // `foo`. Its node becomes the HUGR entrypoint below.
    let main_node = {
        let mut main_fn = module.define_function("main", main_sig).unwrap();
        let foo_ptr = main_fn.load_func(foo.handle(), &[]).unwrap();
        let controlled_foo = main_fn
            .add_dataflow_op(control_op, [foo_ptr])
            .unwrap()
            .out_wire(0);
        let mut alloc_qubit = || {
            main_fn
                .add_dataflow_op(TketOp::QAlloc, vec![])
                .unwrap()
                .out_wire(0)
        };
        let ctrl_qubit = alloc_qubit();
        let target = alloc_qubit();
        let ctrl_array = main_fn.add_new_array(qb_t(), [ctrl_qubit]).unwrap();
        let results = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: controlled_sig,
                },
                [controlled_foo, ctrl_array, target],
            )
            .unwrap()
            .outputs();
        main_fn.finish_with_outputs(results).unwrap().node()
    };

    let mut h = module.finish_hugr().unwrap();
    h.set_entrypoint(main_node);
    assert_matches!(h.validate(), Ok(()));
    let scope = PassScope::Global(Preserve::Entrypoint);
    let roots = [scope.root(&h).unwrap()];
    resolve_modifier_with_entrypoints_and_scope(&mut h, roots, &scope).unwrap();
    assert!(!h.contains_node(foo_node));
    assert_matches!(h.validate(), Ok(()));
}
/// `main` uses `bar` twice: through a `Control` modifier and through a plain
/// `Call`. Because the unmodified `bar` is still called directly, both `bar`
/// and the `foo` it calls must survive modifier resolution.
#[test]
fn modified_dependency_is_preserved_when_original_caller_is_live() {
    // Fixtures: a 1-qubit control modifier and the signatures that go with it.
    let ctrl_num = 1;
    let control_op: ExtensionOp = MODIFIER_EXTENSION
        .instantiate_extension_op(
            &CONTROL_OP_ID,
            [
                Term::BoundedNat(ctrl_num),
                vec![qb_t().into()].into(),
                vec![].into(),
            ],
        )
        .unwrap();
    let qubit_endo = Signature::new_endo(vec![qb_t()]);
    let controlled_sig = Signature::new_endo(vec![array_type(ctrl_num, qb_t()), qb_t()]);
    let main_sig = Signature::new(
        type_row![],
        vec![array_type(ctrl_num, qb_t()), qb_t(), qb_t()],
    );

    let mut module = ModuleBuilder::new();

    // `foo`: unitary qubit -> qubit applying a single X.
    let foo = {
        let mut foo_fn = module.define_function("foo", qubit_endo.clone()).unwrap();
        foo_fn.set_unitary();
        let [q] = foo_fn.input_wires_arr();
        let q = foo_fn.add_dataflow_op(TketOp::X, [q]).unwrap().out_wire(0);
        foo_fn.finish_with_outputs([q]).unwrap()
    };
    let foo_node = foo.node();

    // `bar`: unitary wrapper whose body is a single direct call to `foo`.
    let bar = {
        let mut bar_fn = module.define_function("bar", qubit_endo).unwrap();
        bar_fn.set_unitary();
        let [q] = bar_fn.input_wires_arr();
        let call = bar_fn.call(foo.handle(), &[], [q]).unwrap();
        bar_fn.finish_with_outputs(call.outputs()).unwrap()
    };
    let bar_node = bar.node();

    // `main`: one controlled (indirect) call of `bar`, plus one direct call.
    {
        let mut main_fn = module.define_function("main", main_sig).unwrap();
        let bar_ptr = main_fn.load_func(bar.handle(), &[]).unwrap();
        let controlled_bar = main_fn
            .add_dataflow_op(control_op, [bar_ptr])
            .unwrap()
            .out_wire(0);
        let mut alloc_qubit = || {
            main_fn
                .add_dataflow_op(TketOp::QAlloc, vec![])
                .unwrap()
                .out_wire(0)
        };
        let ctrl_qubit = alloc_qubit();
        let ctrl_target = alloc_qubit();
        let direct_target = alloc_qubit();
        let ctrl_array = main_fn.add_new_array(qb_t(), [ctrl_qubit]).unwrap();
        let [ctrl_array, ctrl_target] = main_fn
            .add_dataflow_op(
                CallIndirect {
                    signature: controlled_sig,
                },
                [controlled_bar, ctrl_array, ctrl_target],
            )
            .unwrap()
            .outputs_arr();
        let direct_target = main_fn
            .call(bar.handle(), &[], [direct_target])
            .unwrap()
            .out_wire(0);
        main_fn
            .finish_with_outputs([ctrl_array, ctrl_target, direct_target])
            .unwrap();
    }

    let mut h = module.finish_hugr().unwrap();
    assert_matches!(h.validate(), Ok(()));
    let roots = [h.entrypoint()];
    resolve_modifier_with_entrypoints(&mut h, roots).unwrap();
    assert!(h.contains_node(bar_node));
    assert!(h.contains_node(foo_node));
    assert_matches!(h.validate(), Ok(()));
}
/// Load a serialized HUGR from `file`.
///
/// Relative paths are resolved against the process's current working
/// directory (as with any `std::fs::File::open` call).
///
/// # Errors
///
/// Returns the underlying I/O error if `file` cannot be opened. If the file
/// opens but its contents cannot be decoded as a HUGR, returns an error of
/// kind [`std::io::ErrorKind::InvalidData`] whose message names the file and
/// includes the decoder's error. Previously a decode failure panicked inside
/// this function, even though the signature advertises a `Result`.
fn load_guppy_example(file: impl AsRef<Path>) -> std::io::Result<Hugr> {
    let path = file.as_ref();
    let reader = BufReader::new(fs::File::open(path)?);
    Hugr::load(reader, None).map_err(|err| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("failed to load HUGR from {}: {err}", path.display()),
        )
    })
}
/// Check `h` around modifier resolution.
///
/// `h` must validate before the pass runs. All modifiers reachable from its
/// entrypoint are then resolved, and `h` must validate again afterwards.
/// Panics (test failure) if either validation or the resolution itself fails.
fn test_resolve(h: &mut Hugr) {
    fn assert_valid(hugr: &Hugr) {
        assert_matches!(hugr.validate(), Ok(()));
    }

    assert_valid(h);
    let roots = [h.entrypoint()];
    resolve_modifier_with_entrypoints(h, roots).unwrap();
    assert_valid(h);
}
#[rstest::rstest]
#[case::multiple_functions_in_ctrl_dagger(
"../test_files/modifier_examples/multiple_functions_in_ctrl_dagger.hugr"
)]
#[case::guppy_modifiers("../test_files/guppy_examples/modifiers.hugr")]
#[case::assign_in_dagger("../test_files/modifier_examples/assign_in_dagger.hugr")]
#[case::classical_array_op("../test_files/modifier_examples/classical_array_op.hugr")]
#[case::classical_function1("../test_files/modifier_examples/classical_function1.hugr")]
#[case::classical_function2("../test_files/modifier_examples/classical_function2.hugr")]
#[case::classical_function3("../test_files/modifier_examples/classical_function3.hugr")]
#[case::ctrl_on_cfg("../test_files/modifier_examples/ctrl_on_cfg.hugr")]
#[case::multiple_gates2_in_ctrl("../test_files/modifier_examples/multiple_gates2_in_ctrl.hugr")]
#[case::subscript_in_ctrl("../test_files/modifier_examples/subscript_in_ctrl.hugr")]
#[case::subscript_in_dagger("../test_files/modifier_examples/subscript_in_dagger.hugr")]
#[case::subscript_as_controller("../test_files/modifier_examples/subscript_as_controller.hugr")]
#[case::complex_modifier_stress("../test_files/modifier_examples/complex_modifier_stress.hugr")]
#[case::ctrl_array_controller("../test_files/modifier_examples/ctrl_array_controller.hugr")]
#[case::call1_in_ctrl("../test_files/modifier_examples/call1_in_ctrl.hugr")]
#[case::call2_in_ctrl("../test_files/modifier_examples/call2_in_ctrl.hugr")]
#[case::multiple_gates1_in_ctrl("../test_files/modifier_examples/multiple_gates1_in_ctrl.hugr")]
#[case::gate_in_ctrl("../test_files/modifier_examples/gate_in_ctrl.hugr")]
#[case::call_in_dagger("../test_files/modifier_examples/call_in_dagger.hugr")]
#[case::multiple_functions_in_dagger(
"../test_files/modifier_examples/multiple_functions_in_dagger.hugr"
)]
#[case::multiple_gates1_in_dagger(
"../test_files/modifier_examples/multiple_gates1_in_dagger.hugr"
)]
#[case::multiple_gates2_in_dagger(
"../test_files/modifier_examples/multiple_gates2_in_dagger.hugr"
)]
#[case::multiple_gates3_in_dagger(
"../test_files/modifier_examples/multiple_gates3_in_dagger.hugr"
)]
#[case::double_modifier("../test_files/modifier_examples/double_modifier.hugr")]
#[case::modify_array("../test_files/modifier_examples/modify_array.hugr")]
#[case::multiple_dagger("../test_files/modifier_examples/multiple_dagger.hugr")]
#[case::nested_ctrl_dagger1("../test_files/modifier_examples/nested_ctrl_dagger1.hugr")]
#[case::nested_multiple_ctrl1("../test_files/modifier_examples/nested_multiple_ctrl1.hugr")]
#[case::swap_in_dagger("../test_files/modifier_examples/swap_in_dagger.hugr")]
#[case::subscript_in_dagger_ctrl(
"../test_files/modifier_examples/subscript_in_dagger_ctrl.hugr"
)]
#[cfg_attr(miri, ignore)] fn test_examples(#[case] example: &str) {
let mut h = load_guppy_example(example).unwrap();
test_resolve(&mut h);
}
}