use derive_more::{Display, Error, From};
use hugr::{HugrView, Node};
use itertools::Itertools;
use crate::serialize::pytket::OpConvertError;
use crate::Circuit;
use super::find_tuple_unpack_rewrites;
/// Lowers a circuit into a form intended for pytket encoding.
///
/// The input circuit is first converted into a standalone dataflow-graph
/// circuit with [`Circuit::extract_dfg`]. Every rewrite yielded by
/// `find_tuple_unpack_rewrites` on that new circuit is then applied in turn.
/// The input circuit itself is not modified.
///
/// # Errors
///
/// Returns [`PytketLoweringError::NonLocalOperations`] if `extract_dfg`
/// fails. The underlying extraction error is discarded.
///
/// # Panics
///
/// Panics if applying any of the collected tuple-unpack rewrites fails.
pub fn lower_to_pytket<T: HugrView<Node = Node>>(
    circ: &Circuit<T>,
) -> Result<Circuit, PytketLoweringError> {
    let mut circ = circ
        .extract_dfg()
        .map_err(|_| PytketLoweringError::NonLocalOperations)?;

    // Collect before mutating: the rewrite iterator is created from `&circ`
    // (presumably keeping it borrowed), while `apply` needs `&mut circ`.
    let rewrites = find_tuple_unpack_rewrites(&circ).collect_vec();
    for rewrite in rewrites {
        // NOTE(review): assumes the rewrites were computed so that applying
        // them sequentially cannot fail; a failure here is treated as a bug.
        rewrite
            .apply(&mut circ)
            .expect("failed to apply a tuple-unpack rewrite during pytket lowering");
    }
    Ok(circ)
}
/// Errors reported when lowering a circuit for pytket.
///
/// Marked `#[non_exhaustive]`, so downstream `match`es must include a
/// wildcard arm; new variants may be added without a breaking change.
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum PytketLoweringError {
    /// An operation could not be converted.
    ///
    /// Wraps an [`OpConvertError`] from `crate::serialize::pytket`. The
    /// `#[from]` attribute derives a `From<OpConvertError>` impl, so `?` can
    /// convert that error into this variant.
    #[display("operation conversion error: {_0}")]
    #[from]
    OpConversionError(OpConvertError),
    /// The circuit could not be extracted into a standalone dataflow graph.
    ///
    /// `lower_to_pytket` returns this for any `extract_dfg` failure. As the
    /// display message states, the failure is attributed to non-local
    /// operations such as function calls.
    #[display("Non-local operations found. Function calls are not supported.")]
    NonLocalOperations,
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::Tk2Op;
    use cool_asserts::assert_matches;
    use hugr::builder::{BuildError, CFGBuilder, Dataflow, HugrBuilder};
    use hugr::extension::prelude::{qb_t, MakeTuple, UnpackTuple};
    use hugr::hugr::hugrmut::HugrMut;
    use hugr::ops::handle::NodeHandle;
    use hugr::ops::{OpTag, OpTrait, OpType, Tag};
    use hugr::types::{Signature, TypeRow};
    use hugr::HugrView;
    use rstest::{fixture, rstest};

    /// Two-qubit circuit shaped like guppy output.
    ///
    /// It is a CFG whose entry block applies `H` then `CX` and packs the
    /// qubits into a tuple, which is immediately unpacked. The block then
    /// branches to the exit block. The HUGR entrypoint is moved to the entry
    /// block, so the circuit under test is that basic block, not the CFG.
    #[fixture]
    fn guppy_like_circuit() -> Circuit {
        let build = || -> Result<Circuit, BuildError> {
            let qubit_pair = TypeRow::from(vec![qb_t(), qb_t()]);
            let mut cfg = CFGBuilder::new(Signature::new_endo(qubit_pair.clone()))?;

            let entry = {
                let mut block = cfg.simple_entry_builder(qubit_pair.clone(), 1)?;
                let [a, b] = block.input_wires_arr();
                let [a] = block.add_dataflow_op(Tk2Op::H, [a])?.outputs_arr();
                let [a, b] = block.add_dataflow_op(Tk2Op::CX, [a, b])?.outputs_arr();

                // Pack then immediately unpack the qubits. This is presumably
                // the pattern `find_tuple_unpack_rewrites` targets.
                let packed = block.add_dataflow_op(MakeTuple::new(qubit_pair.clone()), [a, b])?;
                let [tuple] = packed.outputs_arr();
                let unpacked = block.add_dataflow_op(UnpackTuple::new(qubit_pair), [tuple])?;
                let [a, b] = unpacked.outputs_arr();

                // Control output selecting successor 0 (the only one).
                let tag = block.add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])?;
                let [predicate] = tag.outputs_arr();

                block.finish_with_outputs(predicate, [a, b])?
            };
            cfg.branch(&entry, 0, &cfg.exit_block())?;

            let mut hugr = cfg.finish_hugr()?;
            hugr.set_entrypoint(entry.node());
            Ok(Circuit::new(hugr))
        };
        build().unwrap()
    }

    #[rstest]
    #[case::guppy_like_circuit(guppy_like_circuit())]
    fn test_pytket_lowering(#[case] circ: Circuit) {
        let lowered_circ = lower_to_pytket(&circ).unwrap();
        let lowered_hugr = lowered_circ.hugr();
        lowered_hugr.validate().unwrap();

        // The lowered entrypoint must be a dataflow container with proper
        // Input/Output boundary nodes.
        let entry_tag = lowered_hugr.entrypoint_optype().tag();
        assert!(OpTag::DataflowParent.is_superset(entry_tag));
        assert_matches!(
            lowered_hugr.get_optype(lowered_circ.input_node()),
            OpType::Input(_)
        );
        assert_matches!(
            lowered_hugr.get_optype(lowered_circ.output_node()),
            OpType::Output(_)
        );

        assert_eq!(lowered_circ.num_operations(), circ.num_operations());

        let original_sig = circ.circuit_signature();
        let lowered_sig = lowered_circ.circuit_signature();
        assert_eq!(lowered_sig.input(), original_sig.input());

        // The lowered circuit may have exactly one fewer output than the
        // original. Presumably this is the leading control output of the
        // original basic block. The remaining outputs must line up.
        let original_outputs = original_sig.output();
        let lowered_outputs = lowered_sig.output();
        let (original_len, lowered_len) = (original_outputs.len(), lowered_outputs.len());
        assert!(
            original_len == lowered_len || original_len == lowered_len + 1,
            "Output count mismatch. Original: {original_sig}, Lowered: {lowered_sig}"
        );
        let skipped = original_len - lowered_len;
        assert_eq!(lowered_outputs[..], original_outputs[skipped..]);

        // The Output node's inputs must agree with the circuit's declared outputs.
        let output_node_sig = lowered_hugr.signature(lowered_circ.output_node()).unwrap();
        assert_eq!(lowered_outputs, output_node_sig.input());
    }
}