use super::Rewrite;
use crate::ops::handle::{DfgID, NodeHandle};
use crate::{IncomingPort, Node, OutgoingPort, PortIndex};
/// Rewrite that inlines a DFG node into its parent.
///
/// When applied, every child of the DFG other than its first two children
/// (the Input and Output nodes) is re-parented to the DFG's parent, the edges
/// that passed through the DFG boundary are rewired directly, and the DFG,
/// Input and Output nodes are removed.
///
/// The wrapped [`DfgID`] identifies the DFG node to inline.
pub struct InlineDFG(pub DfgID);
/// Errors reported when verifying or applying an [`InlineDFG`] rewrite.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum InlineDFGError {
/// The node targeted by the rewrite does not hold a DFG operation.
/// Carries the offending node.
#[error("Node {0} was not a DFG")]
NotDFG(Node),
/// The node targeted by the rewrite has no parent, so there is nowhere
/// to move its children.
#[error("Node did not have a parent into which to inline")]
NoParent,
}
impl Rewrite for InlineDFG {
    /// The nodes removed by the rewrite, in order: the DFG node itself, then
    /// its Input child, then its Output child.
    type ApplyResult = [Node; 3];
    type Error = InlineDFGError;

    // `apply` runs `verify` before touching the Hugr, so every error it can
    // return is produced while the Hugr is still unmodified.
    const UNCHANGED_ON_FAILURE: bool = true;

    /// Checks that the rewrite can be applied to `h`.
    ///
    /// # Errors
    ///
    /// * [`InlineDFGError::NotDFG`] if the target node is not a DFG.
    /// * [`InlineDFGError::NoParent`] if the target node has no parent.
    fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> {
        let dfg = self.0.node();
        if h.get_optype(dfg).as_dfg().is_none() {
            Err(InlineDFGError::NotDFG(dfg))
        } else if h.get_parent(dfg).is_none() {
            Err(InlineDFGError::NoParent)
        } else {
            Ok(())
        }
    }

    /// Inlines the DFG into its parent and returns the three removed nodes.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Self::verify`] returns; no mutation has happened
    /// at that point.
    ///
    /// # Panics
    ///
    /// Panics if a value input of the DFG, or a value input of its Output
    /// node, is not linked to exactly one source (the `unwrap`s on
    /// `single_linked_output`), or if the DFG still has children after its
    /// Input and Output nodes were removed.
    fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result<Self::ApplyResult, Self::Error> {
        self.verify(h)?;
        let dfg = self.0.node();

        // Ports of the DFG node that carry its non-dataflow ("other") edges.
        let dfg_op = h.get_optype(dfg);
        let (order_in, order_out) = (
            dfg_op.other_input_port().unwrap(),
            dfg_op.other_output_port().unwrap(),
        );

        // `verify` has already established that a parent exists.
        let host = h.get_parent(dfg).unwrap();
        let [input, output] = h.get_io(dfg).unwrap();

        // Hoist everything except the first two children (Input and Output)
        // into the DFG's parent.
        let hoisted: Vec<_> = h.children(dfg).skip(2).collect();
        for child in hoisted {
            h.set_parent(child, host);
        }

        // --- Incoming side of the DFG -----------------------------------
        // Every node linked into the DFG's "other" input port gets an
        // "other" edge to each node that followed the inner Input node.
        let outer_order_preds: Vec<_> = h.linked_outputs(dfg, order_in).collect();
        for (pred, pred_port) in outer_order_preds {
            debug_assert_eq!(Some(pred_port), h.get_optype(pred).other_output_port());
            let after_input: Vec<_> = h.output_neighbours(input).collect();
            for succ in after_input {
                h.add_other_edge(pred, succ);
            }
        }

        // Nodes hanging off the Input node's own "other" output port.
        let input_order_port = h.get_optype(input).other_output_port().unwrap();
        let input_order_succs: Vec<_> = h.linked_inputs(input, input_order_port).collect();

        // Remaining inputs of the DFG: connect each outer producer directly
        // to whatever consumed the matching Input-node output.
        let value_inputs: Vec<_> = h
            .node_inputs(dfg)
            .filter(|port| *port != order_in)
            .collect();
        for in_port in value_inputs {
            let (producer, producer_port) = h.single_linked_output(dfg, in_port).unwrap();
            // Detach both sides of the boundary before reconnecting. A
            // consumer may be the Output node itself (a value passed straight
            // through); the outgoing pass below then picks up the new source.
            h.disconnect(dfg, in_port);
            let inner_port = OutgoingPort::from(in_port.index());
            let consumers: Vec<_> = h.linked_inputs(input, inner_port).collect();
            h.disconnect(input, inner_port);
            for (consumer, consumer_port) in consumers {
                h.connect(producer, producer_port, consumer, consumer_port);
            }
            // Nodes that were linked after the Input node via its "other"
            // port are now linked after each outer producer instead.
            for &(succ, _) in &input_order_succs {
                h.add_other_edge(producer, succ);
            }
        }

        // --- Outgoing side of the DFG -----------------------------------
        // Every node linked from the DFG's "other" output port gets an
        // "other" edge from each node that preceded the inner Output node.
        let outer_order_succs: Vec<_> = h.linked_inputs(dfg, order_out).collect();
        for (succ, succ_port) in outer_order_succs {
            debug_assert_eq!(Some(succ_port), h.get_optype(succ).other_input_port());
            let before_output: Vec<_> = h.input_neighbours(output).collect();
            for pred in before_output {
                h.add_other_edge(pred, succ);
            }
        }

        // Nodes feeding the Output node's own "other" input port.
        let output_order_port = h.get_optype(output).other_input_port().unwrap();
        let output_order_preds: Vec<_> = h.linked_outputs(output, output_order_port).collect();

        // Remaining outputs of the DFG: connect whatever fed the matching
        // Output-node input directly to each outer consumer.
        let value_outputs: Vec<_> = h
            .node_outputs(dfg)
            .filter(|port| *port != order_out)
            .collect();
        for out_port in value_outputs {
            let inner_port = IncomingPort::from(out_port.index());
            let (producer, producer_port) = h.single_linked_output(output, inner_port).unwrap();
            h.disconnect(output, inner_port);

            let consumers: Vec<_> = h.linked_inputs(dfg, out_port).collect();
            for (consumer, consumer_port) in consumers {
                h.connect(producer, producer_port, consumer, consumer_port);
                // Nodes that were linked before the Output node via its
                // "other" port are now linked before each outer consumer.
                for &(pred, _) in &output_order_preds {
                    h.add_other_edge(pred, consumer);
                }
            }
            h.disconnect(dfg, out_port);
        }

        // The boundary nodes and the now-empty DFG shell can go.
        h.remove_node(input);
        h.remove_node(output);
        assert!(h.children(dfg).next().is_none());
        h.remove_node(dfg);
        Ok([dfg, input, output])
    }

    /// Yields the single node this rewrite targets: the DFG being inlined.
    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        std::iter::once(self.0.node())
    }
}
#[cfg(test)]
mod test {
use std::collections::HashSet;
use rstest::rstest;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
};
use crate::extension::prelude::QB_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::hugr::rewrite::inline_dfg::InlineDFGError;
use crate::hugr::HugrMut;
use crate::ops::handle::{DfgID, NodeHandle};
use crate::ops::{Lift, OpType, Value};
use crate::std_extensions::arithmetic::float_types;
use crate::std_extensions::arithmetic::int_ops::{self, IntOpDef};
use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
use crate::types::FunctionType;
use crate::utils::test_quantum_extension;
use crate::{type_row, Direction, HugrView, Node, Port};
use crate::{Hugr, Wire};
use super::InlineDFG;
/// Collects, in the order `h.nodes()` yields them, every node of `h` whose
/// operation is a DFG.
fn find_dfgs(h: &impl HugrView) -> Vec<Node> {
    let mut dfgs = Vec::new();
    for node in h.nodes() {
        if h.get_optype(node).as_dfg().is_some() {
            dfgs.push(node);
        }
    }
    dfgs
}
/// Collects, in the order `h.nodes()` yields them, every node of `h` whose
/// operation is an `OpType::CustomOp`.
fn extension_ops(h: &impl HugrView) -> Vec<Node> {
    let mut custom = Vec::new();
    for node in h.nodes() {
        if matches!(h.get_optype(node), OpType::CustomOp(_)) {
            custom.push(node);
        }
    }
    custom
}
#[rstest]
#[case(true)]
#[case(false)]
fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
let delta = ExtensionSet::from_iter([int_ops::EXTENSION_ID, int_types::EXTENSION_ID]);
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
int_ops::EXTENSION.to_owned(),
int_types::EXTENSION.to_owned(),
])
.unwrap();
let int_ty = &int_types::INT_TYPES[6];
let mut outer = DFGBuilder::new(
FunctionType::new(vec![int_ty.clone(); 2], vec![int_ty.clone()])
.with_extension_delta(delta.clone()),
)?;
let [a, b] = outer.input_wires_arr();
fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
d: &mut DFGBuilder<T>,
) -> Result<Wire, Box<dyn std::error::Error>> {
let int_ty = &int_types::INT_TYPES[6];
let cst = Value::extension(ConstInt::new_u(6, 15)?);
let c1 = d.add_load_const(cst);
let [lifted] = d
.add_dataflow_op(
Lift {
type_row: vec![int_ty.clone()].into(),
new_extension: int_ops::EXTENSION_ID,
},
[c1],
)?
.outputs_arr();
Ok(lifted)
}
let c1 = nonlocal.then(|| make_const(&mut outer));
let inner = {
let mut inner = outer.dfg_builder(
FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(delta),
None,
[a],
)?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
inner.finish_with_outputs(a1.outputs())?
};
let [a1] = inner.outputs_arr();
let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), ®)?;
assert_eq!(
outer.children(inner.node()).len(),
if nonlocal { 3 } else { 6 }
); assert_eq!(find_dfgs(&outer), vec![outer.root(), inner.node()]);
let [add, sub] = extension_ops(&outer).try_into().unwrap();
assert_eq!(
outer.get_parent(outer.get_parent(add).unwrap()),
outer.get_parent(sub)
);
assert_eq!(outer.nodes().len(), 11); {
let mut h = outer.clone();
assert_eq!(
h.apply_rewrite(InlineDFG(DfgID::from(h.root()))),
Err(InlineDFGError::NoParent)
);
assert_eq!(h, outer); }
outer.apply_rewrite(InlineDFG(*inner.handle()))?;
outer.validate(®)?;
assert_eq!(outer.nodes().len(), 8);
assert_eq!(find_dfgs(&outer), vec![outer.root()]);
let [add, sub] = extension_ops(&outer).try_into().unwrap();
assert_eq!(outer.get_parent(add), Some(outer.root()));
assert_eq!(outer.get_parent(sub), Some(outer.root()));
assert_eq!(
outer.node_connections(add, sub).collect::<Vec<_>>().len(),
1
);
Ok(())
}
/// Inlines a nested DFG that only swaps its two inputs.
///
/// Before the rewrite the H gate feeds port 0 of the swap DFG and the DFG
/// feeds the CX straight through; after inlining, the H gate's output must
/// arrive at port 1 of the CX.
#[test]
fn permutation() -> Result<(), Box<dyn std::error::Error>> {
    let mut h = DFGBuilder::new(
        FunctionType::new_endo(type_row![QB_T, QB_T])
            .with_extension_delta(test_quantum_extension::EXTENSION_ID),
    )?;
    let [p, q] = h.input_wires_arr();
    let [p_h] = h
        .add_dataflow_op(test_quantum_extension::h_gate(), [p])?
        .outputs_arr();
    let swap = {
        let swap = h.dfg_builder(
            FunctionType::new_endo(type_row![QB_T, QB_T]),
            None,
            [p_h, q],
        )?;
        let [a, b] = swap.input_wires_arr();
        // The DFG body is empty: the inputs are wired to the outputs crossed.
        swap.finish_with_outputs([b, a])?
    };
    let [q, p] = swap.outputs_arr();
    let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
    let reg = ExtensionRegistry::try_new([
        test_quantum_extension::EXTENSION.to_owned(),
        PRELUDE.to_owned(),
        float_types::EXTENSION.to_owned(),
    ])
    .unwrap();

    let mut h = h.finish_hugr_with_outputs(cx.outputs(), &reg)?;
    assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]);
    // Root DFG + Input + Output, H, CX, swap DFG + Input + Output.
    assert_eq!(h.nodes().len(), 8);
    // Outside the swap DFG the wiring is not permuted.
    assert_eq!(
        h.node_connections(p_h.node(), swap.node())
            .collect::<Vec<_>>(),
        vec![[
            Port::new(Direction::Outgoing, 0),
            Port::new(Direction::Incoming, 0)
        ]]
    );
    assert_eq!(
        h.node_connections(swap.node(), cx.node())
            .collect::<Vec<_>>(),
        vec![
            [
                Port::new(Direction::Outgoing, 0),
                Port::new(Direction::Incoming, 0)
            ],
            [
                Port::new(Direction::Outgoing, 1),
                Port::new(Direction::Incoming, 1)
            ]
        ]
    );

    h.apply_rewrite(InlineDFG(*swap.handle()))?;
    assert_eq!(find_dfgs(&h), vec![h.root()]);
    // The swap DFG and its Input and Output nodes are gone.
    assert_eq!(h.nodes().len(), 5);
    let mut ops = extension_ops(&h);
    // Order the two remaining gates by output count so H comes before CX.
    ops.sort_by_key(|n| h.num_outputs(*n));
    let [h_gate, cx] = ops.try_into().unwrap();
    // The permutation is now visible in the outer wiring.
    assert_eq!(
        h.node_connections(h_gate, cx).collect::<Vec<_>>(),
        vec![[
            Port::new(Direction::Outgoing, 0),
            Port::new(Direction::Incoming, 1)
        ]]
    );
    Ok(())
}
/// Checks how order ("other") edges touching the inlined DFG and its
/// Input/Output nodes are redistributed by the rewrite.
///
/// Order edges set up before inlining:
/// * `h_a` -> inner DFG
/// * inner Input -> `f` (the loaded float constant)
/// * `if_n` (the conditional) -> inner Output
/// * inner DFG -> `h_a2`
#[test]
fn order_edges() -> Result<(), Box<dyn std::error::Error>> {
    let reg = ExtensionRegistry::try_new([
        test_quantum_extension::EXTENSION.to_owned(),
        float_types::EXTENSION.to_owned(),
        PRELUDE.to_owned(),
    ])
    .unwrap();
    let mut outer = DFGBuilder::new(
        FunctionType::new_endo(type_row![QB_T, QB_T])
            .with_extension_delta(float_types::EXTENSION_ID),
    )?;
    let [a, b] = outer.input_wires_arr();
    let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
    let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;

    // Inner DFG: takes the output of `h_b`, applies an Rz and a measure.
    let mut inner = outer.dfg_builder(
        FunctionType::new_endo(type_row![QB_T]).with_extension_delta(float_types::EXTENSION_ID),
        None,
        h_b.outputs(),
    )?;
    let [i] = inner.input_wires_arr();
    let f = inner.add_load_value(float_types::ConstF64::new(1.0));
    // Order edge: inner Input -> loaded constant.
    inner.add_other_wire(inner.input().node(), f.node());
    let r = inner.add_dataflow_op(test_quantum_extension::rz_f64(), [i, f])?;
    let [m, b] = inner
        .add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
        .outputs_arr();
    // A node consuming the measured boolean: a conditional choosing between
    // two empty cases.
    let mut if_n = inner.conditional_builder(
        ([type_row![], type_row![]], b),
        [],
        type_row![],
        ExtensionSet::new(),
    )?;
    if_n.case_builder(0)?.finish_with_outputs([])?;
    if_n.case_builder(1)?.finish_with_outputs([])?;
    let if_n = if_n.finish_sub_container()?;
    // Order edge: conditional -> inner Output.
    inner.add_other_wire(if_n.node(), inner.output().node());
    let inner = inner.finish_with_outputs([m])?;

    // Order edge: h_a -> inner DFG.
    outer.add_other_wire(h_a.node(), inner.node());
    let h_a2 = outer.add_dataflow_op(test_quantum_extension::h_gate(), h_a.outputs())?;
    // Order edge: inner DFG -> h_a2.
    outer.add_other_wire(inner.node(), h_a2.node());
    let cx = outer.add_dataflow_op(
        test_quantum_extension::cx_gate(),
        h_a2.outputs().chain(inner.outputs()),
    )?;
    let mut outer = outer.finish_hugr_with_outputs(cx.outputs(), &reg)?;

    outer.apply_rewrite(InlineDFG(*inner.handle()))?;
    outer.validate(&reg)?;

    // Nodes linked to `n`'s "other" port in direction `d`.
    let order_neighbours = |n, d| {
        let p = outer.get_optype(n).other_port(d).unwrap();
        outer
            .linked_ports(n, p)
            .map(|(n, _)| n)
            .collect::<HashSet<_>>()
    };

    // h_a preceded the DFG, so it now precedes everything that followed the
    // inner Input node: the Rz and the loaded constant.
    assert_eq!(
        order_neighbours(h_a.node(), Direction::Outgoing),
        HashSet::from([r.node(), f.node()])
    );
    // The loaded constant followed the inner Input node, so it now follows
    // h_a (order predecessor of the DFG) and h_b (producer of the DFG input).
    assert_eq!(
        order_neighbours(f.node(), Direction::Incoming),
        HashSet::from([h_a.node(), h_b.node()])
    );
    // h_a2 followed the DFG, so it now follows everything that preceded the
    // inner Output node: the measure and the conditional.
    assert_eq!(
        order_neighbours(h_a2.node(), Direction::Incoming),
        HashSet::from([m.node(), if_n.node()])
    );
    // The conditional preceded the inner Output node, so it now precedes
    // h_a2 (order successor of the DFG) and the CX (consumer of the DFG output).
    assert_eq!(
        order_neighbours(if_n.node(), Direction::Outgoing),
        HashSet::from([h_a2.node(), cx.node()])
    );
    Ok(())
}
}