use hugr::IncomingPort;
use hugr::Wire;
use hugr::builder::{BuildError, Dataflow};
use hugr::extension::fold_out_row;
use hugr::extension::prelude::{const_none, const_some};
use hugr::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use hugr::extension::{ExtensionId, Version, prelude::option_type};
use hugr::ops::Value;
use hugr::ops::constant::{CustomConst, TryHash, downcast_equal_consts};
use hugr::std_extensions::arithmetic::float_types::{ConstF64, float64_type};
use hugr::utils::sorted_consts;
use hugr::{
    Extension,
    types::{ConstTypeError, CustomType, Signature, Type, TypeBound},
};
use lazy_static::lazy_static;
use smol_str::SmolStr;
use std::f64::consts::PI;
use std::sync::{Arc, Weak};
use strum::{EnumIter, EnumString, IntoStaticStr};
/// Identifier of the rotation extension.
pub const ROTATION_EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket.rotation");
/// Version of the rotation extension.
pub const ROTATION_EXTENSION_VERSION: Version = Version::new(0, 2, 0);

lazy_static! {
    /// The rotation extension, built on first access; its type and ops are registered by
    /// `add_to_extension`.
    pub static ref ROTATION_EXTENSION: Arc<Extension> =
        Extension::new_arc(ROTATION_EXTENSION_ID, ROTATION_EXTENSION_VERSION, add_to_extension);
}
pub const ROTATION_TYPE_ID: SmolStr = SmolStr::new_inline("rotation");
pub fn rotation_custom_type(extension_ref: &Weak<Extension>) -> CustomType {
CustomType::new(
ROTATION_TYPE_ID,
[],
ROTATION_EXTENSION_ID,
TypeBound::Copyable,
extension_ref,
)
}
pub fn rotation_type() -> Type {
rotation_custom_type(&Arc::downgrade(&ROTATION_EXTENSION)).into()
}
/// A constant rotation, stored as a number of half-turns (multiples of π radians).
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstRotation {
    /// Angle in half-turns; [`ConstRotation::new`] rejects non-finite values.
    half_turns: f64,
}

impl ConstRotation {
    /// π radians (one half-turn).
    pub const PI: Self = Self::new_unchecked(1.0);
    /// 2π radians (two half-turns, a full turn).
    pub const TAU: Self = Self::new_unchecked(2.0);
    /// π/2 radians (half a half-turn).
    pub const PI_2: Self = Self::new_unchecked(0.5);
    /// π/4 radians (a quarter of a half-turn).
    pub const PI_4: Self = Self::new_unchecked(0.25);

    /// Build a rotation without validation; only used for the finite constants above.
    const fn new_unchecked(half_turns: f64) -> Self {
        Self { half_turns }
    }

    /// Create a rotation of `half_turns` half-turns.
    ///
    /// # Errors
    /// Returns [`ConstTypeError::CustomCheckFail`] when `half_turns` is NaN or infinite.
    pub fn new(half_turns: f64) -> Result<Self, ConstTypeError> {
        if half_turns.is_finite() {
            return Ok(Self::new_unchecked(half_turns));
        }
        let message = format!("Invalid rotation value {half_turns}.");
        Err(ConstTypeError::CustomCheckFail(
            hugr::types::CustomCheckFailure::Message(message),
        ))
    }

    /// The rotation angle in radians (`half_turns * π`).
    pub fn to_radians(&self) -> f64 {
        self.half_turns() * PI
    }

    /// Create a rotation from an angle in radians.
    ///
    /// # Errors
    /// Returns [`ConstTypeError::CustomCheckFail`] when `theta / π` is not finite.
    pub fn from_radians(theta: f64) -> Result<Self, ConstTypeError> {
        let half_turns = theta / PI;
        Self::new(half_turns)
    }

    /// The rotation angle in half-turns.
    pub fn half_turns(&self) -> f64 {
        self.half_turns
    }
}
impl TryHash for ConstRotation {}
#[typetag::serde]
impl CustomConst for ConstRotation {
fn name(&self) -> SmolStr {
format!("a(π*{})", self.half_turns).into()
}
fn get_type(&self) -> Type {
rotation_type()
}
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
downcast_equal_consts(self, other)
}
}
/// Operations provided by the rotation extension.
///
/// The variant names double as the op names (via `IntoStaticStr` in `opdef_id`), which is why
/// they are snake_case.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[expect(non_camel_case_types)]
#[non_exhaustive]
pub enum RotationOp {
    /// float64 (half-turns) -> option of rotation; None if the float is non-finite.
    from_halfturns,
    /// float64 (half-turns) -> rotation; panics at runtime if the float is non-finite.
    from_halfturns_unchecked,
    /// rotation -> float64 (half-turns).
    to_halfturns,
    /// rotation, rotation -> rotation: sum of the two angles (experimental).
    radd,
}
/// Signatures, descriptions and constant folders of the [`RotationOp`]s.
impl MakeOpDef for RotationOp {
    fn opdef_id(&self) -> hugr::ops::OpName {
        <&'static str>::from(self).into()
    }

    fn from_def(
        op_def: &hugr::extension::OpDef,
    ) -> Result<Self, hugr::extension::simple_op::OpLoadError>
    where
        Self: Sized,
    {
        hugr::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
    }

    fn init_signature(&self, extension_ref: &Weak<Extension>) -> hugr::extension::SignatureFunc {
        // Built from the weak reference because this runs while the extension is constructed.
        let rotation_type = Type::new_extension(rotation_custom_type(extension_ref));
        // Only one arm runs, so each arm can take `rotation_type` by value.
        match self {
            RotationOp::from_halfturns => Signature::new(
                [float64_type()],
                [Type::from(option_type([rotation_type]))],
            ),
            RotationOp::from_halfturns_unchecked => {
                Signature::new([float64_type()], [rotation_type])
            }
            RotationOp::to_halfturns => Signature::new([rotation_type], [float64_type()]),
            RotationOp::radd => Signature::new(
                [rotation_type.clone(), rotation_type.clone()],
                [rotation_type],
            ),
        }
        .into()
    }

    fn description(&self) -> String {
        match self {
            RotationOp::from_halfturns => {
                "Construct rotation from number of half-turns (would be multiples of PI in radians). Returns None if the float is non-finite."
            }
            RotationOp::from_halfturns_unchecked => {
                "Construct rotation from number of half-turns (would be multiples of PI in radians). Panics if the float is non-finite."
            }
            RotationOp::to_halfturns => {
                "Convert rotation to number of half-turns (would be multiples of PI in radians)."
            }
            RotationOp::radd => "Add two angles together (experimental).",
        }
        .to_owned()
    }

    fn extension(&self) -> hugr::extension::ExtensionId {
        ROTATION_EXTENSION_ID
    }

    fn extension_ref(&self) -> Weak<Extension> {
        Arc::downgrade(&ROTATION_EXTENSION)
    }

    /// Install a constant folder for each op.
    fn post_opdef(&self, def: &mut hugr::extension::OpDef) {
        match self {
            RotationOp::radd => {
                def.set_constant_folder(|consts: &[(IncomingPort, Value)]| {
                    // Both inputs must be constant; order them by port.
                    let [a, b]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
                    let a_rot = a.get_custom_value::<ConstRotation>()?;
                    let b_rot = b.get_custom_value::<ConstRotation>()?;
                    // An overflowing (non-finite) sum is not a valid rotation: leave unfolded.
                    let sum = ConstRotation::new(a_rot.half_turns() + b_rot.half_turns()).ok()?;
                    fold_out_row([Value::extension(sum)])
                });
            }
            RotationOp::from_halfturns_unchecked => {
                def.set_constant_folder(|consts: &[(IncomingPort, Value)]| {
                    let (_, v) = consts.first()?;
                    let f = v.get_custom_value::<ConstF64>()?;
                    // A non-finite input would panic at runtime; don't fold it away.
                    let rot = ConstRotation::new(f.value()).ok()?;
                    fold_out_row([Value::extension(rot)])
                });
            }
            RotationOp::from_halfturns => {
                def.set_constant_folder(|consts: &[(IncomingPort, Value)]| {
                    let (_, v) = consts.first()?;
                    let f = v.get_custom_value::<ConstF64>()?;
                    // The op yields None for a non-finite float, so a constant non-finite
                    // input folds to the None value instead of being left unfolded.
                    let option_value = match ConstRotation::new(f.value()) {
                        Ok(rot) => const_some(Value::extension(rot)),
                        Err(_) => const_none([rotation_type()]),
                    };
                    fold_out_row([option_value])
                });
            }
            RotationOp::to_halfturns => {
                def.set_constant_folder(|consts: &[(IncomingPort, Value)]| {
                    let (_, v) = consts.first()?;
                    let rot = v.get_custom_value::<ConstRotation>()?;
                    fold_out_row([Value::extension(ConstF64::new(rot.half_turns()))])
                });
            }
        }
    }
}
/// Registers every [`RotationOp`] against the shared [`ROTATION_EXTENSION`].
impl MakeRegisteredOp for RotationOp {
    fn extension_id(&self) -> hugr::extension::ExtensionId {
        ROTATION_EXTENSION_ID
    }

    fn extension_ref(&self) -> Arc<Extension> {
        // Cheap reference-count bump; the extension itself is not copied.
        Arc::clone(&ROTATION_EXTENSION)
    }
}
/// Register the rotation type and every [`RotationOp`] definition in `extension`.
///
/// # Panics
/// Panics if the type or any op definition cannot be added. This function is the initialiser
/// passed to `Extension::new_arc` for [`ROTATION_EXTENSION`], so such a failure indicates a
/// bug in this module rather than a recoverable condition.
pub(super) fn add_to_extension(extension: &mut Extension, extension_ref: &Weak<Extension>) {
    extension
        .add_type(
            ROTATION_TYPE_ID,
            vec![],
            "rotation type expressed as number of half turns".to_owned(),
            TypeBound::Copyable.into(),
            extension_ref,
        )
        .expect("failed to add the rotation type to the rotation extension");
    RotationOp::load_all_ops(extension, extension_ref)
        .expect("failed to add the rotation op definitions to the rotation extension");
}
/// Helpers for adding [`RotationOp`]s to any dataflow builder.
///
/// Each method adds one op and returns its single output wire.
///
/// # Errors
/// Every method propagates the [`BuildError`] returned by `add_dataflow_op`.
pub trait RotationOpBuilder: Dataflow {
    /// Add [`RotationOp::from_halfturns`]: float64 half-turns -> option of rotation.
    fn add_from_halfturns(&mut self, turns: Wire) -> Result<Wire, BuildError> {
        Ok(self
            .add_dataflow_op(RotationOp::from_halfturns, [turns])?
            .out_wire(0))
    }

    /// Add [`RotationOp::from_halfturns_unchecked`]: float64 half-turns -> rotation.
    fn add_from_halfturns_unchecked(&mut self, turns: Wire) -> Result<Wire, BuildError> {
        Ok(self
            .add_dataflow_op(RotationOp::from_halfturns_unchecked, [turns])?
            .out_wire(0))
    }

    /// Add [`RotationOp::to_halfturns`]: rotation -> float64 half-turns.
    fn add_to_halfturns(&mut self, rotation: Wire) -> Result<Wire, BuildError> {
        Ok(self
            .add_dataflow_op(RotationOp::to_halfturns, [rotation])?
            .out_wire(0))
    }

    /// Add [`RotationOp::radd`] (experimental): rotation, rotation -> their sum.
    ///
    /// `lhs` is wired to input port 0 and `rhs` to input port 1.
    fn add_radd(&mut self, lhs: Wire, rhs: Wire) -> Result<Wire, BuildError> {
        Ok(self
            .add_dataflow_op(RotationOp::radd, [lhs, rhs])?
            .out_wire(0))
    }
}

impl<D: Dataflow> RotationOpBuilder for D {}
#[cfg(test)]
mod test {
    use hugr::{
        builder::{DFGBuilder, DataflowHugr},
        ops::OpType,
    };
    use strum::IntoEnumIterator;

    use super::*;

    /// Validation, equality, naming and radian conversion of rotation constants.
    #[test]
    fn test_rotation_consts() {
        let const_57 = ConstRotation::new(5.7).unwrap();
        let const_01 = ConstRotation::new(0.1).unwrap();
        let const_256 = ConstRotation::new(256.0).unwrap();

        assert_ne!(const_57, const_01);
        assert_ne!(const_57, const_256);
        assert_eq!(const_57, ConstRotation::new(5.7).unwrap());
        assert_eq!(const_57.get_type(), rotation_type());

        // Non-finite values are rejected.
        assert!(matches!(
            ConstRotation::new(f64::INFINITY),
            Err(ConstTypeError::CustomCheckFail(_))
        ));
        assert!(matches!(
            ConstRotation::new(f64::NEG_INFINITY),
            Err(ConstTypeError::CustomCheckFail(_))
        ));
        assert!(matches!(
            ConstRotation::new(f64::NAN),
            Err(ConstTypeError::CustomCheckFail(_))
        ));

        let const_af1 = ConstRotation::from_radians(0.75 * PI).unwrap();
        assert_eq!(const_af1.half_turns(), 0.75);
        // Scaling by a power of two is exact, so exact float comparison is fine here.
        assert_eq!(ConstRotation::PI.to_radians(), PI);
        assert_eq!(ConstRotation::TAU.to_radians(), 2.0 * PI);

        // `equal_consts` agrees with `PartialEq` for both equal and unequal values.
        assert!(const_57.equal_consts(&ConstRotation::new(5.7).unwrap()));
        assert!(!const_57.equal_consts(&const_01));

        assert_eq!(const_256.name(), "a(π*256)");
    }

    /// Every op round-trips through `OpType`.
    #[test]
    fn test_ops() {
        for op in RotationOp::iter() {
            let optype: OpType = op.into();
            assert_eq!(optype.cast(), Some(op));
        }
    }

    #[test]
    fn test_builder() {
        let mut builder = DFGBuilder::new(Signature::new(
            [rotation_type()],
            [Type::from(option_type([rotation_type()])), rotation_type()],
        ))
        .unwrap();
        let [rotation] = builder.input_wires_arr();
        let turns = builder.add_to_halfturns(rotation).unwrap();
        let mb_rotation = builder.add_from_halfturns(turns).unwrap();
        let unwrapped_rotation = builder.add_from_halfturns_unchecked(turns).unwrap();
        let _hugr = builder
            .finish_hugr_with_outputs([mb_rotation, unwrapped_rotation])
            .unwrap();
    }

    /// The predefined constants pass validation unchanged.
    #[rstest::rstest]
    fn const_rotation_statics(
        #[values(
            ConstRotation::TAU,
            ConstRotation::PI,
            ConstRotation::PI_2,
            ConstRotation::PI_4
        )]
        konst: ConstRotation,
    ) {
        assert_eq!(ConstRotation::new(konst.half_turns()), Ok(konst));
    }

    /// Run `op`'s constant folder on `consts`, returning only the output values.
    fn do_fold(op: RotationOp, consts: Vec<(IncomingPort, Value)>) -> Option<Vec<Value>> {
        let ext_op = MakeRegisteredOp::to_extension_op(op).unwrap();
        ext_op
            .constant_fold(&consts)
            .map(|r| r.into_iter().map(|(_, v)| v).collect())
    }

    #[rstest::rstest]
    #[case(0.25, 0.5, 0.75)]
    #[case(1.0, 0.0, 1.0)]
    #[case(0.0, 0.0, 0.0)]
    #[case(0.5, 0.5, 1.0)]
    #[case(-0.25, 0.5, 0.25)]
    fn test_radd_fold(#[case] a: f64, #[case] b: f64, #[case] expected: f64) {
        let consts = vec![
            (
                0usize.into(),
                Value::extension(ConstRotation::new(a).unwrap()),
            ),
            (
                1usize.into(),
                Value::extension(ConstRotation::new(b).unwrap()),
            ),
        ];
        let result = do_fold(RotationOp::radd, consts).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0]
                .get_custom_value::<ConstRotation>()
                .unwrap()
                .half_turns(),
            expected
        );
    }

    /// Only one constant input: nothing to fold.
    #[test]
    fn test_radd_no_fold() {
        let consts = vec![(
            0usize.into(),
            Value::extension(ConstRotation::new(0.25).unwrap()),
        )];
        assert!(do_fold(RotationOp::radd, consts).is_none());
    }

    /// `f64::MAX + f64::MAX` overflows to infinity, which is not a valid rotation.
    #[test]
    fn test_radd_overflow_no_fold() {
        let big = || Value::extension(ConstRotation::new(f64::MAX).unwrap());
        let consts = vec![(0usize.into(), big()), (1usize.into(), big())];
        assert!(do_fold(RotationOp::radd, consts).is_none());
    }

    #[rstest::rstest]
    #[case(0.25)]
    #[case(1.0)]
    #[case(0.0)]
    fn test_to_halfturns_fold(#[case] val: f64) {
        let consts = vec![(
            0usize.into(),
            Value::extension(ConstRotation::new(val).unwrap()),
        )];
        let result_checked = do_fold(RotationOp::to_halfturns, consts).unwrap();
        assert_eq!(result_checked.len(), 1);
        assert_eq!(
            result_checked[0]
                .get_custom_value::<ConstF64>()
                .unwrap()
                .value(),
            val
        );
    }

    #[test]
    fn test_to_halfturns_no_fold() {
        let consts: Vec<(IncomingPort, Value)> = vec![];
        assert!(do_fold(RotationOp::to_halfturns, consts).is_none());
    }

    #[rstest::rstest]
    #[case(0.5)]
    #[case(1.0)]
    #[case(0.0)]
    fn test_from_halfturns_fold(#[case] val: f64) {
        let consts = vec![(0usize.into(), Value::extension(ConstF64::new(val)))];
        let result_checked = do_fold(RotationOp::from_halfturns, consts.clone()).unwrap();
        let result_unchecked = do_fold(RotationOp::from_halfturns_unchecked, consts).unwrap();
        assert_eq!(result_checked.len(), 1);
        assert_eq!(result_unchecked.len(), 1);
        assert_eq!(
            result_checked[0],
            const_some(Value::extension(ConstRotation::new(val).unwrap()))
        );
        assert_eq!(
            result_unchecked[0]
                .get_custom_value::<ConstRotation>()
                .unwrap()
                .half_turns(),
            val
        );
    }

    /// No constant inputs: neither conversion folds.
    #[test]
    fn test_from_halfturns_no_fold() {
        let consts: Vec<(IncomingPort, Value)> = vec![];
        assert!(do_fold(RotationOp::from_halfturns, consts.clone()).is_none());
        assert!(do_fold(RotationOp::from_halfturns_unchecked, consts).is_none());
    }
}