use std::sync::{Arc, Weak};
use hugr::{
Extension, Wire,
builder::{BuildError, Dataflow},
extension::{
ExtensionBuildError, ExtensionId, OpDef, SignatureFunc, TypeDef, Version,
prelude::bool_t,
simple_op::{MakeOpDef, MakeRegisteredOp, try_from_name},
},
types::{CustomType, Signature, Type, TypeBound},
};
use lazy_static::lazy_static;
use smol_str::SmolStr;
use strum::{EnumIter, EnumString, IntoStaticStr};
pub const MEASUREMENT_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket.measurement");
pub const MEASUREMENT_EXTENSION_VERSION: Version = Version::new(0, 1, 0);
lazy_static! {
pub static ref MEASUREMENT_EXTENSION: Arc<Extension> = {
Extension::new_arc(
MEASUREMENT_EXTENSION_ID,
MEASUREMENT_EXTENSION_VERSION,
|ext, ext_ref| {
let _ = add_measurement_type_def(ext, ext_ref.clone()).unwrap();
MeasurementOp::load_all_ops(ext, ext_ref).unwrap();
},
)
};
pub static ref MEASUREMENT_TYPE_ID: SmolStr = SmolStr::new_inline("Measurement");
}
fn add_measurement_type_def(
ext: &mut Extension,
extension_ref: Weak<Extension>,
) -> Result<&TypeDef, ExtensionBuildError> {
ext.add_type(
MEASUREMENT_TYPE_ID.to_owned(),
vec![],
"A copyable type representing the result of a measurement operation".into(),
TypeBound::Copyable.into(),
&extension_ref,
)
}
pub fn measurement_custom_type() -> CustomType {
MEASUREMENT_EXTENSION
.get_type(&MEASUREMENT_TYPE_ID)
.unwrap()
.instantiate([])
.unwrap()
}
pub fn measurement_type() -> Type {
measurement_custom_type().into()
}
#[derive(
Clone,
Copy,
Debug,
serde::Serialize,
serde::Deserialize,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
EnumIter,
IntoStaticStr,
EnumString,
)]
#[non_exhaustive]
pub enum MeasurementOp {
Read,
}
impl MakeOpDef for MeasurementOp {
fn opdef_id(&self) -> hugr::ops::OpName {
<&'static str>::from(self).into()
}
fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc {
let measurement_type = Type::new_extension(CustomType::new(
MEASUREMENT_TYPE_ID.to_owned(),
vec![],
MEASUREMENT_EXTENSION_ID,
MEASUREMENT_EXTENSION_VERSION,
TypeBound::Copyable,
extension_ref,
));
match self {
MeasurementOp::Read => Signature::new([measurement_type], [bool_t()]).into(),
}
}
fn from_def(op_def: &OpDef) -> Result<Self, hugr::extension::simple_op::OpLoadError> {
try_from_name(op_def.name(), op_def.extension_id())
}
fn extension(&self) -> ExtensionId {
MEASUREMENT_EXTENSION_ID
}
fn description(&self) -> String {
match self {
MeasurementOp::Read => "Consumes a measurement, converting it into a bool.".into(),
}
}
fn extension_ref(&self) -> Weak<Extension> {
Arc::downgrade(&MEASUREMENT_EXTENSION)
}
}
impl MakeRegisteredOp for MeasurementOp {
fn extension_id(&self) -> ExtensionId {
MEASUREMENT_EXTENSION_ID
}
fn extension_ref(&self) -> Arc<Extension> {
MEASUREMENT_EXTENSION.clone()
}
}
pub trait MeasurementOpBuilder: Dataflow {
fn add_measurement_read(&mut self, measurement: Wire) -> Result<[Wire; 1], BuildError> {
Ok(self
.add_dataflow_op(MeasurementOp::Read, [measurement])?
.outputs_arr())
}
}
impl<D: Dataflow> MeasurementOpBuilder for D {}
#[cfg(test)]
mod test {
use std::sync::Arc;
use hugr::HugrView;
use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::OpDef;
use hugr::types::Signature;
use strum::IntoEnumIterator;
use super::*;
fn get_opdef(op: MeasurementOp) -> Option<&'static Arc<OpDef>> {
MEASUREMENT_EXTENSION.get_op(&op.opdef_id())
}
#[test]
fn create_extension() {
assert_eq!(MEASUREMENT_EXTENSION.name(), &MEASUREMENT_EXTENSION_ID);
for op in MeasurementOp::iter() {
assert_eq!(MeasurementOp::from_def(get_opdef(op).unwrap()), Ok(op));
}
}
#[test]
fn measurement_ops_validate() {
let mut builder =
DFGBuilder::new(Signature::new(vec![measurement_type()], vec![bool_t()])).unwrap();
let [msmt] = builder.input_wires_arr();
let [bit] = builder.add_measurement_read(msmt).unwrap();
let hugr = builder.finish_hugr_with_outputs([bit]).unwrap();
hugr.validate().unwrap();
}
}