hugr_core/extension/
simple_op.rs

//! A trait for enums of op definitions that gathers up some shared functionality.
2
3use std::sync::Weak;
4
5use strum::IntoEnumIterator;
6
7use crate::ops::{ExtensionOp, OpName, OpNameRef};
8use crate::{Extension, ops::OpType, types::TypeArg};
9
10use super::{ExtensionBuildError, ExtensionId, OpDef, SignatureError, op_def::SignatureFunc};
11use delegate::delegate;
12use thiserror::Error;
13
14/// Error loading operation.
15#[derive(Debug, Error, PartialEq, Clone)]
16#[error("{0}")]
17#[allow(missing_docs)]
18#[non_exhaustive]
19pub enum OpLoadError {
20    #[error("Op with name {0} is not a member of this set.")]
21    NotMember(String),
22    #[error("Type args invalid: {0}.")]
23    InvalidArgs(#[from] SignatureError),
24    #[error("OpDef belongs to extension {0}, expected {1}.")]
25    WrongExtension(ExtensionId, ExtensionId),
26}
27
28/// Traits implemented by types which can add themselves to [`Extension`]s as
29/// [`OpDef`]s or load themselves from an [`OpDef`].
30///
31/// Particularly useful with C-style enums that implement [`strum::IntoEnumIterator`],
32/// as then all definitions can be added to an extension at once.
33///
34/// [`MakeExtensionOp`] has a blanket impl for types that impl [`MakeOpDef`].
35pub trait MakeOpDef {
36    /// The [`OpDef::name`] which will be used when `Self`  is added to an [Extension]
37    /// or when `Self` is loaded from an [`OpDef`].
38    ///
39    /// This identifer must be unique within the extension with which the
40    /// [`OpDef`] is registered. An [`ExtensionOp`] instantiating this [`OpDef`] will
41    /// report `self.opdef_id()` as its [`ExtensionOp::unqualified_id`].
42    ///
43    /// [`MakeExtensionOp::op_id`] must match this function.
44    fn opdef_id(&self) -> OpName;
45
46    /// Try to load one of the operations of this set from an [`OpDef`].
47    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
48    where
49        Self: Sized;
50
51    /// The ID of the extension this operation is defined in.
52    fn extension(&self) -> ExtensionId;
53
54    /// Returns a weak reference to the extension this operation is defined in.
55    fn extension_ref(&self) -> Weak<Extension>;
56
57    /// Compute the signature of the operation while the extension definition is being built.
58    ///
59    /// Requires a [`Weak`] reference to the extension defining the operation.
60    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`],
61    /// and it is normally internally called by [`MakeOpDef::add_to_extension`].
62    fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
63
64    /// Return the signature (polymorphic function type) of the operation.
65    fn signature(&self) -> SignatureFunc {
66        self.init_signature(&self.extension_ref())
67    }
68
69    /// Description of the operation. By default, the same as `self.opdef_id()`.
70    fn description(&self) -> String {
71        self.opdef_id().to_string()
72    }
73
74    /// Edit the opdef before finalising. By default does nothing.
75    fn post_opdef(&self, _def: &mut OpDef) {}
76
77    /// Add an operation implemented as an [`MakeOpDef`], which can provide the data
78    /// required to define an [`OpDef`], to an extension.
79    ///
80    /// Requires a [`Weak`] reference to the extension defining the operation.
81    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
82    fn add_to_extension(
83        &self,
84        extension: &mut Extension,
85        extension_ref: &Weak<Extension>,
86    ) -> Result<(), ExtensionBuildError> {
87        let def = extension.add_op(
88            self.opdef_id(),
89            self.description(),
90            self.init_signature(extension_ref),
91            extension_ref,
92        )?;
93
94        self.post_opdef(def);
95
96        Ok(())
97    }
98
99    /// Load all variants of an enum of op definitions in to an extension as op defs.
100    /// See [`strum::IntoEnumIterator`].
101    ///
102    /// Requires a [`Weak`] reference to the extension defining the operation.
103    /// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
104    fn load_all_ops(
105        extension: &mut Extension,
106        extension_ref: &Weak<Extension>,
107    ) -> Result<(), ExtensionBuildError>
108    where
109        Self: IntoEnumIterator,
110    {
111        for op in Self::iter() {
112            op.add_to_extension(extension, extension_ref)?;
113        }
114        Ok(())
115    }
116
117    /// If the definition can be loaded from a string, load from an [`ExtensionOp`].
118    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
119    where
120        Self: Sized + std::str::FromStr,
121    {
122        Self::from_extension_op(ext_op)
123    }
124}
125
126/// [`MakeOpDef`] with an associate concrete Op type which can be instantiated with type arguments.
127pub trait HasConcrete: MakeOpDef {
128    /// Associated concrete type.
129    type Concrete: MakeExtensionOp;
130
131    /// Instantiate the operation with type arguments.
132    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
133}
134
135/// [`MakeExtensionOp`] with an associated [`HasConcrete`].
136pub trait HasDef: MakeExtensionOp {
137    /// Associated [`HasConcrete`] type.
138    type Def: HasConcrete<Concrete = Self> + std::str::FromStr;
139
140    /// Load the operation from a [`ExtensionOp`].
141    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
142    where
143        Self: Sized,
144    {
145        Self::from_extension_op(ext_op)
146    }
147}
148
149/// Traits implemented by types which can be loaded from [`ExtensionOp`]s,
150/// i.e. concrete instances of [`OpDef`]s, with defined type arguments.
151pub trait MakeExtensionOp {
152    /// The [`OpDef::name`] of [`ExtensionOp`]s from which `Self` can be loaded.
153    ///
154    /// This identifer must be unique within the extension with which the
155    /// [`OpDef`] is registered. An [`ExtensionOp`] instantiating this [`OpDef`] will
156    /// report `self.opdef_id()` as its [`ExtensionOp::unqualified_id`].
157    fn op_id(&self) -> OpName;
158
159    /// Try to load one of the operations of this set from an [`OpDef`].
160    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
161    where
162        Self: Sized;
163    /// Try to instantiate a variant from an [`OpType`]. Default behaviour assumes
164    /// an [`ExtensionOp`] and loads from the name.
165    #[must_use]
166    fn from_optype(op: &OpType) -> Option<Self>
167    where
168        Self: Sized,
169    {
170        let ext: &ExtensionOp = op.as_extension_op()?;
171        Self::from_extension_op(ext).ok()
172    }
173
174    /// Any type args which define this operation.
175    fn type_args(&self) -> Vec<TypeArg>;
176
177    /// Given the ID of the extension this operation is defined in, and a
178    /// registry containing that extension, return a [`RegisteredOp`].
179    fn to_registered(
180        self,
181        extension_id: ExtensionId,
182        extension: Weak<Extension>,
183    ) -> RegisteredOp<Self>
184    where
185        Self: Sized,
186    {
187        RegisteredOp {
188            extension_id,
189            extension,
190            op: self,
191        }
192    }
193}
194
195/// Blanket implementation for non-polymorphic operations - [`OpDef`]s with no type parameters.
196impl<T: MakeOpDef> MakeExtensionOp for T {
197    fn op_id(&self) -> OpName {
198        self.opdef_id()
199    }
200
201    #[inline]
202    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
203    where
204        Self: Sized,
205    {
206        Self::from_def(ext_op.def())
207    }
208
209    #[inline]
210    fn type_args(&self) -> Vec<TypeArg> {
211        vec![]
212    }
213}
214
215/// Load an [`MakeOpDef`] from its name.
216/// See [`strum::EnumString`].
217pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
218where
219    T: std::str::FromStr + MakeOpDef,
220{
221    let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
222    let expected_extension = op.extension();
223    if def_extension != &expected_extension {
224        return Err(OpLoadError::WrongExtension(
225            def_extension.clone(),
226            expected_extension,
227        ));
228    }
229
230    Ok(op)
231}
232
233/// Wrap an [`MakeExtensionOp`] with an extension registry to allow type computation.
234/// Generate from [`MakeExtensionOp::to_registered`]
235#[derive(Clone, Debug)]
236pub struct RegisteredOp<T> {
237    /// The name of the extension these ops belong to.
238    pub extension_id: ExtensionId,
239    /// A registry of all extensions, used for type computation.
240    extension: Weak<Extension>,
241    /// The inner [`MakeExtensionOp`]
242    op: T,
243}
244
245impl<T> RegisteredOp<T> {
246    /// Extract the inner wrapped value
247    pub fn to_inner(self) -> T {
248        self.op
249    }
250}
251
252impl<T: MakeExtensionOp> RegisteredOp<T> {
253    /// Generate an [`OpType`].
254    pub fn to_extension_op(&self) -> Option<ExtensionOp> {
255        ExtensionOp::new(
256            self.extension.upgrade()?.get_op(&self.op_id())?.clone(),
257            self.type_args(),
258        )
259        .ok()
260    }
261
262    delegate! {
263        to self.op {
264            /// Name of the operation - derived from strum serialization.
265            pub fn op_id(&self) -> OpName;
266            /// Any type args which define this operation. Default is no type arguments.
267            pub fn type_args(&self) -> Vec<TypeArg>;
268        }
269    }
270}
271
272/// Trait for operations that can self report the extension ID they belong to
273/// and the registry required to compute their types.
274/// Allows conversion to [`ExtensionOp`]
275pub trait MakeRegisteredOp: MakeExtensionOp {
276    /// The ID of the extension this op belongs to.
277    fn extension_id(&self) -> ExtensionId;
278    /// A reference to the [Extension] which defines this operation.
279    fn extension_ref(&self) -> Weak<Extension>;
280
281    /// Convert this operation in to an [`ExtensionOp`]. Returns None if the type
282    /// cannot be computed.
283    fn to_extension_op(self) -> Option<ExtensionOp>
284    where
285        Self: Sized,
286    {
287        let registered: RegisteredOp<_> = self.into();
288        registered.to_extension_op()
289    }
290}
291
292impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
293    fn from(ext_op: T) -> Self {
294        let extension_id = ext_op.extension_id();
295        let extension = ext_op.extension_ref();
296        ext_op.to_registered(extension_id, extension)
297    }
298}
299
300impl<T: MakeRegisteredOp> From<T> for OpType {
301    /// Convert
302    fn from(ext_op: T) -> Self {
303        ext_op.to_extension_op().unwrap().into()
304    }
305}
306
#[cfg(test)]
mod test {
    use std::sync::{Arc, LazyLock};

    use crate::{
        const_extension_ids, type_row,
        types::{Signature, Term},
    };

    use super::*;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    /// A single-variant op set used to exercise the traits in this module.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn opdef_id(&self) -> OpName {
            <&'static str>::from(self).into()
        }

        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.clone()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            // `DummyEnum` has no type parameters, so only an empty list is valid.
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }

    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        const OTHER_EXT_ID: ExtensionId = "OtherDummyExt";
    }

    static EXT: LazyLock<Arc<Extension>> = LazyLock::new(|| {
        Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::Dumb
                .add_to_extension(ext, extension_ref)
                .unwrap();
        })
    });

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.clone()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.opdef_id()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );

        let ext_op = o.clone().to_extension_op().unwrap();
        assert_eq!(<DummyEnum as MakeOpDef>::from_op(&ext_op).unwrap(), o);

        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[Term::from(1u64)]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(
                OTHER_EXT_ID.clone(),
                EXT_ID.clone()
            ))
        );
    }

    #[test]
    fn test_load_all_ops() {
        let ext = Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::load_all_ops(ext, extension_ref).unwrap();
        });
        for op in DummyEnum::iter() {
            assert!(ext.get_op(&op.opdef_id()).is_some());
        }
    }
}