use std::sync::{Arc, Weak};
use crate::extension::bool::bool_type;
use crate::extension::rotation::rotation_type;
use crate::extension::sympy::SympyOpDef;
use crate::extension::{TKET2_EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID};
use hugr::ops::custom::ExtensionOp;
use hugr::types::Type;
use hugr::{
extension::{
prelude::{bool_t, option_type, qb_t},
simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp},
ExtensionId, OpDef, SignatureFunc,
},
ops::OpType,
type_row,
types::Signature,
};
use derive_more::{Display, Error};
use serde::{Deserialize, Serialize};
use smol_str::ToSmolStr;
use strum::{EnumIter, EnumString, IntoStaticStr};
/// The operations defined by the tket2 extension ([`TKET2_EXTENSION`]).
///
/// Each variant maps to one operation definition in the extension. Its name
/// there is the variant name (derived via [`IntoStaticStr`]), and its
/// signature is produced by [`MakeOpDef::init_signature`].
///
/// NOTE: the declaration order of the variants is observable. It fixes the
/// derived [`PartialOrd`]/[`Ord`] ordering and the [`EnumIter`] iteration
/// order, so new variants should be appended rather than inserted.
#[derive(
    Clone,
    Copy,
    Debug,
    Serialize,
    Deserialize,
    Hash,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    EnumIter,
    IntoStaticStr,
    EnumString,
)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum Tk2Op {
    /// Signature: `qubit -> qubit`.
    H,
    /// Signature: `(qubit, qubit) -> (qubit, qubit)`.
    CX,
    /// Signature: `(qubit, qubit) -> (qubit, qubit)`.
    CY,
    /// Signature: `(qubit, qubit) -> (qubit, qubit)`.
    CZ,
    /// Signature: `(qubit, qubit, rotation) -> (qubit, qubit)`.
    CRz,
    /// Signature: `qubit -> qubit`.
    T,
    /// Signature: `qubit -> qubit`.
    Tdg,
    /// Signature: `qubit -> qubit`.
    S,
    /// Signature: `qubit -> qubit`.
    Sdg,
    /// Signature: `qubit -> qubit`.
    X,
    /// Signature: `qubit -> qubit`.
    Y,
    /// Signature: `qubit -> qubit`.
    Z,
    /// Signature: `(qubit, rotation) -> qubit`.
    Rx,
    /// Signature: `(qubit, rotation) -> qubit`.
    Ry,
    /// Signature: `(qubit, rotation) -> qubit`.
    Rz,
    /// Signature: `(qubit, qubit, qubit) -> (qubit, qubit, qubit)`.
    Toffoli,
    /// Signature: `qubit -> (qubit, bool)`, using the prelude `bool_t`.
    Measure,
    /// Signature: `qubit -> bool`, using the tket2 `bool_type()` (the qubit is
    /// not returned).
    MeasureFree,
    /// Signature: `() -> qubit`.
    QAlloc,
    /// Signature: `() -> option(qubit)`.
    TryQAlloc,
    /// Signature: `qubit -> ()`.
    QFree,
    /// Signature: `qubit -> qubit`.
    Reset,
    /// Signature: `qubit -> qubit`.
    V,
    /// Signature: `qubit -> qubit`.
    Vdg,
}
impl Tk2Op {
pub fn exposed_name(&self) -> smol_str::SmolStr {
<Tk2Op as Into<OpType>>::into(*self).to_smolstr()
}
pub fn into_extension_op(self) -> ExtensionOp {
<Self as MakeRegisteredOp>::to_extension_op(self)
.expect("Failed to convert to extension op.")
}
}
/// Returns `true` when `op` renders (via `Display`) to exactly the same string
/// as [`Tk2Op::exposed_name`] of `tk2op`.
///
/// This is a purely textual comparison of the two names.
pub fn op_matches(op: &OpType, tk2op: Tk2Op) -> bool {
    let rendered = op.to_string();
    let expected = tk2op.exposed_name();
    rendered.as_str() == expected.as_str()
}
/// A single-qubit Pauli label, used to describe the commutation properties of
/// [`Tk2Op`]s (see `Tk2Op::qubit_commutation`).
///
/// NOTE: the declaration order of the variants determines the derived
/// [`PartialOrd`] ordering.
#[derive(
    Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd, EnumString,
)]
#[allow(missing_docs)]
pub enum Pauli {
    /// Commutes with every label under [`Pauli::commutes_with`].
    I,
    X,
    Y,
    Z,
}
/// Error holding an [`OpType`] that is not a [`Tk2Op`].
///
/// Displays as `"<op> is not a Tk2Op."`.
#[derive(Display, Debug, Error, PartialEq, Clone)]
#[display("{} is not a Tk2Op.", op)]
pub struct NotTk2Op {
    /// The operation that is not a [`Tk2Op`].
    pub op: OpType,
}
impl Pauli {
    /// Returns `true` if `self` commutes with `other`.
    ///
    /// `I` commutes with every label; otherwise two labels commute only when
    /// they are equal.
    pub fn commutes_with(&self, other: Self) -> bool {
        match (*self, other) {
            (Pauli::I, _) | (_, Pauli::I) => true,
            (lhs, rhs) => lhs == rhs,
        }
    }
}
impl MakeOpDef for Tk2Op {
    /// Name of the operation definition: the variant name (e.g. `"CX"`),
    /// obtained through the derived [`strum::IntoStaticStr`] impl.
    fn opdef_id(&self) -> hugr::ops::OpName {
        <&'static str>::from(self).into()
    }

    /// Signature of each operation, grouped by shape.
    fn init_signature(&self, _extension_ref: &Weak<hugr::Extension>) -> SignatureFunc {
        use Tk2Op::*;
        match self {
            // Single-qubit endomorphisms.
            H | T | S | V | X | Y | Z | Tdg | Sdg | Vdg | Reset => Signature::new_endo(qb_t()),
            // Multi-qubit endomorphisms.
            CX | CZ | CY => Signature::new_endo(vec![qb_t(); 2]),
            Toffoli => Signature::new_endo(vec![qb_t(); 3]),
            // `Measure` returns the qubit alongside a prelude bool;
            // `MeasureFree` consumes the qubit and returns the tket2 bool type.
            Measure => Signature::new(qb_t(), vec![qb_t(), bool_t()]),
            MeasureFree => Signature::new(qb_t(), bool_type()),
            // Parametric ops take an extra `rotation_type()` input.
            Rz | Rx | Ry => Signature::new(vec![qb_t(), rotation_type()], qb_t()),
            CRz => Signature::new(vec![qb_t(), qb_t(), rotation_type()], vec![qb_t(); 2]),
            // No qubit input (`QAlloc`, `TryQAlloc`) or no qubit output (`QFree`).
            QAlloc => Signature::new(type_row![], qb_t()),
            TryQAlloc => Signature::new(type_row![], Type::from(option_type(qb_t()))),
            QFree => Signature::new(qb_t(), type_row![]),
        }
        .into()
    }

    /// Every [`Tk2Op`] belongs to the tket2 extension.
    fn extension(&self) -> ExtensionId {
        EXTENSION_ID.to_owned()
    }

    /// Attaches the op's commutation data under the `"commutation"` key.
    fn post_opdef(&self, def: &mut OpDef) {
        // A `Vec<(usize, Pauli)>` has no maps and only derived `Serialize`
        // impls, so JSON serialization cannot fail; failure would be a bug.
        let commutation = serde_json::to_value(self.qubit_commutation())
            .expect("serializing Vec<(usize, Pauli)> to JSON cannot fail");
        def.add_misc("commutation", commutation);
    }

    /// Recovers the variant from a definition's name and extension id.
    fn from_def(op_def: &OpDef) -> Result<Self, hugr::extension::simple_op::OpLoadError> {
        try_from_name(op_def.name(), op_def.extension_id())
    }

    /// Weak handle to the registered tket2 extension.
    fn extension_ref(&self) -> Weak<hugr::Extension> {
        Arc::downgrade(&TKET2_EXTENSION)
    }
}
impl MakeRegisteredOp for Tk2Op {
    /// All [`Tk2Op`]s are registered under the tket2 extension id.
    fn extension_id(&self) -> ExtensionId {
        EXTENSION_ID.clone()
    }

    /// Weak handle to the registered tket2 extension.
    fn extension_ref(&self) -> Weak<hugr::Extension> {
        let extension: &Arc<hugr::Extension> = &TKET2_EXTENSION;
        Arc::downgrade(extension)
    }
}
impl Tk2Op {
    /// Commutation data for this operation, as `(qubit_index, pauli)` pairs.
    ///
    /// This is the data stored under the `"commutation"` key of the op
    /// definition. Presumably each pair marks the given qubit input as
    /// commuting with that Pauli; confirm against the consumers of this data.
    /// Operations without listed data return an empty vector.
    pub(crate) fn qubit_commutation(&self) -> Vec<(usize, Pauli)> {
        use Tk2Op::*;
        let table: &[(usize, Pauli)] = match self {
            CX => &[(0, Pauli::Z), (1, Pauli::X)],
            CZ => &[(0, Pauli::Z), (1, Pauli::Z)],
            Y => &[(0, Pauli::Y)],
            Rx | X | V | Vdg => &[(0, Pauli::X)],
            Rz | Measure | T | Tdg | S | Sdg | Z => &[(0, Pauli::Z)],
            // No commutation data for any other operation.
            _ => &[],
        };
        table.to_vec()
    }

    /// Returns `false` for `Measure`, `MeasureFree`, `QAlloc`, `TryQAlloc`,
    /// `QFree` and `Reset`, and `true` for every other variant.
    ///
    /// The match is kept exhaustive so that adding a variant forces a decision.
    pub fn is_quantum(&self) -> bool {
        use Tk2Op::*;
        match self {
            Measure | MeasureFree | QAlloc | TryQAlloc | QFree | Reset => false,
            H | CX | CY | CZ | CRz | T | Tdg | S | Sdg | V | Vdg | X | Y | Z | Rx | Ry | Rz
            | Toffoli => true,
        }
    }
}
/// Builds an [`OpType`] for the symbolic expression `arg` by calling
/// `SympyOpDef::with_expr` and converting the result.
pub fn symbolic_constant_op(arg: String) -> OpType {
    let expr_op = SympyOpDef.with_expr(arg);
    expr_op.into()
}
#[cfg(test)]
pub(crate) mod test {
    use std::str::FromStr;
    use std::sync::Arc;

    use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
    use hugr::extension::prelude::{option_type, qb_t};
    use hugr::extension::simple_op::{MakeExtensionOp, MakeOpDef};
    use hugr::extension::{prelude::UnwrapBuilder as _, OpDef};
    use hugr::types::Signature;
    use hugr::{type_row, CircuitUnit, HugrView};
    use itertools::Itertools;
    use rstest::{fixture, rstest};
    use strum::IntoEnumIterator;

    use super::Tk2Op;
    use crate::circuit::Circuit;
    use crate::extension::bool::bool_type;
    use crate::extension::{TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID};
    use crate::utils::build_simple_circuit;
    use crate::Pauli;

    /// Looks up the [`OpDef`] registered in the extension under `op`'s id.
    fn get_opdef(op: Tk2Op) -> Option<&'static Arc<OpDef>> {
        EXTENSION.get_op(&op.op_id())
    }

    /// Every variant has a definition in the extension that round-trips back
    /// to the same variant.
    #[test]
    fn create_extension() {
        assert_eq!(EXTENSION.name(), &EXTENSION_ID);

        for o in Tk2Op::iter() {
            assert_eq!(Tk2Op::from_def(get_opdef(o).unwrap()), Ok(o));
        }
    }

    /// Two-qubit circuit applying `H` on qubit 0 followed by `CX` on `[0, 1]`.
    #[fixture]
    pub(crate) fn t2_bell_circuit() -> Circuit {
        let h = build_simple_circuit(2, |circ| {
            circ.append(Tk2Op::H, [0])?;
            circ.append(Tk2Op::CX, [0, 1])?;
            Ok(())
        });

        h.unwrap()
    }

    #[rstest]
    fn check_t2_bell(t2_bell_circuit: Circuit) {
        assert_eq!(t2_bell_circuit.commands().count(), 2);
    }

    /// Allocates an ancilla wire, resets it, applies `CX` with qubit 0,
    /// measures it and finally frees it.
    #[test]
    fn ancilla_circ() {
        let h = build_simple_circuit(1, |circ| {
            let empty: [CircuitUnit; 0] = [];
            let ancilla = circ.append_with_outputs(Tk2Op::QAlloc, empty)?[0];
            let ancilla = circ.append_with_outputs(Tk2Op::Reset, [ancilla])?[0];

            let ancilla = circ.append_with_outputs(
                Tk2Op::CX,
                [CircuitUnit::Linear(0), CircuitUnit::Wire(ancilla)],
            )?[0];
            let ancilla = circ.append_with_outputs(Tk2Op::Measure, [ancilla])?[0];
            circ.append_and_consume(Tk2Op::QFree, [ancilla])?;

            Ok(())
        })
        .unwrap();

        // 5 commands: alloc, reset, cx, measure, free
        assert_eq!(h.commands().count(), 5);
    }

    /// `TryQAlloc` followed by an unwrap conditional and `MeasureFree`.
    #[test]
    fn try_qalloc_measure_free() {
        let mut b = DFGBuilder::new(Signature::new(type_row![], bool_type())).unwrap();

        let try_q = b.add_dataflow_op(Tk2Op::TryQAlloc, []).unwrap().out_wire(0);
        let [q] = b.build_unwrap_sum(1, option_type(qb_t()), try_q).unwrap();
        let measured = b
            .add_dataflow_op(Tk2Op::MeasureFree, [q])
            .unwrap()
            .out_wire(0);
        let h = b.finish_hugr_with_outputs([measured]).unwrap();

        let top_ops = h
            .children(h.entrypoint())
            .map(|n| h.get_optype(n))
            .collect_vec();

        assert_eq!(top_ops.len(), 5);
        // first two are I/O
        assert_eq!(
            Tk2Op::from_op(top_ops[2].as_extension_op().unwrap()).unwrap(),
            Tk2Op::TryQAlloc
        );
        assert!(top_ops[3].is_conditional());
        assert_eq!(
            Tk2Op::from_op(top_ops[4].as_extension_op().unwrap()).unwrap(),
            Tk2Op::MeasureFree
        );
    }

    /// Names, extension membership, parsing and commutation data of the ops.
    #[test]
    fn tk2op_properties() {
        for op in Tk2Op::iter() {
            // The exposed name is prefixed with the extension id.
            assert!(op.exposed_name().starts_with(&EXTENSION_ID.to_string()));

            let ext_op = op.into_extension_op();
            assert_eq!(ext_op.args(), &[]);
            assert_eq!(ext_op.def().extension_id(), &EXTENSION_ID);
            let name = ext_op.def().name();
            assert_eq!(Tk2Op::from_str(name), Ok(op));
        }

        // Other calls
        assert!(Tk2Op::H.is_quantum());
        assert!(!Tk2Op::Measure.is_quantum());

        for (op, pauli) in [
            (Tk2Op::X, Pauli::X),
            (Tk2Op::Y, Pauli::Y),
            (Tk2Op::Z, Pauli::Z),
        ] {
            assert_eq!(op.qubit_commutation(), &[(0, pauli)]);
        }
    }
}