hugr_core/extension/
simple_op.rs

//! Traits for enums of op definitions, gathering up some shared functionality.
2
3use std::sync::Weak;
4
5use strum::IntoEnumIterator;
6
7use crate::ops::{ExtensionOp, OpName, OpNameRef};
8use crate::{
9    ops::{NamedOp, OpType},
10    types::TypeArg,
11    Extension,
12};
13
14use super::{op_def::SignatureFunc, ExtensionBuildError, ExtensionId, OpDef, SignatureError};
15use delegate::delegate;
16use thiserror::Error;
17
/// Error loading operation.
///
/// Returned when an operation cannot be loaded, e.g. from an [`OpDef`]
/// ([`MakeOpDef::from_def`]), from an [`ExtensionOp`]
/// ([`MakeExtensionOp::from_extension_op`]) or by name ([`try_from_name`]).
#[derive(Debug, Error, PartialEq, Clone)]
// NOTE(review): every variant below carries its own `#[error]` message, so this
// enum-level format looks unused — confirm against thiserror's semantics.
#[error("{0}")]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum OpLoadError {
    /// The given name does not correspond to any member of the op set.
    #[error("Op with name {0} is not a member of this set.")]
    NotMember(String),
    /// The supplied type arguments were rejected.
    #[error("Type args invalid: {0}.")]
    InvalidArgs(#[from] SignatureError),
    /// The [`OpDef`] belongs to a different extension than expected.
    ///
    /// The first field is the extension the definition belongs to; the second
    /// is the extension that was expected.
    #[error("OpDef belongs to extension {0}, expected {1}.")]
    WrongExtension(ExtensionId, ExtensionId),
}
31
32impl<T> NamedOp for T
33where
34    for<'a> &'a T: Into<&'static str>,
35{
36    fn name(&self) -> OpName {
37        let s = self.into();
38        s.into()
39    }
40}
41
42/// Traits implemented by types which can add themselves to [`Extension`]s as
43/// [`OpDef`]s or load themselves from an [`OpDef`].
44///
45/// Particularly useful with C-style enums that implement [strum::IntoEnumIterator],
46/// as then all definitions can be added to an extension at once.
47pub trait MakeOpDef: NamedOp {
48    /// Try to load one of the operations of this set from an [OpDef].
49    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
50    where
51        Self: Sized;
52
53    /// The ID of the extension this operation is defined in.
54    fn extension(&self) -> ExtensionId;
55
56    /// Returns a weak reference to the extension this operation is defined in.
57    fn extension_ref(&self) -> Weak<Extension>;
58
59    /// Compute the signature of the operation while the extension definition is being built.
60    ///
61    /// Requires a [`Weak`] reference to the extension defining the operation.
62    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`],
63    /// and it is normally internally called by [`MakeOpDef::add_to_extension`].
64    fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
65
66    /// Return the signature (polymorphic function type) of the operation.
67    fn signature(&self) -> SignatureFunc {
68        self.init_signature(&self.extension_ref())
69    }
70
71    /// Description of the operation. By default, the same as `self.name()`.
72    fn description(&self) -> String {
73        self.name().to_string()
74    }
75
76    /// Edit the opdef before finalising. By default does nothing.
77    fn post_opdef(&self, _def: &mut OpDef) {}
78
79    /// Add an operation implemented as an [MakeOpDef], which can provide the data
80    /// required to define an [OpDef], to an extension.
81    ///
82    /// Requires a [`Weak`] reference to the extension defining the operation.
83    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
84    fn add_to_extension(
85        &self,
86        extension: &mut Extension,
87        extension_ref: &Weak<Extension>,
88    ) -> Result<(), ExtensionBuildError> {
89        let def = extension.add_op(
90            self.name(),
91            self.description(),
92            self.init_signature(extension_ref),
93            extension_ref,
94        )?;
95
96        self.post_opdef(def);
97
98        Ok(())
99    }
100
101    /// Load all variants of an enum of op definitions in to an extension as op defs.
102    /// See [strum::IntoEnumIterator].
103    ///
104    /// Requires a [`Weak`] reference to the extension defining the operation.
105    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
106    fn load_all_ops(
107        extension: &mut Extension,
108        extension_ref: &Weak<Extension>,
109    ) -> Result<(), ExtensionBuildError>
110    where
111        Self: IntoEnumIterator,
112    {
113        for op in Self::iter() {
114            op.add_to_extension(extension, extension_ref)?;
115        }
116        Ok(())
117    }
118
119    /// If the definition can be loaded from a string, load from an [ExtensionOp].
120    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
121    where
122        Self: Sized + std::str::FromStr,
123    {
124        Self::from_extension_op(ext_op)
125    }
126}
127
/// [MakeOpDef] with an associated concrete Op type which can be instantiated with type arguments.
pub trait HasConcrete: MakeOpDef {
    /// Associated concrete type, produced by [`HasConcrete::instantiate`].
    type Concrete: MakeExtensionOp;

    /// Instantiate the operation with type arguments.
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] if the concrete operation cannot be built
    /// from `type_args`.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
136
137/// [MakeExtensionOp] with an associated [HasConcrete].
138pub trait HasDef: MakeExtensionOp {
139    /// Associated [HasConcrete] type.
140    type Def: HasConcrete<Concrete = Self> + std::str::FromStr;
141
142    /// Load the operation from a [ExtensionOp].
143    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
144    where
145        Self: Sized,
146    {
147        Self::from_extension_op(ext_op)
148    }
149}
150
151/// Traits implemented by types which can be loaded from [`ExtensionOp`]s,
152/// i.e. concrete instances of [`OpDef`]s, with defined type arguments.
153pub trait MakeExtensionOp: NamedOp {
154    /// Try to load one of the operations of this set from an [OpDef].
155    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
156    where
157        Self: Sized;
158    /// Try to instantiate a variant from an [OpType]. Default behaviour assumes
159    /// an [ExtensionOp] and loads from the name.
160    fn from_optype(op: &OpType) -> Option<Self>
161    where
162        Self: Sized,
163    {
164        let ext: &ExtensionOp = op.as_extension_op()?;
165        Self::from_extension_op(ext).ok()
166    }
167
168    /// Any type args which define this operation.
169    fn type_args(&self) -> Vec<TypeArg>;
170
171    /// Given the ID of the extension this operation is defined in, and a
172    /// registry containing that extension, return a [RegisteredOp].
173    fn to_registered(
174        self,
175        extension_id: ExtensionId,
176        extension: Weak<Extension>,
177    ) -> RegisteredOp<Self>
178    where
179        Self: Sized,
180    {
181        RegisteredOp {
182            extension_id,
183            extension,
184            op: self,
185        }
186    }
187}
188
189/// Blanket implementation for non-polymorphic operations - [OpDef]s with no type parameters.
190impl<T: MakeOpDef> MakeExtensionOp for T {
191    #[inline]
192    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
193    where
194        Self: Sized,
195    {
196        Self::from_def(ext_op.def())
197    }
198
199    #[inline]
200    fn type_args(&self) -> Vec<TypeArg> {
201        vec![]
202    }
203}
204
205/// Load an [MakeOpDef] from its name.
206/// See [strum::EnumString].
207pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
208where
209    T: std::str::FromStr + MakeOpDef,
210{
211    let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
212    let expected_extension = op.extension();
213    if def_extension != &expected_extension {
214        return Err(OpLoadError::WrongExtension(
215            def_extension.clone(),
216            expected_extension,
217        ));
218    }
219
220    Ok(op)
221}
222
/// Wrap a [MakeExtensionOp] with a weak reference to the extension defining it,
/// allowing it to be converted into an [`ExtensionOp`].
/// Generate from [MakeExtensionOp::to_registered], or via `From` for
/// [MakeRegisteredOp] types.
#[derive(Clone, Debug)]
pub struct RegisteredOp<T> {
    /// The name of the extension these ops belong to.
    pub extension_id: ExtensionId,
    /// A weak reference to the extension defining the operation; it is
    /// upgraded to look up the operation's [`OpDef`] by name.
    extension: Weak<Extension>,
    /// The inner [MakeExtensionOp]
    op: T,
}
234
235impl<T> RegisteredOp<T> {
236    /// Extract the inner wrapped value
237    pub fn to_inner(self) -> T {
238        self.op
239    }
240}
241
242impl<T: MakeExtensionOp> RegisteredOp<T> {
243    /// Generate an [OpType].
244    pub fn to_extension_op(&self) -> Option<ExtensionOp> {
245        ExtensionOp::new(
246            self.extension.upgrade()?.get_op(&self.name())?.clone(),
247            self.type_args(),
248        )
249        .ok()
250    }
251
252    delegate! {
253        to self.op {
254            /// Name of the operation - derived from strum serialization.
255            pub fn name(&self) -> OpName;
256            /// Any type args which define this operation. Default is no type arguments.
257            pub fn type_args(&self) -> Vec<TypeArg>;
258        }
259    }
260}
261
262/// Trait for operations that can self report the extension ID they belong to
263/// and the registry required to compute their types.
264/// Allows conversion to [`ExtensionOp`]
265pub trait MakeRegisteredOp: MakeExtensionOp {
266    /// The ID of the extension this op belongs to.
267    fn extension_id(&self) -> ExtensionId;
268    /// A reference to the [Extension] which defines this operation.
269    fn extension_ref(&self) -> Weak<Extension>;
270
271    /// Convert this operation in to an [ExtensionOp]. Returns None if the type
272    /// cannot be computed.
273    fn to_extension_op(self) -> Option<ExtensionOp>
274    where
275        Self: Sized,
276    {
277        let registered: RegisteredOp<_> = self.into();
278        registered.to_extension_op()
279    }
280}
281
282impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
283    fn from(ext_op: T) -> Self {
284        let extension_id = ext_op.extension_id();
285        let extension = ext_op.extension_ref();
286        ext_op.to_registered(extension_id, extension)
287    }
288}
289
290impl<T: MakeRegisteredOp> From<T> for OpType {
291    /// Convert
292    fn from(ext_op: T) -> Self {
293        ext_op.to_extension_op().unwrap().into()
294    }
295}
296
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::{const_extension_ids, type_row, types::Signature};

    use super::*;
    use lazy_static::lazy_static;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }
    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        // A second, unrelated extension used to exercise `WrongExtension`.
        const OTHER_EXT_ID: ExtensionId = "OtherDummyExt";
    }

    lazy_static! {
        static ref EXT: Arc<Extension> = {
            Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
                DummyEnum::Dumb
                    .add_to_extension(ext, extension_ref)
                    .unwrap();
            })
        };
    }
    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.name()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[TypeArg::BoundedNat { n: 1 }]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        // A known name belonging to the expected extension loads successfully.
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );

        // Unknown names are reported as non-members.
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );

        // A known name claimed by a different extension is rejected.
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(OTHER_EXT_ID, EXT_ID))
        );
    }
}