use itertools::{Either, Itertools};
use std::{
collections::{HashMap, VecDeque},
iter, mem,
};
pub mod array_modify;
pub mod call_modify;
pub mod dfg_modify;
pub mod global_phase_modify;
pub mod tket_op_modify;
use super::{CombinedModifier, ModifierFlags};
use crate::passes::utils::unpack_container::TypeUnpacker;
use crate::{TketOp, extension::global_phase::GlobalPhase, modifier::Modifier};
use global_phase_modify::delete_phase;
use hugr::{
HugrView, IncomingPort, Node, OutgoingPort, Port, PortIndex, Wire,
builder::{BuildError, CFGBuilder, Container, Dataflow, SubContainer},
core::HugrNode,
extension::{prelude::qb_t, simple_op::MakeExtensionOp},
hugr::hugrmut::HugrMut,
ops::{CFG, Const, OpType},
std_extensions::collections::array::array_type,
type_row,
types::{EdgeKind, FuncTypeBase, Signature, Type},
};
/// A port of a node, together with its direction.
///
/// Field `0` is the node and field `1` the port; the port's direction
/// (incoming/outgoing) is what distinguishes this from a plain [`Wire`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct DirWire<N = Node>(N, Port);

impl<N: HugrNode> std::fmt::Display for DirWire<N> {
    /// Renders as `DirWire(<node>, In(<index>))` or `DirWire(<node>, Out(<index>))`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DirWire(node, port) = self;
        let direction = if port.as_directed().is_left() {
            "In"
        } else {
            "Out"
        };
        write!(f, "DirWire({}, {}({}))", node, direction, port.index())
    }
}

impl<N> DirWire<N> {
    /// Creates a directed wire on `node` at `port`.
    pub fn new(node: N, port: Port) -> Self {
        DirWire(node, port)
    }

    /// Returns the port with the same index on the same node but the opposite
    /// direction (an input becomes an output and vice versa).
    pub fn reverse(self) -> Self {
        let DirWire(node, port) = self;
        let idx = port.index();
        let flipped: Port = if port.as_directed().is_left() {
            OutgoingPort::from(idx).into()
        } else {
            IncomingPort::from(idx).into()
        };
        Self::new(node, flipped)
    }
}
impl<N: HugrNode> From<Wire<N>> for DirWire<N> {
fn from(wire: Wire<N>) -> Self {
DirWire(wire.node(), wire.source().into())
}
}
impl<N: HugrNode> From<(N, OutgoingPort)> for DirWire<N> {
fn from((node, port): (N, OutgoingPort)) -> Self {
DirWire(node, port.into())
}
}
impl<N: HugrNode> From<(N, IncomingPort)> for DirWire<N> {
fn from((node, port): (N, IncomingPort)) -> Self {
DirWire(node, port.into())
}
}
impl<N: HugrNode> TryFrom<DirWire<N>> for Wire<N> {
type Error = hugr::hugr::HugrError;
fn try_from(value: DirWire<N>) -> Result<Self, Self::Error> {
let out_port = value.1.as_outgoing()?;
Ok(Wire::new(value.0, out_port))
}
}
impl<N: HugrNode> TryFrom<DirWire<N>> for (N, IncomingPort) {
type Error = hugr::hugr::HugrError;
fn try_from(value: DirWire<N>) -> Result<Self, Self::Error> {
let in_port = value.1.as_incoming()?;
Ok((value.0, in_port))
}
}
/// Adds an edge between two directed wires in the hugr owned by `new_dfg`.
///
/// One of `w1`/`w2` must be an outgoing port and the other an incoming port; the
/// argument order does not matter.
///
/// # Errors
/// Returns an `Unreachable` error when both wires have the same direction.
fn connect<N>(
    new_dfg: &mut impl Container,
    w1: &DirWire<Node>,
    w2: &DirWire<Node>,
) -> Result<(), ModifierResolverErrors<N>> {
    let ((src, src_port), (dst, dst_port)) = match (w1.1.as_directed(), w2.1.as_directed()) {
        (Either::Right(out), Either::Left(inp)) => ((w1.0, out), (w2.0, inp)),
        (Either::Left(inp), Either::Right(out)) => ((w2.0, out), (w1.0, inp)),
        _ => {
            return Err(ModifierResolverErrors::unreachable(format!(
                "Cannot connect the wires with the same direction: {} -> {}",
                w1, w2
            )));
        }
    };
    new_dfg.hugr_mut().connect(src, src_port, dst, dst_port);
    Ok(())
}
fn connect_by_num(
new_dfg: &mut impl Dataflow,
dw: &DirWire<Node>,
node: Node,
num: usize,
) -> DirWire<Node> {
let dw_node = dw.0;
match dw.1.as_directed() {
Either::Left(incoming) => {
new_dfg.hugr_mut().connect(node, num, dw_node, incoming);
(node, IncomingPort::from(num)).into()
}
Either::Right(outgoing) => {
new_dfg.hugr_mut().connect(dw_node, outgoing, node, num);
(node, OutgoingPort::from(num)).into()
}
}
}
trait PortExt {
fn shift(self, offset: usize) -> Self;
}
impl PortExt for Port {
fn shift(self, offset: usize) -> Self {
Port::new(self.direction(), self.index() + offset)
}
}
impl PortExt for IncomingPort {
fn shift(self, offset: usize) -> Self {
IncomingPort::from(self.index() + offset)
}
}
impl PortExt for OutgoingPort {
fn shift(self, offset: usize) -> Self {
OutgoingPort::from(self.index() + offset)
}
}
impl<N> PortExt for DirWire<N> {
fn shift(self, offset: usize) -> Self {
DirWire(self.0, self.1.shift(offset))
}
}
pub struct PortVector<N = Node> {
incoming: Vec<DirWire<N>>,
outgoing: Vec<DirWire<N>>,
}
impl<N: HugrNode> PortVector<N> {
fn from_single_node(
n: N,
inputs: impl Iterator<Item = usize>,
outputs: impl Iterator<Item = usize>,
) -> Self {
let incoming = inputs.map(|p| (n, IncomingPort::from(p)).into()).collect();
let outgoing = outputs.map(|p| (n, OutgoingPort::from(p)).into()).collect();
PortVector { incoming, outgoing }
}
fn port_vector_rev(
n: N,
inputs: impl Iterator<Item = usize>,
outputs: impl Iterator<Item = usize>,
iter: impl Iterator<Item = usize>,
) -> Self {
let iter = iter.collect::<Vec<_>>();
let incoming = inputs
.map(|p| {
if iter.contains(&p) {
(n, OutgoingPort::from(p)).into()
} else {
(n, IncomingPort::from(p)).into()
}
})
.collect();
let outgoing = outputs
.map(|p| {
if iter.contains(&p) {
(n, IncomingPort::from(p)).into()
} else {
(n, OutgoingPort::from(p)).into()
}
})
.collect();
PortVector { incoming, outgoing }
}
}
/// State of the modifier-resolution pass.
pub struct ModifierResolver<N = Node> {
    /// The modifiers currently being applied.
    modifiers: CombinedModifier,
    /// For each directed wire of the original hugr, the wires of the newly built
    /// hugr that correspond to it (possibly none).
    corresp_map: HashMap<DirWire<N>, Vec<DirWire>>,
    /// Current control-qubit wires in the new hugr.
    controls: Vec<Wire>,
    /// Nodes still waiting to be processed.
    worklist: VecDeque<N>,
    /// Per-node `(Node, IncomingPort)` records kept by the call-handling code.
    call_map: HashMap<N, (Node, IncomingPort)>,
    /// Detects types that contain qubits.
    qubit_finder: TypeUnpacker,
}

impl<N> ModifierResolver<N> {
    /// Creates a resolver with default modifiers and empty bookkeeping.
    pub fn new() -> Self {
        Self {
            modifiers: CombinedModifier::default(),
            corresp_map: HashMap::new(),
            controls: Vec::new(),
            worklist: VecDeque::new(),
            call_map: HashMap::new(),
            qubit_finder: TypeUnpacker::for_qubits(),
        }
    }
}

impl<N> Default for ModifierResolver<N> {
    /// Same as [`ModifierResolver::new`].
    fn default() -> Self {
        ModifierResolver::new()
    }
}
/// Reasons why a node cannot be treated as the start of a modifier chain.
///
/// `resolve_modifier_with_entrypoints` treats these as "not applicable here" and
/// skips the node rather than aborting.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum ModifierError<N = Node> {
    /// The node (with the given optype) is not a modifier.
    #[display("Node to modify {_0} expected to be a modifier but actually {_1}")]
    NotModifier(N, OpType),
    /// The modifier's output does not have exactly one consumer.
    #[display("No caller of the modified function exists for node {_0}")]
    #[error(ignore)]
    NoCaller(N),
    /// No target to modify was found for the node.
    #[display("No modification target exists for node {_0}")]
    #[error(ignore)]
    NoTarget(N),
    /// The modifier's output is consumed by another modifier (the node and optype
    /// carried here), so it is not the first modifier of its chain.
    #[display(
        "The modifier is not the first in a chain: its output is used by modifier node {_0} ({_1})"
    )]
    NotInitialModifier(N, OpType),
    /// A modifier was applied to a node (of the given optype) it cannot act on.
    #[display("Modifier cannot be applied to the node {_0} of type {_1}")]
    ModifierNotApplicable(N, OpType),
}
impl<N> ModifierError<N> {
fn node(self) -> N {
match self {
ModifierError::NotModifier(n, _)
| ModifierError::NoCaller(n)
| ModifierError::NoTarget(n)
| ModifierError::NotInitialModifier(n, _)
| ModifierError::ModifierNotApplicable(n, _) => n,
}
}
}
/// Errors produced while resolving modifiers.
#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
pub enum ModifierResolverErrors<N = Node> {
    /// A node is not a resolvable modifier chain; see [`ModifierError`].
    #[display("{_0}")]
    #[from]
    ModifierError(ModifierError<N>),
    /// Building part of the replacement hugr failed.
    #[display("{_0}")]
    #[from]
    BuildError(BuildError),
    /// An internal invariant of the resolver was violated.
    #[display("Unreachable error: {msg}")]
    Unreachable {
        /// Description of the violated invariant.
        msg: String,
    },
    /// The node `node` (of type `optype`) cannot be modified, for the reason `msg`.
    #[display("Node {node} of type {optype} cannot be modified: {msg}")]
    UnResolvable {
        /// The node that could not be modified.
        node: N,
        /// Why the node could not be modified.
        msg: String,
        /// The optype of `node`.
        optype: OpType,
    },
    /// The given modifier has no defined action on the given op.
    #[display("Modification by {_0:?} is not defined for the node {_1}")]
    Unimplemented(Modifier, OpType),
}
impl<N> ModifierResolverErrors<N> {
    /// Builds an [`Self::Unreachable`] error carrying `msg`.
    pub fn unreachable(msg: impl Into<String>) -> Self {
        let msg: String = msg.into();
        Self::Unreachable { msg }
    }

    /// Builds an [`Self::UnResolvable`] error for `node` of type `optype`.
    pub fn unresolvable(node: N, msg: impl Into<String>, optype: OpType) -> Self {
        let msg: String = msg.into();
        Self::UnResolvable { node, optype, msg }
    }
}
impl<N: HugrNode> ModifierResolver<N> {
    /// Mutable access to the modifiers currently being applied.
    fn modifiers_mut(&mut self) -> &mut CombinedModifier {
        &mut self.modifiers
    }

    /// Shared access to the modifiers currently being applied.
    fn modifiers(&self) -> &CombinedModifier {
        &self.modifiers
    }

    /// Number of controls recorded in the current modifiers.
    fn control_num(&self) -> usize {
        self.modifiers.control
    }

    /// Mutable access to the current control-qubit wires.
    fn controls(&mut self) -> &mut Vec<Wire> {
        &mut self.controls
    }

    /// Shared access to the current control-qubit wires.
    fn controls_ref(&self) -> &Vec<Wire> {
        &self.controls
    }

    /// Mutable access to the worklist.
    fn worklist(&mut self) -> &mut VecDeque<N> {
        &mut self.worklist
    }

    /// Mutable access to the old-wire → new-wires correspondence map.
    fn corresp_map(&mut self) -> &mut HashMap<DirWire<N>, Vec<DirWire>> {
        &mut self.corresp_map
    }

    /// Mutable access to the call map.
    fn call_map(&mut self) -> &mut HashMap<N, (Node, IncomingPort)> {
        &mut self.call_map
    }

    /// Runs `f` with `worklist` temporarily installed as the worklist, then
    /// restores the previous one.
    fn with_worklist<T>(&mut self, worklist: VecDeque<N>, f: impl FnOnce(&mut Self) -> T) -> T {
        let saved = mem::replace(self.worklist(), worklist);
        let r = f(self);
        *self.worklist() = saved;
        r
    }

    /// Runs `f` with `modifiers` temporarily installed as the current modifiers,
    /// then restores the previous ones.
    fn with_modifiers<T>(
        &mut self,
        modifiers: CombinedModifier,
        f: impl FnOnce(&mut Self) -> T,
    ) -> T {
        let saved = mem::replace(self.modifiers_mut(), modifiers);
        let r = f(self);
        *self.modifiers_mut() = saved;
        r
    }

    /// Pushes `*wire` onto `ancilla`, runs `f`, then pops the top of `ancilla`
    /// back into `*wire`.
    ///
    /// # Panics
    /// If `f` leaves `ancilla` empty.
    fn with_ancilla<T>(
        &mut self,
        wire: &mut Wire<Node>,
        ancilla: &mut Vec<Wire<Node>>,
        f: impl FnOnce(&mut Self, &mut Vec<Wire<Node>>) -> T,
    ) -> T {
        ancilla.push(*wire);
        let r = f(self, ancilla);
        *wire = ancilla.pop().unwrap();
        r
    }

    /// Removes the last control wire, decrementing the control count.
    /// Returns `None` (and leaves the count untouched) if there are no controls.
    fn pop_control(&mut self) -> Option<Wire<Node>> {
        let c = self.controls().pop()?;
        self.modifiers.control -= 1;
        Some(c)
    }

    /// Appends a control wire, incrementing the control count.
    fn push_control(&mut self, c: Wire<Node>) {
        self.controls().push(c);
        self.modifiers.control += 1;
    }

    /// Records `new` as the sole correspondent of `old`.
    ///
    /// # Errors
    /// Returns an `Unreachable` error if `old` was already registered. The new
    /// entry replaces the old one in that case.
    fn map_insert(
        &mut self,
        old: DirWire<N>,
        new: DirWire,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let Some(former) = self.corresp_map().insert(old, vec![new]) else {
            return Ok(());
        };
        // `map_insert_none` registers an *empty* correspondence, so `former` may
        // be empty; indexing it directly would panic.
        let former = match former.first() {
            Some(first) => format!("[{},...]", first),
            None => "[]".to_string(),
        };
        Err(ModifierResolverErrors::unreachable(format!(
            "Wire already registered for node {}. Former {}, Latter {}.",
            old.0, former, new
        )))
    }

    /// Registers `old` with no correspondents, unless it is already registered.
    fn map_insert_none(&mut self, old: DirWire<N>) -> Result<(), ModifierResolverErrors<N>> {
        self.corresp_map().entry(old).or_default();
        Ok(())
    }

    /// Looks up the new wires corresponding to `key`.
    ///
    /// # Errors
    /// Returns an `Unreachable` error if `key` was never registered.
    fn map_get(&self, key: &DirWire<N>) -> Result<&Vec<DirWire>, ModifierResolverErrors<N>> {
        // `ok_or_else`: the message is only formatted on the (rare) failure path.
        self.corresp_map.get(key).ok_or_else(|| {
            ModifierResolverErrors::unreachable(format!(
                "No correspondence for the wire: {}",
                key
            ))
        })
    }

    /// Registers every port of `n` with no correspondents, so that edges touching
    /// `n` are not recreated by [`Self::connect_all`].
    fn forget_node(
        &mut self,
        h: &impl HugrView<Node = N>,
        n: N,
    ) -> Result<(), ModifierResolverErrors<N>> {
        for port in h.all_node_ports(n) {
            self.map_insert_none(DirWire(n, port))?;
        }
        Ok(())
    }

    /// Adds `op` to `new_dfg` and maps every port of `old_n` to the same port
    /// (same direction and index) of the new node.
    fn add_node_no_modification(
        &mut self,
        h: &impl HugrMut<Node = N>,
        old_n: N,
        op: impl Into<OpType>,
        new_dfg: &mut impl Container,
    ) -> Result<Node, ModifierResolverErrors<N>> {
        let node = new_dfg.add_child_node(op);
        for port in h.all_node_ports(old_n) {
            self.map_insert(DirWire(old_n, port), DirWire(node, port))?;
        }
        Ok(node)
    }

    /// Builds the port vector of `n`. When the dagger modifier is active, the port
    /// indices in `iter` have their direction flipped; otherwise `iter` is ignored.
    fn port_vector_dagger(
        &self,
        n: Node,
        inputs: impl Iterator<Item = usize>,
        outputs: impl Iterator<Item = usize>,
        iter: impl Iterator<Item = usize>,
    ) -> PortVector<Node> {
        if self.modifiers.dagger {
            PortVector::port_vector_rev(n, inputs, outputs, iter)
        } else {
            PortVector::from_single_node(n, inputs, outputs)
        }
    }

    /// Maps input `k` of `n` to `pv.incoming[k]` and output `k` to `pv.outgoing[k]`.
    /// Entries beyond the number of ports of `n` are ignored.
    fn add_edge_from_pv(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        pv: PortVector<Node>,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let PortVector { incoming, outgoing } = pv;
        for (old_in, new) in (0..h.num_inputs(n)).map(IncomingPort::from).zip(incoming) {
            self.map_insert((n, old_in).into(), new)?;
        }
        for (old_out, new) in (0..h.num_outputs(n)).map(OutgoingPort::from).zip(outgoing) {
            self.map_insert((n, old_out).into(), new)?;
        }
        Ok(())
    }

    /// Adds `op` to `new_dfg` and threads the control wires through it: control `i`
    /// is connected to input `i` of the new node and then replaced by its output `i`.
    fn add_node_control(&mut self, new_dfg: &mut impl Container, op: impl Into<OpType>) -> Node {
        let node = new_dfg.add_child_node(op);
        for (i, ctrl) in self.controls().iter_mut().enumerate() {
            new_dfg
                .hugr_mut()
                .connect(ctrl.node(), ctrl.source(), node, i);
            *ctrl = Wire::new(node, i);
        }
        node
    }

    /// Recreates in `new_dfg` every non-state-order edge between children of
    /// `parent`, connecting every correspondent of the source port to every
    /// correspondent of the target port.
    ///
    /// # Errors
    /// Fails if either endpoint has no registered correspondence, or if two
    /// correspondents cannot be connected (same direction).
    pub fn connect_all(
        &mut self,
        h: &impl HugrView<Node = N>,
        new_dfg: &mut impl Container,
        parent: N,
    ) -> Result<(), ModifierResolverErrors<N>> {
        for out_node in h.children(parent) {
            for out_port in h.node_outputs(out_node) {
                if let Some(EdgeKind::StateOrder) = h.get_optype(out_node).port_kind(out_port) {
                    continue;
                }
                for (in_node, in_port) in h.linked_inputs(out_node, out_port) {
                    for a in self.map_get(&(in_node, in_port).into())? {
                        for b in self.map_get(&(out_node, out_port).into())? {
                            connect(new_dfg, a, b)?;
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
impl<N: HugrNode> ModifierResolver<N> {
fn verify(&self, h: &impl HugrView<Node = N>, n: N) -> Result<(), ModifierError<N>> {
let optype = h.get_optype(n);
if Modifier::from_optype(optype).is_none() {
return Err(ModifierError::NotModifier(n, optype.clone()));
}
let Ok((caller, _)) = h.linked_inputs(n, 0).exactly_one() else {
return Err(ModifierError::NoCaller(n));
};
let optype = h.get_optype(caller);
if Modifier::from_optype(optype).is_some() {
return Err(ModifierError::NotInitialModifier(caller, optype.clone()));
}
Ok(())
}
fn try_rewrite(
&mut self,
h: &mut impl HugrMut<Node = N>,
n: N,
) -> Result<(), ModifierResolverErrors<N>> {
self.verify(h, n)?;
let modified_fn_loader: Vec<(_, Vec<_>)> = h
.node_outputs(n)
.map(|p| (p, h.linked_inputs(n, p).collect()))
.collect();
let modifiers = CombinedModifier::default();
let new_load = self.with_modifiers(modifiers, |this| {
this.apply_modifier_chain_to_loaded_fn(h, n)
})?;
for (out_port, inputs) in modified_fn_loader {
for (recv, recv_port) in inputs {
h.disconnect(recv, recv_port);
h.connect(new_load, out_port, recv, recv_port);
}
}
Ok(())
}
pub fn modify_signature(&self, signature: &mut Signature, flatten: bool) {
let FuncTypeBase { input, output } = signature;
if flatten {
let n = self.control_num();
input.to_mut().splice(0..0, iter::repeat_n(qb_t(), n));
output.to_mut().splice(0..0, iter::repeat_n(qb_t(), n));
} else {
for ctrls in &self.modifiers.accum_ctrl {
let n = *ctrls as u64;
input.to_mut().insert(0, array_type(n, qb_t()));
output.to_mut().insert(0, array_type(n, qb_t()));
}
}
}
fn modify_op(
&mut self,
h: &mut impl HugrMut<Node = N>,
n: N,
new_dfg: &mut impl Dataflow,
) -> Result<(), ModifierResolverErrors<N>> {
let optype = &h.get_optype(n).clone();
match optype {
OpType::Input(_) | OpType::Output(_) => {}
OpType::CFG(cfg) => self.modify_cfg(h, n, cfg, new_dfg)?,
OpType::DFG(dfg) => self.modify_dfg(h, n, dfg, new_dfg)?,
OpType::TailLoop(tail_loop) => self.modify_tail_loop(h, n, tail_loop, new_dfg)?,
OpType::Conditional(conditional) => {
self.modify_conditional(h, n, conditional, new_dfg)?
}
OpType::Call(_) => self.modify_call(h, n, optype, new_dfg)?,
OpType::CallIndirect(indir_call) => {
self.modify_indirect_call(h, n, indir_call, new_dfg)?
}
OpType::LoadFunction(load) => self.modify_load_function(h, n, load, new_dfg)?,
OpType::ExtensionOp(_) => {
self.modify_extension_op(h, n, optype, new_dfg)?;
}
OpType::Const(constant) => {
self.modify_constant(n, constant, new_dfg)?;
}
OpType::LoadConstant(_) | OpType::OpaqueOp(_) | OpType::Tag(_) => {
self.add_node_no_modification(h, n, optype.clone(), new_dfg)?;
}
OpType::FuncDefn(_) | OpType::FuncDecl(_) | OpType::Module(_) => {
return Err(ModifierResolverErrors::unreachable(format!(
"Invalid node found inside modified function (OpType = {})",
optype.clone()
)));
}
OpType::Case(_) => {
return Err(ModifierResolverErrors::unreachable(
"Case cannot be directly modified.".to_string(),
));
}
OpType::AliasDecl(_)
| OpType::AliasDefn(_)
| OpType::ExitBlock(_)
| OpType::DataflowBlock(_) => {
return Err(ModifierResolverErrors::unresolvable(
n,
"Unmodifiable node found".to_string(),
optype.clone(),
));
}
_ => {
return Err(ModifierResolverErrors::unresolvable(
n,
"Unknown operation".to_string(),
optype.clone(),
));
}
}
Ok(())
}
fn wire_node_inout<'a>(
&mut self,
n: N,
node: Node,
(inputs, outputs): (
impl Iterator<Item = &'a Type>,
impl Iterator<Item = &'a Type>,
),
(input_offset, output_offset, new_offset): (usize, usize, usize),
) -> Result<(), ModifierResolverErrors<N>> {
self.wire_inout(
(n, n),
(node, node),
(inputs, outputs),
(input_offset, output_offset, new_offset),
)
}
fn wire_inout<'a>(
&mut self,
(old_in, old_out): (N, N),
(new_in, new_out): (Node, Node),
(mut inputs, mut outputs): (
impl Iterator<Item = &'a Type>,
impl Iterator<Item = &'a Type>,
),
(input_offset, output_offset, new_offset): (usize, usize, usize),
) -> Result<(), ModifierResolverErrors<N>> {
let mut old_in_wire = (old_in, IncomingPort::from(input_offset)).into();
let mut old_out_wire = (old_out, OutgoingPort::from(output_offset)).into();
let mut new_in_wire = (new_in, IncomingPort::from(input_offset + new_offset)).into();
let mut new_out_wire = (new_out, OutgoingPort::from(output_offset + new_offset)).into();
let mut in_ty = inputs.next();
let mut out_ty = outputs.next();
loop {
while let Some(ty) = in_ty {
if self.qubit_finder.contains_element_type(ty) {
break;
}
self.map_insert(old_in_wire, new_in_wire)?;
old_in_wire = old_in_wire.shift(1);
new_in_wire = new_in_wire.shift(1);
in_ty = inputs.next();
}
while let Some(ty) = out_ty {
if self.qubit_finder.contains_element_type(ty) {
break;
}
self.map_insert(old_out_wire, new_out_wire)?;
old_out_wire = old_out_wire.shift(1);
new_out_wire = new_out_wire.shift(1);
out_ty = outputs.next();
}
while let Some(ty) = in_ty {
if !self.qubit_finder.contains_element_type(ty) {
break;
}
let new_in = if !self.modifiers.dagger {
let new_in = new_in_wire;
new_in_wire = new_in_wire.shift(1);
new_in
} else {
let new_in = new_out_wire;
new_out_wire = new_out_wire.shift(1);
new_in
};
self.map_insert(old_in_wire, new_in)?;
old_in_wire = old_in_wire.shift(1);
in_ty = inputs.next();
}
while let Some(ty) = out_ty {
if !self.qubit_finder.contains_element_type(ty) {
break;
}
let new_out = if !self.modifiers.dagger {
let new_out = new_out_wire;
new_out_wire = new_out_wire.shift(1);
new_out
} else {
let new_out = new_in_wire;
new_in_wire = new_in_wire.shift(1);
new_out
};
self.map_insert(old_out_wire, new_out)?;
old_out_wire = old_out_wire.shift(1);
out_ty = outputs.next();
}
if in_ty.is_none() && out_ty.is_none() {
break;
}
}
Ok(())
}
fn _wire_others(
&mut self,
n: N,
n_optype: &OpType,
node: Node,
node_optype: &OpType,
) -> Result<(), ModifierResolverErrors<N>> {
if let (Some(old), Some(new)) =
(n_optype.other_input_port(), node_optype.other_input_port())
{
self.map_insert((n, old).into(), (node, new).into())?;
}
if let (Some(old), Some(new)) = (
n_optype.other_output_port(),
node_optype.other_output_port(),
) {
self.map_insert((n, old).into(), (node, new).into())?;
}
Ok(())
}
fn modify_constant(
&mut self,
n: N,
constant: &Const,
new_dfg: &mut impl Container,
) -> Result<(), ModifierResolverErrors<N>> {
let output = new_dfg.add_child_node(constant.clone());
self.map_insert(Wire::new(n, 0).into(), Wire::new(output, 0).into())
}
fn modify_dataflow_op(
&mut self,
h: &impl HugrMut<Node = N>,
n: N,
optype: &OpType,
new_dfg: &mut impl Container,
) -> Result<(), ModifierResolverErrors<N>> {
let node = new_dfg.add_child_node(optype.clone());
let signature = h.signature(n).unwrap();
let inputs = signature.input.iter();
let outputs = signature.output.iter();
self.wire_node_inout(n, node, (inputs, outputs), (0, 0, 0))?;
Ok(())
}
fn modify_extension_op(
&mut self,
h: &impl HugrMut<Node = N>,
n: N,
optype: &OpType,
new_dfg: &mut impl Dataflow,
) -> Result<(), ModifierResolverErrors<N>> {
if self.controls().len() != self.control_num() {
return Err(ModifierResolverErrors::unreachable(
"Control qubits are not set correctly.".to_string(),
));
}
if let Some(op) = TketOp::from_optype(optype) {
let pv = self.modify_tket_op(n, op, new_dfg, &mut vec![])?;
self.add_edge_from_pv(h, n, pv)
} else if GlobalPhase::from_optype(optype).is_some() {
let inputs = self.modify_global_phase(n, new_dfg, &mut vec![])?;
self.corresp_map().insert(
(n, IncomingPort::from(0)).into(),
inputs.into_iter().map(Into::into).collect(),
);
Ok(())
} else if Modifier::from_optype(optype).is_some() {
self.forget_node(h, n)
} else if self.modify_array_op(h, n, optype, new_dfg)?
|| self.try_array_convert(h, n, optype, new_dfg)?
{
Ok(())
} else {
self.modify_dataflow_op(h, n, optype, new_dfg)
}
}
fn modify_cfg(
&mut self,
h: &mut impl HugrMut<Node = N>,
n: N,
cfg: &CFG,
new_dfg: &mut impl Container,
) -> Result<(), ModifierResolverErrors<N>> {
let children: Vec<N> = h
.children(n)
.filter(|child| h.get_optype(*child).is_dataflow_block())
.collect();
if children.len() != 1 {
return Err(ModifierResolverErrors::unresolvable(
n,
"CFG with more than one node found.".to_string(),
cfg.clone().into(),
));
}
let old_bb = children[0];
let mut signature = cfg.signature.clone();
self.modify_signature(&mut signature, true);
let mut new_cfg = CFGBuilder::new(signature.clone())?;
let mut new_bb = new_cfg.entry_builder([type_row![]], signature.output.clone())?;
self.modify_dfg_body(h, old_bb, &mut new_bb)?;
let bb_id = new_bb.finish_sub_container()?;
new_cfg.branch(&bb_id, 0, &new_cfg.exit_block())?;
let new = self.insert_sub_dfg(new_dfg, new_cfg)?;
for (i, c) in self.controls().iter_mut().enumerate() {
new_dfg.hugr_mut().connect(c.node(), c.source(), new, i);
*c = Wire::new(new, i);
}
let offset = self.control_num();
self.wire_node_inout(
n,
new,
(signature.input.iter(), signature.output.iter()),
(0, 0, offset),
)?;
Ok(())
}
}
pub fn resolve_modifier_with_entrypoints(
h: &mut impl HugrMut<Node = Node>,
entry_points: impl IntoIterator<Item = Node>,
) -> Result<(), ModifierResolverErrors<Node>> {
use ModifierResolverErrors::*;
let entry_points: VecDeque<_> = entry_points.into_iter().collect();
let mut resolver = ModifierResolver::new();
let mut worklist = entry_points.clone();
let mut visited = vec![];
while let Some(node) = worklist.pop_front() {
if !h.contains_node(node) || visited.contains(&node) {
continue;
}
worklist.extend(h.children(node).filter(|n| !visited.contains(n)));
worklist.extend(h.all_neighbours(node).filter(|n| !visited.contains(n)));
visited.push(node);
if let Err(e) = resolver.try_rewrite(h, node) {
if !matches!(e, ModifierError(_)) {
return Err(e);
}
}
}
let mut deletelist = entry_points.clone();
let mut visited = vec![];
while let Some(node) = deletelist.pop_front() {
deletelist.extend(h.children(node).filter(|n| !visited.contains(n)));
deletelist.extend(h.all_neighbours(node).filter(|n| !visited.contains(n)));
visited.push(node);
if h.contains_node(node) {
let optype = h.get_optype(node);
if Modifier::from_optype(optype).is_some() {
let mut l = vec![node];
while let Some(n) = l.pop() {
l.extend(h.output_neighbours(n));
h.remove_node(n);
}
}
}
}
delete_phase(h, entry_points)?;
h.validate()
.map_err(|e| ModifierResolverErrors::BuildError(e.into()))?;
Ok(())
}
// Test helpers for modifier resolution.
#[cfg(test)]
mod tests {
use cool_asserts::assert_matches;
use hugr::{
Hugr,
builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder},
ops::{CallIndirect, ExtensionOp, handle::FuncID},
std_extensions::collections::array::ArrayOpBuilder,
types::Term,
};
use crate::{
TketOp,
extension::modifier::{CONTROL_OP_ID, DAGGER_OP_ID, MODIFIER_EXTENSION},
metadata,
};
use super::*;
/// Lets a builder tag its container node with `metadata::Unitary`.
pub(crate) trait SetUnitary {
/// Sets `metadata::Unitary` on the builder's container node.
fn set_unitary(&mut self);
}
impl<T: Container> SetUnitary for T {
fn set_unitary(&mut self) {
let node = self.container_node();
// NOTE(review): the meaning of the value 7 is not visible here; it looks like
// a set of flags for `metadata::Unitary` — confirm against its definition.
self.hugr_mut().set_metadata::<metadata::Unitary>(node, 7);
}
}
/// Builds a module with a `main` that calls a modified `foo`, resolves the
/// modifiers, and asserts that the hugr validates before and after resolution.
///
/// `main` does the following:
/// * Loads `foo` (built by the `foo` callback for `t_num` targets).
/// * If `dagger` is set, wraps it in a dagger modifier.
/// * Wraps the result in a control modifier with `c_num` controls.
/// * Calls it through `CallIndirect` with an array of `c_num` freshly allocated
///   control qubits followed by `t_num` freshly allocated target qubits.
/// * Returns the call's outputs.
pub(crate) fn test_modifier_resolver(
t_num: usize,
c_num: u64,
foo: impl FnOnce(&mut ModuleBuilder<Hugr>, usize) -> FuncID<true>,
dagger: bool,
) {
let mut module = ModuleBuilder::new();
// Signature of the controlled function: [qubit array of size c_num, t_num qubits] -> same.
let call_sig = Signature::new_endo(
[array_type(c_num, qb_t())]
.into_iter()
.chain(iter::repeat_n(qb_t(), t_num))
.collect::<Vec<_>>(),
);
// `main` takes nothing and returns the control array and the targets.
let main_sig = Signature::new(
type_row![],
vec![array_type(c_num, qb_t())]
.into_iter()
.chain(iter::repeat_n(qb_t(), t_num))
.collect::<Vec<_>>(),
);
let dagger_op: ExtensionOp = {
MODIFIER_EXTENSION
.instantiate_extension_op(
&DAGGER_OP_ID,
[
iter::repeat_n(qb_t().into(), t_num)
.collect::<Vec<_>>()
.into(),
vec![].into(),
],
)
.unwrap()
};
let control_op: ExtensionOp = {
MODIFIER_EXTENSION
.instantiate_extension_op(
&CONTROL_OP_ID,
[
Term::BoundedNat(c_num),
iter::repeat_n(qb_t().into(), t_num)
.collect::<Vec<_>>()
.into(),
vec![].into(),
],
)
.unwrap()
};
let foo = foo(&mut module, t_num);
let _main = {
let mut func = module.define_function("main", main_sig).unwrap();
let mut call = func.load_func(&foo, &[]).unwrap();
// Optional dagger, then control: the control modifier is the one consumed by the call.
if dagger {
call = func
.add_dataflow_op(dagger_op, vec![call])
.unwrap()
.out_wire(0);
}
call = func
.add_dataflow_op(control_op, vec![call])
.unwrap()
.out_wire(0);
let mut controls = Vec::new();
for _ in 0..c_num {
controls.push(
func.add_dataflow_op(TketOp::QAlloc, vec![])
.unwrap()
.out_wire(0),
);
}
let mut targ = Vec::new();
for _ in 0..t_num {
targ.push(
func.add_dataflow_op(TketOp::QAlloc, vec![])
.unwrap()
.out_wire(0),
)
}
let control_arr = func.add_new_array(qb_t(), controls).unwrap();
let fn_outs = func
.add_dataflow_op(
CallIndirect {
signature: call_sig,
},
[call, control_arr].into_iter().chain(targ),
)
.unwrap()
.outputs();
func.finish_with_outputs(fn_outs).unwrap()
};
let mut h = module.finish_hugr().unwrap();
assert_matches!(h.validate(), Ok(()));
let entrypoint = h.entrypoint();
resolve_modifier_with_entrypoints(&mut h, [entrypoint]).unwrap();
assert_matches!(h.validate(), Ok(()));
}
}